Compare commits
11 Commits
55261c0b72
...
90c8b427fc
| Author | SHA1 | Date | |
|---|---|---|---|
| 90c8b427fc | |||
| 077f26d9d6 | |||
| f446e5766e | |||
| 728d1b0c0e | |||
| 57420807ca | |||
| 9ca0e0844c | |||
| 54668536c1 | |||
| c4bf8c371f | |||
| 2f4540f1d3 | |||
| 50971c0946 | |||
| 3b2f219d66 |
@ -189,6 +189,18 @@
|
|||||||
/* Full blend transition time: MANUAL→AUTO takes this many ms */
|
/* Full blend transition time: MANUAL→AUTO takes this many ms */
|
||||||
#define MODE_BLEND_MS 500
|
#define MODE_BLEND_MS 500
|
||||||
|
|
||||||
|
// --- Power Management (STOP mode, Issue #178) ---
|
||||||
|
#define PM_IDLE_TIMEOUT_MS 30000u // 30s no activity → PM_SLEEP_PENDING
|
||||||
|
#define PM_FADE_MS 3000u // LED fade-out duration before STOP entry
|
||||||
|
#define PM_LED_PERIOD_MS 2000u // sleep-pending triangle-wave period (ms)
|
||||||
|
// Estimated per-subsystem currents (mA) — used for JLINK_TLM_POWER telemetry
|
||||||
|
#define PM_CURRENT_BASE_MA 30 // SPI1(IMU)+UART4(CRSF)+USART1(JLink)+core
|
||||||
|
#define PM_CURRENT_AUDIO_MA 8 // I2S3 + amplifier quiescent
|
||||||
|
#define PM_CURRENT_OSD_MA 5 // SPI2 OSD (MAX7456)
|
||||||
|
#define PM_CURRENT_DEBUG_MA 1 // UART5 + USART6
|
||||||
|
#define PM_CURRENT_STOP_MA 1 // MCU in STOP mode (< 1 mA)
|
||||||
|
#define PM_TLM_HZ 1 // JLINK_TLM_POWER transmit rate (Hz)
|
||||||
|
|
||||||
// --- Audio Amplifier (I2S3, Issue #143) ---
|
// --- Audio Amplifier (I2S3, Issue #143) ---
|
||||||
// SPI3 repurposed as I2S3; blackbox flash unused on balance bot
|
// SPI3 repurposed as I2S3; blackbox flash unused on balance bot
|
||||||
#define AUDIO_BCLK_PORT GPIOC
|
#define AUDIO_BCLK_PORT GPIOC
|
||||||
|
|||||||
@ -54,9 +54,11 @@
|
|||||||
#define JLINK_CMD_DFU_ENTER 0x06u
|
#define JLINK_CMD_DFU_ENTER 0x06u
|
||||||
#define JLINK_CMD_ESTOP 0x07u
|
#define JLINK_CMD_ESTOP 0x07u
|
||||||
#define JLINK_CMD_AUDIO 0x08u /* PCM audio chunk: int16 samples, up to 126 */
|
#define JLINK_CMD_AUDIO 0x08u /* PCM audio chunk: int16 samples, up to 126 */
|
||||||
|
#define JLINK_CMD_SLEEP 0x09u /* no payload; request STOP-mode sleep */
|
||||||
|
|
||||||
/* ---- Telemetry IDs (STM32 → Jetson) ---- */
|
/* ---- Telemetry IDs (STM32 → Jetson) ---- */
|
||||||
#define JLINK_TLM_STATUS 0x80u
|
#define JLINK_TLM_STATUS 0x80u
|
||||||
|
#define JLINK_TLM_POWER 0x81u /* jlink_tlm_power_t (11 bytes) */
|
||||||
|
|
||||||
/* ---- Telemetry STATUS payload (20 bytes, packed) ---- */
|
/* ---- Telemetry STATUS payload (20 bytes, packed) ---- */
|
||||||
typedef struct __attribute__((packed)) {
|
typedef struct __attribute__((packed)) {
|
||||||
@ -77,6 +79,15 @@ typedef struct __attribute__((packed)) {
|
|||||||
uint8_t fw_patch;
|
uint8_t fw_patch;
|
||||||
} jlink_tlm_status_t; /* 20 bytes */
|
} jlink_tlm_status_t; /* 20 bytes */
|
||||||
|
|
||||||
|
/* ---- Telemetry POWER payload (11 bytes, packed) ---- */
|
||||||
|
typedef struct __attribute__((packed)) {
|
||||||
|
uint8_t power_state; /* PowerState: 0=ACTIVE,1=SLEEP_PENDING,2=SLEEPING,3=WAKING */
|
||||||
|
uint16_t est_total_ma; /* estimated total current draw (mA) */
|
||||||
|
uint16_t est_audio_ma; /* estimated I2S3+amp current (mA); 0 if gated */
|
||||||
|
uint16_t est_osd_ma; /* estimated OSD SPI2 current (mA); 0 if gated */
|
||||||
|
uint32_t idle_ms; /* ms since last cmd_vel activity */
|
||||||
|
} jlink_tlm_power_t; /* 11 bytes */
|
||||||
|
|
||||||
/* ---- Volatile state (read from main loop) ---- */
|
/* ---- Volatile state (read from main loop) ---- */
|
||||||
typedef struct {
|
typedef struct {
|
||||||
/* Drive command — updated on JLINK_CMD_DRIVE */
|
/* Drive command — updated on JLINK_CMD_DRIVE */
|
||||||
@ -99,6 +110,8 @@ typedef struct {
|
|||||||
|
|
||||||
/* DFU reboot request — set by parser, cleared by main loop */
|
/* DFU reboot request — set by parser, cleared by main loop */
|
||||||
volatile uint8_t dfu_req;
|
volatile uint8_t dfu_req;
|
||||||
|
/* Sleep request — set by JLINK_CMD_SLEEP, cleared by main loop */
|
||||||
|
volatile uint8_t sleep_req;
|
||||||
} JLinkState;
|
} JLinkState;
|
||||||
|
|
||||||
extern volatile JLinkState jlink_state;
|
extern volatile JLinkState jlink_state;
|
||||||
@ -130,4 +143,10 @@ void jlink_send_telemetry(const jlink_tlm_status_t *status);
|
|||||||
*/
|
*/
|
||||||
void jlink_process(void);
|
void jlink_process(void);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* jlink_send_power_telemetry(power) — build and transmit a JLINK_TLM_POWER
|
||||||
|
* frame (17 bytes) at PM_TLM_HZ. Call from main loop when not in STOP mode.
|
||||||
|
*/
|
||||||
|
void jlink_send_power_telemetry(const jlink_tlm_power_t *power);
|
||||||
|
|
||||||
#endif /* JLINK_H */
|
#endif /* JLINK_H */
|
||||||
|
|||||||
96
include/power_mgmt.h
Normal file
96
include/power_mgmt.h
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
#ifndef POWER_MGMT_H
|
||||||
|
#define POWER_MGMT_H
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
|
||||||
|
/*
|
||||||
|
* power_mgmt — STM32F7 STOP-mode sleep/wake manager (Issue #178).
|
||||||
|
*
|
||||||
|
* State machine:
|
||||||
|
* PM_ACTIVE ──(idle ≥ PM_IDLE_TIMEOUT_MS or sleep cmd)──► PM_SLEEP_PENDING
|
||||||
|
* PM_SLEEP_PENDING ──(fade complete, ≥ PM_FADE_MS)──► PM_SLEEPING (WFI)
|
||||||
|
* PM_SLEEPING ──(EXTI wake)──► PM_WAKING ──(clocks restored)──► PM_ACTIVE
|
||||||
|
*
|
||||||
|
* Any call to power_mgmt_activity() during SLEEP_PENDING or SLEEPING
|
||||||
|
* immediately transitions back toward PM_ACTIVE.
|
||||||
|
*
|
||||||
|
* Wake sources (EXTI, falling edge on UART idle-high RX pin or IMU INT):
|
||||||
|
* EXTI1 PA1 UART4_RX — CRSF/ELRS start bit
|
||||||
|
* EXTI7 PB7 USART1_RX — JLink start bit
|
||||||
|
* EXTI4 PC4 MPU6000 INT — IMU motion (handler owned by mpu6000.c)
|
||||||
|
*
|
||||||
|
* Peripheral gating on sleep entry (clock disable, state preserved):
|
||||||
|
* Disabled: SPI3/I2S3 (audio amp), SPI2 (OSD), USART6, UART5 (debug)
|
||||||
|
* Active: SPI1 (IMU), UART4 (CRSF), USART1 (JLink), I2C1 (baro/mag)
|
||||||
|
*
|
||||||
|
* Sleep LED (LED1, active-low PC15):
|
||||||
|
* PM_SLEEP_PENDING: triangle-wave pulse, period PM_LED_PERIOD_MS
|
||||||
|
* All other states: 0 (caller uses normal LED logic)
|
||||||
|
*
|
||||||
|
* IWDG:
|
||||||
|
* Fed immediately before WFI. STOP wakeup <10 ms typical — well within
|
||||||
|
* WATCHDOG_TIMEOUT_MS (50 ms).
|
||||||
|
*
|
||||||
|
* Safety interlock:
|
||||||
|
* Caller MUST NOT call power_mgmt_tick() while armed; call
|
||||||
|
* power_mgmt_activity() instead to keep the idle timer reset.
|
||||||
|
*
|
||||||
|
* JLink integration:
|
||||||
|
* JLINK_CMD_SLEEP (0x09) → power_mgmt_request_sleep()
|
||||||
|
* Any valid JLink frame → power_mgmt_activity() (handled in main loop)
|
||||||
|
*/
|
||||||
|
|
||||||
|
typedef enum {
|
||||||
|
PM_ACTIVE = 0, /* Normal, all peripherals running */
|
||||||
|
PM_SLEEP_PENDING = 1, /* Idle timeout reached; LED fade-out in progress */
|
||||||
|
PM_SLEEPING = 2, /* In STOP mode (WFI); execution blocked in tick() */
|
||||||
|
PM_WAKING = 3, /* Transitional; clocks/peripherals being restored */
|
||||||
|
} PowerState;
|
||||||
|
|
||||||
|
/* ---- API ---- */
|
||||||
|
|
||||||
|
/*
|
||||||
|
* power_mgmt_init() — configure wake EXTI lines (EXTI1, EXTI7).
|
||||||
|
* Call after crsf_init() and jlink_init().
|
||||||
|
*/
|
||||||
|
void power_mgmt_init(void);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* power_mgmt_activity() — record cmd_vel event (CRSF frame, JLink frame).
|
||||||
|
* Resets idle timer; aborts any pending/active sleep.
|
||||||
|
*/
|
||||||
|
void power_mgmt_activity(void);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* power_mgmt_request_sleep() — force sleep regardless of idle timer
|
||||||
|
* (called on JLINK_CMD_SLEEP). Next tick() enters PM_SLEEP_PENDING.
|
||||||
|
*/
|
||||||
|
void power_mgmt_request_sleep(void);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* power_mgmt_tick(now_ms) — drive state machine. May block in WFI during
|
||||||
|
* STOP mode. Returns state after this tick.
|
||||||
|
* MUST NOT be called while balance_state == BALANCE_ARMED.
|
||||||
|
*/
|
||||||
|
PowerState power_mgmt_tick(uint32_t now_ms);
|
||||||
|
|
||||||
|
/* power_mgmt_state() — non-blocking read of current state. */
|
||||||
|
PowerState power_mgmt_state(void);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* power_mgmt_led_brightness() — 0-255 brightness for sleep-pending pulse.
|
||||||
|
* Returns 0 when not in PM_SLEEP_PENDING; caller uses normal LED logic.
|
||||||
|
*/
|
||||||
|
uint8_t power_mgmt_led_brightness(void);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* power_mgmt_current_ma() — estimated total current draw (mA) based on
|
||||||
|
* gating state; populated in JLINK_TLM_POWER telemetry.
|
||||||
|
*/
|
||||||
|
uint16_t power_mgmt_current_ma(void);
|
||||||
|
|
||||||
|
/* power_mgmt_idle_ms() — ms elapsed since last power_mgmt_activity() call. */
|
||||||
|
uint32_t power_mgmt_idle_ms(void);
|
||||||
|
|
||||||
|
#endif /* POWER_MGMT_H */
|
||||||
@ -40,6 +40,11 @@ rosbridge_websocket:
|
|||||||
"/person/target",
|
"/person/target",
|
||||||
"/person/detections",
|
"/person/detections",
|
||||||
"/camera/*/image_raw/compressed",
|
"/camera/*/image_raw/compressed",
|
||||||
|
"/camera/depth/image_rect_raw/compressed",
|
||||||
|
"/camera/panoramic/compressed",
|
||||||
|
"/social/faces/detections",
|
||||||
|
"/social/gestures",
|
||||||
|
"/social/scene/objects",
|
||||||
"/scan",
|
"/scan",
|
||||||
"/cmd_vel",
|
"/cmd_vel",
|
||||||
"/saltybot/imu",
|
"/saltybot/imu",
|
||||||
|
|||||||
@ -94,4 +94,33 @@ def generate_launch_description():
|
|||||||
for name in _CAMERAS
|
for name in _CAMERAS
|
||||||
]
|
]
|
||||||
|
|
||||||
return LaunchDescription([rosbridge] + republishers)
|
# ── D435i colour republisher (Issue #177) ────────────────────────────────
|
||||||
|
d435i_color = Node(
|
||||||
|
package='image_transport',
|
||||||
|
executable='republish',
|
||||||
|
name='compress_d435i_color',
|
||||||
|
arguments=['raw', 'compressed'],
|
||||||
|
remappings=[
|
||||||
|
('in', '/camera/color/image_raw'),
|
||||||
|
('out/compressed', '/camera/color/image_raw/compressed'),
|
||||||
|
],
|
||||||
|
parameters=[{'compressed.jpeg_quality': _JPEG_QUALITY}],
|
||||||
|
output='screen',
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── D435i depth republisher (Issue #177) ─────────────────────────────────
|
||||||
|
# Depth stream as compressedDepth (PNG16) — preserves uint16 depth values.
|
||||||
|
# Browser displays as greyscale PNG (darker = closer).
|
||||||
|
d435i_depth = Node(
|
||||||
|
package='image_transport',
|
||||||
|
executable='republish',
|
||||||
|
name='compress_d435i_depth',
|
||||||
|
arguments=['raw', 'compressedDepth'],
|
||||||
|
remappings=[
|
||||||
|
('in', '/camera/depth/image_rect_raw'),
|
||||||
|
('out/compressedDepth', '/camera/depth/image_rect_raw/compressed'),
|
||||||
|
],
|
||||||
|
output='screen',
|
||||||
|
)
|
||||||
|
|
||||||
|
return LaunchDescription([rosbridge] + republishers + [d435i_color, d435i_depth])
|
||||||
|
|||||||
16
jetson/ros2_ws/src/saltybot_dynamic_obs_msgs/CMakeLists.txt
Normal file
16
jetson/ros2_ws/src/saltybot_dynamic_obs_msgs/CMakeLists.txt
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.8)
|
||||||
|
project(saltybot_dynamic_obs_msgs)
|
||||||
|
|
||||||
|
find_package(ament_cmake REQUIRED)
|
||||||
|
find_package(rosidl_default_generators REQUIRED)
|
||||||
|
find_package(std_msgs REQUIRED)
|
||||||
|
find_package(geometry_msgs REQUIRED)
|
||||||
|
|
||||||
|
rosidl_generate_interfaces(${PROJECT_NAME}
|
||||||
|
"msg/TrackedObject.msg"
|
||||||
|
"msg/MovingObjectArray.msg"
|
||||||
|
DEPENDENCIES std_msgs geometry_msgs
|
||||||
|
)
|
||||||
|
|
||||||
|
ament_export_dependencies(rosidl_default_runtime)
|
||||||
|
ament_package()
|
||||||
@ -0,0 +1,12 @@
|
|||||||
|
# MovingObjectArray — all currently tracked moving obstacles.
|
||||||
|
#
|
||||||
|
# Published at ~10 Hz on /saltybot/moving_objects.
|
||||||
|
# Only confirmed tracks (hits >= confirm_frames) appear here.
|
||||||
|
|
||||||
|
std_msgs/Header header
|
||||||
|
|
||||||
|
saltybot_dynamic_obs_msgs/TrackedObject[] objects
|
||||||
|
|
||||||
|
uint32 active_count # number of confirmed tracks
|
||||||
|
uint32 tentative_count # tracks not yet confirmed
|
||||||
|
float32 detector_latency_ms # pipeline latency hint
|
||||||
@ -0,0 +1,21 @@
|
|||||||
|
# TrackedObject — a single tracked moving obstacle.
|
||||||
|
#
|
||||||
|
# predicted_path[i] is the estimated pose at predicted_times[i] seconds from now.
|
||||||
|
# All poses are in the same frame as header.frame_id (typically 'odom').
|
||||||
|
|
||||||
|
std_msgs/Header header
|
||||||
|
|
||||||
|
uint32 object_id # stable ID across frames (monotonically increasing)
|
||||||
|
|
||||||
|
geometry_msgs/PoseWithCovariance pose # current best-estimate pose (x, y, yaw)
|
||||||
|
geometry_msgs/Vector3 velocity # vx, vy in m/s (vz = 0 for ground objects)
|
||||||
|
|
||||||
|
geometry_msgs/Pose[] predicted_path # future positions at predicted_times
|
||||||
|
float32[] predicted_times # seconds from header.stamp for each pose
|
||||||
|
|
||||||
|
float32 speed_mps # scalar |v|
|
||||||
|
float32 confidence # 0.0–1.0 (higher after more confirmed frames)
|
||||||
|
|
||||||
|
uint32 age_frames # frames since first detection
|
||||||
|
uint32 hits # number of successful associations
|
||||||
|
bool is_valid # false if in tentative / just-created state
|
||||||
23
jetson/ros2_ws/src/saltybot_dynamic_obs_msgs/package.xml
Normal file
23
jetson/ros2_ws/src/saltybot_dynamic_obs_msgs/package.xml
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
<?xml version="1.0"?>
|
||||||
|
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||||
|
<package format="3">
|
||||||
|
<name>saltybot_dynamic_obs_msgs</name>
|
||||||
|
<version>0.1.0</version>
|
||||||
|
<description>Custom message types for dynamic obstacle tracking.</description>
|
||||||
|
<maintainer email="robot@saltylab.local">SaltyLab</maintainer>
|
||||||
|
<license>MIT</license>
|
||||||
|
|
||||||
|
<buildtool_depend>ament_cmake</buildtool_depend>
|
||||||
|
<buildtool_depend>rosidl_default_generators</buildtool_depend>
|
||||||
|
|
||||||
|
<depend>std_msgs</depend>
|
||||||
|
<depend>geometry_msgs</depend>
|
||||||
|
|
||||||
|
<exec_depend>rosidl_default_runtime</exec_depend>
|
||||||
|
|
||||||
|
<member_of_group>rosidl_interface_packages</member_of_group>
|
||||||
|
|
||||||
|
<export>
|
||||||
|
<build_type>ament_cmake</build_type>
|
||||||
|
</export>
|
||||||
|
</package>
|
||||||
@ -0,0 +1,52 @@
|
|||||||
|
# saltybot_dynamic_obstacles — runtime parameters
|
||||||
|
#
|
||||||
|
# Requires:
|
||||||
|
# /scan (sensor_msgs/LaserScan) — RPLIDAR A1M8 at ~5.5 Hz
|
||||||
|
#
|
||||||
|
# LIDAR scan is published by rplidar_ros node.
|
||||||
|
# Make sure RPLIDAR is running before starting this stack.
|
||||||
|
|
||||||
|
dynamic_obs_tracker:
|
||||||
|
ros__parameters:
|
||||||
|
max_tracks: 20 # max simultaneous tracked objects
|
||||||
|
confirm_frames: 3 # hits before a track is published
|
||||||
|
max_missed_frames: 6 # missed frames before track deletion
|
||||||
|
assoc_dist_m: 1.5 # max assignment distance (Hungarian)
|
||||||
|
prediction_hz: 10.0 # tracker update + publish rate
|
||||||
|
horizon_s: 2.5 # prediction look-ahead
|
||||||
|
pred_step_s: 0.5 # time between predicted waypoints
|
||||||
|
odom_frame: 'odom'
|
||||||
|
min_speed_mps: 0.05 # suppress near-stationary tracks
|
||||||
|
max_range_m: 8.0 # ignore detections beyond this
|
||||||
|
|
||||||
|
dynamic_obs_costmap:
|
||||||
|
ros__parameters:
|
||||||
|
inflation_radius_m: 0.35 # safety bubble around each predicted point
|
||||||
|
ring_points: 8 # polygon points for inflation circle
|
||||||
|
clear_on_empty: true # push empty cloud to clear stale Nav2 markings
|
||||||
|
|
||||||
|
|
||||||
|
# ── Nav2 costmap integration ───────────────────────────────────────────────────
|
||||||
|
# In your nav2_params.yaml, under local_costmap or global_costmap > plugins, add
|
||||||
|
# an ObstacleLayer with:
|
||||||
|
#
|
||||||
|
# obstacle_layer:
|
||||||
|
# plugin: "nav2_costmap_2d::ObstacleLayer"
|
||||||
|
# enabled: true
|
||||||
|
# observation_sources: static_scan dynamic_obs
|
||||||
|
# static_scan:
|
||||||
|
# topic: /scan
|
||||||
|
# data_type: LaserScan
|
||||||
|
# ...
|
||||||
|
# dynamic_obs:
|
||||||
|
# topic: /saltybot/dynamic_obs_cloud
|
||||||
|
# data_type: PointCloud2
|
||||||
|
# sensor_frame: odom
|
||||||
|
# obstacle_max_range: 10.0
|
||||||
|
# raytrace_max_range: 10.0
|
||||||
|
# marking: true
|
||||||
|
# clearing: false
|
||||||
|
#
|
||||||
|
# This feeds the predicted trajectory smear directly into Nav2's obstacle
|
||||||
|
# inflation, forcing the planner to route around the predicted future path
|
||||||
|
# of every tracked moving object.
|
||||||
@ -0,0 +1,75 @@
|
|||||||
|
"""
|
||||||
|
dynamic_obstacles.launch.py — Dynamic obstacle tracking + Nav2 costmap layer.
|
||||||
|
|
||||||
|
Starts:
|
||||||
|
dynamic_obs_tracker — LIDAR motion detection + Kalman tracking @10 Hz
|
||||||
|
dynamic_obs_costmap — Predicted-trajectory → PointCloud2 for Nav2
|
||||||
|
|
||||||
|
Launch args:
|
||||||
|
max_tracks int '20'
|
||||||
|
assoc_dist_m float '1.5'
|
||||||
|
horizon_s float '2.5'
|
||||||
|
inflation_radius_m float '0.35'
|
||||||
|
|
||||||
|
Verify:
|
||||||
|
ros2 topic hz /saltybot/moving_objects # ~10 Hz
|
||||||
|
ros2 topic echo /saltybot/moving_objects # TrackedObject list
|
||||||
|
ros2 topic hz /saltybot/dynamic_obs_cloud # ~10 Hz (when objects present)
|
||||||
|
rviz2 → add MarkerArray → /saltybot/moving_objects_viz
|
||||||
|
|
||||||
|
Nav2 costmap integration:
|
||||||
|
In your costmap_params.yaml ObstacleLayer observation_sources, add:
|
||||||
|
dynamic_obs:
|
||||||
|
topic: /saltybot/dynamic_obs_cloud
|
||||||
|
data_type: PointCloud2
|
||||||
|
marking: true
|
||||||
|
clearing: false
|
||||||
|
"""
|
||||||
|
|
||||||
|
from launch import LaunchDescription
|
||||||
|
from launch.actions import DeclareLaunchArgument
|
||||||
|
from launch.substitutions import LaunchConfiguration
|
||||||
|
from launch_ros.actions import Node
|
||||||
|
|
||||||
|
|
||||||
|
def generate_launch_description():
|
||||||
|
args = [
|
||||||
|
DeclareLaunchArgument('max_tracks', default_value='20'),
|
||||||
|
DeclareLaunchArgument('assoc_dist_m', default_value='1.5'),
|
||||||
|
DeclareLaunchArgument('horizon_s', default_value='2.5'),
|
||||||
|
DeclareLaunchArgument('inflation_radius_m', default_value='0.35'),
|
||||||
|
DeclareLaunchArgument('min_speed_mps', default_value='0.05'),
|
||||||
|
]
|
||||||
|
|
||||||
|
tracker = Node(
|
||||||
|
package='saltybot_dynamic_obstacles',
|
||||||
|
executable='dynamic_obs_tracker',
|
||||||
|
name='dynamic_obs_tracker',
|
||||||
|
output='screen',
|
||||||
|
parameters=[{
|
||||||
|
'max_tracks': LaunchConfiguration('max_tracks'),
|
||||||
|
'assoc_dist_m': LaunchConfiguration('assoc_dist_m'),
|
||||||
|
'horizon_s': LaunchConfiguration('horizon_s'),
|
||||||
|
'min_speed_mps': LaunchConfiguration('min_speed_mps'),
|
||||||
|
'prediction_hz': 10.0,
|
||||||
|
'confirm_frames': 3,
|
||||||
|
'max_missed_frames': 6,
|
||||||
|
'pred_step_s': 0.5,
|
||||||
|
'odom_frame': 'odom',
|
||||||
|
'max_range_m': 8.0,
|
||||||
|
}],
|
||||||
|
)
|
||||||
|
|
||||||
|
costmap = Node(
|
||||||
|
package='saltybot_dynamic_obstacles',
|
||||||
|
executable='dynamic_obs_costmap',
|
||||||
|
name='dynamic_obs_costmap',
|
||||||
|
output='screen',
|
||||||
|
parameters=[{
|
||||||
|
'inflation_radius_m': LaunchConfiguration('inflation_radius_m'),
|
||||||
|
'ring_points': 8,
|
||||||
|
'clear_on_empty': True,
|
||||||
|
}],
|
||||||
|
)
|
||||||
|
|
||||||
|
return LaunchDescription(args + [tracker, costmap])
|
||||||
29
jetson/ros2_ws/src/saltybot_dynamic_obstacles/package.xml
Normal file
29
jetson/ros2_ws/src/saltybot_dynamic_obstacles/package.xml
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
<?xml version="1.0"?>
|
||||||
|
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||||
|
<package format="3">
|
||||||
|
<name>saltybot_dynamic_obstacles</name>
|
||||||
|
<version>0.1.0</version>
|
||||||
|
<description>
|
||||||
|
Dynamic obstacle detection, multi-object Kalman tracking, trajectory
|
||||||
|
prediction, and Nav2 costmap layer integration for SaltyBot.
|
||||||
|
</description>
|
||||||
|
<maintainer email="robot@saltylab.local">SaltyLab</maintainer>
|
||||||
|
<license>MIT</license>
|
||||||
|
|
||||||
|
<depend>rclpy</depend>
|
||||||
|
<depend>std_msgs</depend>
|
||||||
|
<depend>sensor_msgs</depend>
|
||||||
|
<depend>geometry_msgs</depend>
|
||||||
|
<depend>nav_msgs</depend>
|
||||||
|
<depend>visualization_msgs</depend>
|
||||||
|
<depend>saltybot_dynamic_obs_msgs</depend>
|
||||||
|
|
||||||
|
<exec_depend>python3-numpy</exec_depend>
|
||||||
|
<exec_depend>python3-scipy</exec_depend>
|
||||||
|
|
||||||
|
<test_depend>pytest</test_depend>
|
||||||
|
|
||||||
|
<export>
|
||||||
|
<build_type>ament_python</build_type>
|
||||||
|
</export>
|
||||||
|
</package>
|
||||||
@ -0,0 +1,178 @@
|
|||||||
|
"""
|
||||||
|
costmap_layer_node.py — Nav2 costmap integration for dynamic obstacles.
|
||||||
|
|
||||||
|
Converts predicted trajectories from /saltybot/moving_objects into a
|
||||||
|
PointCloud2 fed into Nav2's ObstacleLayer. Each predicted future position
|
||||||
|
is added as a point, creating a "smeared" dynamic obstacle zone that
|
||||||
|
covers the full 2-3 s prediction horizon.
|
||||||
|
|
||||||
|
Nav2 ObstacleLayer config (in costmap_params.yaml):
|
||||||
|
obstacle_layer:
|
||||||
|
enabled: true
|
||||||
|
observation_sources: dynamic_obs
|
||||||
|
dynamic_obs:
|
||||||
|
topic: /saltybot/dynamic_obs_cloud
|
||||||
|
sensor_frame: odom
|
||||||
|
data_type: PointCloud2
|
||||||
|
obstacle_max_range: 12.0
|
||||||
|
obstacle_min_range: 0.0
|
||||||
|
raytrace_max_range: 12.0
|
||||||
|
marking: true
|
||||||
|
clearing: false # let the tracker handle clearing
|
||||||
|
|
||||||
|
The node also clears old obstacle points when tracks are dropped, by
|
||||||
|
publishing a clearing cloud to /saltybot/dynamic_obs_clear.
|
||||||
|
|
||||||
|
Subscribes:
|
||||||
|
/saltybot/moving_objects saltybot_dynamic_obs_msgs/MovingObjectArray
|
||||||
|
|
||||||
|
Publishes:
|
||||||
|
/saltybot/dynamic_obs_cloud sensor_msgs/PointCloud2 marking cloud
|
||||||
|
/saltybot/dynamic_obs_clear sensor_msgs/PointCloud2 clearing cloud
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
inflation_radius_m float 0.35 (each predicted point inflated by this radius)
|
||||||
|
ring_points int 8 (polygon approximation of inflation circle)
|
||||||
|
clear_on_empty bool true (publish clear cloud when no objects tracked)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
import struct
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||||
|
|
||||||
|
from sensor_msgs.msg import PointCloud2, PointField
|
||||||
|
from std_msgs.msg import Header
|
||||||
|
|
||||||
|
try:
|
||||||
|
from saltybot_dynamic_obs_msgs.msg import MovingObjectArray
|
||||||
|
_MSGS_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
_MSGS_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
_RELIABLE_QOS = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.RELIABLE,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pc2(header: Header, points_xyz: List[tuple]) -> PointCloud2:
|
||||||
|
"""Pack a list of (x, y, z) into a PointCloud2 message."""
|
||||||
|
fields = [
|
||||||
|
PointField(name='x', offset=0, datatype=PointField.FLOAT32, count=1),
|
||||||
|
PointField(name='y', offset=4, datatype=PointField.FLOAT32, count=1),
|
||||||
|
PointField(name='z', offset=8, datatype=PointField.FLOAT32, count=1),
|
||||||
|
]
|
||||||
|
point_step = 12 # 3 × float32
|
||||||
|
data = bytearray(len(points_xyz) * point_step)
|
||||||
|
for i, (x, y, z) in enumerate(points_xyz):
|
||||||
|
struct.pack_into('<fff', data, i * point_step, x, y, z)
|
||||||
|
|
||||||
|
pc = PointCloud2()
|
||||||
|
pc.header = header
|
||||||
|
pc.height = 1
|
||||||
|
pc.width = len(points_xyz)
|
||||||
|
pc.fields = fields
|
||||||
|
pc.is_bigendian = False
|
||||||
|
pc.point_step = point_step
|
||||||
|
pc.row_step = point_step * len(points_xyz)
|
||||||
|
pc.data = bytes(data)
|
||||||
|
pc.is_dense = True
|
||||||
|
return pc
|
||||||
|
|
||||||
|
|
||||||
|
class CostmapLayerNode(Node):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__('dynamic_obs_costmap')
|
||||||
|
|
||||||
|
self.declare_parameter('inflation_radius_m', 0.35)
|
||||||
|
self.declare_parameter('ring_points', 8)
|
||||||
|
self.declare_parameter('clear_on_empty', True)
|
||||||
|
|
||||||
|
self._infl_r = self.get_parameter('inflation_radius_m').value
|
||||||
|
self._ring_n = self.get_parameter('ring_points').value
|
||||||
|
self._clear_empty = self.get_parameter('clear_on_empty').value
|
||||||
|
|
||||||
|
# Pre-compute ring offsets for inflation
|
||||||
|
self._ring_offsets = [
|
||||||
|
(self._infl_r * math.cos(2 * math.pi * i / self._ring_n),
|
||||||
|
self._infl_r * math.sin(2 * math.pi * i / self._ring_n))
|
||||||
|
for i in range(self._ring_n)
|
||||||
|
]
|
||||||
|
|
||||||
|
if _MSGS_AVAILABLE:
|
||||||
|
self.create_subscription(
|
||||||
|
MovingObjectArray,
|
||||||
|
'/saltybot/moving_objects',
|
||||||
|
self._on_objects,
|
||||||
|
_RELIABLE_QOS,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.get_logger().warning(
|
||||||
|
'[costmap_layer] saltybot_dynamic_obs_msgs not built — '
|
||||||
|
'will not subscribe to MovingObjectArray'
|
||||||
|
)
|
||||||
|
|
||||||
|
self._mark_pub = self.create_publisher(
|
||||||
|
PointCloud2, '/saltybot/dynamic_obs_cloud', 10
|
||||||
|
)
|
||||||
|
self._clear_pub = self.create_publisher(
|
||||||
|
PointCloud2, '/saltybot/dynamic_obs_clear', 10
|
||||||
|
)
|
||||||
|
|
||||||
|
self.get_logger().info(
|
||||||
|
f'dynamic_obs_costmap ready — '
|
||||||
|
f'inflation={self._infl_r}m ring_pts={self._ring_n}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Callback ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _on_objects(self, msg: 'MovingObjectArray') -> None:
|
||||||
|
hdr = msg.header
|
||||||
|
mark_pts: List[tuple] = []
|
||||||
|
|
||||||
|
for obj in msg.objects:
|
||||||
|
if not obj.is_valid:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Current position
|
||||||
|
self._add_inflated(obj.pose.pose.position.x,
|
||||||
|
obj.pose.pose.position.y, mark_pts)
|
||||||
|
|
||||||
|
# Predicted future positions
|
||||||
|
for pose in obj.predicted_path:
|
||||||
|
self._add_inflated(pose.position.x, pose.position.y, mark_pts)
|
||||||
|
|
||||||
|
if mark_pts:
|
||||||
|
self._mark_pub.publish(_make_pc2(hdr, mark_pts))
|
||||||
|
elif self._clear_empty:
|
||||||
|
# Publish tiny clear cloud so Nav2 clears stale markings
|
||||||
|
self._clear_pub.publish(_make_pc2(hdr, []))
|
||||||
|
|
||||||
|
def _add_inflated(self, cx: float, cy: float, pts: List[tuple]) -> None:
|
||||||
|
"""Add the centre + ring of inflation points at height 0.5 m."""
|
||||||
|
pts.append((cx, cy, 0.5))
|
||||||
|
for ox, oy in self._ring_offsets:
|
||||||
|
pts.append((cx + ox, cy + oy, 0.5))
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None):
|
||||||
|
rclpy.init(args=args)
|
||||||
|
node = CostmapLayerNode()
|
||||||
|
try:
|
||||||
|
rclpy.spin(node)
|
||||||
|
finally:
|
||||||
|
node.destroy_node()
|
||||||
|
rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@ -0,0 +1,319 @@
|
|||||||
|
"""
|
||||||
|
dynamic_obs_node.py — ROS2 node: LIDAR moving-object detection + Kalman tracking.
|
||||||
|
|
||||||
|
Pipeline:
|
||||||
|
1. Subscribe /scan (RPLIDAR LaserScan, ~5.5 Hz).
|
||||||
|
2. ObjectDetector performs background subtraction → moving blobs.
|
||||||
|
3. TrackerManager runs Hungarian assignment + Kalman predict/update at 10 Hz.
|
||||||
|
4. Publish /saltybot/moving_objects (MovingObjectArray).
|
||||||
|
5. Publish /saltybot/moving_objects_viz (MarkerArray) for RViz.
|
||||||
|
|
||||||
|
The 10 Hz timer drives the tracker regardless of scan rate, so prediction
|
||||||
|
continues between scans (pure-predict steps).
|
||||||
|
|
||||||
|
Subscribes:
|
||||||
|
/scan sensor_msgs/LaserScan RPLIDAR A1M8
|
||||||
|
|
||||||
|
Publishes:
|
||||||
|
/saltybot/moving_objects saltybot_dynamic_obs_msgs/MovingObjectArray
|
||||||
|
/saltybot/moving_objects_viz visualization_msgs/MarkerArray
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
max_tracks int 20
|
||||||
|
confirm_frames int 3
|
||||||
|
max_missed_frames int 6
|
||||||
|
assoc_dist_m float 1.5
|
||||||
|
prediction_hz float 10.0 (tracker + publish rate)
|
||||||
|
horizon_s float 2.5
|
||||||
|
pred_step_s float 0.5
|
||||||
|
odom_frame str 'odom'
|
||||||
|
min_speed_mps float 0.05 (suppress near-zero velocity tracks)
|
||||||
|
max_range_m float 8.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||||
|
|
||||||
|
from sensor_msgs.msg import LaserScan
|
||||||
|
from geometry_msgs.msg import Pose, Point, Quaternion, Vector3
|
||||||
|
from std_msgs.msg import Header, ColorRGBA
|
||||||
|
from visualization_msgs.msg import Marker, MarkerArray
|
||||||
|
|
||||||
|
try:
|
||||||
|
from saltybot_dynamic_obs_msgs.msg import TrackedObject, MovingObjectArray
|
||||||
|
_MSGS_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
_MSGS_AVAILABLE = False
|
||||||
|
|
||||||
|
from .object_detector import ObjectDetector
|
||||||
|
from .tracker_manager import TrackerManager, Track
|
||||||
|
|
||||||
|
|
||||||
|
_SENSOR_QOS = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _yaw_quat(yaw: float) -> Quaternion:
|
||||||
|
q = Quaternion()
|
||||||
|
q.w = math.cos(yaw * 0.5)
|
||||||
|
q.z = math.sin(yaw * 0.5)
|
||||||
|
return q
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicObsNode(Node):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__('dynamic_obs_tracker')
|
||||||
|
|
||||||
|
# ── Parameters ──────────────────────────────────────────────────────
|
||||||
|
self.declare_parameter('max_tracks', 20)
|
||||||
|
self.declare_parameter('confirm_frames', 3)
|
||||||
|
self.declare_parameter('max_missed_frames', 6)
|
||||||
|
self.declare_parameter('assoc_dist_m', 1.5)
|
||||||
|
self.declare_parameter('prediction_hz', 10.0)
|
||||||
|
self.declare_parameter('horizon_s', 2.5)
|
||||||
|
self.declare_parameter('pred_step_s', 0.5)
|
||||||
|
self.declare_parameter('odom_frame', 'odom')
|
||||||
|
self.declare_parameter('min_speed_mps', 0.05)
|
||||||
|
self.declare_parameter('max_range_m', 8.0)
|
||||||
|
|
||||||
|
max_tracks = self.get_parameter('max_tracks').value
|
||||||
|
confirm_f = self.get_parameter('confirm_frames').value
|
||||||
|
max_missed = self.get_parameter('max_missed_frames').value
|
||||||
|
assoc_dist = self.get_parameter('assoc_dist_m').value
|
||||||
|
pred_hz = self.get_parameter('prediction_hz').value
|
||||||
|
horizon_s = self.get_parameter('horizon_s').value
|
||||||
|
pred_step = self.get_parameter('pred_step_s').value
|
||||||
|
self._frame = self.get_parameter('odom_frame').value
|
||||||
|
self._min_spd = self.get_parameter('min_speed_mps').value
|
||||||
|
self._max_rng = self.get_parameter('max_range_m').value
|
||||||
|
|
||||||
|
# ── Core modules ────────────────────────────────────────────────────
|
||||||
|
self._detector = ObjectDetector(
|
||||||
|
grid_radius_m=min(self._max_rng + 2.0, 12.0),
|
||||||
|
max_cluster=int((self._max_rng / 0.1) ** 2 * 0.5),
|
||||||
|
)
|
||||||
|
self._tracker = TrackerManager(
|
||||||
|
max_tracks=max_tracks,
|
||||||
|
confirm_frames=confirm_f,
|
||||||
|
max_missed=max_missed,
|
||||||
|
assoc_dist_m=assoc_dist,
|
||||||
|
horizon_s=horizon_s,
|
||||||
|
pred_step_s=pred_step,
|
||||||
|
)
|
||||||
|
self._horizon_s = horizon_s
|
||||||
|
self._pred_step = pred_step
|
||||||
|
|
||||||
|
# ── State ────────────────────────────────────────────────────────────
|
||||||
|
self._latest_scan: LaserScan | None = None
|
||||||
|
self._last_track_t: float = time.monotonic()
|
||||||
|
self._scan_processed_stamp: float | None = None
|
||||||
|
|
||||||
|
# ── Subscriptions ────────────────────────────────────────────────────
|
||||||
|
self.create_subscription(LaserScan, '/scan', self._on_scan, _SENSOR_QOS)
|
||||||
|
|
||||||
|
# ── Publishers ───────────────────────────────────────────────────────
|
||||||
|
if _MSGS_AVAILABLE:
|
||||||
|
self._obj_pub = self.create_publisher(
|
||||||
|
MovingObjectArray, '/saltybot/moving_objects', 10
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._obj_pub = None
|
||||||
|
self.get_logger().warning(
|
||||||
|
'[dyn_obs] saltybot_dynamic_obs_msgs not built — '
|
||||||
|
'MovingObjectArray will not be published'
|
||||||
|
)
|
||||||
|
|
||||||
|
self._viz_pub = self.create_publisher(
|
||||||
|
MarkerArray, '/saltybot/moving_objects_viz', 10
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Timer ────────────────────────────────────────────────────────────
|
||||||
|
self.create_timer(1.0 / pred_hz, self._track_tick)
|
||||||
|
|
||||||
|
self.get_logger().info(
|
||||||
|
f'dynamic_obs_tracker ready — '
|
||||||
|
f'max_tracks={max_tracks} horizon={horizon_s}s assoc={assoc_dist}m'
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Scan callback ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _on_scan(self, msg: LaserScan) -> None:
|
||||||
|
self._latest_scan = msg
|
||||||
|
|
||||||
|
# ── 10 Hz tracker tick ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _track_tick(self) -> None:
|
||||||
|
t0 = time.monotonic()
|
||||||
|
|
||||||
|
now_mono = t0
|
||||||
|
dt = now_mono - self._last_track_t
|
||||||
|
dt = max(1e-3, min(dt, 0.5))
|
||||||
|
self._last_track_t = now_mono
|
||||||
|
|
||||||
|
scan = self._latest_scan
|
||||||
|
detections = []
|
||||||
|
|
||||||
|
if scan is not None:
|
||||||
|
stamp_sec = scan.header.stamp.sec + scan.header.stamp.nanosec * 1e-9
|
||||||
|
if stamp_sec != self._scan_processed_stamp:
|
||||||
|
self._scan_processed_stamp = stamp_sec
|
||||||
|
ranges = np.asarray(scan.ranges, dtype=np.float32)
|
||||||
|
detections = self._detector.update(
|
||||||
|
ranges,
|
||||||
|
scan.angle_min,
|
||||||
|
scan.angle_increment,
|
||||||
|
min(scan.range_max, self._max_rng),
|
||||||
|
)
|
||||||
|
|
||||||
|
confirmed = self._tracker.update(detections, dt)
|
||||||
|
|
||||||
|
latency_ms = (time.monotonic() - t0) * 1000.0
|
||||||
|
stamp = self.get_clock().now().to_msg()
|
||||||
|
|
||||||
|
if _MSGS_AVAILABLE and self._obj_pub is not None:
|
||||||
|
self._publish_objects(confirmed, stamp, latency_ms)
|
||||||
|
self._publish_viz(confirmed, stamp)
|
||||||
|
|
||||||
|
# ── Publish helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _publish_objects(self, confirmed: list, stamp, latency_ms: float) -> None:
|
||||||
|
arr = MovingObjectArray()
|
||||||
|
arr.header.stamp = stamp
|
||||||
|
arr.header.frame_id = self._frame
|
||||||
|
arr.active_count = len(confirmed)
|
||||||
|
arr.tentative_count = self._tracker.tentative_count
|
||||||
|
arr.detector_latency_ms = float(latency_ms)
|
||||||
|
|
||||||
|
for tr in confirmed:
|
||||||
|
px, py = tr.kalman.position
|
||||||
|
vx, vy = tr.kalman.velocity
|
||||||
|
speed = tr.kalman.speed
|
||||||
|
if speed < self._min_spd:
|
||||||
|
continue
|
||||||
|
|
||||||
|
obj = TrackedObject()
|
||||||
|
obj.header = arr.header
|
||||||
|
obj.object_id = tr.track_id
|
||||||
|
obj.pose.pose.position.x = px
|
||||||
|
obj.pose.pose.position.y = py
|
||||||
|
obj.pose.pose.orientation = _yaw_quat(math.atan2(vy, vx))
|
||||||
|
cov = tr.kalman.covariance_2x2
|
||||||
|
obj.pose.covariance[0] = float(cov[0, 0])
|
||||||
|
obj.pose.covariance[1] = float(cov[0, 1])
|
||||||
|
obj.pose.covariance[6] = float(cov[1, 0])
|
||||||
|
obj.pose.covariance[7] = float(cov[1, 1])
|
||||||
|
obj.velocity.x = vx
|
||||||
|
obj.velocity.y = vy
|
||||||
|
obj.speed_mps = speed
|
||||||
|
obj.confidence = min(1.0, tr.hits / (self._tracker._confirm_frames * 3))
|
||||||
|
obj.age_frames = tr.age
|
||||||
|
obj.hits = tr.hits
|
||||||
|
obj.is_valid = True
|
||||||
|
|
||||||
|
# Predicted path
|
||||||
|
for px_f, py_f, t_f in tr.kalman.predict_horizon(
|
||||||
|
self._horizon_s, self._pred_step
|
||||||
|
):
|
||||||
|
p = Pose()
|
||||||
|
p.position.x = px_f
|
||||||
|
p.position.y = py_f
|
||||||
|
p.orientation.w = 1.0
|
||||||
|
obj.predicted_path.append(p)
|
||||||
|
obj.predicted_times.append(float(t_f))
|
||||||
|
|
||||||
|
arr.objects.append(obj)
|
||||||
|
|
||||||
|
self._obj_pub.publish(arr)
|
||||||
|
|
||||||
|
def _publish_viz(self, confirmed: list, stamp) -> None:
|
||||||
|
markers = MarkerArray()
|
||||||
|
|
||||||
|
# Delete old markers
|
||||||
|
del_marker = Marker()
|
||||||
|
del_marker.header.stamp = stamp
|
||||||
|
del_marker.header.frame_id = self._frame
|
||||||
|
del_marker.action = Marker.DELETEALL
|
||||||
|
markers.markers.append(del_marker)
|
||||||
|
|
||||||
|
for tr in confirmed:
|
||||||
|
px, py = tr.kalman.position
|
||||||
|
vx, vy = tr.kalman.velocity
|
||||||
|
speed = tr.kalman.speed
|
||||||
|
if speed < self._min_spd:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Cylinder at current position
|
||||||
|
m = Marker()
|
||||||
|
m.header.stamp = stamp
|
||||||
|
m.header.frame_id = self._frame
|
||||||
|
m.ns = 'dyn_obs'
|
||||||
|
m.id = tr.track_id
|
||||||
|
m.type = Marker.CYLINDER
|
||||||
|
m.action = Marker.ADD
|
||||||
|
m.pose.position.x = px
|
||||||
|
m.pose.position.y = py
|
||||||
|
m.pose.position.z = 0.5
|
||||||
|
m.pose.orientation.w = 1.0
|
||||||
|
m.scale.x = 0.4
|
||||||
|
m.scale.y = 0.4
|
||||||
|
m.scale.z = 1.0
|
||||||
|
m.color = ColorRGBA(r=1.0, g=0.2, b=0.0, a=0.7)
|
||||||
|
m.lifetime.sec = 1
|
||||||
|
markers.markers.append(m)
|
||||||
|
|
||||||
|
# Arrow for velocity
|
||||||
|
vel_m = Marker()
|
||||||
|
vel_m.header = m.header
|
||||||
|
vel_m.ns = 'dyn_obs_vel'
|
||||||
|
vel_m.id = tr.track_id
|
||||||
|
vel_m.type = Marker.ARROW
|
||||||
|
vel_m.action = Marker.ADD
|
||||||
|
from geometry_msgs.msg import Point as GPoint
|
||||||
|
p_start = GPoint(x=px, y=py, z=1.0)
|
||||||
|
p_end = GPoint(x=px + vx, y=py + vy, z=1.0)
|
||||||
|
vel_m.points = [p_start, p_end]
|
||||||
|
vel_m.scale.x = 0.05
|
||||||
|
vel_m.scale.y = 0.10
|
||||||
|
vel_m.color = ColorRGBA(r=1.0, g=1.0, b=0.0, a=0.9)
|
||||||
|
vel_m.lifetime.sec = 1
|
||||||
|
markers.markers.append(vel_m)
|
||||||
|
|
||||||
|
# Line strip for predicted path
|
||||||
|
path_m = Marker()
|
||||||
|
path_m.header = m.header
|
||||||
|
path_m.ns = 'dyn_obs_path'
|
||||||
|
path_m.id = tr.track_id
|
||||||
|
path_m.type = Marker.LINE_STRIP
|
||||||
|
path_m.action = Marker.ADD
|
||||||
|
path_m.scale.x = 0.04
|
||||||
|
path_m.color = ColorRGBA(r=1.0, g=0.5, b=0.0, a=0.5)
|
||||||
|
path_m.lifetime.sec = 1
|
||||||
|
path_m.points.append(GPoint(x=px, y=py, z=0.5))
|
||||||
|
for fx, fy, _ in tr.kalman.predict_horizon(self._horizon_s, self._pred_step):
|
||||||
|
path_m.points.append(GPoint(x=fx, y=fy, z=0.5))
|
||||||
|
markers.markers.append(path_m)
|
||||||
|
|
||||||
|
self._viz_pub.publish(markers)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None):
|
||||||
|
rclpy.init(args=args)
|
||||||
|
node = DynamicObsNode()
|
||||||
|
try:
|
||||||
|
rclpy.spin(node)
|
||||||
|
finally:
|
||||||
|
node.destroy_node()
|
||||||
|
rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@ -0,0 +1,132 @@
|
|||||||
|
"""
|
||||||
|
kalman_tracker.py — Single-object Kalman filter for 2-D ground-plane tracking.
|
||||||
|
|
||||||
|
State vector: x = [px, py, vx, vy] (position + velocity)
|
||||||
|
Motion model: constant-velocity (CV) with process noise on acceleration
|
||||||
|
|
||||||
|
Predict step:
|
||||||
|
F = | 1 0 dt 0 | x_{k|k-1} = F @ x_{k-1|k-1}
|
||||||
|
| 0 1 0 dt | P_{k|k-1} = F @ P @ F^T + Q
|
||||||
|
| 0 0 1 0 |
|
||||||
|
| 0 0 0 1 |
|
||||||
|
|
||||||
|
Update step (position observation only):
|
||||||
|
H = | 1 0 0 0 | y = z - H @ x
|
||||||
|
| 0 1 0 0 | S = H @ P @ H^T + R
|
||||||
|
K = P @ H^T @ inv(S)
|
||||||
|
x = x + K @ y
|
||||||
|
P = (I - K @ H) @ P (Joseph form for stability)
|
||||||
|
|
||||||
|
Trajectory prediction: unrolls the CV model forward at fixed time steps.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# ── Default noise matrices ────────────────────────────────────────────────────
|
||||||
|
# Process noise: models uncertainty in acceleration between frames
|
||||||
|
_Q_BASE = np.diag([0.02, 0.02, 0.8, 0.8]).astype(np.float64)
|
||||||
|
|
||||||
|
# Measurement noise: LIDAR centroid uncertainty (~0.15 m std)
|
||||||
|
_R = np.diag([0.025, 0.025]).astype(np.float64) # 0.16 m sigma each axis
|
||||||
|
|
||||||
|
# Observation matrix
|
||||||
|
_H = np.array([[1, 0, 0, 0],
|
||||||
|
[0, 1, 0, 0]], dtype=np.float64)
|
||||||
|
_I4 = np.eye(4, dtype=np.float64)
|
||||||
|
|
||||||
|
|
||||||
|
class KalmanTracker:
|
||||||
|
"""
|
||||||
|
Kalman filter tracking one object.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x0, y0 : initial position (metres, odom frame)
|
||||||
|
process_noise : scalar multiplier on _Q_BASE
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, x0: float, y0: float, process_noise: float = 1.0):
|
||||||
|
self._x = np.array([x0, y0, 0.0, 0.0], dtype=np.float64)
|
||||||
|
self._P = np.eye(4, dtype=np.float64) * 0.5
|
||||||
|
self._Q = _Q_BASE * process_noise
|
||||||
|
|
||||||
|
# ── Core filter ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def predict(self, dt: float) -> None:
|
||||||
|
"""Propagate state by dt seconds."""
|
||||||
|
F = np.array([
|
||||||
|
[1, 0, dt, 0],
|
||||||
|
[0, 1, 0, dt],
|
||||||
|
[0, 0, 1, 0],
|
||||||
|
[0, 0, 0, 1],
|
||||||
|
], dtype=np.float64)
|
||||||
|
self._x = F @ self._x
|
||||||
|
self._P = F @ self._P @ F.T + self._Q
|
||||||
|
|
||||||
|
def update(self, z: np.ndarray) -> None:
|
||||||
|
"""
|
||||||
|
Incorporate a position measurement z = [x, y].
|
||||||
|
Uses Joseph-form covariance update for numerical stability.
|
||||||
|
"""
|
||||||
|
y = z.astype(np.float64) - _H @ self._x
|
||||||
|
S = _H @ self._P @ _H.T + _R
|
||||||
|
K = self._P @ _H.T @ np.linalg.inv(S)
|
||||||
|
self._x = self._x + K @ y
|
||||||
|
IKH = _I4 - K @ _H
|
||||||
|
# Joseph form: (I-KH) P (I-KH)^T + K R K^T
|
||||||
|
self._P = IKH @ self._P @ IKH.T + K @ _R @ K.T
|
||||||
|
|
||||||
|
# ── Prediction horizon ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def predict_horizon(
|
||||||
|
self,
|
||||||
|
horizon_s: float = 2.5,
|
||||||
|
step_s: float = 0.5,
|
||||||
|
) -> List[Tuple[float, float, float]]:
|
||||||
|
"""
|
||||||
|
Return [(x, y, t), ...] at equally-spaced future times.
|
||||||
|
|
||||||
|
Does NOT modify internal filter state.
|
||||||
|
"""
|
||||||
|
predictions: List[Tuple[float, float, float]] = []
|
||||||
|
state = self._x.copy()
|
||||||
|
t = 0.0
|
||||||
|
F_step = np.array([
|
||||||
|
[1, 0, step_s, 0],
|
||||||
|
[0, 1, 0, step_s],
|
||||||
|
[0, 0, 1, 0],
|
||||||
|
[0, 0, 0, 1],
|
||||||
|
], dtype=np.float64)
|
||||||
|
while t < horizon_s - 1e-6:
|
||||||
|
state = F_step @ state
|
||||||
|
t += step_s
|
||||||
|
predictions.append((float(state[0]), float(state[1]), t))
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
# ── Properties ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@property
|
||||||
|
def position(self) -> Tuple[float, float]:
|
||||||
|
return float(self._x[0]), float(self._x[1])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def velocity(self) -> Tuple[float, float]:
|
||||||
|
return float(self._x[2]), float(self._x[3])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def speed(self) -> float:
|
||||||
|
return float(np.hypot(self._x[2], self._x[3]))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def covariance_2x2(self) -> np.ndarray:
|
||||||
|
"""Position covariance (top-left 2×2 of P)."""
|
||||||
|
return self._P[:2, :2].copy()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self) -> np.ndarray:
|
||||||
|
return self._x.copy()
|
||||||
@ -0,0 +1,168 @@
|
|||||||
|
"""
|
||||||
|
object_detector.py — LIDAR-based moving object detector.
|
||||||
|
|
||||||
|
Algorithm:
|
||||||
|
1. Convert each LaserScan to a 2-D occupancy grid (robot-centred, fixed size).
|
||||||
|
2. Maintain a background model via exponential moving average (EMA):
|
||||||
|
bg_t = α * current + (1-α) * bg_{t-1} (only for non-moving cells)
|
||||||
|
3. Foreground = cells whose occupancy significantly exceeds the background.
|
||||||
|
4. Cluster foreground cells with scipy connected-components → Detection list.
|
||||||
|
|
||||||
|
The grid is robot-relative (origin at robot centre) so it naturally tracks
|
||||||
|
the robot's motion without needing TF at this stage. The caller is responsible
|
||||||
|
for transforming detections into a stable frame (odom) before passing to the
|
||||||
|
tracker.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from scipy import ndimage
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Detection:
|
||||||
|
"""A clustered moving foreground blob from one scan."""
|
||||||
|
x: float # centroid x in sensor frame (m)
|
||||||
|
y: float # centroid y in sensor frame (m)
|
||||||
|
size_m2: float # approximate area of the cluster (m²)
|
||||||
|
range_m: float # distance from robot (m)
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectDetector:
|
||||||
|
"""
|
||||||
|
Detects moving objects in consecutive 2-D LIDAR scans.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
grid_radius_m : half-size of the occupancy grid (grid covers ±radius)
|
||||||
|
resolution : metres per cell
|
||||||
|
bg_alpha : EMA update rate for background (small = slow forgetting)
|
||||||
|
motion_thr : occupancy delta above background to count as moving
|
||||||
|
min_cluster : minimum cells to keep a cluster
|
||||||
|
max_cluster : maximum cells before a cluster is considered static wall
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
grid_radius_m: float = 10.0,
|
||||||
|
resolution: float = 0.10,
|
||||||
|
bg_alpha: float = 0.04,
|
||||||
|
motion_thr: float = 0.45,
|
||||||
|
min_cluster: int = 3,
|
||||||
|
max_cluster: int = 200,
|
||||||
|
):
|
||||||
|
cells = int(2 * grid_radius_m / resolution)
|
||||||
|
self._cells = cells
|
||||||
|
self._res = resolution
|
||||||
|
self._origin = -grid_radius_m # world x/y at grid index 0
|
||||||
|
self._bg_alpha = bg_alpha
|
||||||
|
self._motion_thr = motion_thr
|
||||||
|
self._min_clust = min_cluster
|
||||||
|
self._max_clust = max_cluster
|
||||||
|
|
||||||
|
self._bg = np.zeros((cells, cells), dtype=np.float32)
|
||||||
|
self._initialized = False
|
||||||
|
|
||||||
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
ranges: np.ndarray,
|
||||||
|
angle_min: float,
|
||||||
|
angle_increment: float,
|
||||||
|
range_max: float,
|
||||||
|
) -> List[Detection]:
|
||||||
|
"""
|
||||||
|
Process one LaserScan and return detected moving blobs.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
ranges : 1-D array of range readings (metres)
|
||||||
|
angle_min : angle of first beam (radians)
|
||||||
|
angle_increment : angular step between beams (radians)
|
||||||
|
range_max : maximum valid range (metres)
|
||||||
|
"""
|
||||||
|
curr = self._scan_to_grid(ranges, angle_min, angle_increment, range_max)
|
||||||
|
|
||||||
|
if not self._initialized:
|
||||||
|
self._bg = curr.copy()
|
||||||
|
self._initialized = True
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Foreground mask
|
||||||
|
motion_mask = (curr - self._bg) > self._motion_thr
|
||||||
|
|
||||||
|
# Update background only on non-moving cells
|
||||||
|
static = ~motion_mask
|
||||||
|
self._bg[static] = (
|
||||||
|
self._bg[static] * (1.0 - self._bg_alpha)
|
||||||
|
+ curr[static] * self._bg_alpha
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._cluster(motion_mask)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self._bg[:] = 0.0
|
||||||
|
self._initialized = False
|
||||||
|
|
||||||
|
# ── Internals ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _scan_to_grid(
|
||||||
|
self,
|
||||||
|
ranges: np.ndarray,
|
||||||
|
angle_min: float,
|
||||||
|
angle_increment: float,
|
||||||
|
range_max: float,
|
||||||
|
) -> np.ndarray:
|
||||||
|
grid = np.zeros((self._cells, self._cells), dtype=np.float32)
|
||||||
|
n = len(ranges)
|
||||||
|
angles = angle_min + np.arange(n) * angle_increment
|
||||||
|
r = np.asarray(ranges, dtype=np.float32)
|
||||||
|
|
||||||
|
valid = (r > 0.05) & (r < range_max) & np.isfinite(r)
|
||||||
|
r, a = r[valid], angles[valid]
|
||||||
|
|
||||||
|
x = r * np.cos(a)
|
||||||
|
y = r * np.sin(a)
|
||||||
|
|
||||||
|
ix = np.clip(
|
||||||
|
((x - self._origin) / self._res).astype(np.int32), 0, self._cells - 1
|
||||||
|
)
|
||||||
|
iy = np.clip(
|
||||||
|
((y - self._origin) / self._res).astype(np.int32), 0, self._cells - 1
|
||||||
|
)
|
||||||
|
grid[iy, ix] = 1.0
|
||||||
|
return grid
|
||||||
|
|
||||||
|
def _cluster(self, mask: np.ndarray) -> List[Detection]:
|
||||||
|
# Dilate slightly to connect nearby hit cells into one blob
|
||||||
|
struct = ndimage.generate_binary_structure(2, 2)
|
||||||
|
dilated = ndimage.binary_dilation(mask, structure=struct, iterations=1)
|
||||||
|
|
||||||
|
labeled, n_labels = ndimage.label(dilated)
|
||||||
|
detections: List[Detection] = []
|
||||||
|
|
||||||
|
for label_id in range(1, n_labels + 1):
|
||||||
|
coords = np.argwhere(labeled == label_id)
|
||||||
|
n_cells = len(coords)
|
||||||
|
if n_cells < self._min_clust or n_cells > self._max_clust:
|
||||||
|
continue
|
||||||
|
|
||||||
|
ys, xs = coords[:, 0], coords[:, 1]
|
||||||
|
cx_grid = float(np.mean(xs))
|
||||||
|
cy_grid = float(np.mean(ys))
|
||||||
|
cx = cx_grid * self._res + self._origin
|
||||||
|
cy = cy_grid * self._res + self._origin
|
||||||
|
|
||||||
|
detections.append(Detection(
|
||||||
|
x=cx,
|
||||||
|
y=cy,
|
||||||
|
size_m2=n_cells * self._res ** 2,
|
||||||
|
range_m=float(np.hypot(cx, cy)),
|
||||||
|
))
|
||||||
|
|
||||||
|
return detections
|
||||||
@ -0,0 +1,206 @@
|
|||||||
|
"""
|
||||||
|
tracker_manager.py — Multi-object tracker with Hungarian data association.
|
||||||
|
|
||||||
|
Track lifecycle:
|
||||||
|
TENTATIVE → confirmed after `confirm_frames` consecutive hits
|
||||||
|
CONFIRMED → normal tracked state
|
||||||
|
LOST → missed for 1..max_missed frames (still predicts, not published)
|
||||||
|
DEAD → missed > max_missed → removed
|
||||||
|
|
||||||
|
Association:
|
||||||
|
Uses scipy.optimize.linear_sum_assignment (Hungarian algorithm) on a cost
|
||||||
|
matrix of Euclidean distances between predicted track positions and new
|
||||||
|
detections. Assignments with cost > assoc_dist_m are rejected.
|
||||||
|
|
||||||
|
Up to `max_tracks` simultaneous live tracks (tentative + confirmed).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import IntEnum
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from scipy.optimize import linear_sum_assignment
|
||||||
|
|
||||||
|
from .kalman_tracker import KalmanTracker
|
||||||
|
from .object_detector import Detection
|
||||||
|
|
||||||
|
|
||||||
|
class TrackState(IntEnum):
|
||||||
|
TENTATIVE = 0
|
||||||
|
CONFIRMED = 1
|
||||||
|
LOST = 2
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Track:
|
||||||
|
track_id: int
|
||||||
|
kalman: KalmanTracker
|
||||||
|
state: TrackState = TrackState.TENTATIVE
|
||||||
|
hits: int = 1
|
||||||
|
age: int = 1 # frames since creation
|
||||||
|
missed: int = 0 # consecutive missed frames
|
||||||
|
|
||||||
|
|
||||||
|
class TrackerManager:
|
||||||
|
"""
|
||||||
|
Manages a pool of Kalman tracks.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
max_tracks : hard cap on simultaneously alive tracks
|
||||||
|
confirm_frames : hits needed before a track is CONFIRMED
|
||||||
|
max_missed : consecutive missed frames before a track is DEAD
|
||||||
|
assoc_dist_m : max allowed distance (m) for a valid assignment
|
||||||
|
horizon_s : prediction horizon for trajectory output (seconds)
|
||||||
|
pred_step_s : time step between predicted waypoints
|
||||||
|
process_noise : KalmanTracker process-noise multiplier
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_tracks: int = 20,
|
||||||
|
confirm_frames: int = 3,
|
||||||
|
max_missed: int = 6,
|
||||||
|
assoc_dist_m: float = 1.5,
|
||||||
|
horizon_s: float = 2.5,
|
||||||
|
pred_step_s: float = 0.5,
|
||||||
|
process_noise: float = 1.0,
|
||||||
|
):
|
||||||
|
self._max_tracks = max_tracks
|
||||||
|
self._confirm_frames = confirm_frames
|
||||||
|
self._max_missed = max_missed
|
||||||
|
self._assoc_dist = assoc_dist_m
|
||||||
|
self._horizon_s = horizon_s
|
||||||
|
self._pred_step = pred_step_s
|
||||||
|
self._proc_noise = process_noise
|
||||||
|
|
||||||
|
self._tracks: Dict[int, Track] = {}
|
||||||
|
self._next_id: int = 1
|
||||||
|
|
||||||
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def update(self, detections: List[Detection], dt: float) -> List[Track]:
|
||||||
|
"""
|
||||||
|
Process one frame of detections.
|
||||||
|
|
||||||
|
1. Predict all tracks by dt.
|
||||||
|
2. Hungarian assignment of predictions → detections.
|
||||||
|
3. Update matched tracks; mark unmatched tracks as LOST.
|
||||||
|
4. Promote tracks crossing `confirm_frames`.
|
||||||
|
5. Create new tracks for unmatched detections (if room).
|
||||||
|
6. Remove DEAD tracks.
|
||||||
|
|
||||||
|
Returns confirmed tracks only.
|
||||||
|
"""
|
||||||
|
# 1. Predict
|
||||||
|
for tr in self._tracks.values():
|
||||||
|
tr.kalman.predict(dt)
|
||||||
|
tr.age += 1
|
||||||
|
|
||||||
|
# 2. Assign
|
||||||
|
matched, unmatched_tracks, unmatched_dets = self._assign(detections)
|
||||||
|
|
||||||
|
# 3a. Update matched
|
||||||
|
for tid, did in matched:
|
||||||
|
tr = self._tracks[tid]
|
||||||
|
det = detections[did]
|
||||||
|
tr.kalman.update(np.array([det.x, det.y]))
|
||||||
|
tr.hits += 1
|
||||||
|
tr.missed = 0
|
||||||
|
if tr.state == TrackState.LOST:
|
||||||
|
tr.state = TrackState.CONFIRMED
|
||||||
|
elif tr.state == TrackState.TENTATIVE and tr.hits >= self._confirm_frames:
|
||||||
|
tr.state = TrackState.CONFIRMED
|
||||||
|
|
||||||
|
# 3b. Unmatched tracks
|
||||||
|
for tid in unmatched_tracks:
|
||||||
|
tr = self._tracks[tid]
|
||||||
|
tr.missed += 1
|
||||||
|
if tr.missed > 1:
|
||||||
|
tr.state = TrackState.LOST
|
||||||
|
|
||||||
|
# 4. New tracks for unmatched detections
|
||||||
|
live = sum(1 for t in self._tracks.values() if t.state != TrackState.LOST
|
||||||
|
or t.missed <= self._max_missed)
|
||||||
|
for did in unmatched_dets:
|
||||||
|
if live >= self._max_tracks:
|
||||||
|
break
|
||||||
|
det = detections[did]
|
||||||
|
init_state = (TrackState.CONFIRMED
|
||||||
|
if self._confirm_frames <= 1
|
||||||
|
else TrackState.TENTATIVE)
|
||||||
|
new_tr = Track(
|
||||||
|
track_id=self._next_id,
|
||||||
|
kalman=KalmanTracker(det.x, det.y, self._proc_noise),
|
||||||
|
state=init_state,
|
||||||
|
)
|
||||||
|
self._tracks[self._next_id] = new_tr
|
||||||
|
self._next_id += 1
|
||||||
|
live += 1
|
||||||
|
|
||||||
|
# 5. Prune dead
|
||||||
|
dead = [tid for tid, t in self._tracks.items() if t.missed > self._max_missed]
|
||||||
|
for tid in dead:
|
||||||
|
del self._tracks[tid]
|
||||||
|
|
||||||
|
return [t for t in self._tracks.values() if t.state == TrackState.CONFIRMED]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_tracks(self) -> List[Track]:
|
||||||
|
return list(self._tracks.values())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tentative_count(self) -> int:
|
||||||
|
return sum(1 for t in self._tracks.values()
|
||||||
|
if t.state == TrackState.TENTATIVE)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self._tracks.clear()
|
||||||
|
self._next_id = 1
|
||||||
|
|
||||||
|
# ── Hungarian assignment ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _assign(
|
||||||
|
self,
|
||||||
|
detections: List[Detection],
|
||||||
|
) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
matched — list of (track_id, det_index)
|
||||||
|
unmatched_tids — track IDs with no detection assigned
|
||||||
|
unmatched_dids — detection indices with no track assigned
|
||||||
|
"""
|
||||||
|
track_ids = list(self._tracks.keys())
|
||||||
|
|
||||||
|
if not track_ids or not detections:
|
||||||
|
return [], track_ids, list(range(len(detections)))
|
||||||
|
|
||||||
|
# Build cost matrix: rows=tracks, cols=detections
|
||||||
|
cost = np.full((len(track_ids), len(detections)), fill_value=np.inf)
|
||||||
|
for r, tid in enumerate(track_ids):
|
||||||
|
tx, ty = self._tracks[tid].kalman.position
|
||||||
|
for c, det in enumerate(detections):
|
||||||
|
cost[r, c] = np.hypot(tx - det.x, ty - det.y)
|
||||||
|
|
||||||
|
row_ind, col_ind = linear_sum_assignment(cost)
|
||||||
|
|
||||||
|
matched: List[Tuple[int, int]] = []
|
||||||
|
matched_track_idx: set = set()
|
||||||
|
matched_det_idx: set = set()
|
||||||
|
|
||||||
|
for r, c in zip(row_ind, col_ind):
|
||||||
|
if cost[r, c] > self._assoc_dist:
|
||||||
|
continue
|
||||||
|
matched.append((track_ids[r], c))
|
||||||
|
matched_track_idx.add(r)
|
||||||
|
matched_det_idx.add(c)
|
||||||
|
|
||||||
|
unmatched_tids = [track_ids[r] for r in range(len(track_ids))
|
||||||
|
if r not in matched_track_idx]
|
||||||
|
unmatched_dids = [c for c in range(len(detections))
|
||||||
|
if c not in matched_det_idx]
|
||||||
|
|
||||||
|
return matched, unmatched_tids, unmatched_dids
|
||||||
4
jetson/ros2_ws/src/saltybot_dynamic_obstacles/setup.cfg
Normal file
4
jetson/ros2_ws/src/saltybot_dynamic_obstacles/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
[develop]
|
||||||
|
script_dir=$base/lib/saltybot_dynamic_obstacles
|
||||||
|
[install]
|
||||||
|
install_scripts=$base/lib/saltybot_dynamic_obstacles
|
||||||
32
jetson/ros2_ws/src/saltybot_dynamic_obstacles/setup.py
Normal file
32
jetson/ros2_ws/src/saltybot_dynamic_obstacles/setup.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from setuptools import setup, find_packages
|
||||||
|
from glob import glob
|
||||||
|
|
||||||
|
package_name = 'saltybot_dynamic_obstacles'
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name=package_name,
|
||||||
|
version='0.1.0',
|
||||||
|
packages=find_packages(exclude=['test']),
|
||||||
|
data_files=[
|
||||||
|
('share/ament_index/resource_index/packages',
|
||||||
|
['resource/' + package_name]),
|
||||||
|
('share/' + package_name, ['package.xml']),
|
||||||
|
('share/' + package_name + '/launch',
|
||||||
|
glob('launch/*.launch.py')),
|
||||||
|
('share/' + package_name + '/config',
|
||||||
|
glob('config/*.yaml')),
|
||||||
|
],
|
||||||
|
install_requires=['setuptools'],
|
||||||
|
zip_safe=True,
|
||||||
|
maintainer='SaltyLab',
|
||||||
|
maintainer_email='robot@saltylab.local',
|
||||||
|
description='Dynamic obstacle tracking: LIDAR motion detection, Kalman tracking, Nav2 costmap',
|
||||||
|
license='MIT',
|
||||||
|
tests_require=['pytest'],
|
||||||
|
entry_points={
|
||||||
|
'console_scripts': [
|
||||||
|
'dynamic_obs_tracker = saltybot_dynamic_obstacles.dynamic_obs_node:main',
|
||||||
|
'dynamic_obs_costmap = saltybot_dynamic_obstacles.costmap_layer_node:main',
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
@ -0,0 +1,262 @@
|
|||||||
|
"""
|
||||||
|
test_dynamic_obstacles.py — Unit tests for KalmanTracker, TrackerManager,
|
||||||
|
and ObjectDetector.
|
||||||
|
|
||||||
|
Runs without ROS2 / hardware (no rclpy imports).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||||
|
|
||||||
|
from saltybot_dynamic_obstacles.kalman_tracker import KalmanTracker
|
||||||
|
from saltybot_dynamic_obstacles.tracker_manager import TrackerManager, TrackState
|
||||||
|
from saltybot_dynamic_obstacles.object_detector import ObjectDetector, Detection
|
||||||
|
|
||||||
|
|
||||||
|
# ── KalmanTracker ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestKalmanTracker:
|
||||||
|
|
||||||
|
def test_initial_position(self):
|
||||||
|
kt = KalmanTracker(3.0, 4.0)
|
||||||
|
px, py = kt.position
|
||||||
|
assert px == pytest.approx(3.0)
|
||||||
|
assert py == pytest.approx(4.0)
|
||||||
|
|
||||||
|
def test_initial_velocity_zero(self):
|
||||||
|
kt = KalmanTracker(0.0, 0.0)
|
||||||
|
vx, vy = kt.velocity
|
||||||
|
assert vx == pytest.approx(0.0)
|
||||||
|
assert vy == pytest.approx(0.0)
|
||||||
|
|
||||||
|
def test_predict_moves_position(self):
|
||||||
|
kt = KalmanTracker(0.0, 0.0)
|
||||||
|
# Give it some velocity via update sequence
|
||||||
|
kt.update(np.array([0.1, 0.0]))
|
||||||
|
kt.update(np.array([0.2, 0.0]))
|
||||||
|
kt.predict(0.1)
|
||||||
|
px, _ = kt.position
|
||||||
|
assert px > 0.0 # should have moved forward
|
||||||
|
|
||||||
|
def test_pure_predict_constant_velocity(self):
|
||||||
|
"""After velocity is established, predict() should move linearly."""
|
||||||
|
kt = KalmanTracker(0.0, 0.0)
|
||||||
|
# Force velocity by repeated updates
|
||||||
|
for i in range(10):
|
||||||
|
kt.update(np.array([i * 0.1, 0.0]))
|
||||||
|
kt.predict(0.1)
|
||||||
|
vx, _ = kt.velocity
|
||||||
|
px0, _ = kt.position
|
||||||
|
kt.predict(1.0)
|
||||||
|
px1, _ = kt.position
|
||||||
|
# Should advance roughly vx * 1.0 metres
|
||||||
|
assert px1 == pytest.approx(px0 + vx * 1.0, abs=0.3)
|
||||||
|
|
||||||
|
def test_update_corrects_position(self):
|
||||||
|
kt = KalmanTracker(0.0, 0.0)
|
||||||
|
# Predict way off
|
||||||
|
kt.predict(10.0)
|
||||||
|
# Then update to ground truth
|
||||||
|
kt.update(np.array([1.0, 2.0]))
|
||||||
|
px, py = kt.position
|
||||||
|
# Should move toward (1, 2)
|
||||||
|
assert px == pytest.approx(1.0, abs=0.5)
|
||||||
|
assert py == pytest.approx(2.0, abs=0.5)
|
||||||
|
|
||||||
|
def test_predict_horizon_length(self):
|
||||||
|
kt = KalmanTracker(0.0, 0.0)
|
||||||
|
preds = kt.predict_horizon(horizon_s=2.5, step_s=0.5)
|
||||||
|
assert len(preds) == 5 # 0.5, 1.0, 1.5, 2.0, 2.5
|
||||||
|
|
||||||
|
def test_predict_horizon_times(self):
|
||||||
|
kt = KalmanTracker(0.0, 0.0)
|
||||||
|
preds = kt.predict_horizon(horizon_s=2.0, step_s=0.5)
|
||||||
|
times = [t for _, _, t in preds]
|
||||||
|
assert times == pytest.approx([0.5, 1.0, 1.5, 2.0], abs=0.01)
|
||||||
|
|
||||||
|
def test_predict_horizon_does_not_mutate_state(self):
|
||||||
|
kt = KalmanTracker(1.0, 2.0)
|
||||||
|
kt.predict_horizon(horizon_s=3.0, step_s=0.5)
|
||||||
|
px, py = kt.position
|
||||||
|
assert px == pytest.approx(1.0)
|
||||||
|
assert py == pytest.approx(2.0)
|
||||||
|
|
||||||
|
def test_speed_zero_at_init(self):
|
||||||
|
kt = KalmanTracker(5.0, 5.0)
|
||||||
|
assert kt.speed == pytest.approx(0.0)
|
||||||
|
|
||||||
|
def test_covariance_shape(self):
|
||||||
|
kt = KalmanTracker(0.0, 0.0)
|
||||||
|
cov = kt.covariance_2x2
|
||||||
|
assert cov.shape == (2, 2)
|
||||||
|
|
||||||
|
def test_covariance_positive_definite(self):
|
||||||
|
kt = KalmanTracker(0.0, 0.0)
|
||||||
|
for _ in range(5):
|
||||||
|
kt.predict(0.1)
|
||||||
|
kt.update(np.array([0.1, 0.0]))
|
||||||
|
eigvals = np.linalg.eigvalsh(kt.covariance_2x2)
|
||||||
|
assert np.all(eigvals > 0)
|
||||||
|
|
||||||
|
def test_joseph_form_stays_symmetric(self):
|
||||||
|
"""Covariance should remain symmetric after many updates."""
|
||||||
|
kt = KalmanTracker(0.0, 0.0)
|
||||||
|
for i in range(50):
|
||||||
|
kt.predict(0.1)
|
||||||
|
kt.update(np.array([i * 0.01, 0.0]))
|
||||||
|
P = kt._P
|
||||||
|
assert np.allclose(P, P.T, atol=1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
# ── TrackerManager ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestTrackerManager:
|
||||||
|
|
||||||
|
def _det(self, x, y):
|
||||||
|
return Detection(x=x, y=y, size_m2=0.1, range_m=math.hypot(x, y))
|
||||||
|
|
||||||
|
def test_empty_detections_no_tracks(self):
|
||||||
|
tm = TrackerManager()
|
||||||
|
confirmed = tm.update([], 0.1)
|
||||||
|
assert confirmed == []
|
||||||
|
|
||||||
|
def test_track_created_on_detection(self):
|
||||||
|
tm = TrackerManager(confirm_frames=1)
|
||||||
|
confirmed = tm.update([self._det(1.0, 0.0)], 0.1)
|
||||||
|
assert len(confirmed) == 1
|
||||||
|
|
||||||
|
def test_track_tentative_before_confirm(self):
|
||||||
|
tm = TrackerManager(confirm_frames=3)
|
||||||
|
tm.update([self._det(1.0, 0.0)], 0.1)
|
||||||
|
# Only 1 hit — should still be tentative
|
||||||
|
assert tm.tentative_count == 1
|
||||||
|
|
||||||
|
def test_track_confirmed_after_N_hits(self):
|
||||||
|
tm = TrackerManager(confirm_frames=3, assoc_dist_m=2.0)
|
||||||
|
for _ in range(4):
|
||||||
|
confirmed = tm.update([self._det(1.0, 0.0)], 0.1)
|
||||||
|
assert len(confirmed) == 1
|
||||||
|
|
||||||
|
def test_track_deleted_after_max_missed(self):
|
||||||
|
tm = TrackerManager(confirm_frames=1, max_missed=3, assoc_dist_m=2.0)
|
||||||
|
tm.update([self._det(1.0, 0.0)], 0.1) # create
|
||||||
|
for _ in range(5):
|
||||||
|
tm.update([], 0.1) # no detections → missed++
|
||||||
|
assert len(tm.all_tracks) == 0
|
||||||
|
|
||||||
|
def test_max_tracks_cap(self):
|
||||||
|
tm = TrackerManager(max_tracks=5, confirm_frames=1)
|
||||||
|
dets = [self._det(float(i), 0.0) for i in range(10)]
|
||||||
|
tm.update(dets, 0.1)
|
||||||
|
assert len(tm.all_tracks) <= 5
|
||||||
|
|
||||||
|
def test_consistent_track_id(self):
|
||||||
|
tm = TrackerManager(confirm_frames=3, assoc_dist_m=1.5)
|
||||||
|
for i in range(5):
|
||||||
|
confirmed = tm.update([self._det(1.0 + i * 0.01, 0.0)], 0.1)
|
||||||
|
assert len(confirmed) == 1
|
||||||
|
track_id = confirmed[0].track_id
|
||||||
|
# One more tick — ID should be stable
|
||||||
|
confirmed2 = tm.update([self._det(1.06, 0.0)], 0.1)
|
||||||
|
assert confirmed2[0].track_id == track_id
|
||||||
|
|
||||||
|
def test_two_independent_tracks(self):
|
||||||
|
tm = TrackerManager(confirm_frames=3, assoc_dist_m=0.8)
|
||||||
|
for _ in range(5):
|
||||||
|
confirmed = tm.update([self._det(1.0, 0.0), self._det(5.0, 0.0)], 0.1)
|
||||||
|
assert len(confirmed) == 2
|
||||||
|
|
||||||
|
def test_reset_clears_all(self):
|
||||||
|
tm = TrackerManager(confirm_frames=1)
|
||||||
|
tm.update([self._det(1.0, 0.0)], 0.1)
|
||||||
|
tm.reset()
|
||||||
|
assert len(tm.all_tracks) == 0
|
||||||
|
|
||||||
|
def test_far_detection_not_assigned(self):
|
||||||
|
tm = TrackerManager(confirm_frames=1, assoc_dist_m=0.5)
|
||||||
|
tm.update([self._det(1.0, 0.0)], 0.1) # create track at (1,0)
|
||||||
|
# Detection 3 m away → new track, not update
|
||||||
|
tm.update([self._det(4.0, 0.0)], 0.1)
|
||||||
|
assert len(tm.all_tracks) == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ── ObjectDetector ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestObjectDetector:
|
||||||
|
|
||||||
|
def _empty_scan(self, n=360, rmax=8.0) -> tuple:
|
||||||
|
"""All readings at max range (static background)."""
|
||||||
|
ranges = np.full(n, rmax - 0.1, dtype=np.float32)
|
||||||
|
return ranges, -math.pi, 2 * math.pi / n, rmax
|
||||||
|
|
||||||
|
def _scan_with_blob(self, blob_r=2.0, blob_theta=0.0, n=360, rmax=8.0) -> tuple:
|
||||||
|
"""Background scan + a short-range cluster at blob_theta."""
|
||||||
|
ranges = np.full(n, rmax - 0.1, dtype=np.float32)
|
||||||
|
angle_inc = 2 * math.pi / n
|
||||||
|
angle_min = -math.pi
|
||||||
|
# Put a cluster of ~10 beams at blob_r
|
||||||
|
center_idx = int((blob_theta - angle_min) / angle_inc) % n
|
||||||
|
for di in range(-5, 6):
|
||||||
|
idx = (center_idx + di) % n
|
||||||
|
ranges[idx] = blob_r
|
||||||
|
return ranges, angle_min, angle_inc, rmax
|
||||||
|
|
||||||
|
def test_empty_scan_no_detections_after_warmup(self):
|
||||||
|
od = ObjectDetector()
|
||||||
|
r, a_min, a_inc, rmax = self._empty_scan()
|
||||||
|
od.update(r, a_min, a_inc, rmax) # init background
|
||||||
|
for _ in range(3):
|
||||||
|
dets = od.update(r, a_min, a_inc, rmax)
|
||||||
|
assert len(dets) == 0
|
||||||
|
|
||||||
|
def test_moving_blob_detected(self):
|
||||||
|
od = ObjectDetector()
|
||||||
|
bg_r, a_min, a_inc, rmax = self._empty_scan()
|
||||||
|
od.update(bg_r, a_min, a_inc, rmax) # seed background
|
||||||
|
for _ in range(5):
|
||||||
|
od.update(bg_r, a_min, a_inc, rmax)
|
||||||
|
# Now inject a foreground blob
|
||||||
|
fg_r, _, _, _ = self._scan_with_blob(blob_r=2.0, blob_theta=0.0)
|
||||||
|
dets = od.update(fg_r, a_min, a_inc, rmax)
|
||||||
|
assert len(dets) >= 1
|
||||||
|
|
||||||
|
def test_detection_centroid_approximate(self):
|
||||||
|
od = ObjectDetector()
|
||||||
|
bg_r, a_min, a_inc, rmax = self._empty_scan()
|
||||||
|
for _ in range(8):
|
||||||
|
od.update(bg_r, a_min, a_inc, rmax)
|
||||||
|
fg_r, _, _, _ = self._scan_with_blob(blob_r=3.0, blob_theta=0.0)
|
||||||
|
dets = od.update(fg_r, a_min, a_inc, rmax)
|
||||||
|
assert len(dets) >= 1
|
||||||
|
# Blob is at ~3 m along x-axis (theta=0)
|
||||||
|
cx = dets[0].x
|
||||||
|
cy = dets[0].y
|
||||||
|
assert abs(cx - 3.0) < 0.5
|
||||||
|
assert abs(cy) < 0.5
|
||||||
|
|
||||||
|
def test_reset_clears_background(self):
|
||||||
|
od = ObjectDetector()
|
||||||
|
bg_r, a_min, a_inc, rmax = self._empty_scan()
|
||||||
|
for _ in range(5):
|
||||||
|
od.update(bg_r, a_min, a_inc, rmax)
|
||||||
|
od.reset()
|
||||||
|
assert not od._initialized
|
||||||
|
|
||||||
|
def test_no_inf_nan_ranges(self):
|
||||||
|
od = ObjectDetector()
|
||||||
|
r = np.array([np.inf, np.nan, 5.0, -1.0, 0.0] * 72, dtype=np.float32)
|
||||||
|
a_min = -math.pi
|
||||||
|
a_inc = 2 * math.pi / len(r)
|
||||||
|
od.update(r, a_min, a_inc, 8.0) # should not raise
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pytest.main([__file__, '-v'])
|
||||||
@ -0,0 +1,43 @@
|
|||||||
|
/**:
|
||||||
|
ros__parameters:
|
||||||
|
# Control loop rate (Hz)
|
||||||
|
control_rate: 20.0
|
||||||
|
|
||||||
|
# Odometry topic for stuck detection
|
||||||
|
odom_topic: "/saltybot/rover_odom"
|
||||||
|
|
||||||
|
# ── LaserScan forward sector ───────────────────────────────────────────────
|
||||||
|
forward_scan_angle_rad: 0.785 # ±45° forward sector
|
||||||
|
|
||||||
|
# ── Obstacle proximity ────────────────────────────────────────────────────
|
||||||
|
stop_distance_m: 0.30 # MAJOR threshold (spec: <30 cm)
|
||||||
|
critical_distance_m: 0.10 # CRITICAL threshold
|
||||||
|
min_cmd_speed_ms: 0.05 # ignore obstacle when nearly stopped
|
||||||
|
|
||||||
|
# ── Fall detection (IMU tilt) ─────────────────────────────────────────────
|
||||||
|
minor_tilt_rad: 0.20 # advisory
|
||||||
|
major_tilt_rad: 0.35 # stop + recover
|
||||||
|
critical_tilt_rad: 0.52 # ~30° — full shutdown
|
||||||
|
floor_drop_m: 0.15 # depth discontinuity triggering MAJOR
|
||||||
|
|
||||||
|
# ── Stuck detection ───────────────────────────────────────────────────────
|
||||||
|
stuck_timeout_s: 3.0 # (spec: 3 s wheel stall)
|
||||||
|
|
||||||
|
# ── Bump / jerk detection ─────────────────────────────────────────────────
|
||||||
|
jerk_threshold_ms3: 8.0
|
||||||
|
critical_jerk_threshold_ms3: 25.0
|
||||||
|
|
||||||
|
# ── FSM / recovery ────────────────────────────────────────────────────────
|
||||||
|
stopped_ms: 0.03 # speed below which robot is "stopped" (m/s)
|
||||||
|
major_count_threshold: 3 # MAJOR alerts before escalation to CRITICAL
|
||||||
|
escalation_window_s: 10.0 # sliding window for escalation counter (s)
|
||||||
|
suppression_s: 1.0 # de-bounce period for duplicate alerts (s)
|
||||||
|
|
||||||
|
# Recovery sequence
|
||||||
|
reverse_speed_ms: -0.15 # back-up speed (m/s; must be negative)
|
||||||
|
reverse_distance_m: 0.30 # distance to reverse each cycle (m)
|
||||||
|
angular_speed_rads: 0.60 # turn speed (rad/s)
|
||||||
|
turn_angle_rad: 1.5708 # ~90° turn (rad)
|
||||||
|
retry_timeout_s: 3.0 # time in RETRYING per attempt (s)
|
||||||
|
clear_hold_s: 0.50 # consecutive clear time to declare success (s)
|
||||||
|
max_retries: 3 # maximum reverse+turn attempts before GAVE_UP
|
||||||
@ -0,0 +1,53 @@
|
|||||||
|
"""
|
||||||
|
emergency.launch.py — Launch the emergency behavior system (Issue #169).
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
ros2 launch saltybot_emergency emergency.launch.py
|
||||||
|
ros2 launch saltybot_emergency emergency.launch.py \
|
||||||
|
stop_distance_m:=0.30 max_retries:=3
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from ament_index_python.packages import get_package_share_directory
|
||||||
|
from launch import LaunchDescription
|
||||||
|
from launch.actions import DeclareLaunchArgument
|
||||||
|
from launch.substitutions import LaunchConfiguration
|
||||||
|
from launch_ros.actions import Node
|
||||||
|
|
||||||
|
|
||||||
|
def generate_launch_description():
|
||||||
|
pkg_share = get_package_share_directory("saltybot_emergency")
|
||||||
|
default_params = os.path.join(pkg_share, "config", "emergency_params.yaml")
|
||||||
|
|
||||||
|
return LaunchDescription([
|
||||||
|
DeclareLaunchArgument(
|
||||||
|
"params_file",
|
||||||
|
default_value=default_params,
|
||||||
|
description="Path to emergency_params.yaml",
|
||||||
|
),
|
||||||
|
DeclareLaunchArgument(
|
||||||
|
"stop_distance_m",
|
||||||
|
default_value="0.30",
|
||||||
|
description="Obstacle distance triggering MAJOR stop (m)",
|
||||||
|
),
|
||||||
|
DeclareLaunchArgument(
|
||||||
|
"max_retries",
|
||||||
|
default_value="3",
|
||||||
|
description="Maximum recovery cycles before ESCALATED",
|
||||||
|
),
|
||||||
|
Node(
|
||||||
|
package="saltybot_emergency",
|
||||||
|
executable="emergency_node",
|
||||||
|
name="emergency",
|
||||||
|
output="screen",
|
||||||
|
parameters=[
|
||||||
|
LaunchConfiguration("params_file"),
|
||||||
|
{
|
||||||
|
"stop_distance_m": LaunchConfiguration("stop_distance_m"),
|
||||||
|
"max_retries": LaunchConfiguration("max_retries"),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
])
|
||||||
24
jetson/ros2_ws/src/saltybot_emergency/package.xml
Normal file
24
jetson/ros2_ws/src/saltybot_emergency/package.xml
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
<?xml version="1.0"?>
|
||||||
|
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||||
|
<package format="3">
|
||||||
|
<name>saltybot_emergency</name>
|
||||||
|
<version>0.1.0</version>
|
||||||
|
<description>Emergency behavior system — collision avoidance, fall prevention, stuck detection, recovery (Issue #169)</description>
|
||||||
|
<maintainer email="sl-controls@saltylab.local">sl-controls</maintainer>
|
||||||
|
<license>MIT</license>
|
||||||
|
|
||||||
|
<depend>rclpy</depend>
|
||||||
|
<depend>sensor_msgs</depend>
|
||||||
|
<depend>nav_msgs</depend>
|
||||||
|
<depend>geometry_msgs</depend>
|
||||||
|
<depend>std_msgs</depend>
|
||||||
|
|
||||||
|
<test_depend>ament_copyright</test_depend>
|
||||||
|
<test_depend>ament_flake8</test_depend>
|
||||||
|
<test_depend>ament_pep257</test_depend>
|
||||||
|
<test_depend>pytest</test_depend>
|
||||||
|
|
||||||
|
<export>
|
||||||
|
<build_type>ament_python</build_type>
|
||||||
|
</export>
|
||||||
|
</package>
|
||||||
@ -0,0 +1,139 @@
|
|||||||
|
"""
|
||||||
|
alert_manager.py — Alert severity escalation for emergency behavior (Issue #169).
|
||||||
|
|
||||||
|
Alert levels
|
||||||
|
────────────
|
||||||
|
NONE : no action
|
||||||
|
MINOR : advisory beep → publish to /saltybot/alert_beep
|
||||||
|
MAJOR : stop + LED flash → publish to /saltybot/alert_flash; cmd_vel override
|
||||||
|
CRITICAL : full shutdown + MQTT → publish to /saltybot/e_stop + /saltybot/critical_alert
|
||||||
|
|
||||||
|
Escalation
|
||||||
|
──────────
|
||||||
|
If major_count_threshold MAJOR alerts occur within escalation_window_s, the
|
||||||
|
next MAJOR is promoted to CRITICAL. This catches persistent stuck / repeated
|
||||||
|
collision scenarios.
|
||||||
|
|
||||||
|
Suppression
|
||||||
|
───────────
|
||||||
|
Identical (type, level) alerts are suppressed within suppression_s to avoid
|
||||||
|
flooding downstream topics.
|
||||||
|
|
||||||
|
Pure module — no ROS2 dependency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import deque
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from saltybot_emergency.threat_detector import ThreatEvent, ThreatLevel, ThreatType
|
||||||
|
|
||||||
|
|
||||||
|
# ── Alert level ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class AlertLevel(Enum):
|
||||||
|
NONE = 0
|
||||||
|
MINOR = 1 # beep
|
||||||
|
MAJOR = 2 # stop + flash
|
||||||
|
CRITICAL = 3 # shutdown + MQTT
|
||||||
|
|
||||||
|
|
||||||
|
# ── Alert ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Alert:
|
||||||
|
level: AlertLevel
|
||||||
|
source: str # ThreatType value string
|
||||||
|
message: str
|
||||||
|
timestamp_s: float
|
||||||
|
|
||||||
|
|
||||||
|
# ── AlertManager ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class AlertManager:
|
||||||
|
"""
|
||||||
|
Converts ThreatEvents to Alerts with escalation and suppression logic.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
major_count_threshold : number of MAJOR alerts within window to escalate
|
||||||
|
escalation_window_s : sliding window for escalation counting (s)
|
||||||
|
suppression_s : suppress duplicate (type, level) alerts within this period
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
major_count_threshold: int = 3,
|
||||||
|
escalation_window_s: float = 10.0,
|
||||||
|
suppression_s: float = 1.0,
|
||||||
|
):
|
||||||
|
self._major_threshold = max(1, int(major_count_threshold))
|
||||||
|
self._esc_window = float(escalation_window_s)
|
||||||
|
self._suppress = float(suppression_s)
|
||||||
|
self._major_times: deque = deque() # timestamps of MAJOR alerts
|
||||||
|
self._last_seen: dict = {} # (type, level) → timestamp
|
||||||
|
|
||||||
|
# ── Update ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def update(self, threat: ThreatEvent) -> Optional[Alert]:
|
||||||
|
"""
|
||||||
|
Convert one ThreatEvent to an Alert, applying escalation and suppression.
|
||||||
|
|
||||||
|
Returns None if threat is CLEAR or the alert is suppressed.
|
||||||
|
"""
|
||||||
|
if threat.level == ThreatLevel.CLEAR:
|
||||||
|
return None
|
||||||
|
|
||||||
|
now = threat.timestamp_s
|
||||||
|
alert_level = _threat_to_alert(threat.level)
|
||||||
|
|
||||||
|
# ── Suppression ───────────────────────────────────────────────────────
|
||||||
|
key = (threat.threat_type, threat.level)
|
||||||
|
last = self._last_seen.get(key)
|
||||||
|
if last is not None and (now - last) < self._suppress:
|
||||||
|
return None
|
||||||
|
self._last_seen[key] = now
|
||||||
|
|
||||||
|
# ── Escalation ────────────────────────────────────────────────────────
|
||||||
|
if alert_level == AlertLevel.MAJOR:
|
||||||
|
# Prune old timestamps
|
||||||
|
while self._major_times and (now - self._major_times[0]) > self._esc_window:
|
||||||
|
self._major_times.popleft()
|
||||||
|
self._major_times.append(now)
|
||||||
|
if len(self._major_times) >= self._major_threshold:
|
||||||
|
alert_level = AlertLevel.CRITICAL
|
||||||
|
|
||||||
|
msg = _build_message(alert_level, threat)
|
||||||
|
return Alert(
|
||||||
|
level=alert_level,
|
||||||
|
source=threat.threat_type.value,
|
||||||
|
message=msg,
|
||||||
|
timestamp_s=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Clear escalation history and suppression state."""
|
||||||
|
self._major_times.clear()
|
||||||
|
self._last_seen.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _threat_to_alert(level: ThreatLevel) -> AlertLevel:
|
||||||
|
return {
|
||||||
|
ThreatLevel.MINOR: AlertLevel.MINOR,
|
||||||
|
ThreatLevel.MAJOR: AlertLevel.MAJOR,
|
||||||
|
ThreatLevel.CRITICAL: AlertLevel.CRITICAL,
|
||||||
|
}.get(level, AlertLevel.NONE)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_message(level: AlertLevel, threat: ThreatEvent) -> str:
|
||||||
|
prefix = {
|
||||||
|
AlertLevel.MINOR: "[MINOR]",
|
||||||
|
AlertLevel.MAJOR: "[MAJOR]",
|
||||||
|
AlertLevel.CRITICAL: "[CRITICAL]",
|
||||||
|
}.get(level, "[?]")
|
||||||
|
return f"{prefix} {threat.threat_type.value}: {threat.detail}"
|
||||||
@ -0,0 +1,232 @@
|
|||||||
|
"""
|
||||||
|
emergency_fsm.py — Master emergency FSM integrating all detectors (Issue #169).
|
||||||
|
|
||||||
|
States
|
||||||
|
──────
|
||||||
|
NOMINAL : normal operation; minor alerts pass through; major/critical → STOPPING.
|
||||||
|
STOPPING : commanding zero velocity until robot speed drops below stopped_ms.
|
||||||
|
Critical threats skip RECOVERING → ESCALATED immediately.
|
||||||
|
RECOVERING : RecoverySequencer executing reverse+turn sequence.
|
||||||
|
Success → NOMINAL; gave-up → ESCALATED.
|
||||||
|
ESCALATED : full stop; critical alert emitted once; stays until acknowledge.
|
||||||
|
|
||||||
|
Alert actions produced by state
|
||||||
|
────────────────────────────────
|
||||||
|
NOMINAL : emit MINOR alert (beep only); no velocity override.
|
||||||
|
STOPPING : suppress nav, publish zero; emit MAJOR alert once.
|
||||||
|
RECOVERING : suppress nav, publish recovery cmds; no new alerts.
|
||||||
|
ESCALATED : suppress nav, publish zero; emit CRITICAL alert once per entry.
|
||||||
|
|
||||||
|
Pure module — no ROS2 dependency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from saltybot_emergency.alert_manager import Alert, AlertLevel, AlertManager
|
||||||
|
from saltybot_emergency.recovery_sequencer import RecoveryInputs, RecoverySequencer
|
||||||
|
from saltybot_emergency.threat_detector import ThreatEvent, ThreatLevel
|
||||||
|
|
||||||
|
|
||||||
|
# ── States ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class EmergencyState(Enum):
|
||||||
|
NOMINAL = "NOMINAL"
|
||||||
|
STOPPING = "STOPPING"
|
||||||
|
RECOVERING = "RECOVERING"
|
||||||
|
ESCALATED = "ESCALATED"
|
||||||
|
|
||||||
|
|
||||||
|
# ── I/O ───────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EmergencyInputs:
|
||||||
|
threat: ThreatEvent # highest-severity threat this tick
|
||||||
|
robot_speed_ms: float = 0.0 # actual speed from odometry (m/s)
|
||||||
|
acknowledge: bool = False # operator cleared the escalation
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EmergencyOutputs:
|
||||||
|
state: EmergencyState = EmergencyState.NOMINAL
|
||||||
|
cmd_override: bool = False # True = emergency owns cmd_vel
|
||||||
|
cmd_linear: float = 0.0
|
||||||
|
cmd_angular: float = 0.0
|
||||||
|
alert: Optional[Alert] = None
|
||||||
|
e_stop: bool = False # assert /saltybot/e_stop
|
||||||
|
state_changed: bool = False
|
||||||
|
recovery_progress: float = 0.0
|
||||||
|
recovery_retry_count: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── EmergencyFSM ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class EmergencyFSM:
|
||||||
|
"""
|
||||||
|
Master emergency FSM.
|
||||||
|
|
||||||
|
Owns an AlertManager and a RecoverySequencer; coordinates them each tick.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
stopped_ms : speed below which robot is considered stopped (m/s)
|
||||||
|
major_count_threshold : MAJOR events within window before escalation
|
||||||
|
escalation_window_s : sliding window for escalation (s)
|
||||||
|
suppression_s : alert de-bounce period (s)
|
||||||
|
reverse_speed_ms : reverse speed during recovery (m/s)
|
||||||
|
reverse_distance_m : reverse travel per cycle (m)
|
||||||
|
angular_speed_rads : turn speed during recovery (rad/s)
|
||||||
|
turn_angle_rad : turn per cycle (rad)
|
||||||
|
retry_timeout_s : time in RETRYING before next cycle (s)
|
||||||
|
clear_hold_s : clear duration required to declare success (s)
|
||||||
|
max_retries : recovery cycles before GAVE_UP
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
stopped_ms: float = 0.03,
|
||||||
|
major_count_threshold: int = 3,
|
||||||
|
escalation_window_s: float = 10.0,
|
||||||
|
suppression_s: float = 1.0,
|
||||||
|
reverse_speed_ms: float = -0.15,
|
||||||
|
reverse_distance_m: float = 0.30,
|
||||||
|
angular_speed_rads: float = 0.60,
|
||||||
|
turn_angle_rad: float = 1.5708,
|
||||||
|
retry_timeout_s: float = 3.0,
|
||||||
|
clear_hold_s: float = 0.5,
|
||||||
|
max_retries: int = 3,
|
||||||
|
):
|
||||||
|
self._stopped_ms = max(0.0, stopped_ms)
|
||||||
|
self._alert_mgr = AlertManager(
|
||||||
|
major_count_threshold=major_count_threshold,
|
||||||
|
escalation_window_s=escalation_window_s,
|
||||||
|
suppression_s=suppression_s,
|
||||||
|
)
|
||||||
|
self._recovery = RecoverySequencer(
|
||||||
|
reverse_speed_ms=reverse_speed_ms,
|
||||||
|
reverse_distance_m=reverse_distance_m,
|
||||||
|
angular_speed_rads=angular_speed_rads,
|
||||||
|
turn_angle_rad=turn_angle_rad,
|
||||||
|
retry_timeout_s=retry_timeout_s,
|
||||||
|
clear_hold_s=clear_hold_s,
|
||||||
|
max_retries=max_retries,
|
||||||
|
)
|
||||||
|
self._state = EmergencyState.NOMINAL
|
||||||
|
self._critical_pending = False # STOPPING → ESCALATED (not RECOVERING)
|
||||||
|
self._escalation_alerted = False # CRITICAL alert emitted once per ESCALATED entry
|
||||||
|
|
||||||
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self) -> EmergencyState:
|
||||||
|
return self._state
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self._state = EmergencyState.NOMINAL
|
||||||
|
self._critical_pending = False
|
||||||
|
self._escalation_alerted = False
|
||||||
|
self._alert_mgr.reset()
|
||||||
|
self._recovery.reset()
|
||||||
|
|
||||||
|
def tick(self, inputs: EmergencyInputs) -> EmergencyOutputs:
|
||||||
|
prev = self._state
|
||||||
|
out = self._step(inputs)
|
||||||
|
out.state = self._state
|
||||||
|
out.state_changed = (self._state != prev)
|
||||||
|
return out
|
||||||
|
|
||||||
|
# ── Step ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _step(self, inp: EmergencyInputs) -> EmergencyOutputs:
|
||||||
|
out = EmergencyOutputs(state=self._state)
|
||||||
|
|
||||||
|
# Run alert manager for this threat
|
||||||
|
alert = self._alert_mgr.update(inp.threat)
|
||||||
|
|
||||||
|
# ── NOMINAL ───────────────────────────────────────────────────────────
|
||||||
|
if self._state == EmergencyState.NOMINAL:
|
||||||
|
if inp.threat.level == ThreatLevel.CRITICAL:
|
||||||
|
self._state = EmergencyState.STOPPING
|
||||||
|
self._critical_pending = True
|
||||||
|
out.alert = alert
|
||||||
|
out.cmd_override = True # start overriding on entry tick
|
||||||
|
elif inp.threat.level == ThreatLevel.MAJOR:
|
||||||
|
self._state = EmergencyState.STOPPING
|
||||||
|
self._critical_pending = False
|
||||||
|
out.alert = alert
|
||||||
|
out.cmd_override = True # start overriding on entry tick
|
||||||
|
elif inp.threat.level == ThreatLevel.MINOR:
|
||||||
|
# Advisory only — no override
|
||||||
|
out.alert = alert
|
||||||
|
|
||||||
|
# ── STOPPING ──────────────────────────────────────────────────────────
|
||||||
|
elif self._state == EmergencyState.STOPPING:
|
||||||
|
out.cmd_override = True
|
||||||
|
out.cmd_linear = 0.0
|
||||||
|
out.cmd_angular = 0.0
|
||||||
|
# Upgrade to critical if new critical arrives
|
||||||
|
if inp.threat.level == ThreatLevel.CRITICAL:
|
||||||
|
self._critical_pending = True
|
||||||
|
if abs(inp.robot_speed_ms) <= self._stopped_ms:
|
||||||
|
if self._critical_pending:
|
||||||
|
self._state = EmergencyState.ESCALATED
|
||||||
|
self._escalation_alerted = False
|
||||||
|
else:
|
||||||
|
self._state = EmergencyState.RECOVERING
|
||||||
|
self._recovery.reset()
|
||||||
|
self._recovery.tick(RecoveryInputs(trigger=True, dt=0.0))
|
||||||
|
|
||||||
|
# ── RECOVERING ────────────────────────────────────────────────────────
|
||||||
|
elif self._state == EmergencyState.RECOVERING:
|
||||||
|
threat_cleared = inp.threat.level == ThreatLevel.CLEAR
|
||||||
|
rec_inp = RecoveryInputs(
|
||||||
|
trigger=False,
|
||||||
|
threat_cleared=threat_cleared,
|
||||||
|
dt=0.02, # nominal dt; node should pass actual dt
|
||||||
|
)
|
||||||
|
rec_out = self._recovery.tick(rec_inp)
|
||||||
|
|
||||||
|
out.cmd_override = True
|
||||||
|
out.cmd_linear = rec_out.cmd_linear
|
||||||
|
out.cmd_angular = rec_out.cmd_angular
|
||||||
|
out.recovery_progress = rec_out.progress
|
||||||
|
out.recovery_retry_count = rec_out.retry_count
|
||||||
|
|
||||||
|
if rec_out.gave_up:
|
||||||
|
self._state = EmergencyState.ESCALATED
|
||||||
|
self._escalation_alerted = False
|
||||||
|
elif rec_out.state.value == "IDLE" and not inp.trigger if hasattr(inp, "trigger") else True:
|
||||||
|
# RecoverySequencer returned to IDLE = success
|
||||||
|
from saltybot_emergency.recovery_sequencer import RecoveryState
|
||||||
|
if self._recovery.state == RecoveryState.IDLE and not rec_out.gave_up:
|
||||||
|
self._state = EmergencyState.NOMINAL
|
||||||
|
self._recovery.reset()
|
||||||
|
|
||||||
|
# ── ESCALATED ─────────────────────────────────────────────────────────
|
||||||
|
elif self._state == EmergencyState.ESCALATED:
|
||||||
|
out.cmd_override = True
|
||||||
|
out.cmd_linear = 0.0
|
||||||
|
out.cmd_angular = 0.0
|
||||||
|
out.e_stop = True
|
||||||
|
if not self._escalation_alerted:
|
||||||
|
# Force a CRITICAL alert regardless of suppression
|
||||||
|
from saltybot_emergency.alert_manager import Alert, AlertLevel
|
||||||
|
out.alert = Alert(
|
||||||
|
level=AlertLevel.CRITICAL,
|
||||||
|
source=inp.threat.threat_type.value,
|
||||||
|
message=f"[CRITICAL] ESCALATED: {inp.threat.detail or 'Recovery gave up'}",
|
||||||
|
timestamp_s=inp.threat.timestamp_s,
|
||||||
|
)
|
||||||
|
self._escalation_alerted = True
|
||||||
|
if inp.acknowledge:
|
||||||
|
self._state = EmergencyState.NOMINAL
|
||||||
|
self._critical_pending = False
|
||||||
|
self._escalation_alerted = False
|
||||||
|
out.e_stop = False
|
||||||
|
self._alert_mgr.reset()
|
||||||
|
self._recovery.reset()
|
||||||
|
|
||||||
|
return out
|
||||||
@ -0,0 +1,383 @@
|
|||||||
|
"""
|
||||||
|
emergency_node.py — Emergency behavior system orchestration (Issue #169).
|
||||||
|
|
||||||
|
Overview
|
||||||
|
────────
|
||||||
|
Aggregates threats from four independent detectors and drives the
|
||||||
|
EmergencyFSM. Overrides /cmd_vel when an emergency is active. Escalates
|
||||||
|
via /saltybot/e_stop and /saltybot/critical_alert for CRITICAL events.
|
||||||
|
|
||||||
|
Pipeline (20 Hz)
|
||||||
|
────────────────
|
||||||
|
1. LaserScan callback → ObstacleDetector → ThreatEvent
|
||||||
|
2. IMU callback → FallDetector + BumpDetector → ThreatEvent (×2)
|
||||||
|
3. Odom callback → StuckDetector (fed in timer) → ThreatEvent
|
||||||
|
4. 20 Hz timer → highest_threat() → EmergencyFSM.tick()
|
||||||
|
→ publish overriding cmd_vel or pass-through
|
||||||
|
→ publish /saltybot/emergency + /saltybot/recovery_action
|
||||||
|
|
||||||
|
Subscribes
|
||||||
|
──────────
|
||||||
|
/scan sensor_msgs/LaserScan — obstacle detection
|
||||||
|
/saltybot/imu sensor_msgs/Imu — fall + bump detection
|
||||||
|
<odom_topic> nav_msgs/Odometry — stuck + speed tracking
|
||||||
|
/cmd_vel geometry_msgs/Twist — nav commands (pass-through)
|
||||||
|
|
||||||
|
Publishes
|
||||||
|
─────────
|
||||||
|
/saltybot/cmd_vel_out geometry_msgs/Twist — muxed cmd_vel (to drive nodes)
|
||||||
|
/saltybot/e_stop std_msgs/Bool — emergency stop flag
|
||||||
|
/saltybot/alert_beep std_msgs/Empty — beep trigger (MINOR)
|
||||||
|
/saltybot/alert_flash std_msgs/Empty — LED flash trigger (MAJOR)
|
||||||
|
/saltybot/critical_alert std_msgs/String (JSON) — CRITICAL event for MQTT bridge
|
||||||
|
/saltybot/emergency saltybot_emergency_msgs/EmergencyEvent
|
||||||
|
/saltybot/recovery_action saltybot_emergency_msgs/RecoveryAction
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
──────────
|
||||||
|
control_rate 20.0 Hz
|
||||||
|
odom_topic /saltybot/rover_odom
|
||||||
|
forward_scan_angle_rad 0.785 rad (±45° forward sector for obstacle check)
|
||||||
|
stop_distance_m 0.30 m
|
||||||
|
critical_distance_m 0.10 m
|
||||||
|
min_cmd_speed_ms 0.05 m/s
|
||||||
|
minor_tilt_rad 0.20 rad
|
||||||
|
major_tilt_rad 0.35 rad
|
||||||
|
critical_tilt_rad 0.52 rad
|
||||||
|
floor_drop_m 0.15 m
|
||||||
|
stuck_timeout_s 3.0 s
|
||||||
|
jerk_threshold_ms3 8.0 m/s³
|
||||||
|
critical_jerk_threshold_ms3 25.0 m/s³
|
||||||
|
stopped_ms 0.03 m/s
|
||||||
|
major_count_threshold 3
|
||||||
|
escalation_window_s 10.0 s
|
||||||
|
suppression_s 1.0 s
|
||||||
|
reverse_speed_ms -0.15 m/s
|
||||||
|
reverse_distance_m 0.30 m
|
||||||
|
angular_speed_rads 0.60 rad/s
|
||||||
|
turn_angle_rad 1.5708 rad
|
||||||
|
retry_timeout_s 3.0 s
|
||||||
|
clear_hold_s 0.50 s
|
||||||
|
max_retries 3
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import HistoryPolicy, QoSProfile, ReliabilityPolicy
|
||||||
|
|
||||||
|
from geometry_msgs.msg import Twist
|
||||||
|
from nav_msgs.msg import Odometry
|
||||||
|
from sensor_msgs.msg import Imu, LaserScan
|
||||||
|
from std_msgs.msg import Bool, Empty, String
|
||||||
|
|
||||||
|
from saltybot_emergency.alert_manager import AlertLevel
|
||||||
|
from saltybot_emergency.emergency_fsm import EmergencyFSM, EmergencyInputs, EmergencyState
|
||||||
|
from saltybot_emergency.threat_detector import (
|
||||||
|
BumpDetector,
|
||||||
|
FallDetector,
|
||||||
|
ObstacleDetector,
|
||||||
|
StuckDetector,
|
||||||
|
ThreatEvent,
|
||||||
|
ThreatType,
|
||||||
|
highest_threat,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from saltybot_emergency_msgs.msg import EmergencyEvent, RecoveryAction
|
||||||
|
_MSGS_OK = True
|
||||||
|
except ImportError:
|
||||||
|
_MSGS_OK = False
|
||||||
|
|
||||||
|
|
||||||
|
def _quaternion_to_pitch_roll(qx, qy, qz, qw):
|
||||||
|
pitch = math.asin(max(-1.0, min(1.0, 2.0 * (qw * qy - qz * qx))))
|
||||||
|
roll = math.atan2(2.0 * (qw * qx + qy * qz), 1.0 - 2.0 * (qx * qx + qy * qy))
|
||||||
|
return pitch, roll
|
||||||
|
|
||||||
|
|
||||||
|
class EmergencyNode(Node):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("emergency")
|
||||||
|
|
||||||
|
self._declare_params()
|
||||||
|
p = self._load_params()
|
||||||
|
|
||||||
|
# ── Detectors ────────────────────────────────────────────────────────
|
||||||
|
self._obstacle = ObstacleDetector(
|
||||||
|
stop_distance_m=p["stop_distance_m"],
|
||||||
|
critical_distance_m=p["critical_distance_m"],
|
||||||
|
min_speed_ms=p["min_cmd_speed_ms"],
|
||||||
|
)
|
||||||
|
self._fall = FallDetector(
|
||||||
|
minor_tilt_rad=p["minor_tilt_rad"],
|
||||||
|
major_tilt_rad=p["major_tilt_rad"],
|
||||||
|
critical_tilt_rad=p["critical_tilt_rad"],
|
||||||
|
floor_drop_m=p["floor_drop_m"],
|
||||||
|
)
|
||||||
|
self._stuck = StuckDetector(
|
||||||
|
stuck_timeout_s=p["stuck_timeout_s"],
|
||||||
|
min_cmd_ms=p["min_cmd_speed_ms"],
|
||||||
|
)
|
||||||
|
self._bump = BumpDetector(
|
||||||
|
jerk_threshold_ms3=p["jerk_threshold_ms3"],
|
||||||
|
critical_jerk_threshold_ms3=p["critical_jerk_threshold_ms3"],
|
||||||
|
)
|
||||||
|
self._fsm = EmergencyFSM(
|
||||||
|
stopped_ms=p["stopped_ms"],
|
||||||
|
major_count_threshold=p["major_count_threshold"],
|
||||||
|
escalation_window_s=p["escalation_window_s"],
|
||||||
|
suppression_s=p["suppression_s"],
|
||||||
|
reverse_speed_ms=p["reverse_speed_ms"],
|
||||||
|
reverse_distance_m=p["reverse_distance_m"],
|
||||||
|
angular_speed_rads=p["angular_speed_rads"],
|
||||||
|
turn_angle_rad=p["turn_angle_rad"],
|
||||||
|
retry_timeout_s=p["retry_timeout_s"],
|
||||||
|
clear_hold_s=p["clear_hold_s"],
|
||||||
|
max_retries=p["max_retries"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── State ────────────────────────────────────────────────────────────
|
||||||
|
self._latest_obstacle_threat = ThreatEvent()
|
||||||
|
self._latest_fall_threat = ThreatEvent()
|
||||||
|
self._latest_bump_threat = ThreatEvent()
|
||||||
|
self._cmd_speed_ms = 0.0
|
||||||
|
self._actual_speed_ms = 0.0
|
||||||
|
self._last_ctrl_t = time.monotonic()
|
||||||
|
self._scan_forward_angle = p["forward_scan_angle_rad"]
|
||||||
|
self._acknowledge_flag = False
|
||||||
|
|
||||||
|
# ── QoS ──────────────────────────────────────────────────────────────
|
||||||
|
reliable = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.RELIABLE,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=10,
|
||||||
|
)
|
||||||
|
best_effort = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Subscriptions ────────────────────────────────────────────────────
|
||||||
|
self.create_subscription(LaserScan, "/scan", self._scan_cb, best_effort)
|
||||||
|
self.create_subscription(Imu, "/saltybot/imu", self._imu_cb, best_effort)
|
||||||
|
self.create_subscription(Odometry, p["odom_topic"], self._odom_cb, reliable)
|
||||||
|
self.create_subscription(Twist, "/cmd_vel", self._cmd_vel_cb, reliable)
|
||||||
|
self.create_subscription(Bool, "/saltybot/emergency_ack", self._ack_cb, reliable)
|
||||||
|
|
||||||
|
# ── Publishers ───────────────────────────────────────────────────────
|
||||||
|
self._cmd_out_pub = self.create_publisher(Twist, "/saltybot/cmd_vel_out", reliable)
|
||||||
|
self._estop_pub = self.create_publisher(Bool, "/saltybot/e_stop", reliable)
|
||||||
|
self._beep_pub = self.create_publisher(Empty, "/saltybot/alert_beep", reliable)
|
||||||
|
self._flash_pub = self.create_publisher(Empty, "/saltybot/alert_flash", reliable)
|
||||||
|
self._critical_pub = self.create_publisher(String, "/saltybot/critical_alert", reliable)
|
||||||
|
|
||||||
|
self._event_pub = None
|
||||||
|
self._recovery_pub = None
|
||||||
|
if _MSGS_OK:
|
||||||
|
self._event_pub = self.create_publisher(EmergencyEvent, "/saltybot/emergency", reliable)
|
||||||
|
self._recovery_pub = self.create_publisher(RecoveryAction, "/saltybot/recovery_action", reliable)
|
||||||
|
|
||||||
|
# ── Timer ────────────────────────────────────────────────────────────
|
||||||
|
rate = p["control_rate"]
|
||||||
|
self._timer = self.create_timer(1.0 / rate, self._control_cb)
|
||||||
|
self.get_logger().info(f"EmergencyNode ready rate={rate}Hz")
|
||||||
|
|
||||||
|
# ── Parameters ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _declare_params(self) -> None:
|
||||||
|
self.declare_parameter("control_rate", 20.0)
|
||||||
|
self.declare_parameter("odom_topic", "/saltybot/rover_odom")
|
||||||
|
self.declare_parameter("forward_scan_angle_rad", 0.785)
|
||||||
|
self.declare_parameter("stop_distance_m", 0.30)
|
||||||
|
self.declare_parameter("critical_distance_m", 0.10)
|
||||||
|
self.declare_parameter("min_cmd_speed_ms", 0.05)
|
||||||
|
self.declare_parameter("minor_tilt_rad", 0.20)
|
||||||
|
self.declare_parameter("major_tilt_rad", 0.35)
|
||||||
|
self.declare_parameter("critical_tilt_rad", 0.52)
|
||||||
|
self.declare_parameter("floor_drop_m", 0.15)
|
||||||
|
self.declare_parameter("stuck_timeout_s", 3.0)
|
||||||
|
self.declare_parameter("jerk_threshold_ms3", 8.0)
|
||||||
|
self.declare_parameter("critical_jerk_threshold_ms3", 25.0)
|
||||||
|
self.declare_parameter("stopped_ms", 0.03)
|
||||||
|
self.declare_parameter("major_count_threshold", 3)
|
||||||
|
self.declare_parameter("escalation_window_s", 10.0)
|
||||||
|
self.declare_parameter("suppression_s", 1.0)
|
||||||
|
self.declare_parameter("reverse_speed_ms", -0.15)
|
||||||
|
self.declare_parameter("reverse_distance_m", 0.30)
|
||||||
|
self.declare_parameter("angular_speed_rads", 0.60)
|
||||||
|
self.declare_parameter("turn_angle_rad", 1.5708)
|
||||||
|
self.declare_parameter("retry_timeout_s", 3.0)
|
||||||
|
self.declare_parameter("clear_hold_s", 0.50)
|
||||||
|
self.declare_parameter("max_retries", 3)
|
||||||
|
|
||||||
|
def _load_params(self) -> dict:
|
||||||
|
g = self.get_parameter
|
||||||
|
return {k: g(k).value for k in [
|
||||||
|
"control_rate", "odom_topic",
|
||||||
|
"forward_scan_angle_rad",
|
||||||
|
"stop_distance_m", "critical_distance_m", "min_cmd_speed_ms",
|
||||||
|
"minor_tilt_rad", "major_tilt_rad", "critical_tilt_rad", "floor_drop_m",
|
||||||
|
"stuck_timeout_s", "jerk_threshold_ms3", "critical_jerk_threshold_ms3",
|
||||||
|
"stopped_ms",
|
||||||
|
"major_count_threshold", "escalation_window_s", "suppression_s",
|
||||||
|
"reverse_speed_ms", "reverse_distance_m",
|
||||||
|
"angular_speed_rads", "turn_angle_rad",
|
||||||
|
"retry_timeout_s", "clear_hold_s", "max_retries",
|
||||||
|
]}
|
||||||
|
|
||||||
|
# ── Callbacks ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _scan_cb(self, msg: LaserScan) -> None:
|
||||||
|
# Extract minimum range within forward sector (±forward_scan_angle_rad)
|
||||||
|
half = self._scan_forward_angle
|
||||||
|
ranges = []
|
||||||
|
for i, r in enumerate(msg.ranges):
|
||||||
|
angle = msg.angle_min + i * msg.angle_increment
|
||||||
|
if abs(angle) <= half and msg.range_min < r < msg.range_max:
|
||||||
|
ranges.append(r)
|
||||||
|
min_r = min(ranges) if ranges else float("inf")
|
||||||
|
self._latest_obstacle_threat = self._obstacle.update(
|
||||||
|
min_r, self._cmd_speed_ms, time.monotonic()
|
||||||
|
)
|
||||||
|
|
||||||
|
def _imu_cb(self, msg: Imu) -> None:
|
||||||
|
now = time.monotonic()
|
||||||
|
ax = msg.linear_acceleration.x
|
||||||
|
ay = msg.linear_acceleration.y
|
||||||
|
az = msg.linear_acceleration.z
|
||||||
|
pitch, roll = _quaternion_to_pitch_roll(
|
||||||
|
msg.orientation.x, msg.orientation.y,
|
||||||
|
msg.orientation.z, msg.orientation.w,
|
||||||
|
)
|
||||||
|
# dt for jerk is approximated; bump detector handles None on first call
|
||||||
|
dt = 0.02 # nominal 20 Hz
|
||||||
|
self._latest_fall_threat = self._fall.update(pitch, roll, 0.0, now)
|
||||||
|
self._latest_bump_threat = self._bump.update(ax, ay, az, dt, now)
|
||||||
|
|
||||||
|
def _odom_cb(self, msg: Odometry) -> None:
|
||||||
|
self._actual_speed_ms = msg.twist.twist.linear.x
|
||||||
|
|
||||||
|
def _cmd_vel_cb(self, msg: Twist) -> None:
|
||||||
|
self._cmd_speed_ms = msg.linear.x
|
||||||
|
|
||||||
|
def _ack_cb(self, msg: Bool) -> None:
|
||||||
|
if msg.data:
|
||||||
|
self._acknowledge_flag = True
|
||||||
|
|
||||||
|
# ── 20 Hz control loop ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _control_cb(self) -> None:
|
||||||
|
now = time.monotonic()
|
||||||
|
dt = now - self._last_ctrl_t
|
||||||
|
self._last_ctrl_t = now
|
||||||
|
|
||||||
|
stuck_threat = self._stuck.update(
|
||||||
|
self._cmd_speed_ms, self._actual_speed_ms, dt, now
|
||||||
|
)
|
||||||
|
|
||||||
|
threat = highest_threat([
|
||||||
|
self._latest_obstacle_threat,
|
||||||
|
self._latest_fall_threat,
|
||||||
|
self._latest_bump_threat,
|
||||||
|
stuck_threat,
|
||||||
|
])
|
||||||
|
|
||||||
|
inp = EmergencyInputs(
|
||||||
|
threat=threat,
|
||||||
|
robot_speed_ms=self._actual_speed_ms,
|
||||||
|
acknowledge=self._acknowledge_flag,
|
||||||
|
)
|
||||||
|
self._acknowledge_flag = False
|
||||||
|
|
||||||
|
out = self._fsm.tick(inp)
|
||||||
|
|
||||||
|
if out.state_changed:
|
||||||
|
self.get_logger().info(f"Emergency FSM → {out.state.value}")
|
||||||
|
|
||||||
|
# ── Alert dispatch ────────────────────────────────────────────────────
|
||||||
|
if out.alert is not None:
|
||||||
|
lvl = out.alert.level
|
||||||
|
self.get_logger().warn(out.alert.message)
|
||||||
|
if lvl == AlertLevel.MINOR:
|
||||||
|
self._beep_pub.publish(Empty())
|
||||||
|
elif lvl == AlertLevel.MAJOR:
|
||||||
|
self._flash_pub.publish(Empty())
|
||||||
|
elif lvl == AlertLevel.CRITICAL:
|
||||||
|
self._flash_pub.publish(Empty())
|
||||||
|
self._publish_critical_alert(out.alert.message, threat)
|
||||||
|
|
||||||
|
# ── E-stop ───────────────────────────────────────────────────────────
|
||||||
|
estop_msg = Bool()
|
||||||
|
estop_msg.data = out.e_stop
|
||||||
|
self._estop_pub.publish(estop_msg)
|
||||||
|
|
||||||
|
# ── cmd_vel mux ───────────────────────────────────────────────────────
|
||||||
|
twist = Twist()
|
||||||
|
if out.cmd_override:
|
||||||
|
twist.linear.x = out.cmd_linear
|
||||||
|
twist.angular.z = out.cmd_angular
|
||||||
|
else:
|
||||||
|
twist.linear.x = self._cmd_speed_ms
|
||||||
|
self._cmd_out_pub.publish(twist)
|
||||||
|
|
||||||
|
# ── Status topics ─────────────────────────────────────────────────────
|
||||||
|
if self._event_pub is not None:
|
||||||
|
self._publish_event(out, threat)
|
||||||
|
if self._recovery_pub is not None:
|
||||||
|
self._publish_recovery(out)
|
||||||
|
|
||||||
|
# ── Publishers ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _publish_critical_alert(self, message: str, threat: ThreatEvent) -> None:
|
||||||
|
msg = String()
|
||||||
|
msg.data = json.dumps({
|
||||||
|
"severity": "CRITICAL",
|
||||||
|
"threat": threat.threat_type.value,
|
||||||
|
"value": round(threat.value, 3),
|
||||||
|
"detail": threat.detail,
|
||||||
|
"message": message,
|
||||||
|
})
|
||||||
|
self._critical_pub.publish(msg)
|
||||||
|
|
||||||
|
def _publish_event(self, out, threat: ThreatEvent) -> None:
|
||||||
|
msg = EmergencyEvent()
|
||||||
|
msg.stamp = self.get_clock().now().to_msg()
|
||||||
|
msg.state = out.state.value
|
||||||
|
msg.threat_type = threat.threat_type.value
|
||||||
|
msg.severity = threat.level.name
|
||||||
|
msg.threat_value = float(threat.value)
|
||||||
|
msg.detail = threat.detail
|
||||||
|
msg.cmd_override = out.cmd_override
|
||||||
|
self._event_pub.publish(msg)
|
||||||
|
|
||||||
|
def _publish_recovery(self, out) -> None:
|
||||||
|
msg = RecoveryAction()
|
||||||
|
msg.stamp = self.get_clock().now().to_msg()
|
||||||
|
msg.action = self._fsm._recovery.state.value
|
||||||
|
msg.retry_count = out.recovery_retry_count
|
||||||
|
msg.progress = float(out.recovery_progress)
|
||||||
|
self._recovery_pub.publish(msg)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Entry point ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def main(args=None):
|
||||||
|
rclpy.init(args=args)
|
||||||
|
node = EmergencyNode()
|
||||||
|
try:
|
||||||
|
rclpy.spin(node)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
node.destroy_node()
|
||||||
|
rclpy.try_shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -0,0 +1,193 @@
|
|||||||
|
"""
|
||||||
|
recovery_sequencer.py — Reverse + turn recovery FSM for emergency behavior (Issue #169).
|
||||||
|
|
||||||
|
Sequence
|
||||||
|
────────
|
||||||
|
IDLE → REVERSING → TURNING → RETRYING → (IDLE on success)
|
||||||
|
→ (REVERSING on re-threat, retry loop)
|
||||||
|
→ (GAVE_UP after max_retries)
|
||||||
|
|
||||||
|
REVERSING : command reverse at reverse_speed_ms until reverse_distance_m covered.
|
||||||
|
TURNING : command angular_speed_rads until turn_angle_rad covered (90°).
|
||||||
|
RETRYING : zero velocity; wait up to retry_timeout_s for threat to clear.
|
||||||
|
If threat clears within clear_hold_s → back to IDLE (success).
|
||||||
|
If timeout without clearance → start another REVERSING cycle.
|
||||||
|
If retry_count >= max_retries → GAVE_UP.
|
||||||
|
|
||||||
|
Pure module — no ROS2 dependency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
# ── States ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class RecoveryState(Enum):
|
||||||
|
IDLE = "IDLE"
|
||||||
|
REVERSING = "REVERSING"
|
||||||
|
TURNING = "TURNING"
|
||||||
|
RETRYING = "RETRYING"
|
||||||
|
GAVE_UP = "GAVE_UP"
|
||||||
|
|
||||||
|
|
||||||
|
# ── I/O ───────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RecoveryInputs:
|
||||||
|
trigger: bool = False # True to start (or restart) recovery
|
||||||
|
threat_cleared: bool = False # True when all threats are CLEAR
|
||||||
|
dt: float = 0.02 # time step (s)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RecoveryOutputs:
|
||||||
|
state: RecoveryState = RecoveryState.IDLE
|
||||||
|
cmd_linear: float = 0.0 # m/s
|
||||||
|
cmd_angular: float = 0.0 # rad/s
|
||||||
|
progress: float = 0.0 # [0, 1] completion of current phase
|
||||||
|
retry_count: int = 0
|
||||||
|
gave_up: bool = False
|
||||||
|
state_changed: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
# ── RecoverySequencer ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class RecoverySequencer:
|
||||||
|
"""
|
||||||
|
Tick-based FSM for executing reverse + turn recovery sequences.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
reverse_speed_ms : backward speed during REVERSING (m/s; stored as negative)
|
||||||
|
reverse_distance_m: total reverse travel before turning (m)
|
||||||
|
angular_speed_rads: yaw rate during TURNING (rad/s; positive = left)
|
||||||
|
turn_angle_rad : total turn angle before RETRYING (rad; default π/2)
|
||||||
|
retry_timeout_s : max RETRYING time per attempt before next reverse cycle
|
||||||
|
clear_hold_s : consecutive clear time needed to declare success
|
||||||
|
max_retries : maximum reverse+turn attempts before GAVE_UP
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
reverse_speed_ms: float = -0.15,
|
||||||
|
reverse_distance_m: float = 0.30,
|
||||||
|
angular_speed_rads: float = 0.60,
|
||||||
|
turn_angle_rad: float = 1.5708, # π/2
|
||||||
|
retry_timeout_s: float = 3.0,
|
||||||
|
clear_hold_s: float = 0.5,
|
||||||
|
max_retries: int = 3,
|
||||||
|
):
|
||||||
|
self._rev_speed = min(0.0, float(reverse_speed_ms)) # ensure negative
|
||||||
|
self._rev_dist = max(0.01, float(reverse_distance_m))
|
||||||
|
self._ang_speed = abs(float(angular_speed_rads))
|
||||||
|
self._turn_angle = max(0.01, float(turn_angle_rad))
|
||||||
|
self._retry_tout = max(0.1, float(retry_timeout_s))
|
||||||
|
self._clear_hold = max(0.0, float(clear_hold_s))
|
||||||
|
self._max_retry = max(1, int(max_retries))
|
||||||
|
|
||||||
|
self._state = RecoveryState.IDLE
|
||||||
|
self._rev_done = 0.0 # distance reversed so far
|
||||||
|
self._turn_done = 0.0 # angle turned so far
|
||||||
|
self._retry_count = 0
|
||||||
|
self._retry_timer = 0.0 # time spent in RETRYING
|
||||||
|
self._clear_timer = 0.0 # consecutive clear time in RETRYING
|
||||||
|
|
||||||
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self) -> RecoveryState:
|
||||||
|
return self._state
|
||||||
|
|
||||||
|
@property
|
||||||
|
def retry_count(self) -> int:
|
||||||
|
return self._retry_count
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Return to IDLE and clear all counters."""
|
||||||
|
self._state = RecoveryState.IDLE
|
||||||
|
self._rev_done = 0.0
|
||||||
|
self._turn_done = 0.0
|
||||||
|
self._retry_count = 0
|
||||||
|
self._retry_timer = 0.0
|
||||||
|
self._clear_timer = 0.0
|
||||||
|
|
||||||
|
def tick(self, inputs: RecoveryInputs) -> RecoveryOutputs:
|
||||||
|
prev = self._state
|
||||||
|
out = self._step(inputs)
|
||||||
|
out.state = self._state
|
||||||
|
out.retry_count = self._retry_count
|
||||||
|
out.state_changed = (self._state != prev)
|
||||||
|
if out.state_changed:
|
||||||
|
self._on_enter(self._state)
|
||||||
|
return out
|
||||||
|
|
||||||
|
# ── Internal step ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _step(self, inp: RecoveryInputs) -> RecoveryOutputs:
|
||||||
|
out = RecoveryOutputs(state=self._state)
|
||||||
|
dt = max(0.0, inp.dt)
|
||||||
|
|
||||||
|
# ── IDLE ──────────────────────────────────────────────────────────────
|
||||||
|
if self._state == RecoveryState.IDLE:
|
||||||
|
if inp.trigger:
|
||||||
|
self._state = RecoveryState.REVERSING
|
||||||
|
|
||||||
|
# ── REVERSING ─────────────────────────────────────────────────────────
|
||||||
|
elif self._state == RecoveryState.REVERSING:
|
||||||
|
step = abs(self._rev_speed) * dt
|
||||||
|
self._rev_done += step
|
||||||
|
out.cmd_linear = self._rev_speed
|
||||||
|
out.progress = min(1.0, self._rev_done / self._rev_dist)
|
||||||
|
if self._rev_done >= self._rev_dist:
|
||||||
|
self._state = RecoveryState.TURNING
|
||||||
|
|
||||||
|
# ── TURNING ───────────────────────────────────────────────────────────
|
||||||
|
elif self._state == RecoveryState.TURNING:
|
||||||
|
step = self._ang_speed * dt
|
||||||
|
self._turn_done += step
|
||||||
|
out.cmd_angular = self._ang_speed
|
||||||
|
out.progress = min(1.0, self._turn_done / self._turn_angle)
|
||||||
|
if self._turn_done >= self._turn_angle:
|
||||||
|
self._retry_count += 1
|
||||||
|
self._state = RecoveryState.RETRYING
|
||||||
|
|
||||||
|
# ── RETRYING ──────────────────────────────────────────────────────────
|
||||||
|
elif self._state == RecoveryState.RETRYING:
|
||||||
|
self._retry_timer += dt
|
||||||
|
|
||||||
|
if inp.threat_cleared:
|
||||||
|
self._clear_timer += dt
|
||||||
|
if self._clear_timer >= self._clear_hold:
|
||||||
|
# Success — threat has cleared
|
||||||
|
self._state = RecoveryState.IDLE
|
||||||
|
return out
|
||||||
|
else:
|
||||||
|
self._clear_timer = 0.0
|
||||||
|
|
||||||
|
if self._retry_timer >= self._retry_tout:
|
||||||
|
if self._retry_count >= self._max_retry:
|
||||||
|
self._state = RecoveryState.GAVE_UP
|
||||||
|
out.gave_up = True
|
||||||
|
else:
|
||||||
|
self._state = RecoveryState.REVERSING
|
||||||
|
|
||||||
|
# ── GAVE_UP ───────────────────────────────────────────────────────────
|
||||||
|
elif self._state == RecoveryState.GAVE_UP:
|
||||||
|
# Stay in GAVE_UP until external reset()
|
||||||
|
out.gave_up = True
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
# ── Entry actions ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _on_enter(self, state: RecoveryState) -> None:
|
||||||
|
if state == RecoveryState.REVERSING:
|
||||||
|
self._rev_done = 0.0
|
||||||
|
elif state == RecoveryState.TURNING:
|
||||||
|
self._turn_done = 0.0
|
||||||
|
elif state == RecoveryState.RETRYING:
|
||||||
|
self._retry_timer = 0.0
|
||||||
|
self._clear_timer = 0.0
|
||||||
@ -0,0 +1,354 @@
|
|||||||
|
"""
|
||||||
|
threat_detector.py — Multi-source threat detection for emergency behavior (Issue #169).
|
||||||
|
|
||||||
|
Detectors
|
||||||
|
─────────
|
||||||
|
ObstacleDetector : Forward-sector minimum range < stop thresholds at speed.
|
||||||
|
Inputs: min_range_m (pre-filtered from LaserScan forward
|
||||||
|
sector), cmd_speed_ms.
|
||||||
|
|
||||||
|
FallDetector : Extreme pitch/roll from IMU, or depth floor-drop ahead.
|
||||||
|
Inputs: pitch_rad, roll_rad, floor_drop_m (depth-derived;
|
||||||
|
0.0 if depth unavailable).
|
||||||
|
|
||||||
|
StuckDetector : Commanded speed vs actual speed mismatch sustained for
|
||||||
|
stuck_timeout_s. Tracks elapsed time with dt argument.
|
||||||
|
|
||||||
|
BumpDetector : IMU acceleration jerk (|Δ|a||/dt) above threshold.
|
||||||
|
MAJOR at jerk_threshold_ms3, CRITICAL at
|
||||||
|
critical_jerk_threshold_ms3.
|
||||||
|
|
||||||
|
ThreatLevel
|
||||||
|
───────────
|
||||||
|
CLEAR : no threat; normal operation
|
||||||
|
MINOR : advisory; log/beep only
|
||||||
|
MAJOR : stop and execute recovery
|
||||||
|
CRITICAL : full shutdown + MQTT escalation
|
||||||
|
|
||||||
|
Pure module — no ROS2 dependency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
# ── Enumerations ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class ThreatLevel(Enum):
|
||||||
|
CLEAR = 0
|
||||||
|
MINOR = 1
|
||||||
|
MAJOR = 2
|
||||||
|
CRITICAL = 3
|
||||||
|
|
||||||
|
|
||||||
|
class ThreatType(Enum):
|
||||||
|
NONE = "NONE"
|
||||||
|
OBSTACLE_PROXIMITY = "OBSTACLE_PROXIMITY"
|
||||||
|
FALL_RISK = "FALL_RISK"
|
||||||
|
WHEEL_STUCK = "WHEEL_STUCK"
|
||||||
|
BUMP = "BUMP"
|
||||||
|
|
||||||
|
|
||||||
|
# ── ThreatEvent ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ThreatEvent:
|
||||||
|
"""Snapshot of a single detected threat."""
|
||||||
|
threat_type: ThreatType = ThreatType.NONE
|
||||||
|
level: ThreatLevel = ThreatLevel.CLEAR
|
||||||
|
value: float = 0.0 # triggering metric
|
||||||
|
detail: str = ""
|
||||||
|
timestamp_s: float = 0.0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def clear(timestamp_s: float = 0.0) -> "ThreatEvent":
|
||||||
|
return ThreatEvent(timestamp_s=timestamp_s)
|
||||||
|
|
||||||
|
|
||||||
|
_CLEAR = ThreatEvent()
|
||||||
|
|
||||||
|
|
||||||
|
# ── ObstacleDetector ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class ObstacleDetector:
|
||||||
|
"""
|
||||||
|
Obstacle proximity threat from forward-sector minimum range.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
stop_distance_m : range below which MAJOR is raised (default 0.30 m)
|
||||||
|
critical_distance_m : range below which CRITICAL is raised (default 0.10 m)
|
||||||
|
min_speed_ms : only active above this commanded speed (default 0.05 m/s)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
stop_distance_m: float = 0.30,
|
||||||
|
critical_distance_m: float = 0.10,
|
||||||
|
min_speed_ms: float = 0.05,
|
||||||
|
):
|
||||||
|
self._stop = max(1e-3, stop_distance_m)
|
||||||
|
self._critical = max(1e-3, min(self._stop, critical_distance_m))
|
||||||
|
self._min_spd = abs(min_speed_ms)
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
min_range_m: float,
|
||||||
|
cmd_speed_ms: float,
|
||||||
|
timestamp_s: float = 0.0,
|
||||||
|
) -> ThreatEvent:
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
min_range_m : minimum obstacle range in forward sector (m)
|
||||||
|
cmd_speed_ms : signed commanded forward speed (m/s)
|
||||||
|
"""
|
||||||
|
if abs(cmd_speed_ms) < self._min_spd:
|
||||||
|
return ThreatEvent.clear(timestamp_s)
|
||||||
|
|
||||||
|
if min_range_m <= self._critical:
|
||||||
|
return ThreatEvent(
|
||||||
|
threat_type=ThreatType.OBSTACLE_PROXIMITY,
|
||||||
|
level=ThreatLevel.CRITICAL,
|
||||||
|
value=min_range_m,
|
||||||
|
detail=f"Obstacle {min_range_m:.2f} m (critical zone)",
|
||||||
|
timestamp_s=timestamp_s,
|
||||||
|
)
|
||||||
|
if min_range_m <= self._stop:
|
||||||
|
return ThreatEvent(
|
||||||
|
threat_type=ThreatType.OBSTACLE_PROXIMITY,
|
||||||
|
level=ThreatLevel.MAJOR,
|
||||||
|
value=min_range_m,
|
||||||
|
detail=f"Obstacle {min_range_m:.2f} m ahead",
|
||||||
|
timestamp_s=timestamp_s,
|
||||||
|
)
|
||||||
|
return ThreatEvent.clear(timestamp_s)
|
||||||
|
|
||||||
|
|
||||||
|
# ── FallDetector ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class FallDetector:
|
||||||
|
"""
|
||||||
|
Fall / tipping risk from IMU pitch/roll and optional depth floor-drop.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
minor_tilt_rad : |pitch| or |roll| above which MINOR fires (default 0.20 rad)
|
||||||
|
major_tilt_rad : above which MAJOR fires (default 0.35 rad)
|
||||||
|
critical_tilt_rad : above which CRITICAL fires (default 0.52 rad ≈ 30°)
|
||||||
|
floor_drop_m : depth discontinuity (m) triggering MAJOR (default 0.15 m)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
minor_tilt_rad: float = 0.20,
|
||||||
|
major_tilt_rad: float = 0.35,
|
||||||
|
critical_tilt_rad: float = 0.52,
|
||||||
|
floor_drop_m: float = 0.15,
|
||||||
|
):
|
||||||
|
self._minor = float(minor_tilt_rad)
|
||||||
|
self._major = float(major_tilt_rad)
|
||||||
|
self._critical = float(critical_tilt_rad)
|
||||||
|
self._drop = float(floor_drop_m)
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
pitch_rad: float,
|
||||||
|
roll_rad: float,
|
||||||
|
floor_drop_m: float = 0.0,
|
||||||
|
timestamp_s: float = 0.0,
|
||||||
|
) -> ThreatEvent:
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
pitch_rad : forward tilt (rad); +ve = nose up
|
||||||
|
roll_rad : lateral tilt (rad); +ve = left side up
|
||||||
|
floor_drop_m : depth discontinuity ahead of robot (m); 0 = not measured
|
||||||
|
"""
|
||||||
|
tilt = max(abs(pitch_rad), abs(roll_rad))
|
||||||
|
|
||||||
|
if tilt >= self._critical:
|
||||||
|
return ThreatEvent(
|
||||||
|
threat_type=ThreatType.FALL_RISK,
|
||||||
|
level=ThreatLevel.CRITICAL,
|
||||||
|
value=tilt,
|
||||||
|
detail=f"Critical tilt {math.degrees(tilt):.1f}°",
|
||||||
|
timestamp_s=timestamp_s,
|
||||||
|
)
|
||||||
|
if tilt >= self._major or floor_drop_m >= self._drop:
|
||||||
|
value = max(tilt, floor_drop_m)
|
||||||
|
detail = (
|
||||||
|
f"Floor drop {floor_drop_m:.2f} m" if floor_drop_m >= self._drop
|
||||||
|
else f"Major tilt {math.degrees(tilt):.1f}°"
|
||||||
|
)
|
||||||
|
return ThreatEvent(
|
||||||
|
threat_type=ThreatType.FALL_RISK,
|
||||||
|
level=ThreatLevel.MAJOR,
|
||||||
|
value=value,
|
||||||
|
detail=detail,
|
||||||
|
timestamp_s=timestamp_s,
|
||||||
|
)
|
||||||
|
if tilt >= self._minor:
|
||||||
|
return ThreatEvent(
|
||||||
|
threat_type=ThreatType.FALL_RISK,
|
||||||
|
level=ThreatLevel.MINOR,
|
||||||
|
value=tilt,
|
||||||
|
detail=f"Tilt advisory {math.degrees(tilt):.1f}°",
|
||||||
|
timestamp_s=timestamp_s,
|
||||||
|
)
|
||||||
|
return ThreatEvent.clear(timestamp_s)
|
||||||
|
|
||||||
|
|
||||||
|
# ── StuckDetector ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class StuckDetector:
|
||||||
|
"""
|
||||||
|
Wheel stall / stuck detection from cmd_vel vs odometry mismatch.
|
||||||
|
|
||||||
|
Accumulates stuck time while |cmd| > min_cmd_ms AND |actual| < moving_ms.
|
||||||
|
Resets when motion resumes or commanded speed drops below min_cmd_ms.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
stuck_timeout_s : accumulated stuck time before MAJOR fires (default 3.0 s)
|
||||||
|
min_cmd_ms : minimum commanded speed to consider stalling (0.05 m/s)
|
||||||
|
moving_threshold_ms : actual speed above which robot is considered moving
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
stuck_timeout_s: float = 3.0,
|
||||||
|
min_cmd_ms: float = 0.05,
|
||||||
|
moving_threshold_ms: float = 0.05,
|
||||||
|
):
|
||||||
|
self._timeout = max(0.1, stuck_timeout_s)
|
||||||
|
self._min_cmd = abs(min_cmd_ms)
|
||||||
|
self._moving = abs(moving_threshold_ms)
|
||||||
|
self._stuck_time: float = 0.0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stuck_time(self) -> float:
|
||||||
|
"""Accumulated stuck duration (s)."""
|
||||||
|
return self._stuck_time
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
cmd_speed_ms: float,
|
||||||
|
actual_speed_ms: float,
|
||||||
|
dt: float,
|
||||||
|
timestamp_s: float = 0.0,
|
||||||
|
) -> ThreatEvent:
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
cmd_speed_ms : commanded forward speed (m/s)
|
||||||
|
actual_speed_ms : measured forward speed from odometry (m/s)
|
||||||
|
dt : elapsed time since last call (s)
|
||||||
|
"""
|
||||||
|
commanding = abs(cmd_speed_ms) >= self._min_cmd
|
||||||
|
moving = abs(actual_speed_ms) >= self._moving
|
||||||
|
|
||||||
|
if not commanding or moving:
|
||||||
|
self._stuck_time = 0.0
|
||||||
|
return ThreatEvent.clear(timestamp_s)
|
||||||
|
|
||||||
|
self._stuck_time += max(0.0, dt)
|
||||||
|
|
||||||
|
if self._stuck_time >= self._timeout:
|
||||||
|
return ThreatEvent(
|
||||||
|
threat_type=ThreatType.WHEEL_STUCK,
|
||||||
|
level=ThreatLevel.MAJOR,
|
||||||
|
value=self._stuck_time,
|
||||||
|
detail=f"Wheels stuck for {self._stuck_time:.1f} s",
|
||||||
|
timestamp_s=timestamp_s,
|
||||||
|
)
|
||||||
|
return ThreatEvent.clear(timestamp_s)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self._stuck_time = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ── BumpDetector ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class BumpDetector:
|
||||||
|
"""
|
||||||
|
Collision / bump detection via IMU acceleration jerk.
|
||||||
|
|
||||||
|
Jerk = |Δ|a|| / dt where |a| = sqrt(ax²+ay²+az²) − g (gravity removed)
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
jerk_threshold_ms3 : MAJOR at jerk above this (m/s³, default 8.0)
|
||||||
|
critical_jerk_threshold_ms3 : CRITICAL at jerk above this (m/s³, default 25.0)
|
||||||
|
gravity_ms2 : gravity magnitude to subtract (default 9.81)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
jerk_threshold_ms3: float = 8.0,
|
||||||
|
critical_jerk_threshold_ms3: float = 25.0,
|
||||||
|
gravity_ms2: float = 9.81,
|
||||||
|
):
|
||||||
|
self._jerk_major = float(jerk_threshold_ms3)
|
||||||
|
self._jerk_critical = float(critical_jerk_threshold_ms3)
|
||||||
|
self._gravity = float(gravity_ms2)
|
||||||
|
self._prev_dyn_mag: Optional[float] = None # previous |a_dynamic|
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
ax: float,
|
||||||
|
ay: float,
|
||||||
|
az: float,
|
||||||
|
dt: float,
|
||||||
|
timestamp_s: float = 0.0,
|
||||||
|
) -> ThreatEvent:
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
ax, ay, az : IMU linear acceleration (m/s²)
|
||||||
|
dt : elapsed time since last call (s)
|
||||||
|
"""
|
||||||
|
raw_mag = math.sqrt(ax * ax + ay * ay + az * az)
|
||||||
|
dyn_mag = abs(raw_mag - self._gravity) # remove gravity component
|
||||||
|
|
||||||
|
if self._prev_dyn_mag is None or dt <= 0.0:
|
||||||
|
self._prev_dyn_mag = dyn_mag
|
||||||
|
return ThreatEvent.clear(timestamp_s)
|
||||||
|
|
||||||
|
jerk = abs(dyn_mag - self._prev_dyn_mag) / dt
|
||||||
|
self._prev_dyn_mag = dyn_mag
|
||||||
|
|
||||||
|
if jerk >= self._jerk_critical:
|
||||||
|
return ThreatEvent(
|
||||||
|
threat_type=ThreatType.BUMP,
|
||||||
|
level=ThreatLevel.CRITICAL,
|
||||||
|
value=jerk,
|
||||||
|
detail=f"Critical impact: jerk {jerk:.1f} m/s³",
|
||||||
|
timestamp_s=timestamp_s,
|
||||||
|
)
|
||||||
|
if jerk >= self._jerk_major:
|
||||||
|
return ThreatEvent(
|
||||||
|
threat_type=ThreatType.BUMP,
|
||||||
|
level=ThreatLevel.MAJOR,
|
||||||
|
value=jerk,
|
||||||
|
detail=f"Bump detected: jerk {jerk:.1f} m/s³",
|
||||||
|
timestamp_s=timestamp_s,
|
||||||
|
)
|
||||||
|
return ThreatEvent.clear(timestamp_s)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self._prev_dyn_mag = None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Utility: pick highest-severity threat ─────────────────────────────────────
|
||||||
|
|
||||||
|
def highest_threat(events: list[ThreatEvent]) -> ThreatEvent:
|
||||||
|
"""Return the ThreatEvent with the highest ThreatLevel from a list."""
|
||||||
|
if not events:
|
||||||
|
return _CLEAR
|
||||||
|
return max(events, key=lambda e: e.level.value)
|
||||||
5
jetson/ros2_ws/src/saltybot_emergency/setup.cfg
Normal file
5
jetson/ros2_ws/src/saltybot_emergency/setup.cfg
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
[develop]
|
||||||
|
script_dir=$base/lib/saltybot_emergency
|
||||||
|
|
||||||
|
[install]
|
||||||
|
install_scripts=$base/lib/saltybot_emergency
|
||||||
32
jetson/ros2_ws/src/saltybot_emergency/setup.py
Normal file
32
jetson/ros2_ws/src/saltybot_emergency/setup.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from setuptools import setup, find_packages
|
||||||
|
import os
|
||||||
|
from glob import glob
|
||||||
|
|
||||||
|
package_name = "saltybot_emergency"
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name=package_name,
|
||||||
|
version="0.1.0",
|
||||||
|
packages=find_packages(exclude=["test"]),
|
||||||
|
data_files=[
|
||||||
|
("share/ament_index/resource_index/packages",
|
||||||
|
[f"resource/{package_name}"]),
|
||||||
|
(f"share/{package_name}", ["package.xml"]),
|
||||||
|
(os.path.join("share", package_name, "config"),
|
||||||
|
glob("config/*.yaml")),
|
||||||
|
(os.path.join("share", package_name, "launch"),
|
||||||
|
glob("launch/*.py")),
|
||||||
|
],
|
||||||
|
install_requires=["setuptools"],
|
||||||
|
zip_safe=True,
|
||||||
|
maintainer="sl-controls",
|
||||||
|
maintainer_email="sl-controls@saltylab.local",
|
||||||
|
description="Emergency behavior system — collision avoidance, fall prevention, stuck detection, recovery",
|
||||||
|
license="MIT",
|
||||||
|
tests_require=["pytest"],
|
||||||
|
entry_points={
|
||||||
|
"console_scripts": [
|
||||||
|
f"emergency_node = {package_name}.emergency_node:main",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
560
jetson/ros2_ws/src/saltybot_emergency/test/test_emergency.py
Normal file
560
jetson/ros2_ws/src/saltybot_emergency/test/test_emergency.py
Normal file
@ -0,0 +1,560 @@
|
|||||||
|
"""
|
||||||
|
test_emergency.py — Unit tests for Issue #169 emergency behavior modules.
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
ObstacleDetector — proximity thresholds, speed gate
|
||||||
|
FallDetector — tilt levels, floor drop
|
||||||
|
StuckDetector — timeout accumulation, reset on motion
|
||||||
|
BumpDetector — jerk thresholds, first-call safety
|
||||||
|
AlertManager — severity mapping, escalation, suppression
|
||||||
|
RecoverySequencer — full sequence, retry, give-up
|
||||||
|
EmergencyFSM — all state transitions and guard conditions
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from saltybot_emergency.threat_detector import (
|
||||||
|
BumpDetector,
|
||||||
|
FallDetector,
|
||||||
|
ObstacleDetector,
|
||||||
|
StuckDetector,
|
||||||
|
ThreatEvent,
|
||||||
|
ThreatLevel,
|
||||||
|
ThreatType,
|
||||||
|
highest_threat,
|
||||||
|
)
|
||||||
|
from saltybot_emergency.alert_manager import Alert, AlertLevel, AlertManager
|
||||||
|
from saltybot_emergency.recovery_sequencer import (
|
||||||
|
RecoveryInputs,
|
||||||
|
RecoverySequencer,
|
||||||
|
RecoveryState,
|
||||||
|
)
|
||||||
|
from saltybot_emergency.emergency_fsm import (
|
||||||
|
EmergencyFSM,
|
||||||
|
EmergencyInputs,
|
||||||
|
EmergencyState,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _obs(**kw):
|
||||||
|
d = dict(stop_distance_m=0.30, critical_distance_m=0.10, min_speed_ms=0.05)
|
||||||
|
d.update(kw)
|
||||||
|
return ObstacleDetector(**d)
|
||||||
|
|
||||||
|
def _fall(**kw):
|
||||||
|
d = dict(minor_tilt_rad=0.20, major_tilt_rad=0.35,
|
||||||
|
critical_tilt_rad=0.52, floor_drop_m=0.15)
|
||||||
|
d.update(kw)
|
||||||
|
return FallDetector(**d)
|
||||||
|
|
||||||
|
def _stuck(**kw):
|
||||||
|
d = dict(stuck_timeout_s=3.0, min_cmd_ms=0.05, moving_threshold_ms=0.05)
|
||||||
|
d.update(kw)
|
||||||
|
return StuckDetector(**d)
|
||||||
|
|
||||||
|
def _bump(**kw):
|
||||||
|
d = dict(jerk_threshold_ms3=8.0, critical_jerk_threshold_ms3=25.0)
|
||||||
|
d.update(kw)
|
||||||
|
return BumpDetector(**d)
|
||||||
|
|
||||||
|
def _alert_mgr(**kw):
|
||||||
|
d = dict(major_count_threshold=3, escalation_window_s=10.0, suppression_s=1.0)
|
||||||
|
d.update(kw)
|
||||||
|
return AlertManager(**d)
|
||||||
|
|
||||||
|
def _seq(**kw):
|
||||||
|
d = dict(
|
||||||
|
reverse_speed_ms=-0.15, reverse_distance_m=0.30,
|
||||||
|
angular_speed_rads=0.60, turn_angle_rad=1.5708,
|
||||||
|
retry_timeout_s=3.0, clear_hold_s=0.5, max_retries=3,
|
||||||
|
)
|
||||||
|
d.update(kw)
|
||||||
|
return RecoverySequencer(**d)
|
||||||
|
|
||||||
|
def _fsm(**kw):
|
||||||
|
d = dict(
|
||||||
|
stopped_ms=0.03, major_count_threshold=3, escalation_window_s=10.0,
|
||||||
|
suppression_s=0.0, # disable suppression for cleaner tests
|
||||||
|
reverse_speed_ms=-0.15, reverse_distance_m=0.30,
|
||||||
|
angular_speed_rads=0.60, turn_angle_rad=1.5708,
|
||||||
|
retry_timeout_s=3.0, clear_hold_s=0.5, max_retries=3,
|
||||||
|
)
|
||||||
|
d.update(kw)
|
||||||
|
return EmergencyFSM(**d)
|
||||||
|
|
||||||
|
def _major_threat(threat_type=ThreatType.OBSTACLE_PROXIMITY, ts=0.0):
|
||||||
|
return ThreatEvent(threat_type=threat_type, level=ThreatLevel.MAJOR,
|
||||||
|
value=1.0, detail="test", timestamp_s=ts)
|
||||||
|
|
||||||
|
def _critical_threat(ts=0.0):
|
||||||
|
return ThreatEvent(threat_type=ThreatType.OBSTACLE_PROXIMITY,
|
||||||
|
level=ThreatLevel.CRITICAL, value=0.05,
|
||||||
|
detail="critical test", timestamp_s=ts)
|
||||||
|
|
||||||
|
def _minor_threat(ts=0.0):
|
||||||
|
return ThreatEvent(threat_type=ThreatType.FALL_RISK, level=ThreatLevel.MINOR,
|
||||||
|
value=0.21, detail="tilt", timestamp_s=ts)
|
||||||
|
|
||||||
|
def _clear_threat():
|
||||||
|
return ThreatEvent.clear()
|
||||||
|
|
||||||
|
def _inp(threat=None, speed=0.0, ack=False):
|
||||||
|
return EmergencyInputs(
|
||||||
|
threat=threat or _clear_threat(),
|
||||||
|
robot_speed_ms=speed,
|
||||||
|
acknowledge=ack,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
# ObstacleDetector
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestObstacleDetector:
|
||||||
|
|
||||||
|
def test_clear_when_far(self):
|
||||||
|
ev = _obs().update(0.5, 0.3)
|
||||||
|
assert ev.level == ThreatLevel.CLEAR
|
||||||
|
|
||||||
|
def test_major_within_stop_distance(self):
|
||||||
|
ev = _obs(stop_distance_m=0.30).update(0.25, 0.3)
|
||||||
|
assert ev.level == ThreatLevel.MAJOR
|
||||||
|
assert ev.threat_type == ThreatType.OBSTACLE_PROXIMITY
|
||||||
|
|
||||||
|
def test_critical_within_critical_distance(self):
|
||||||
|
ev = _obs(critical_distance_m=0.10).update(0.05, 0.3)
|
||||||
|
assert ev.level == ThreatLevel.CRITICAL
|
||||||
|
|
||||||
|
def test_clear_when_stopped(self):
|
||||||
|
"""Obstacle detection suppressed when not moving."""
|
||||||
|
ev = _obs(min_speed_ms=0.05).update(0.05, 0.01)
|
||||||
|
assert ev.level == ThreatLevel.CLEAR
|
||||||
|
|
||||||
|
def test_active_at_min_speed(self):
|
||||||
|
ev = _obs(min_speed_ms=0.05).update(0.20, 0.06)
|
||||||
|
assert ev.level == ThreatLevel.MAJOR
|
||||||
|
|
||||||
|
def test_value_is_distance(self):
|
||||||
|
ev = _obs().update(0.20, 0.3)
|
||||||
|
assert ev.value == pytest.approx(0.20, abs=1e-9)
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
# FallDetector
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestFallDetector:
|
||||||
|
|
||||||
|
def test_clear_on_flat(self):
|
||||||
|
ev = _fall().update(0.0, 0.0)
|
||||||
|
assert ev.level == ThreatLevel.CLEAR
|
||||||
|
|
||||||
|
def test_minor_moderate_tilt(self):
|
||||||
|
ev = _fall(minor_tilt_rad=0.20, major_tilt_rad=0.35).update(0.25, 0.0)
|
||||||
|
assert ev.level == ThreatLevel.MINOR
|
||||||
|
|
||||||
|
def test_major_high_tilt(self):
|
||||||
|
ev = _fall(major_tilt_rad=0.35, critical_tilt_rad=0.52).update(0.40, 0.0)
|
||||||
|
assert ev.level == ThreatLevel.MAJOR
|
||||||
|
|
||||||
|
def test_critical_extreme_tilt(self):
|
||||||
|
ev = _fall(critical_tilt_rad=0.52).update(0.60, 0.0)
|
||||||
|
assert ev.level == ThreatLevel.CRITICAL
|
||||||
|
|
||||||
|
def test_major_on_floor_drop(self):
|
||||||
|
ev = _fall(floor_drop_m=0.15).update(0.0, 0.0, floor_drop_m=0.20)
|
||||||
|
assert ev.level == ThreatLevel.MAJOR
|
||||||
|
assert "drop" in ev.detail.lower()
|
||||||
|
|
||||||
|
def test_roll_triggers_same_as_pitch(self):
|
||||||
|
"""Roll beyond minor threshold also fires."""
|
||||||
|
ev = _fall(minor_tilt_rad=0.20).update(0.0, 0.25)
|
||||||
|
assert ev.level == ThreatLevel.MINOR
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
# StuckDetector
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestStuckDetector:
|
||||||
|
|
||||||
|
def test_clear_when_not_commanded(self):
|
||||||
|
s = _stuck(stuck_timeout_s=1.0, min_cmd_ms=0.05)
|
||||||
|
ev = s.update(0.01, 0.0, dt=1.0) # cmd below threshold
|
||||||
|
assert ev.level == ThreatLevel.CLEAR
|
||||||
|
|
||||||
|
def test_clear_when_moving(self):
|
||||||
|
s = _stuck(stuck_timeout_s=1.0)
|
||||||
|
ev = s.update(0.2, 0.2, dt=1.0) # actually moving
|
||||||
|
assert ev.level == ThreatLevel.CLEAR
|
||||||
|
|
||||||
|
def test_major_after_timeout(self):
|
||||||
|
s = _stuck(stuck_timeout_s=3.0, min_cmd_ms=0.05, moving_threshold_ms=0.05)
|
||||||
|
for _ in range(6):
|
||||||
|
ev = s.update(0.2, 0.0, dt=0.5) # cmd=0.2, actual=0 → stuck
|
||||||
|
assert ev.level == ThreatLevel.MAJOR
|
||||||
|
|
||||||
|
def test_no_major_before_timeout(self):
|
||||||
|
s = _stuck(stuck_timeout_s=3.0)
|
||||||
|
ev = s.update(0.2, 0.0, dt=1.0) # only 1s — not yet
|
||||||
|
assert ev.level == ThreatLevel.CLEAR
|
||||||
|
|
||||||
|
def test_reset_on_motion_resume(self):
|
||||||
|
s = _stuck(stuck_timeout_s=1.0)
|
||||||
|
s.update(0.2, 0.0, dt=0.8) # accumulate stuck time
|
||||||
|
s.update(0.2, 0.3, dt=0.1) # motion resumes → reset
|
||||||
|
ev = s.update(0.2, 0.0, dt=0.3) # only 0.3s since reset → still clear
|
||||||
|
assert ev.level == ThreatLevel.CLEAR
|
||||||
|
|
||||||
|
def test_stuck_time_property(self):
|
||||||
|
s = _stuck(stuck_timeout_s=5.0)
|
||||||
|
s.update(0.2, 0.0, dt=1.0)
|
||||||
|
s.update(0.2, 0.0, dt=1.0)
|
||||||
|
assert s.stuck_time == pytest.approx(2.0, abs=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
# BumpDetector
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestBumpDetector:
|
||||||
|
|
||||||
|
def test_clear_on_first_call(self):
|
||||||
|
"""No jerk on first sample (no previous value)."""
|
||||||
|
ev = _bump().update(0.0, 0.0, 9.81, dt=0.05)
|
||||||
|
assert ev.level == ThreatLevel.CLEAR
|
||||||
|
|
||||||
|
def test_major_on_jerk(self):
|
||||||
|
b = _bump(jerk_threshold_ms3=5.0, critical_jerk_threshold_ms3=20.0)
|
||||||
|
b.update(0.0, 0.0, 9.81, dt=0.05) # seed → dyn_mag = 0
|
||||||
|
# ax=4.5: raw≈10.79, dyn≈0.98, jerk≈9.8 m/s³ → MAJOR (5.0 < 9.8 < 20.0)
|
||||||
|
ev = b.update(4.5, 0.0, 9.81, dt=0.1)
|
||||||
|
assert ev.level == ThreatLevel.MAJOR
|
||||||
|
|
||||||
|
def test_critical_on_severe_jerk(self):
|
||||||
|
b = _bump(jerk_threshold_ms3=5.0, critical_jerk_threshold_ms3=20.0)
|
||||||
|
b.update(0.0, 0.0, 9.81, dt=0.05)
|
||||||
|
# Very large spike
|
||||||
|
ev = b.update(50.0, 0.0, 9.81, dt=0.1)
|
||||||
|
assert ev.level == ThreatLevel.CRITICAL
|
||||||
|
|
||||||
|
def test_clear_on_gentle_acceleration(self):
|
||||||
|
b = _bump(jerk_threshold_ms3=8.0)
|
||||||
|
b.update(0.0, 0.0, 9.81, dt=0.05)
|
||||||
|
ev = b.update(0.1, 0.0, 9.81, dt=0.05) # tiny change
|
||||||
|
assert ev.level == ThreatLevel.CLEAR
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
# highest_threat
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestHighestThreat:
|
||||||
|
|
||||||
|
def test_empty_returns_clear(self):
|
||||||
|
assert highest_threat([]).level == ThreatLevel.CLEAR
|
||||||
|
|
||||||
|
def test_picks_highest(self):
|
||||||
|
a = ThreatEvent(level=ThreatLevel.MINOR)
|
||||||
|
b = ThreatEvent(level=ThreatLevel.CRITICAL)
|
||||||
|
c = ThreatEvent(level=ThreatLevel.MAJOR)
|
||||||
|
assert highest_threat([a, b, c]).level == ThreatLevel.CRITICAL
|
||||||
|
|
||||||
|
def test_single_item(self):
|
||||||
|
ev = ThreatEvent(level=ThreatLevel.MAJOR)
|
||||||
|
assert highest_threat([ev]) is ev
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
# AlertManager
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestAlertManager:
|
||||||
|
|
||||||
|
def test_clear_returns_none(self):
|
||||||
|
am = _alert_mgr()
|
||||||
|
assert am.update(_clear_threat()) is None
|
||||||
|
|
||||||
|
def test_minor_threat_gives_minor_alert(self):
|
||||||
|
am = _alert_mgr(suppression_s=0.0)
|
||||||
|
alert = am.update(_minor_threat(ts=0.0))
|
||||||
|
assert alert is not None
|
||||||
|
assert alert.level == AlertLevel.MINOR
|
||||||
|
|
||||||
|
def test_major_threat_gives_major_alert(self):
|
||||||
|
am = _alert_mgr(suppression_s=0.0)
|
||||||
|
alert = am.update(_major_threat(ts=0.0))
|
||||||
|
assert alert is not None
|
||||||
|
assert alert.level == AlertLevel.MAJOR
|
||||||
|
|
||||||
|
def test_critical_threat_gives_critical_alert(self):
|
||||||
|
am = _alert_mgr(suppression_s=0.0)
|
||||||
|
alert = am.update(_critical_threat(ts=0.0))
|
||||||
|
assert alert is not None
|
||||||
|
assert alert.level == AlertLevel.CRITICAL
|
||||||
|
|
||||||
|
def test_suppression_blocks_duplicate(self):
|
||||||
|
am = _alert_mgr(suppression_s=5.0)
|
||||||
|
am.update(_major_threat(ts=0.0))
|
||||||
|
alert = am.update(_major_threat(ts=1.0)) # within 5s window
|
||||||
|
assert alert is None
|
||||||
|
|
||||||
|
def test_suppression_expires(self):
|
||||||
|
am = _alert_mgr(suppression_s=2.0)
|
||||||
|
am.update(_major_threat(ts=0.0))
|
||||||
|
alert = am.update(_major_threat(ts=3.0)) # after 2s window
|
||||||
|
assert alert is not None
|
||||||
|
|
||||||
|
def test_escalation_major_to_critical(self):
|
||||||
|
"""After major_count_threshold major alerts, next one becomes CRITICAL."""
|
||||||
|
am = _alert_mgr(major_count_threshold=3, escalation_window_s=60.0,
|
||||||
|
suppression_s=0.0)
|
||||||
|
for i in range(3):
|
||||||
|
am.update(_major_threat(ts=float(i)))
|
||||||
|
# 4th should be escalated
|
||||||
|
alert = am.update(_major_threat(ts=4.0))
|
||||||
|
assert alert is not None
|
||||||
|
assert alert.level == AlertLevel.CRITICAL
|
||||||
|
|
||||||
|
def test_escalation_resets_after_window(self):
|
||||||
|
"""Major alerts outside the window don't contribute to escalation."""
|
||||||
|
am = _alert_mgr(major_count_threshold=3, escalation_window_s=5.0,
|
||||||
|
suppression_s=0.0)
|
||||||
|
am.update(_major_threat(ts=0.0))
|
||||||
|
am.update(_major_threat(ts=1.0))
|
||||||
|
am.update(_major_threat(ts=2.0))
|
||||||
|
# All 3 are old; new one at t=10 is outside window
|
||||||
|
alert = am.update(_major_threat(ts=10.0))
|
||||||
|
assert alert is not None
|
||||||
|
assert alert.level == AlertLevel.MAJOR # not escalated
|
||||||
|
|
||||||
|
def test_reset_clears_escalation_state(self):
|
||||||
|
am = _alert_mgr(major_count_threshold=2, suppression_s=0.0)
|
||||||
|
am.update(_major_threat(ts=0.0))
|
||||||
|
am.update(_major_threat(ts=1.0)) # now at threshold
|
||||||
|
am.reset()
|
||||||
|
alert = am.update(_major_threat(ts=2.0))
|
||||||
|
assert alert.level == AlertLevel.MAJOR # back to major after reset
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
# RecoverySequencer
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestRecoverySequencer:
|
||||||
|
|
||||||
|
def _trigger(self, seq):
|
||||||
|
return seq.tick(RecoveryInputs(trigger=True, dt=0.02))
|
||||||
|
|
||||||
|
def test_idle_on_init(self):
|
||||||
|
seq = _seq()
|
||||||
|
assert seq.state == RecoveryState.IDLE
|
||||||
|
|
||||||
|
def test_trigger_starts_reversing(self):
|
||||||
|
seq = _seq()
|
||||||
|
out = self._trigger(seq)
|
||||||
|
assert seq.state == RecoveryState.REVERSING
|
||||||
|
|
||||||
|
def test_reversing_backward_velocity(self):
|
||||||
|
seq = _seq(reverse_speed_ms=-0.15)
|
||||||
|
self._trigger(seq)
|
||||||
|
out = seq.tick(RecoveryInputs(dt=0.02))
|
||||||
|
assert out.cmd_linear < 0.0
|
||||||
|
|
||||||
|
def test_reversing_completes_to_turning(self):
|
||||||
|
seq = _seq(reverse_speed_ms=-1.0, reverse_distance_m=0.5)
|
||||||
|
self._trigger(seq)
|
||||||
|
for _ in range(30):
|
||||||
|
out = seq.tick(RecoveryInputs(dt=0.02))
|
||||||
|
assert seq.state == RecoveryState.TURNING
|
||||||
|
|
||||||
|
def test_turning_positive_angular(self):
|
||||||
|
seq = _seq(reverse_speed_ms=-1.0, reverse_distance_m=0.1,
|
||||||
|
angular_speed_rads=1.0)
|
||||||
|
self._trigger(seq)
|
||||||
|
# Skip through reversing quickly
|
||||||
|
for _ in range(20):
|
||||||
|
seq.tick(RecoveryInputs(dt=0.02))
|
||||||
|
if seq.state == RecoveryState.TURNING:
|
||||||
|
out = seq.tick(RecoveryInputs(dt=0.02))
|
||||||
|
assert out.cmd_angular > 0.0
|
||||||
|
|
||||||
|
def test_retrying_increments_count(self):
|
||||||
|
seq = _seq(reverse_speed_ms=-1.0, reverse_distance_m=0.05,
|
||||||
|
angular_speed_rads=10.0, turn_angle_rad=0.1)
|
||||||
|
self._trigger(seq)
|
||||||
|
for _ in range(100):
|
||||||
|
seq.tick(RecoveryInputs(dt=0.02))
|
||||||
|
assert seq.state == RecoveryState.RETRYING
|
||||||
|
assert seq.retry_count == 1
|
||||||
|
|
||||||
|
def test_threat_cleared_returns_idle(self):
|
||||||
|
seq = _seq(reverse_speed_ms=-1.0, reverse_distance_m=0.05,
|
||||||
|
angular_speed_rads=10.0, turn_angle_rad=0.1,
|
||||||
|
clear_hold_s=0.1)
|
||||||
|
self._trigger(seq)
|
||||||
|
# Fast-forward to RETRYING
|
||||||
|
for _ in range(100):
|
||||||
|
seq.tick(RecoveryInputs(dt=0.02))
|
||||||
|
assert seq.state == RecoveryState.RETRYING
|
||||||
|
# Feed cleared ticks until clear_hold met
|
||||||
|
for _ in range(20):
|
||||||
|
seq.tick(RecoveryInputs(threat_cleared=True, dt=0.02))
|
||||||
|
assert seq.state == RecoveryState.IDLE
|
||||||
|
|
||||||
|
def test_max_retries_gives_up(self):
|
||||||
|
seq = _seq(reverse_speed_ms=-1.0, reverse_distance_m=0.05,
|
||||||
|
angular_speed_rads=10.0, turn_angle_rad=0.1,
|
||||||
|
retry_timeout_s=0.1, max_retries=2)
|
||||||
|
self._trigger(seq)
|
||||||
|
for _ in range(500):
|
||||||
|
out = seq.tick(RecoveryInputs(threat_cleared=False, dt=0.05))
|
||||||
|
if seq.state == RecoveryState.GAVE_UP:
|
||||||
|
break
|
||||||
|
assert seq.state == RecoveryState.GAVE_UP
|
||||||
|
|
||||||
|
def test_reset_returns_to_idle(self):
|
||||||
|
seq = _seq()
|
||||||
|
self._trigger(seq)
|
||||||
|
seq.reset()
|
||||||
|
assert seq.state == RecoveryState.IDLE
|
||||||
|
assert seq.retry_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
# EmergencyFSM
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestEmergencyFSMBasic:
|
||||||
|
|
||||||
|
def test_initial_state_nominal(self):
|
||||||
|
fsm = _fsm()
|
||||||
|
assert fsm.state == EmergencyState.NOMINAL
|
||||||
|
|
||||||
|
def test_nominal_stays_on_clear(self):
|
||||||
|
fsm = _fsm()
|
||||||
|
out = fsm.tick(_inp())
|
||||||
|
assert fsm.state == EmergencyState.NOMINAL
|
||||||
|
assert out.cmd_override is False
|
||||||
|
|
||||||
|
def test_minor_alert_no_override(self):
|
||||||
|
fsm = _fsm()
|
||||||
|
out = fsm.tick(_inp(_minor_threat(ts=0.0)))
|
||||||
|
assert fsm.state == EmergencyState.NOMINAL
|
||||||
|
assert out.cmd_override is False
|
||||||
|
assert out.alert is not None
|
||||||
|
assert out.alert.level == AlertLevel.MINOR
|
||||||
|
|
||||||
|
def test_major_threat_enters_stopping(self):
|
||||||
|
fsm = _fsm()
|
||||||
|
out = fsm.tick(_inp(_major_threat()))
|
||||||
|
assert fsm.state == EmergencyState.STOPPING
|
||||||
|
assert out.cmd_override is True
|
||||||
|
|
||||||
|
def test_critical_threat_enters_stopping_critical_pending(self):
|
||||||
|
fsm = _fsm()
|
||||||
|
fsm.tick(_inp(_critical_threat()))
|
||||||
|
assert fsm.state == EmergencyState.STOPPING
|
||||||
|
assert fsm._critical_pending is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmergencyFSMStopping:
|
||||||
|
|
||||||
|
def test_stopping_commands_zero(self):
|
||||||
|
fsm = _fsm()
|
||||||
|
fsm.tick(_inp(_major_threat()))
|
||||||
|
out = fsm.tick(_inp(_major_threat(), speed=0.5))
|
||||||
|
assert out.cmd_linear == pytest.approx(0.0, abs=1e-9)
|
||||||
|
assert out.cmd_angular == pytest.approx(0.0, abs=1e-9)
|
||||||
|
|
||||||
|
def test_stopped_enters_recovering(self):
|
||||||
|
fsm = _fsm(stopped_ms=0.03)
|
||||||
|
fsm.tick(_inp(_major_threat()))
|
||||||
|
out = fsm.tick(_inp(_major_threat(), speed=0.01)) # below stopped_ms
|
||||||
|
assert fsm.state == EmergencyState.RECOVERING
|
||||||
|
|
||||||
|
def test_critical_pending_enters_escalated(self):
|
||||||
|
fsm = _fsm(stopped_ms=0.03)
|
||||||
|
fsm.tick(_inp(_critical_threat()))
|
||||||
|
fsm.tick(_inp(_critical_threat(), speed=0.01)) # stopped → ESCALATED
|
||||||
|
assert fsm.state == EmergencyState.ESCALATED
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmergencyFSMRecovering:
|
||||||
|
|
||||||
|
def _reach_recovering(self, fsm):
|
||||||
|
fsm.tick(_inp(_major_threat()))
|
||||||
|
fsm.tick(_inp(_major_threat(), speed=0.0)) # stopped → RECOVERING
|
||||||
|
assert fsm.state == EmergencyState.RECOVERING
|
||||||
|
|
||||||
|
def test_recovering_has_cmd_override(self):
|
||||||
|
fsm = _fsm()
|
||||||
|
self._reach_recovering(fsm)
|
||||||
|
out = fsm.tick(_inp(_clear_threat()))
|
||||||
|
assert out.cmd_override is True
|
||||||
|
|
||||||
|
def test_recovering_gave_up_escalates(self):
|
||||||
|
fsm = _fsm(max_retries=1, retry_timeout_s=0.05)
|
||||||
|
self._reach_recovering(fsm)
|
||||||
|
# Drive recovery to GAVE_UP by feeding many non-clearing ticks
|
||||||
|
for _ in range(500):
|
||||||
|
out = fsm.tick(_inp(_major_threat()))
|
||||||
|
if fsm.state == EmergencyState.ESCALATED:
|
||||||
|
break
|
||||||
|
assert fsm.state == EmergencyState.ESCALATED
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmergencyFSMEscalated:
|
||||||
|
|
||||||
|
def _reach_escalated(self, fsm):
|
||||||
|
fsm.tick(_inp(_critical_threat()))
|
||||||
|
fsm.tick(_inp(_critical_threat(), speed=0.0))
|
||||||
|
assert fsm.state == EmergencyState.ESCALATED
|
||||||
|
|
||||||
|
def test_escalated_emits_critical_alert_once(self):
|
||||||
|
fsm = _fsm()
|
||||||
|
self._reach_escalated(fsm)
|
||||||
|
out1 = fsm.tick(_inp())
|
||||||
|
out2 = fsm.tick(_inp())
|
||||||
|
assert out1.alert is not None
|
||||||
|
assert out1.alert.level == AlertLevel.CRITICAL
|
||||||
|
assert out2.alert is None # suppressed after first emission
|
||||||
|
|
||||||
|
def test_escalated_e_stop_asserted(self):
|
||||||
|
fsm = _fsm()
|
||||||
|
self._reach_escalated(fsm)
|
||||||
|
out = fsm.tick(_inp())
|
||||||
|
assert out.e_stop is True
|
||||||
|
|
||||||
|
def test_escalated_stays_without_ack(self):
|
||||||
|
fsm = _fsm()
|
||||||
|
self._reach_escalated(fsm)
|
||||||
|
for _ in range(5):
|
||||||
|
fsm.tick(_inp())
|
||||||
|
assert fsm.state == EmergencyState.ESCALATED
|
||||||
|
|
||||||
|
def test_acknowledge_returns_to_nominal(self):
|
||||||
|
fsm = _fsm()
|
||||||
|
self._reach_escalated(fsm)
|
||||||
|
fsm.tick(_inp(ack=True))
|
||||||
|
assert fsm.state == EmergencyState.NOMINAL
|
||||||
|
|
||||||
|
def test_reset_returns_to_nominal(self):
|
||||||
|
fsm = _fsm()
|
||||||
|
self._reach_escalated(fsm)
|
||||||
|
fsm.reset()
|
||||||
|
assert fsm.state == EmergencyState.NOMINAL
|
||||||
|
|
||||||
|
def test_e_stop_cleared_on_ack(self):
|
||||||
|
fsm = _fsm()
|
||||||
|
self._reach_escalated(fsm)
|
||||||
|
out = fsm.tick(_inp(ack=True))
|
||||||
|
assert out.e_stop is False
|
||||||
15
jetson/ros2_ws/src/saltybot_emergency_msgs/CMakeLists.txt
Normal file
15
jetson/ros2_ws/src/saltybot_emergency_msgs/CMakeLists.txt
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.8)
|
||||||
|
project(saltybot_emergency_msgs)
|
||||||
|
|
||||||
|
find_package(ament_cmake REQUIRED)
|
||||||
|
find_package(rosidl_default_generators REQUIRED)
|
||||||
|
find_package(builtin_interfaces REQUIRED)
|
||||||
|
|
||||||
|
rosidl_generate_interfaces(${PROJECT_NAME}
|
||||||
|
"msg/EmergencyEvent.msg"
|
||||||
|
"msg/RecoveryAction.msg"
|
||||||
|
DEPENDENCIES builtin_interfaces
|
||||||
|
)
|
||||||
|
|
||||||
|
ament_export_dependencies(rosidl_default_runtime)
|
||||||
|
ament_package()
|
||||||
@ -0,0 +1,25 @@
|
|||||||
|
# EmergencyEvent.msg — Real-time emergency system state snapshot (Issue #169)
|
||||||
|
# Published by: /saltybot/emergency_node
|
||||||
|
# Topic: /saltybot/emergency
|
||||||
|
|
||||||
|
builtin_interfaces/Time stamp
|
||||||
|
|
||||||
|
# Overall FSM state
|
||||||
|
# Values: "NOMINAL" | "STOPPING" | "RECOVERING" | "ESCALATED"
|
||||||
|
string state
|
||||||
|
|
||||||
|
# Active threat (highest severity across all detectors)
|
||||||
|
# threat_type values: "NONE" | "OBSTACLE_PROXIMITY" | "FALL_RISK" | "WHEEL_STUCK" | "BUMP"
|
||||||
|
string threat_type
|
||||||
|
|
||||||
|
# Severity: "CLEAR" | "MINOR" | "MAJOR" | "CRITICAL"
|
||||||
|
string severity
|
||||||
|
|
||||||
|
# Triggering metric value (e.g. distance in m, jerk in m/s³, stuck seconds)
|
||||||
|
float32 threat_value
|
||||||
|
|
||||||
|
# Human-readable description of the active threat
|
||||||
|
string detail
|
||||||
|
|
||||||
|
# True when emergency system is overriding normal cmd_vel with its own commands
|
||||||
|
bool cmd_override
|
||||||
@ -0,0 +1,15 @@
|
|||||||
|
# RecoveryAction.msg — Recovery sequencer state (Issue #169)
|
||||||
|
# Published by: /saltybot/emergency_node
|
||||||
|
# Topic: /saltybot/recovery_action
|
||||||
|
|
||||||
|
builtin_interfaces/Time stamp
|
||||||
|
|
||||||
|
# Current recovery action
|
||||||
|
# Values: "IDLE" | "REVERSING" | "TURNING" | "RETRYING" | "GAVE_UP"
|
||||||
|
string action
|
||||||
|
|
||||||
|
# Number of reverse+turn attempts completed so far
|
||||||
|
int32 retry_count
|
||||||
|
|
||||||
|
# Progress through current phase [0.0 – 1.0]
|
||||||
|
float32 progress
|
||||||
22
jetson/ros2_ws/src/saltybot_emergency_msgs/package.xml
Normal file
22
jetson/ros2_ws/src/saltybot_emergency_msgs/package.xml
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
<?xml version="1.0"?>
|
||||||
|
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||||
|
<package format="3">
|
||||||
|
<name>saltybot_emergency_msgs</name>
|
||||||
|
<version>0.1.0</version>
|
||||||
|
<description>Emergency behavior message definitions for SaltyBot (Issue #169)</description>
|
||||||
|
<maintainer email="sl-controls@saltylab.local">sl-controls</maintainer>
|
||||||
|
<license>MIT</license>
|
||||||
|
|
||||||
|
<buildtool_depend>ament_cmake</buildtool_depend>
|
||||||
|
<buildtool_depend>rosidl_default_generators</buildtool_depend>
|
||||||
|
|
||||||
|
<depend>builtin_interfaces</depend>
|
||||||
|
|
||||||
|
<exec_depend>rosidl_default_runtime</exec_depend>
|
||||||
|
|
||||||
|
<member_of_group>rosidl_interface_packages</member_of_group>
|
||||||
|
|
||||||
|
<export>
|
||||||
|
<build_type>ament_cmake</build_type>
|
||||||
|
</export>
|
||||||
|
</package>
|
||||||
@ -0,0 +1,38 @@
|
|||||||
|
emotion_node:
|
||||||
|
ros__parameters:
|
||||||
|
# Path to the TRT FP16 emotion engine (built from emotion_cnn.onnx).
|
||||||
|
# Build with:
|
||||||
|
# trtexec --onnx=emotion_cnn.onnx --fp16 --saveEngine=emotion_fp16.trt
|
||||||
|
# Leave empty to use landmark heuristic only (no GPU required).
|
||||||
|
engine_path: "/models/emotion_fp16.trt"
|
||||||
|
|
||||||
|
# Minimum smoothed confidence to publish an expression.
|
||||||
|
# Lower = more detections but more noise; 0.40 is a good production default.
|
||||||
|
min_confidence: 0.40
|
||||||
|
|
||||||
|
# EMA smoothing weight applied to each new observation.
|
||||||
|
# 0.0 = frozen (never updates) 1.0 = no smoothing (raw output)
|
||||||
|
# 0.30 gives stable ~3-frame moving average at 10 Hz face detection rate.
|
||||||
|
smoothing_alpha: 0.30
|
||||||
|
|
||||||
|
# Comma-separated person_ids that have opted out of emotion tracking.
|
||||||
|
# Empty string = everyone is tracked by default.
|
||||||
|
# Example: "1,3,7"
|
||||||
|
opt_out_persons: ""
|
||||||
|
|
||||||
|
# Skip face crops smaller than this side length (pixels).
|
||||||
|
# Very small crops produce unreliable emotion predictions.
|
||||||
|
face_min_size: 24
|
||||||
|
|
||||||
|
# If true and TRT engine is unavailable, fall back to the 5-point
|
||||||
|
# landmark heuristic. Set false to suppress output entirely without TRT.
|
||||||
|
landmark_fallback: true
|
||||||
|
|
||||||
|
# Camera names matching saltybot_cameras topic naming convention.
|
||||||
|
camera_names: "front,left,rear,right"
|
||||||
|
|
||||||
|
# Number of active CSI camera streams to subscribe to.
|
||||||
|
n_cameras: 4
|
||||||
|
|
||||||
|
# Publish /social/emotion/context (JSON string) for conversation_node.
|
||||||
|
publish_context: true
|
||||||
67
jetson/ros2_ws/src/saltybot_social/launch/emotion.launch.py
Normal file
67
jetson/ros2_ws/src/saltybot_social/launch/emotion.launch.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
"""emotion.launch.py — Launch emotion_node for facial expression recognition (Issue #161)."""
|
||||||
|
|
||||||
|
from launch import LaunchDescription
|
||||||
|
from launch.actions import DeclareLaunchArgument
|
||||||
|
from launch.substitutions import LaunchConfiguration, PathJoinSubstitution
|
||||||
|
from launch_ros.actions import Node
|
||||||
|
from launch_ros.substitutions import FindPackageShare
|
||||||
|
|
||||||
|
|
||||||
|
def generate_launch_description() -> LaunchDescription:
|
||||||
|
pkg = FindPackageShare("saltybot_social")
|
||||||
|
params_file = PathJoinSubstitution([pkg, "config", "emotion_params.yaml"])
|
||||||
|
|
||||||
|
return LaunchDescription([
|
||||||
|
DeclareLaunchArgument(
|
||||||
|
"params_file",
|
||||||
|
default_value=params_file,
|
||||||
|
description="Path to emotion_node parameter YAML",
|
||||||
|
),
|
||||||
|
DeclareLaunchArgument(
|
||||||
|
"engine_path",
|
||||||
|
default_value="/models/emotion_fp16.trt",
|
||||||
|
description="Path to TensorRT FP16 emotion engine (empty = landmark heuristic)",
|
||||||
|
),
|
||||||
|
DeclareLaunchArgument(
|
||||||
|
"min_confidence",
|
||||||
|
default_value="0.40",
|
||||||
|
description="Minimum detection confidence to publish an expression",
|
||||||
|
),
|
||||||
|
DeclareLaunchArgument(
|
||||||
|
"smoothing_alpha",
|
||||||
|
default_value="0.30",
|
||||||
|
description="EMA smoothing weight (0=frozen, 1=no smoothing)",
|
||||||
|
),
|
||||||
|
DeclareLaunchArgument(
|
||||||
|
"opt_out_persons",
|
||||||
|
default_value="",
|
||||||
|
description="Comma-separated person_ids that opted out",
|
||||||
|
),
|
||||||
|
DeclareLaunchArgument(
|
||||||
|
"landmark_fallback",
|
||||||
|
default_value="true",
|
||||||
|
description="Use landmark heuristic when TRT engine unavailable",
|
||||||
|
),
|
||||||
|
DeclareLaunchArgument(
|
||||||
|
"publish_context",
|
||||||
|
default_value="true",
|
||||||
|
description="Publish /social/emotion/context JSON for LLM context",
|
||||||
|
),
|
||||||
|
Node(
|
||||||
|
package="saltybot_social",
|
||||||
|
executable="emotion_node",
|
||||||
|
name="emotion_node",
|
||||||
|
output="screen",
|
||||||
|
parameters=[
|
||||||
|
LaunchConfiguration("params_file"),
|
||||||
|
{
|
||||||
|
"engine_path": LaunchConfiguration("engine_path"),
|
||||||
|
"min_confidence": LaunchConfiguration("min_confidence"),
|
||||||
|
"smoothing_alpha": LaunchConfiguration("smoothing_alpha"),
|
||||||
|
"opt_out_persons": LaunchConfiguration("opt_out_persons"),
|
||||||
|
"landmark_fallback": LaunchConfiguration("landmark_fallback"),
|
||||||
|
"publish_context": LaunchConfiguration("publish_context"),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
])
|
||||||
@ -59,7 +59,10 @@ class ConversationNode(Node):
|
|||||||
self.create_subscription(String, "/social/emotion/context", self._on_emotion_context, 10)
|
self.create_subscription(String, "/social/emotion/context", self._on_emotion_context, 10)
|
||||||
threading.Thread(target=self._load_llm, daemon=True).start()
|
threading.Thread(target=self._load_llm, daemon=True).start()
|
||||||
self._save_timer = self.create_timer(self._save_interval, self._save_context)
|
self._save_timer = self.create_timer(self._save_interval, self._save_context)
|
||||||
self.get_logger().info(f"ConversationNode init (model={self._model_path}, gpu_layers={self._n_gpu})")
|
self.get_logger().info(
|
||||||
|
f"ConversationNode init (model={self._model_path}, "
|
||||||
|
f"gpu_layers={self._n_gpu}, ctx={self._n_ctx})"
|
||||||
|
)
|
||||||
|
|
||||||
def _load_llm(self) -> None:
|
def _load_llm(self) -> None:
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|||||||
@ -0,0 +1,414 @@
|
|||||||
|
"""emotion_classifier.py — 7-class facial expression classifier (Issue #161).
|
||||||
|
|
||||||
|
Pure Python, no ROS2 / TensorRT / OpenCV dependencies.
|
||||||
|
Wraps a TensorRT FP16 emotion CNN and provides:
|
||||||
|
- On-device TRT inference on 48×48 grayscale face crops
|
||||||
|
- Heuristic fallback from 5-point SCRFD facial landmarks
|
||||||
|
- Per-person EMA temporal smoothing for stable outputs
|
||||||
|
- Per-person opt-out registry
|
||||||
|
|
||||||
|
Emotion classes (index order matches CNN output layer)
|
||||||
|
------------------------------------------------------
|
||||||
|
0 = happy
|
||||||
|
1 = sad
|
||||||
|
2 = angry
|
||||||
|
3 = surprised
|
||||||
|
4 = fearful
|
||||||
|
5 = disgusted
|
||||||
|
6 = neutral
|
||||||
|
|
||||||
|
Coordinate convention
|
||||||
|
---------------------
|
||||||
|
Face crop: BGR uint8 ndarray, any size (resized to INPUT_SIZE internally).
|
||||||
|
Landmarks (lm10): 10 floats from FaceDetection.landmarks —
|
||||||
|
[left_eye_x, left_eye_y, right_eye_x, right_eye_y,
|
||||||
|
nose_x, nose_y, left_mouth_x, left_mouth_y,
|
||||||
|
right_mouth_x, right_mouth_y]
|
||||||
|
All coordinates are normalised image-space (0.0–1.0).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
# ── Constants ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
EMOTIONS: List[str] = [
|
||||||
|
"happy", "sad", "angry", "surprised", "fearful", "disgusted", "neutral"
|
||||||
|
]
|
||||||
|
N_CLASSES: int = len(EMOTIONS)
|
||||||
|
|
||||||
|
# Landmark slot indices within the 10-value array
|
||||||
|
_LEX, _LEY = 0, 1 # left eye
|
||||||
|
_REX, _REY = 2, 3 # right eye
|
||||||
|
_NX, _NY = 4, 5 # nose
|
||||||
|
_LMX, _LMY = 6, 7 # left mouth corner
|
||||||
|
_RMX, _RMY = 8, 9 # right mouth corner
|
||||||
|
|
||||||
|
# CNN input size
|
||||||
|
INPUT_SIZE: int = 48
|
||||||
|
|
||||||
|
|
||||||
|
# ── Result type ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ClassifiedEmotion:
|
||||||
|
emotion: str # dominant emotion label
|
||||||
|
confidence: float # smoothed softmax probability (0.0–1.0)
|
||||||
|
scores: List[float] # per-class scores, len == N_CLASSES
|
||||||
|
source: str = "cnn_trt" # "cnn_trt" | "landmark_heuristic" | "opt_out"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Math helpers ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _softmax(logits: List[float]) -> List[float]:
|
||||||
|
"""Numerically stable softmax."""
|
||||||
|
m = max(logits)
|
||||||
|
exps = [math.exp(x - m) for x in logits]
|
||||||
|
total = sum(exps)
|
||||||
|
return [e / total for e in exps]
|
||||||
|
|
||||||
|
|
||||||
|
def _argmax(values: List[float]) -> int:
|
||||||
|
best = 0
|
||||||
|
for i in range(1, len(values)):
|
||||||
|
if values[i] > values[best]:
|
||||||
|
best = i
|
||||||
|
return best
|
||||||
|
|
||||||
|
|
||||||
|
def _uniform_scores() -> List[float]:
|
||||||
|
return [1.0 / N_CLASSES] * N_CLASSES
|
||||||
|
|
||||||
|
|
||||||
|
# ── Per-person temporal smoother + opt-out registry ───────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class PersonEmotionTracker:
|
||||||
|
"""Exponential moving-average smoother + opt-out registry per person.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
alpha: EMA weight for the newest observation (0.0 = frozen, 1.0 = no smooth).
|
||||||
|
max_age: Frames without update before the EMA is expired and reset.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, alpha: float = 0.3, max_age: int = 16) -> None:
|
||||||
|
self._alpha = alpha
|
||||||
|
self._max_age = max_age
|
||||||
|
self._ema: Dict[int, List[float]] = {} # person_id → EMA scores
|
||||||
|
self._age: Dict[int, int] = {} # person_id → frames since last update
|
||||||
|
self._opt_out: set = set() # person_ids that opted out
|
||||||
|
|
||||||
|
# ── Opt-out management ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def set_opt_out(self, person_id: int, value: bool) -> None:
|
||||||
|
if value:
|
||||||
|
self._opt_out.add(person_id)
|
||||||
|
self._ema.pop(person_id, None)
|
||||||
|
self._age.pop(person_id, None)
|
||||||
|
else:
|
||||||
|
self._opt_out.discard(person_id)
|
||||||
|
|
||||||
|
def is_opt_out(self, person_id: int) -> bool:
|
||||||
|
return person_id in self._opt_out
|
||||||
|
|
||||||
|
# ── Smoothing ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def smooth(self, person_id: int, scores: List[float]) -> List[float]:
|
||||||
|
"""Apply EMA to raw scores. Returns smoothed scores (length N_CLASSES)."""
|
||||||
|
if person_id < 0:
|
||||||
|
return scores[:]
|
||||||
|
|
||||||
|
# Reset stale entries
|
||||||
|
if person_id in self._age and self._age[person_id] > self._max_age:
|
||||||
|
self.reset(person_id)
|
||||||
|
|
||||||
|
if person_id not in self._ema:
|
||||||
|
self._ema[person_id] = scores[:]
|
||||||
|
self._age[person_id] = 0
|
||||||
|
return scores[:]
|
||||||
|
|
||||||
|
ema = self._ema[person_id]
|
||||||
|
a = self._alpha
|
||||||
|
smoothed = [a * s + (1.0 - a) * e for s, e in zip(scores, ema)]
|
||||||
|
self._ema[person_id] = smoothed
|
||||||
|
self._age[person_id] = 0
|
||||||
|
return smoothed
|
||||||
|
|
||||||
|
def tick(self) -> None:
|
||||||
|
"""Advance age counter for all tracked persons (call once per frame)."""
|
||||||
|
for pid in list(self._age.keys()):
|
||||||
|
self._age[pid] += 1
|
||||||
|
|
||||||
|
def reset(self, person_id: int) -> None:
|
||||||
|
self._ema.pop(person_id, None)
|
||||||
|
self._age.pop(person_id, None)
|
||||||
|
|
||||||
|
def reset_all(self) -> None:
|
||||||
|
self._ema.clear()
|
||||||
|
self._age.clear()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tracked_ids(self) -> List[int]:
|
||||||
|
return list(self._ema.keys())
|
||||||
|
|
||||||
|
|
||||||
|
# ── TensorRT engine loader ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _TrtEngine:
|
||||||
|
"""Thin wrapper around a TensorRT FP16 emotion CNN engine.
|
||||||
|
|
||||||
|
Expected engine:
|
||||||
|
- Input binding: name "input" shape (1, 1, 48, 48) float32
|
||||||
|
- Output binding: name "output" shape (1, 7) float32 (softmax logits)
|
||||||
|
|
||||||
|
The engine is built offline from a MobileNetV2-based 48×48 grayscale CNN
|
||||||
|
(FER+ or AffectNet trained) via:
|
||||||
|
trtexec --onnx=emotion_cnn.onnx --fp16 --saveEngine=emotion_fp16.trt
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._context = None
|
||||||
|
self._h_input = None
|
||||||
|
self._h_output = None
|
||||||
|
self._d_input = None
|
||||||
|
self._d_output = None
|
||||||
|
self._stream = None
|
||||||
|
self._bindings: List = []
|
||||||
|
|
||||||
|
def load(self, engine_path: str) -> bool:
|
||||||
|
"""Load TRT engine from file. Returns True on success."""
|
||||||
|
try:
|
||||||
|
import tensorrt as trt # type: ignore
|
||||||
|
import pycuda.autoinit # type: ignore # noqa: F401
|
||||||
|
import pycuda.driver as cuda # type: ignore
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
|
||||||
|
with open(engine_path, "rb") as f:
|
||||||
|
engine_data = f.read()
|
||||||
|
|
||||||
|
runtime = trt.Runtime(TRT_LOGGER)
|
||||||
|
engine = runtime.deserialize_cuda_engine(engine_data)
|
||||||
|
self._context = engine.create_execution_context()
|
||||||
|
|
||||||
|
# Allocate host + device buffers
|
||||||
|
self._h_input = cuda.pagelocked_empty((1, 1, INPUT_SIZE, INPUT_SIZE),
|
||||||
|
dtype=np.float32)
|
||||||
|
self._h_output = cuda.pagelocked_empty((1, N_CLASSES), dtype=np.float32)
|
||||||
|
self._d_input = cuda.mem_alloc(self._h_input.nbytes)
|
||||||
|
self._d_output = cuda.mem_alloc(self._h_output.nbytes)
|
||||||
|
self._stream = cuda.Stream()
|
||||||
|
self._bindings = [int(self._d_input), int(self._d_output)]
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
self._context = None
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ready(self) -> bool:
|
||||||
|
return self._context is not None
|
||||||
|
|
||||||
|
def infer(self, face_bgr) -> List[float]:
|
||||||
|
"""Run TRT inference on a BGR face crop. Returns 7 softmax scores."""
|
||||||
|
import pycuda.driver as cuda # type: ignore
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
img = _preprocess(face_bgr) # (1, 1, 48, 48) float32
|
||||||
|
np.copyto(self._h_input, img)
|
||||||
|
cuda.memcpy_htod_async(self._d_input, self._h_input, self._stream)
|
||||||
|
self._context.execute_async_v2(self._bindings, self._stream.handle, None)
|
||||||
|
cuda.memcpy_dtoh_async(self._h_output, self._d_output, self._stream)
|
||||||
|
self._stream.synchronize()
|
||||||
|
logits = list(self._h_output[0])
|
||||||
|
return _softmax(logits)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Image pre-processing helper (importable without cv2 in test mode) ─────────
|
||||||
|
|
||||||
|
|
||||||
|
def _preprocess(face_bgr) -> "np.ndarray": # type: ignore
|
||||||
|
"""Resize BGR crop to 48×48 grayscale, normalise to [-1, 1].
|
||||||
|
|
||||||
|
Returns ndarray shape (1, 1, 48, 48) float32.
|
||||||
|
"""
|
||||||
|
import numpy as np # type: ignore
|
||||||
|
import cv2 # type: ignore
|
||||||
|
|
||||||
|
gray = cv2.cvtColor(face_bgr, cv2.COLOR_BGR2GRAY)
|
||||||
|
resized = cv2.resize(gray, (INPUT_SIZE, INPUT_SIZE),
|
||||||
|
interpolation=cv2.INTER_LINEAR)
|
||||||
|
norm = resized.astype(np.float32) / 127.5 - 1.0
|
||||||
|
return norm.reshape(1, 1, INPUT_SIZE, INPUT_SIZE)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Landmark heuristic ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def classify_from_landmarks(lm10: List[float]) -> ClassifiedEmotion:
|
||||||
|
"""Estimate emotion from 5-point SCRFD landmarks (10 floats).
|
||||||
|
|
||||||
|
Uses geometric ratios between eyes, nose, and mouth corners.
|
||||||
|
Accuracy is limited — treats this as a soft prior, not a definitive label.
|
||||||
|
|
||||||
|
Returns ClassifiedEmotion with source="landmark_heuristic".
|
||||||
|
"""
|
||||||
|
if len(lm10) < 10:
|
||||||
|
scores = _uniform_scores()
|
||||||
|
scores[6] = 0.5 # bias neutral
|
||||||
|
scores = _renorm(scores)
|
||||||
|
return ClassifiedEmotion("neutral", scores[6], scores, "landmark_heuristic")
|
||||||
|
|
||||||
|
eye_y = (lm10[_LEY] + lm10[_REY]) / 2.0
|
||||||
|
nose_y = lm10[_NY]
|
||||||
|
mouth_y = (lm10[_LMY] + lm10[_RMY]) / 2.0
|
||||||
|
eye_span = max(abs(lm10[_REX] - lm10[_LEX]), 1e-4)
|
||||||
|
mouth_width = abs(lm10[_RMX] - lm10[_LMX])
|
||||||
|
mouth_asymm = abs(lm10[_LMY] - lm10[_RMY])
|
||||||
|
|
||||||
|
face_h = max(mouth_y - eye_y, 1e-4)
|
||||||
|
|
||||||
|
# Ratio of mouth span to interocular distance
|
||||||
|
width_ratio = mouth_width / eye_span # >1.0 = wide open / happy smile
|
||||||
|
# How far mouth is below nose, relative to face height
|
||||||
|
mouth_below_nose = (mouth_y - nose_y) / face_h # ~0.3–0.6 typical
|
||||||
|
# Relative asymmetry of mouth corners
|
||||||
|
asym_ratio = mouth_asymm / face_h # >0.05 = notable asymmetry
|
||||||
|
|
||||||
|
# Build soft scores
|
||||||
|
scores = _uniform_scores() # start uniform
|
||||||
|
|
||||||
|
if width_ratio > 0.85 and mouth_below_nose > 0.35:
|
||||||
|
# Wide mouth, normal vertical position → happy
|
||||||
|
scores[0] = 0.55 + 0.25 * min(1.0, (width_ratio - 0.85) / 0.5)
|
||||||
|
elif mouth_below_nose < 0.20 and width_ratio < 0.7:
|
||||||
|
# Tight, compressed mouth high up → surprised OR angry
|
||||||
|
scores[3] = 0.35 # surprised
|
||||||
|
scores[2] = 0.30 # angry
|
||||||
|
elif asym_ratio > 0.06:
|
||||||
|
# Asymmetric mouth → disgust or sadness
|
||||||
|
scores[5] = 0.30 # disgusted
|
||||||
|
scores[1] = 0.25 # sad
|
||||||
|
elif width_ratio < 0.65 and mouth_below_nose < 0.30:
|
||||||
|
# Tight and compressed → sad/angry
|
||||||
|
scores[1] = 0.35 # sad
|
||||||
|
scores[2] = 0.25 # angry
|
||||||
|
else:
|
||||||
|
# Default to neutral
|
||||||
|
scores[6] = 0.45
|
||||||
|
|
||||||
|
scores = _renorm(scores)
|
||||||
|
top_idx = _argmax(scores)
|
||||||
|
return ClassifiedEmotion(
|
||||||
|
emotion=EMOTIONS[top_idx],
|
||||||
|
confidence=round(scores[top_idx], 3),
|
||||||
|
scores=[round(s, 4) for s in scores],
|
||||||
|
source="landmark_heuristic",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _renorm(scores: List[float]) -> List[float]:
|
||||||
|
"""Re-normalise scores so they sum to 1.0."""
|
||||||
|
total = sum(scores)
|
||||||
|
if total <= 0:
|
||||||
|
return _uniform_scores()
|
||||||
|
return [s / total for s in scores]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Public classifier ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class EmotionClassifier:
|
||||||
|
"""Facade combining TRT inference + landmark fallback + per-person smoothing.
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
>>> clf = EmotionClassifier(engine_path="/models/emotion_fp16.trt")
|
||||||
|
>>> clf.load()
|
||||||
|
True
|
||||||
|
>>> result = clf.classify_crop(face_bgr, person_id=42, tracker=tracker)
|
||||||
|
>>> result.emotion, result.confidence
|
||||||
|
('happy', 0.87)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
engine_path: str = "",
|
||||||
|
alpha: float = 0.3,
|
||||||
|
) -> None:
|
||||||
|
self._engine_path = engine_path
|
||||||
|
self._alpha = alpha
|
||||||
|
self._engine = _TrtEngine()
|
||||||
|
|
||||||
|
def load(self) -> bool:
|
||||||
|
"""Load TRT engine. Returns True if engine is ready."""
|
||||||
|
if not self._engine_path:
|
||||||
|
return False
|
||||||
|
return self._engine.load(self._engine_path)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ready(self) -> bool:
|
||||||
|
return self._engine.ready
|
||||||
|
|
||||||
|
def classify_crop(
|
||||||
|
self,
|
||||||
|
face_bgr,
|
||||||
|
person_id: int = -1,
|
||||||
|
tracker: Optional[PersonEmotionTracker] = None,
|
||||||
|
) -> ClassifiedEmotion:
|
||||||
|
"""Classify a BGR face crop.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
face_bgr: BGR ndarray from cv_bridge / direct crop.
|
||||||
|
person_id: For temporal smoothing. -1 = no smoothing.
|
||||||
|
tracker: Optional PersonEmotionTracker for EMA smoothing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ClassifiedEmotion with smoothed scores.
|
||||||
|
"""
|
||||||
|
if self._engine.ready:
|
||||||
|
raw = self._engine.infer(face_bgr)
|
||||||
|
source = "cnn_trt"
|
||||||
|
else:
|
||||||
|
# Uniform fallback — no inference without engine
|
||||||
|
raw = _uniform_scores()
|
||||||
|
raw[6] = 0.40 # mild neutral bias
|
||||||
|
raw = _renorm(raw)
|
||||||
|
source = "landmark_heuristic"
|
||||||
|
|
||||||
|
smoothed = raw
|
||||||
|
if tracker is not None and person_id >= 0:
|
||||||
|
smoothed = tracker.smooth(person_id, raw)
|
||||||
|
|
||||||
|
top_idx = _argmax(smoothed)
|
||||||
|
return ClassifiedEmotion(
|
||||||
|
emotion=EMOTIONS[top_idx],
|
||||||
|
confidence=round(smoothed[top_idx], 3),
|
||||||
|
scores=[round(s, 4) for s in smoothed],
|
||||||
|
source=source,
|
||||||
|
)
|
||||||
|
|
||||||
|
def classify_from_landmarks(
|
||||||
|
self,
|
||||||
|
lm10: List[float],
|
||||||
|
person_id: int = -1,
|
||||||
|
tracker: Optional[PersonEmotionTracker] = None,
|
||||||
|
) -> ClassifiedEmotion:
|
||||||
|
"""Classify using landmark geometry only (no crop required)."""
|
||||||
|
result = classify_from_landmarks(lm10)
|
||||||
|
if tracker is not None and person_id >= 0:
|
||||||
|
smoothed = tracker.smooth(person_id, result.scores)
|
||||||
|
top_idx = _argmax(smoothed)
|
||||||
|
result = ClassifiedEmotion(
|
||||||
|
emotion=EMOTIONS[top_idx],
|
||||||
|
confidence=round(smoothed[top_idx], 3),
|
||||||
|
scores=[round(s, 4) for s in smoothed],
|
||||||
|
source=result.source,
|
||||||
|
)
|
||||||
|
return result
|
||||||
@ -0,0 +1,380 @@
|
|||||||
|
"""emotion_node.py — Facial expression recognition node (Issue #161).
|
||||||
|
|
||||||
|
Piggybacks on the face detection pipeline: subscribes to
|
||||||
|
/social/faces/detections (FaceDetectionArray), extracts face crops from the
|
||||||
|
latest camera frames, runs a TensorRT FP16 emotion CNN, applies per-person
|
||||||
|
EMA temporal smoothing, and publishes results on /social/faces/expressions.
|
||||||
|
|
||||||
|
Architecture
|
||||||
|
------------
|
||||||
|
FaceDetectionArray → crop extraction → TRT FP16 inference (< 5 ms) →
|
||||||
|
EMA smoothing → ExpressionArray publish
|
||||||
|
|
||||||
|
If TRT engine is not available the node falls back to the 5-point landmark
|
||||||
|
heuristic (classify_from_landmarks) which requires no GPU and adds < 0.1 ms.
|
||||||
|
|
||||||
|
ROS2 topics
|
||||||
|
-----------
|
||||||
|
Subscribe:
|
||||||
|
/social/faces/detections (saltybot_social_msgs/FaceDetectionArray)
|
||||||
|
/camera/{name}/image_raw (sensor_msgs/Image) × 4
|
||||||
|
/social/persons (saltybot_social_msgs/PersonStateArray)
|
||||||
|
Publish:
|
||||||
|
/social/faces/expressions (saltybot_social_msgs/ExpressionArray)
|
||||||
|
/social/emotion/context (std_msgs/String) — JSON for LLM context
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
engine_path (str) "" — path to emotion_fp16.trt; empty = landmark only
|
||||||
|
min_confidence (float) 0.40 — suppress results below this threshold
|
||||||
|
smoothing_alpha (float) 0.30 — EMA weight (higher = faster, less stable)
|
||||||
|
opt_out_persons (str) "" — comma-separated person_ids that opted out
|
||||||
|
face_min_size (int) 24 — skip faces whose bbox is smaller (px side)
|
||||||
|
landmark_fallback (bool) true — use landmark heuristic when TRT unavailable
|
||||||
|
camera_names (str) "front,left,rear,right"
|
||||||
|
n_cameras (int) 4
|
||||||
|
publish_context (bool) true — publish /social/emotion/context JSON
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||||
|
|
||||||
|
from std_msgs.msg import String
|
||||||
|
from sensor_msgs.msg import Image
|
||||||
|
|
||||||
|
from saltybot_social_msgs.msg import (
|
||||||
|
FaceDetectionArray,
|
||||||
|
PersonStateArray,
|
||||||
|
ExpressionArray,
|
||||||
|
Expression,
|
||||||
|
)
|
||||||
|
from saltybot_social.emotion_classifier import (
|
||||||
|
EmotionClassifier,
|
||||||
|
PersonEmotionTracker,
|
||||||
|
EMOTIONS,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from cv_bridge import CvBridge
|
||||||
|
_CV_BRIDGE_OK = True
|
||||||
|
except ImportError:
|
||||||
|
_CV_BRIDGE_OK = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
_CV2_OK = True
|
||||||
|
except ImportError:
|
||||||
|
_CV2_OK = False
|
||||||
|
|
||||||
|
|
||||||
|
# ── Per-camera latest-frame buffer ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _FrameBuffer:
|
||||||
|
"""Thread-safe store of the most recent image per camera."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._frames: Dict[int, object] = {} # camera_id → cv2 BGR image
|
||||||
|
|
||||||
|
def put(self, camera_id: int, img) -> None:
|
||||||
|
with self._lock:
|
||||||
|
self._frames[camera_id] = img
|
||||||
|
|
||||||
|
def get(self, camera_id: int):
|
||||||
|
with self._lock:
|
||||||
|
return self._frames.get(camera_id)
|
||||||
|
|
||||||
|
|
||||||
|
# ── ROS2 Node ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class EmotionNode(Node):
|
||||||
|
"""Facial expression recognition — TRT FP16 + landmark fallback."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__("emotion_node")
|
||||||
|
|
||||||
|
# ── Parameters ────────────────────────────────────────────────────────
|
||||||
|
self.declare_parameter("engine_path", "")
|
||||||
|
self.declare_parameter("min_confidence", 0.40)
|
||||||
|
self.declare_parameter("smoothing_alpha", 0.30)
|
||||||
|
self.declare_parameter("opt_out_persons", "")
|
||||||
|
self.declare_parameter("face_min_size", 24)
|
||||||
|
self.declare_parameter("landmark_fallback", True)
|
||||||
|
self.declare_parameter("camera_names", "front,left,rear,right")
|
||||||
|
self.declare_parameter("n_cameras", 4)
|
||||||
|
self.declare_parameter("publish_context", True)
|
||||||
|
|
||||||
|
engine_path = self.get_parameter("engine_path").value
|
||||||
|
self._min_conf = self.get_parameter("min_confidence").value
|
||||||
|
alpha = self.get_parameter("smoothing_alpha").value
|
||||||
|
opt_out_str = self.get_parameter("opt_out_persons").value
|
||||||
|
self._face_min = self.get_parameter("face_min_size").value
|
||||||
|
self._lm_fallback = self.get_parameter("landmark_fallback").value
|
||||||
|
cam_names_str = self.get_parameter("camera_names").value
|
||||||
|
n_cameras = self.get_parameter("n_cameras").value
|
||||||
|
self._pub_ctx = self.get_parameter("publish_context").value
|
||||||
|
|
||||||
|
# ── Classifier + tracker ───────────────────────────────────────────
|
||||||
|
self._classifier = EmotionClassifier(engine_path=engine_path, alpha=alpha)
|
||||||
|
self._tracker = PersonEmotionTracker(alpha=alpha)
|
||||||
|
|
||||||
|
# Parse opt-out list
|
||||||
|
for pid_str in opt_out_str.split(","):
|
||||||
|
pid_str = pid_str.strip()
|
||||||
|
if pid_str.isdigit():
|
||||||
|
self._tracker.set_opt_out(int(pid_str), True)
|
||||||
|
|
||||||
|
# ── Camera frame buffer + cv_bridge ───────────────────────────────
|
||||||
|
self._frame_buf = _FrameBuffer()
|
||||||
|
self._bridge = CvBridge() if _CV_BRIDGE_OK else None
|
||||||
|
self._cam_names = [n.strip() for n in cam_names_str.split(",")]
|
||||||
|
self._cam_id_map: Dict[str, int] = {
|
||||||
|
name: idx for idx, name in enumerate(self._cam_names)
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Latest persons for person_id ↔ face_id correlation ────────────
|
||||||
|
self._persons_lock = threading.Lock()
|
||||||
|
self._face_to_person: Dict[int, int] = {} # face_id → person_id
|
||||||
|
|
||||||
|
# ── Context for LLM (latest emotion per person) ───────────────────
|
||||||
|
self._emotion_context: Dict[int, str] = {} # person_id → emotion label
|
||||||
|
|
||||||
|
# ── QoS profiles ──────────────────────────────────────────────────
|
||||||
|
be_qos = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Subscriptions ──────────────────────────────────────────────────
|
||||||
|
self.create_subscription(
|
||||||
|
FaceDetectionArray,
|
||||||
|
"/social/faces/detections",
|
||||||
|
self._on_faces,
|
||||||
|
be_qos,
|
||||||
|
)
|
||||||
|
self.create_subscription(
|
||||||
|
PersonStateArray,
|
||||||
|
"/social/persons",
|
||||||
|
self._on_persons,
|
||||||
|
be_qos,
|
||||||
|
)
|
||||||
|
for name in self._cam_names[:n_cameras]:
|
||||||
|
cam_id = self._cam_id_map.get(name, 0)
|
||||||
|
self.create_subscription(
|
||||||
|
Image,
|
||||||
|
f"/camera/{name}/image_raw",
|
||||||
|
lambda msg, cid=cam_id: self._on_image(msg, cid),
|
||||||
|
be_qos,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Publishers ─────────────────────────────────────────────────────
|
||||||
|
self._expr_pub = self.create_publisher(
|
||||||
|
ExpressionArray, "/social/faces/expressions", be_qos
|
||||||
|
)
|
||||||
|
if self._pub_ctx:
|
||||||
|
self._ctx_pub = self.create_publisher(
|
||||||
|
String, "/social/emotion/context", 10
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._ctx_pub = None
|
||||||
|
|
||||||
|
# ── Load TRT engine in background ──────────────────────────────────
|
||||||
|
threading.Thread(target=self._load_engine, daemon=True).start()
|
||||||
|
|
||||||
|
self.get_logger().info(
|
||||||
|
f"EmotionNode ready — engine='{engine_path}', "
|
||||||
|
f"alpha={alpha:.2f}, min_conf={self._min_conf:.2f}, "
|
||||||
|
f"cameras={self._cam_names[:n_cameras]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Engine loading ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _load_engine(self) -> None:
|
||||||
|
if not self._classifier._engine_path:
|
||||||
|
if self._lm_fallback:
|
||||||
|
self.get_logger().info(
|
||||||
|
"No TRT engine path — using landmark heuristic fallback"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.get_logger().warn(
|
||||||
|
"No TRT engine and landmark_fallback=false — "
|
||||||
|
"emotion_node will not classify"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
ok = self._classifier.load()
|
||||||
|
if ok:
|
||||||
|
self.get_logger().info("TRT emotion engine loaded ✓")
|
||||||
|
else:
|
||||||
|
self.get_logger().warn(
|
||||||
|
"TRT engine load failed — falling back to landmark heuristic"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Camera frame ingestion ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _on_image(self, msg: Image, camera_id: int) -> None:
|
||||||
|
if not _CV_BRIDGE_OK or self._bridge is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
bgr = self._bridge.imgmsg_to_cv2(msg, desired_encoding="bgr8")
|
||||||
|
self._frame_buf.put(camera_id, bgr)
|
||||||
|
except Exception as exc:
|
||||||
|
self.get_logger().debug(f"cv_bridge error cam{camera_id}: {exc}")
|
||||||
|
|
||||||
|
# ── Person-state for face_id → person_id mapping ─────────────────────────
|
||||||
|
|
||||||
|
def _on_persons(self, msg: PersonStateArray) -> None:
|
||||||
|
mapping: Dict[int, int] = {}
|
||||||
|
for ps in msg.persons:
|
||||||
|
if ps.face_id >= 0:
|
||||||
|
mapping[ps.face_id] = ps.person_id
|
||||||
|
with self._persons_lock:
|
||||||
|
self._face_to_person = mapping
|
||||||
|
|
||||||
|
def _get_person_id(self, face_id: int) -> int:
|
||||||
|
with self._persons_lock:
|
||||||
|
return self._face_to_person.get(face_id, -1)
|
||||||
|
|
||||||
|
# ── Main face-detection callback ──────────────────────────────────────────
|
||||||
|
|
||||||
|
def _on_faces(self, msg: FaceDetectionArray) -> None:
|
||||||
|
if not msg.faces:
|
||||||
|
return
|
||||||
|
|
||||||
|
expressions: List[Expression] = []
|
||||||
|
|
||||||
|
for face in msg.faces:
|
||||||
|
person_id = self._get_person_id(face.face_id)
|
||||||
|
|
||||||
|
# Opt-out check
|
||||||
|
if person_id >= 0 and self._tracker.is_opt_out(person_id):
|
||||||
|
expr = Expression()
|
||||||
|
expr.header = msg.header
|
||||||
|
expr.person_id = person_id
|
||||||
|
expr.face_id = face.face_id
|
||||||
|
expr.emotion = ""
|
||||||
|
expr.confidence = 0.0
|
||||||
|
expr.scores = [0.0] * 7
|
||||||
|
expr.is_opt_out = True
|
||||||
|
expr.source = "opt_out"
|
||||||
|
expressions.append(expr)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Try TRT crop classification first
|
||||||
|
result = None
|
||||||
|
if self._classifier.ready and _CV2_OK:
|
||||||
|
result = self._classify_crop(face, person_id)
|
||||||
|
|
||||||
|
# Fallback to landmark heuristic
|
||||||
|
if result is None and self._lm_fallback:
|
||||||
|
result = self._classifier.classify_from_landmarks(
|
||||||
|
list(face.landmarks),
|
||||||
|
person_id=person_id,
|
||||||
|
tracker=self._tracker,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if result.confidence < self._min_conf:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Build ROS message
|
||||||
|
expr = Expression()
|
||||||
|
expr.header = msg.header
|
||||||
|
expr.person_id = person_id
|
||||||
|
expr.face_id = face.face_id
|
||||||
|
expr.emotion = result.emotion
|
||||||
|
expr.confidence = result.confidence
|
||||||
|
# Pad/trim scores to exactly 7
|
||||||
|
sc = result.scores
|
||||||
|
expr.scores = (sc + [0.0] * 7)[:7]
|
||||||
|
expr.is_opt_out = False
|
||||||
|
expr.source = result.source
|
||||||
|
expressions.append(expr)
|
||||||
|
|
||||||
|
# Update LLM context cache
|
||||||
|
if person_id >= 0:
|
||||||
|
self._emotion_context[person_id] = result.emotion
|
||||||
|
|
||||||
|
if not expressions:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Publish ExpressionArray
|
||||||
|
arr = ExpressionArray()
|
||||||
|
arr.header = msg.header
|
||||||
|
arr.expressions = expressions
|
||||||
|
self._expr_pub.publish(arr)
|
||||||
|
|
||||||
|
# Publish JSON context for conversation node
|
||||||
|
if self._ctx_pub is not None:
|
||||||
|
self._publish_context()
|
||||||
|
|
||||||
|
self._tracker.tick()
|
||||||
|
|
||||||
|
def _classify_crop(self, face, person_id: int):
|
||||||
|
"""Extract crop from frame buffer and run TRT inference."""
|
||||||
|
# Resolve camera_id from face header frame_id
|
||||||
|
frame_id = face.header.frame_id if face.header.frame_id else "front"
|
||||||
|
cam_name = frame_id.split("/")[-1].split("_")[0] # "front", "left", etc.
|
||||||
|
camera_id = self._cam_id_map.get(cam_name, 0)
|
||||||
|
|
||||||
|
frame = self._frame_buf.get(camera_id)
|
||||||
|
if frame is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
h, w = frame.shape[:2]
|
||||||
|
x1 = max(0, int(face.bbox_x * w))
|
||||||
|
y1 = max(0, int(face.bbox_y * h))
|
||||||
|
x2 = min(w, int((face.bbox_x + face.bbox_w) * w))
|
||||||
|
y2 = min(h, int((face.bbox_y + face.bbox_h) * h))
|
||||||
|
|
||||||
|
if (x2 - x1) < self._face_min or (y2 - y1) < self._face_min:
|
||||||
|
return None
|
||||||
|
|
||||||
|
crop = frame[y1:y2, x1:x2]
|
||||||
|
if crop.size == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self._classifier.classify_crop(
|
||||||
|
crop, person_id=person_id, tracker=self._tracker
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── LLM context publisher ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _publish_context(self) -> None:
|
||||||
|
"""Publish latest-emotions dict as JSON for conversation_node."""
|
||||||
|
ctx = {str(pid): emo for pid, emo in self._emotion_context.items()}
|
||||||
|
msg = String()
|
||||||
|
msg.data = json.dumps({"emotions": ctx, "ts": time.time()})
|
||||||
|
self._ctx_pub.publish(msg)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Entry point ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None) -> None:
|
||||||
|
rclpy.init(args=args)
|
||||||
|
node = EmotionNode()
|
||||||
|
try:
|
||||||
|
rclpy.spin(node)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
node.destroy_node()
|
||||||
|
rclpy.try_shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -37,6 +37,8 @@ setup(
|
|||||||
'voice_command_node = saltybot_social.voice_command_node:main',
|
'voice_command_node = saltybot_social.voice_command_node:main',
|
||||||
# Multi-camera gesture recognition (Issue #140)
|
# Multi-camera gesture recognition (Issue #140)
|
||||||
'gesture_node = saltybot_social.gesture_node:main',
|
'gesture_node = saltybot_social.gesture_node:main',
|
||||||
|
# Facial expression recognition (Issue #161)
|
||||||
|
'emotion_node = saltybot_social.emotion_node:main',
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@ -0,0 +1,528 @@
|
|||||||
|
"""test_emotion_classifier.py — Unit tests for emotion_classifier (Issue #161).
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- Emotion constant definitions
|
||||||
|
- Softmax normalisation
|
||||||
|
- PersonEmotionTracker: EMA smoothing, age expiry, opt-out, reset
|
||||||
|
- Landmark heuristic classifier: geometric edge cases
|
||||||
|
- EmotionClassifier: classify_crop fallback, classify_from_landmarks
|
||||||
|
- Score renormalisation, argmax, utility helpers
|
||||||
|
- Integration: full classify pipeline with mock engine
|
||||||
|
|
||||||
|
No ROS2, TensorRT, or OpenCV runtime required.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from saltybot_social.emotion_classifier import (
|
||||||
|
EMOTIONS,
|
||||||
|
N_CLASSES,
|
||||||
|
ClassifiedEmotion,
|
||||||
|
PersonEmotionTracker,
|
||||||
|
EmotionClassifier,
|
||||||
|
classify_from_landmarks,
|
||||||
|
_softmax,
|
||||||
|
_argmax,
|
||||||
|
_uniform_scores,
|
||||||
|
_renorm,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _lm_neutral() -> list:
|
||||||
|
"""5-point SCRFD landmarks for a frontal neutral face (normalised)."""
|
||||||
|
return [
|
||||||
|
0.35, 0.38, # left_eye
|
||||||
|
0.65, 0.38, # right_eye
|
||||||
|
0.50, 0.52, # nose
|
||||||
|
0.38, 0.72, # left_mouth
|
||||||
|
0.62, 0.72, # right_mouth
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _lm_happy() -> list:
|
||||||
|
"""Wide mouth, symmetric corners, mouth well below nose."""
|
||||||
|
return [
|
||||||
|
0.35, 0.38,
|
||||||
|
0.65, 0.38,
|
||||||
|
0.50, 0.52,
|
||||||
|
0.30, 0.74, # wide mouth
|
||||||
|
0.70, 0.74,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _lm_sad() -> list:
|
||||||
|
"""Compressed mouth, tight width."""
|
||||||
|
return [
|
||||||
|
0.35, 0.38,
|
||||||
|
0.65, 0.38,
|
||||||
|
0.50, 0.52,
|
||||||
|
0.44, 0.62, # tight, close to nose
|
||||||
|
0.56, 0.62,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _lm_asymmetric() -> list:
|
||||||
|
"""Asymmetric mouth corners → disgust/sad signal."""
|
||||||
|
return [
|
||||||
|
0.35, 0.38,
|
||||||
|
0.65, 0.38,
|
||||||
|
0.50, 0.52,
|
||||||
|
0.38, 0.65,
|
||||||
|
0.62, 0.74, # right corner lower → asymmetric
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestEmotionConstants ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmotionConstants:
|
||||||
|
def test_emotions_length(self):
|
||||||
|
assert len(EMOTIONS) == 7
|
||||||
|
|
||||||
|
def test_n_classes(self):
|
||||||
|
assert N_CLASSES == 7
|
||||||
|
|
||||||
|
def test_emotions_labels(self):
|
||||||
|
expected = ["happy", "sad", "angry", "surprised", "fearful", "disgusted", "neutral"]
|
||||||
|
assert EMOTIONS == expected
|
||||||
|
|
||||||
|
def test_happy_index(self):
|
||||||
|
assert EMOTIONS[0] == "happy"
|
||||||
|
|
||||||
|
def test_neutral_index(self):
|
||||||
|
assert EMOTIONS[6] == "neutral"
|
||||||
|
|
||||||
|
def test_all_lowercase(self):
|
||||||
|
for e in EMOTIONS:
|
||||||
|
assert e == e.lower()
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestSoftmax ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestSoftmax:
|
||||||
|
def test_sums_to_one(self):
|
||||||
|
logits = [1.0, 2.0, 0.5, -1.0, 3.0, 0.0, 1.5]
|
||||||
|
result = _softmax(logits)
|
||||||
|
assert abs(sum(result) - 1.0) < 1e-6
|
||||||
|
|
||||||
|
def test_all_positive(self):
|
||||||
|
logits = [1.0, 2.0, 3.0, 0.0, -1.0, -2.0, 0.5]
|
||||||
|
result = _softmax(logits)
|
||||||
|
assert all(s > 0 for s in result)
|
||||||
|
|
||||||
|
def test_max_logit_gives_max_prob(self):
|
||||||
|
logits = [0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0]
|
||||||
|
result = _softmax(logits)
|
||||||
|
assert _argmax(result) == 2
|
||||||
|
|
||||||
|
def test_uniform_logits_uniform_probs(self):
|
||||||
|
logits = [1.0] * 7
|
||||||
|
result = _softmax(logits)
|
||||||
|
for p in result:
|
||||||
|
assert abs(p - 1.0 / 7.0) < 1e-6
|
||||||
|
|
||||||
|
def test_length_preserved(self):
|
||||||
|
logits = [0.0] * 7
|
||||||
|
assert len(_softmax(logits)) == 7
|
||||||
|
|
||||||
|
def test_numerically_stable_large_values(self):
|
||||||
|
logits = [1000.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
|
||||||
|
result = _softmax(logits)
|
||||||
|
assert abs(result[0] - 1.0) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestArgmax ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestArgmax:
|
||||||
|
def test_finds_max(self):
|
||||||
|
assert _argmax([0.1, 0.2, 0.9, 0.3, 0.1, 0.1, 0.1]) == 2
|
||||||
|
|
||||||
|
def test_single_element(self):
|
||||||
|
assert _argmax([0.5]) == 0
|
||||||
|
|
||||||
|
def test_first_when_all_equal(self):
|
||||||
|
result = _argmax([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
|
||||||
|
assert result == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestUniformScores ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestUniformScores:
|
||||||
|
def test_length(self):
|
||||||
|
assert len(_uniform_scores()) == N_CLASSES
|
||||||
|
|
||||||
|
def test_sums_to_one(self):
|
||||||
|
assert abs(sum(_uniform_scores()) - 1.0) < 1e-9
|
||||||
|
|
||||||
|
def test_all_equal(self):
|
||||||
|
s = _uniform_scores()
|
||||||
|
assert all(abs(x - s[0]) < 1e-9 for x in s)
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestRenorm ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestRenorm:
|
||||||
|
def test_sums_to_one(self):
|
||||||
|
s = [0.1, 0.2, 0.3, 0.1, 0.1, 0.1, 0.1]
|
||||||
|
r = _renorm(s)
|
||||||
|
assert abs(sum(r) - 1.0) < 1e-9
|
||||||
|
|
||||||
|
def test_all_zero_returns_uniform(self):
|
||||||
|
r = _renorm([0.0] * 7)
|
||||||
|
assert len(r) == 7
|
||||||
|
assert abs(sum(r) - 1.0) < 1e-9
|
||||||
|
|
||||||
|
def test_preserves_order(self):
|
||||||
|
s = [0.5, 0.3, 0.0, 0.0, 0.0, 0.0, 0.2]
|
||||||
|
r = _renorm(s)
|
||||||
|
assert r[0] > r[1] > r[6]
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestPersonEmotionTracker ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestPersonEmotionTracker:
|
||||||
|
def _scores(self, idx: int, val: float = 0.8) -> list:
|
||||||
|
s = [0.0] * N_CLASSES
|
||||||
|
s[idx] = val
|
||||||
|
total = val + (N_CLASSES - 1) * 0.02
|
||||||
|
s = [v if i == idx else 0.02 for i, v in enumerate(s)]
|
||||||
|
s[idx] = val
|
||||||
|
t = sum(s)
|
||||||
|
return [x / t for x in s]
|
||||||
|
|
||||||
|
def test_first_call_returns_input(self):
|
||||||
|
tracker = PersonEmotionTracker(alpha=0.5)
|
||||||
|
scores = self._scores(0)
|
||||||
|
result = tracker.smooth(1, scores)
|
||||||
|
assert abs(result[0] - scores[0]) < 1e-9
|
||||||
|
|
||||||
|
def test_ema_converges_toward_new_dominant(self):
|
||||||
|
tracker = PersonEmotionTracker(alpha=0.5)
|
||||||
|
happy = self._scores(0)
|
||||||
|
sad = self._scores(1)
|
||||||
|
tracker.smooth(1, happy)
|
||||||
|
# Push sad repeatedly
|
||||||
|
prev_sad = tracker.smooth(1, sad)[1]
|
||||||
|
for _ in range(20):
|
||||||
|
result = tracker.smooth(1, sad)
|
||||||
|
assert result[1] > prev_sad # sad score increased
|
||||||
|
|
||||||
|
def test_alpha_1_no_smoothing(self):
|
||||||
|
tracker = PersonEmotionTracker(alpha=1.0)
|
||||||
|
s1 = self._scores(0)
|
||||||
|
s2 = self._scores(1)
|
||||||
|
tracker.smooth(1, s1)
|
||||||
|
result = tracker.smooth(1, s2)
|
||||||
|
for a, b in zip(result, s2):
|
||||||
|
assert abs(a - b) < 1e-9
|
||||||
|
|
||||||
|
def test_alpha_0_frozen(self):
|
||||||
|
tracker = PersonEmotionTracker(alpha=0.0)
|
||||||
|
s1 = self._scores(0)
|
||||||
|
s2 = self._scores(1)
|
||||||
|
tracker.smooth(1, s1)
|
||||||
|
result = tracker.smooth(1, s2)
|
||||||
|
for a, b in zip(result, s1):
|
||||||
|
assert abs(a - b) < 1e-9
|
||||||
|
|
||||||
|
def test_different_persons_independent(self):
|
||||||
|
tracker = PersonEmotionTracker(alpha=0.5)
|
||||||
|
happy = self._scores(0)
|
||||||
|
sad = self._scores(1)
|
||||||
|
tracker.smooth(1, happy)
|
||||||
|
tracker.smooth(2, sad)
|
||||||
|
r1 = tracker.smooth(1, happy)
|
||||||
|
r2 = tracker.smooth(2, sad)
|
||||||
|
assert r1[0] > r2[0] # person 1 more happy
|
||||||
|
assert r2[1] > r1[1] # person 2 more sad
|
||||||
|
|
||||||
|
def test_negative_person_id_no_tracking(self):
|
||||||
|
tracker = PersonEmotionTracker()
|
||||||
|
scores = self._scores(2)
|
||||||
|
result = tracker.smooth(-1, scores)
|
||||||
|
assert result == scores # unchanged, not stored
|
||||||
|
|
||||||
|
def test_reset_clears_ema(self):
|
||||||
|
tracker = PersonEmotionTracker(alpha=0.5)
|
||||||
|
s1 = self._scores(0)
|
||||||
|
tracker.smooth(1, s1)
|
||||||
|
tracker.reset(1)
|
||||||
|
assert 1 not in tracker.tracked_ids
|
||||||
|
|
||||||
|
def test_reset_all_clears_all(self):
|
||||||
|
tracker = PersonEmotionTracker(alpha=0.5)
|
||||||
|
for pid in range(5):
|
||||||
|
tracker.smooth(pid, self._scores(pid % N_CLASSES))
|
||||||
|
tracker.reset_all()
|
||||||
|
assert tracker.tracked_ids == []
|
||||||
|
|
||||||
|
def test_tracked_ids_populated(self):
|
||||||
|
tracker = PersonEmotionTracker()
|
||||||
|
tracker.smooth(10, self._scores(0))
|
||||||
|
tracker.smooth(20, self._scores(1))
|
||||||
|
assert set(tracker.tracked_ids) == {10, 20}
|
||||||
|
|
||||||
|
def test_age_expiry_resets_ema(self):
|
||||||
|
tracker = PersonEmotionTracker(alpha=0.5, max_age=3)
|
||||||
|
tracker.smooth(1, self._scores(0))
|
||||||
|
# Advance age beyond max_age
|
||||||
|
for _ in range(4):
|
||||||
|
tracker.tick()
|
||||||
|
# Next smooth after expiry should reset (first call returns input unchanged)
|
||||||
|
fresh_scores = self._scores(1)
|
||||||
|
result = tracker.smooth(1, fresh_scores)
|
||||||
|
# After reset, result should equal fresh_scores exactly
|
||||||
|
for a, b in zip(result, fresh_scores):
|
||||||
|
assert abs(a - b) < 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestOptOut ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestOptOut:
|
||||||
|
def test_set_opt_out_true(self):
|
||||||
|
tracker = PersonEmotionTracker()
|
||||||
|
tracker.set_opt_out(42, True)
|
||||||
|
assert tracker.is_opt_out(42)
|
||||||
|
|
||||||
|
def test_set_opt_out_false(self):
|
||||||
|
tracker = PersonEmotionTracker()
|
||||||
|
tracker.set_opt_out(42, True)
|
||||||
|
tracker.set_opt_out(42, False)
|
||||||
|
assert not tracker.is_opt_out(42)
|
||||||
|
|
||||||
|
def test_opt_out_clears_ema(self):
|
||||||
|
tracker = PersonEmotionTracker()
|
||||||
|
tracker.smooth(42, [0.5, 0.1, 0.1, 0.1, 0.1, 0.0, 0.1])
|
||||||
|
assert 42 in tracker.tracked_ids
|
||||||
|
tracker.set_opt_out(42, True)
|
||||||
|
assert 42 not in tracker.tracked_ids
|
||||||
|
|
||||||
|
def test_unknown_person_not_opted_out(self):
|
||||||
|
tracker = PersonEmotionTracker()
|
||||||
|
assert not tracker.is_opt_out(99)
|
||||||
|
|
||||||
|
def test_multiple_opt_outs(self):
|
||||||
|
tracker = PersonEmotionTracker()
|
||||||
|
for pid in [1, 2, 3]:
|
||||||
|
tracker.set_opt_out(pid, True)
|
||||||
|
for pid in [1, 2, 3]:
|
||||||
|
assert tracker.is_opt_out(pid)
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestClassifyFromLandmarks ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestClassifyFromLandmarks:
|
||||||
|
def test_returns_classified_emotion(self):
|
||||||
|
result = classify_from_landmarks(_lm_neutral())
|
||||||
|
assert isinstance(result, ClassifiedEmotion)
|
||||||
|
|
||||||
|
def test_emotion_is_valid_label(self):
|
||||||
|
result = classify_from_landmarks(_lm_neutral())
|
||||||
|
assert result.emotion in EMOTIONS
|
||||||
|
|
||||||
|
def test_scores_length(self):
|
||||||
|
result = classify_from_landmarks(_lm_neutral())
|
||||||
|
assert len(result.scores) == N_CLASSES
|
||||||
|
|
||||||
|
def test_scores_sum_to_one(self):
|
||||||
|
result = classify_from_landmarks(_lm_neutral())
|
||||||
|
# Scores are rounded to 4 dp; allow 1e-3 accumulation across 7 terms
|
||||||
|
assert abs(sum(result.scores) - 1.0) < 1e-3
|
||||||
|
|
||||||
|
def test_confidence_matches_top_score(self):
|
||||||
|
result = classify_from_landmarks(_lm_neutral())
|
||||||
|
# confidence is round(score, 3) and max score is round(s, 4) → ≤0.5e-3 diff
|
||||||
|
assert abs(result.confidence - max(result.scores)) < 5e-3
|
||||||
|
|
||||||
|
def test_source_is_landmark_heuristic(self):
|
||||||
|
result = classify_from_landmarks(_lm_neutral())
|
||||||
|
assert result.source == "landmark_heuristic"
|
||||||
|
|
||||||
|
def test_happy_landmarks_boost_happy(self):
|
||||||
|
happy = classify_from_landmarks(_lm_happy())
|
||||||
|
neutral = classify_from_landmarks(_lm_neutral())
|
||||||
|
# Happy landmarks should give relatively higher happy score
|
||||||
|
assert happy.scores[0] >= neutral.scores[0]
|
||||||
|
|
||||||
|
def test_sad_landmarks_suppress_happy(self):
|
||||||
|
sad_result = classify_from_landmarks(_lm_sad())
|
||||||
|
# Happy score should be relatively low for sad landmarks
|
||||||
|
assert sad_result.scores[0] < 0.5
|
||||||
|
|
||||||
|
def test_asymmetric_mouth_non_neutral(self):
|
||||||
|
asym = classify_from_landmarks(_lm_asymmetric())
|
||||||
|
# Asymmetric → disgust or sad should be elevated
|
||||||
|
assert asym.scores[5] > 0.10 or asym.scores[1] > 0.10
|
||||||
|
|
||||||
|
def test_empty_landmarks_returns_neutral(self):
|
||||||
|
result = classify_from_landmarks([])
|
||||||
|
assert result.emotion == "neutral"
|
||||||
|
|
||||||
|
def test_short_landmarks_returns_neutral(self):
|
||||||
|
result = classify_from_landmarks([0.5, 0.5])
|
||||||
|
assert result.emotion == "neutral"
|
||||||
|
|
||||||
|
def test_all_positive_scores(self):
|
||||||
|
result = classify_from_landmarks(_lm_happy())
|
||||||
|
assert all(s >= 0.0 for s in result.scores)
|
||||||
|
|
||||||
|
def test_confidence_in_range(self):
|
||||||
|
result = classify_from_landmarks(_lm_neutral())
|
||||||
|
assert 0.0 <= result.confidence <= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestEmotionClassifier ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmotionClassifier:
|
||||||
|
def test_init_no_engine_not_ready(self):
|
||||||
|
clf = EmotionClassifier(engine_path="")
|
||||||
|
assert not clf.ready
|
||||||
|
|
||||||
|
def test_load_empty_path_returns_false(self):
|
||||||
|
clf = EmotionClassifier(engine_path="")
|
||||||
|
assert clf.load() is False
|
||||||
|
|
||||||
|
def test_load_nonexistent_path_returns_false(self):
|
||||||
|
clf = EmotionClassifier(engine_path="/nonexistent/engine.trt")
|
||||||
|
result = clf.load()
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
def test_classify_crop_fallback_without_engine(self):
|
||||||
|
"""Without TRT engine, classify_crop should return a valid result
|
||||||
|
with landmark heuristic source."""
|
||||||
|
clf = EmotionClassifier(engine_path="")
|
||||||
|
# Build a minimal synthetic 48x48 BGR image
|
||||||
|
try:
|
||||||
|
import numpy as np
|
||||||
|
fake_crop = np.zeros((48, 48, 3), dtype="uint8")
|
||||||
|
result = clf.classify_crop(fake_crop, person_id=-1, tracker=None)
|
||||||
|
assert isinstance(result, ClassifiedEmotion)
|
||||||
|
assert result.emotion in EMOTIONS
|
||||||
|
assert 0.0 <= result.confidence <= 1.0
|
||||||
|
assert len(result.scores) == N_CLASSES
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("numpy not available")
|
||||||
|
|
||||||
|
def test_classify_from_landmarks_delegates(self):
|
||||||
|
clf = EmotionClassifier(engine_path="")
|
||||||
|
result = clf.classify_from_landmarks(_lm_happy())
|
||||||
|
assert isinstance(result, ClassifiedEmotion)
|
||||||
|
assert result.source == "landmark_heuristic"
|
||||||
|
|
||||||
|
def test_classify_from_landmarks_with_tracker_smooths(self):
|
||||||
|
clf = EmotionClassifier(engine_path="", alpha=0.5)
|
||||||
|
tracker = PersonEmotionTracker(alpha=0.5)
|
||||||
|
r1 = clf.classify_from_landmarks(_lm_happy(), person_id=1, tracker=tracker)
|
||||||
|
r2 = clf.classify_from_landmarks(_lm_sad(), person_id=1, tracker=tracker)
|
||||||
|
# After smoothing, r2's top score should not be identical to raw sad scores
|
||||||
|
# (EMA blends r1 history into r2)
|
||||||
|
raw_sad = classify_from_landmarks(_lm_sad())
|
||||||
|
assert r2.scores != raw_sad.scores
|
||||||
|
|
||||||
|
def test_classify_from_landmarks_no_tracker_no_smooth(self):
|
||||||
|
clf = EmotionClassifier(engine_path="")
|
||||||
|
r1 = clf.classify_from_landmarks(_lm_happy(), person_id=1, tracker=None)
|
||||||
|
r2 = clf.classify_from_landmarks(_lm_happy(), person_id=1, tracker=None)
|
||||||
|
# Without tracker, same input → same output
|
||||||
|
assert r1.scores == r2.scores
|
||||||
|
|
||||||
|
def test_source_fallback_when_no_engine(self):
|
||||||
|
try:
|
||||||
|
import numpy as np
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("numpy not available")
|
||||||
|
clf = EmotionClassifier(engine_path="")
|
||||||
|
crop = __import__("numpy").zeros((48, 48, 3), dtype="uint8")
|
||||||
|
result = clf.classify_crop(crop)
|
||||||
|
assert result.source == "landmark_heuristic"
|
||||||
|
|
||||||
|
def test_classifier_alpha_propagated_to_tracker(self):
|
||||||
|
clf = EmotionClassifier(engine_path="", alpha=0.1)
|
||||||
|
assert clf._alpha == 0.1
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestClassifiedEmotionDataclass ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestClassifiedEmotionDataclass:
|
||||||
|
def test_fields(self):
|
||||||
|
ce = ClassifiedEmotion(
|
||||||
|
emotion="happy",
|
||||||
|
confidence=0.85,
|
||||||
|
scores=[0.85, 0.02, 0.02, 0.03, 0.02, 0.02, 0.04],
|
||||||
|
source="cnn_trt",
|
||||||
|
)
|
||||||
|
assert ce.emotion == "happy"
|
||||||
|
assert ce.confidence == 0.85
|
||||||
|
assert len(ce.scores) == 7
|
||||||
|
assert ce.source == "cnn_trt"
|
||||||
|
|
||||||
|
def test_default_source(self):
|
||||||
|
ce = ClassifiedEmotion("neutral", 0.6, [0.0] * 7)
|
||||||
|
assert ce.source == "cnn_trt"
|
||||||
|
|
||||||
|
def test_emotion_is_mutable(self):
|
||||||
|
ce = ClassifiedEmotion("neutral", 0.6, [0.0] * 7)
|
||||||
|
ce.emotion = "happy"
|
||||||
|
assert ce.emotion == "happy"
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestEdgeCases ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestEdgeCases:
|
||||||
|
def test_softmax_single_element(self):
|
||||||
|
result = _softmax([5.0])
|
||||||
|
assert abs(result[0] - 1.0) < 1e-9
|
||||||
|
|
||||||
|
def test_tracker_negative_id_no_stored_state(self):
|
||||||
|
tracker = PersonEmotionTracker()
|
||||||
|
scores = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.4]
|
||||||
|
tracker.smooth(-1, scores)
|
||||||
|
assert tracker.tracked_ids == []
|
||||||
|
|
||||||
|
def test_tick_increments_age(self):
|
||||||
|
tracker = PersonEmotionTracker(max_age=2)
|
||||||
|
tracker.smooth(1, [0.1] * 7)
|
||||||
|
tracker.tick()
|
||||||
|
tracker.tick()
|
||||||
|
tracker.tick()
|
||||||
|
# Age should exceed max_age → next smooth resets
|
||||||
|
fresh = [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
|
||||||
|
result = tracker.smooth(1, fresh)
|
||||||
|
assert abs(result[2] - 1.0) < 1e-9
|
||||||
|
|
||||||
|
def test_renorm_negative_scores_safe(self):
|
||||||
|
# Negative scores shouldn't crash (though unusual in practice)
|
||||||
|
scores = [0.1, 0.2, 0.0, 0.0, 0.0, 0.0, 0.3]
|
||||||
|
r = _renorm(scores)
|
||||||
|
assert abs(sum(r) - 1.0) < 1e-9
|
||||||
|
|
||||||
|
def test_landmark_heuristic_very_close_eye_mouth(self):
|
||||||
|
# Degenerate face where everything is at same y → should not crash
|
||||||
|
lm = [0.3, 0.5, 0.7, 0.5, 0.5, 0.5, 0.4, 0.5, 0.6, 0.5]
|
||||||
|
result = classify_from_landmarks(lm)
|
||||||
|
assert result.emotion in EMOTIONS
|
||||||
|
|
||||||
|
def test_opt_out_then_re_enable(self):
|
||||||
|
tracker = PersonEmotionTracker()
|
||||||
|
tracker.smooth(5, [0.1] * 7)
|
||||||
|
tracker.set_opt_out(5, True)
|
||||||
|
assert tracker.is_opt_out(5)
|
||||||
|
tracker.set_opt_out(5, False)
|
||||||
|
assert not tracker.is_opt_out(5)
|
||||||
|
# Should be able to smooth again after re-enable
|
||||||
|
result = tracker.smooth(5, [0.2] * 7)
|
||||||
|
assert len(result) == N_CLASSES
|
||||||
@ -38,6 +38,9 @@ rosidl_generate_interfaces(${PROJECT_NAME}
|
|||||||
# Issue #140 — gesture recognition
|
# Issue #140 — gesture recognition
|
||||||
"msg/Gesture.msg"
|
"msg/Gesture.msg"
|
||||||
"msg/GestureArray.msg"
|
"msg/GestureArray.msg"
|
||||||
|
# Issue #161 — emotion detection
|
||||||
|
"msg/Expression.msg"
|
||||||
|
"msg/ExpressionArray.msg"
|
||||||
DEPENDENCIES std_msgs geometry_msgs builtin_interfaces
|
DEPENDENCIES std_msgs geometry_msgs builtin_interfaces
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
17
jetson/ros2_ws/src/saltybot_social_msgs/msg/Expression.msg
Normal file
17
jetson/ros2_ws/src/saltybot_social_msgs/msg/Expression.msg
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
# Expression.msg — Detected facial expression for one person (Issue #161).
|
||||||
|
# Published by emotion_node on /social/faces/expressions
|
||||||
|
|
||||||
|
std_msgs/Header header
|
||||||
|
|
||||||
|
int32 person_id # -1 = unidentified; matches PersonState.person_id
|
||||||
|
int32 face_id # matches FaceDetection.face_id
|
||||||
|
|
||||||
|
string emotion # one of: happy sad angry surprised fearful disgusted neutral
|
||||||
|
float32 confidence # smoothed confidence of the top emotion (0.0–1.0)
|
||||||
|
|
||||||
|
float32[7] scores # per-class softmax scores, order:
|
||||||
|
# [0]=happy [1]=sad [2]=angry [3]=surprised
|
||||||
|
# [4]=fearful [5]=disgusted [6]=neutral
|
||||||
|
|
||||||
|
bool is_opt_out # true = this person opted out; no emotion data published
|
||||||
|
string source # "cnn_trt" | "landmark_heuristic" | "opt_out"
|
||||||
@ -0,0 +1,5 @@
|
|||||||
|
# ExpressionArray.msg — Batch of detected facial expressions (Issue #161).
|
||||||
|
# Published by emotion_node on /social/faces/expressions
|
||||||
|
|
||||||
|
std_msgs/Header header
|
||||||
|
Expression[] expressions
|
||||||
30
src/jlink.c
30
src/jlink.c
@ -176,6 +176,11 @@ static void dispatch(const uint8_t *payload, uint8_t cmd, uint8_t plen)
|
|||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
case JLINK_CMD_SLEEP:
|
||||||
|
/* Payload-less; main loop calls power_mgmt_request_sleep() */
|
||||||
|
jlink_state.sleep_req = 1u;
|
||||||
|
break;
|
||||||
|
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -290,3 +295,28 @@ void jlink_send_telemetry(const jlink_tlm_status_t *status)
|
|||||||
|
|
||||||
HAL_UART_Transmit(&s_uart, frame, sizeof(frame), 5u);
|
HAL_UART_Transmit(&s_uart, frame, sizeof(frame), 5u);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ---- jlink_send_power_telemetry() ---- */
|
||||||
|
void jlink_send_power_telemetry(const jlink_tlm_power_t *power)
|
||||||
|
{
|
||||||
|
/*
|
||||||
|
* Frame: [STX][LEN][0x81][11 bytes POWER][CRC_hi][CRC_lo][ETX]
|
||||||
|
* LEN = 1 (CMD) + 11 (payload) = 12; total = 17 bytes
|
||||||
|
* At 921600 baud: 17×10/921600 ≈ 0.18 ms — safe to block.
|
||||||
|
*/
|
||||||
|
static uint8_t frame[17];
|
||||||
|
const uint8_t plen = (uint8_t)sizeof(jlink_tlm_power_t); /* 11 */
|
||||||
|
const uint8_t len = 1u + plen; /* 12 */
|
||||||
|
|
||||||
|
frame[0] = JLINK_STX;
|
||||||
|
frame[1] = len;
|
||||||
|
frame[2] = JLINK_TLM_POWER;
|
||||||
|
memcpy(&frame[3], power, plen);
|
||||||
|
|
||||||
|
uint16_t crc = crc16_xmodem(&frame[2], len);
|
||||||
|
frame[3 + plen] = (uint8_t)(crc >> 8);
|
||||||
|
frame[3 + plen + 1] = (uint8_t)(crc & 0xFFu);
|
||||||
|
frame[3 + plen + 2] = JLINK_ETX;
|
||||||
|
|
||||||
|
HAL_UART_Transmit(&s_uart, frame, sizeof(frame), 5u);
|
||||||
|
}
|
||||||
|
|||||||
45
src/main.c
45
src/main.c
@ -19,6 +19,7 @@
|
|||||||
#include "jlink.h"
|
#include "jlink.h"
|
||||||
#include "ota.h"
|
#include "ota.h"
|
||||||
#include "audio.h"
|
#include "audio.h"
|
||||||
|
#include "power_mgmt.h"
|
||||||
#include "battery.h"
|
#include "battery.h"
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
@ -149,6 +150,9 @@ int main(void) {
|
|||||||
audio_init();
|
audio_init();
|
||||||
audio_play_tone(AUDIO_TONE_STARTUP);
|
audio_play_tone(AUDIO_TONE_STARTUP);
|
||||||
|
|
||||||
|
/* Init power management — STOP-mode sleep/wake, wake EXTIs configured */
|
||||||
|
power_mgmt_init();
|
||||||
|
|
||||||
/* Init mode manager (RC/autonomous blend; CH6 mode switch) */
|
/* Init mode manager (RC/autonomous blend; CH6 mode switch) */
|
||||||
mode_manager_t mode;
|
mode_manager_t mode;
|
||||||
mode_manager_init(&mode);
|
mode_manager_init(&mode);
|
||||||
@ -183,6 +187,8 @@ int main(void) {
|
|||||||
uint32_t esc_tick = 0;
|
uint32_t esc_tick = 0;
|
||||||
uint32_t crsf_telem_tick = 0; /* CRSF uplink telemetry TX timer */
|
uint32_t crsf_telem_tick = 0; /* CRSF uplink telemetry TX timer */
|
||||||
uint32_t jlink_tlm_tick = 0; /* Jetson binary telemetry TX timer */
|
uint32_t jlink_tlm_tick = 0; /* Jetson binary telemetry TX timer */
|
||||||
|
uint32_t pm_tlm_tick = 0; /* JLINK_TLM_POWER transmit timer */
|
||||||
|
uint8_t pm_pwm_phase = 0; /* Software PWM counter for sleep LED */
|
||||||
const float dt = 1.0f / PID_LOOP_HZ; /* 1ms at 1kHz */
|
const float dt = 1.0f / PID_LOOP_HZ; /* 1ms at 1kHz */
|
||||||
|
|
||||||
while (1) {
|
while (1) {
|
||||||
@ -196,6 +202,23 @@ int main(void) {
|
|||||||
/* Advance audio tone sequencer (non-blocking, call every tick) */
|
/* Advance audio tone sequencer (non-blocking, call every tick) */
|
||||||
audio_tick(now);
|
audio_tick(now);
|
||||||
|
|
||||||
|
/* Sleep LED: software PWM on LED1 (active-low PC15) driven by PM brightness.
|
||||||
|
* pm_pwm_phase rolls over each ms; brightness sets duty cycle 0-255. */
|
||||||
|
pm_pwm_phase++;
|
||||||
|
{
|
||||||
|
uint8_t pm_bright = power_mgmt_led_brightness();
|
||||||
|
if (pm_bright > 0u) {
|
||||||
|
bool led_on = (pm_pwm_phase < pm_bright);
|
||||||
|
HAL_GPIO_WritePin(LED1_PORT, LED1_PIN,
|
||||||
|
led_on ? GPIO_PIN_RESET : GPIO_PIN_SET);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Power manager tick — may block in WFI (STOP mode) when disarmed */
|
||||||
|
if (bal.state != BALANCE_ARMED) {
|
||||||
|
power_mgmt_tick(now);
|
||||||
|
}
|
||||||
|
|
||||||
/* Mode manager: update RC liveness, CH6 mode selection, blend ramp */
|
/* Mode manager: update RC liveness, CH6 mode selection, blend ramp */
|
||||||
mode_manager_update(&mode, now);
|
mode_manager_update(&mode, now);
|
||||||
|
|
||||||
@ -249,6 +272,16 @@ int main(void) {
|
|||||||
* never returns when disarmed — MCU resets into DFU mode. */
|
* never returns when disarmed — MCU resets into DFU mode. */
|
||||||
ota_enter_dfu(bal.state == BALANCE_ARMED);
|
ota_enter_dfu(bal.state == BALANCE_ARMED);
|
||||||
}
|
}
|
||||||
|
if (jlink_state.sleep_req) {
|
||||||
|
jlink_state.sleep_req = 0u;
|
||||||
|
power_mgmt_request_sleep();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Power management: CRSF/JLink activity or armed state resets idle timer */
|
||||||
|
if (crsf_is_active(now) || jlink_is_active(now) ||
|
||||||
|
bal.state == BALANCE_ARMED) {
|
||||||
|
power_mgmt_activity();
|
||||||
|
}
|
||||||
|
|
||||||
/* RC CH5 kill switch: disarm immediately if RC is alive and CH5 off.
|
/* RC CH5 kill switch: disarm immediately if RC is alive and CH5 off.
|
||||||
* Applies regardless of active mode (CH5 always has kill authority). */
|
* Applies regardless of active mode (CH5 always has kill authority). */
|
||||||
@ -424,6 +457,18 @@ int main(void) {
|
|||||||
jlink_send_telemetry(&tlm);
|
jlink_send_telemetry(&tlm);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* JLINK_TLM_POWER telemetry at PM_TLM_HZ (1 Hz) */
|
||||||
|
if (now - pm_tlm_tick >= (1000u / PM_TLM_HZ)) {
|
||||||
|
pm_tlm_tick = now;
|
||||||
|
jlink_tlm_power_t pow;
|
||||||
|
pow.power_state = (uint8_t)power_mgmt_state();
|
||||||
|
pow.est_total_ma = power_mgmt_current_ma();
|
||||||
|
pow.est_audio_ma = (uint16_t)(power_mgmt_state() == PM_SLEEPING ? 0u : PM_CURRENT_AUDIO_MA);
|
||||||
|
pow.est_osd_ma = (uint16_t)(power_mgmt_state() == PM_SLEEPING ? 0u : PM_CURRENT_OSD_MA);
|
||||||
|
pow.idle_ms = power_mgmt_idle_ms();
|
||||||
|
jlink_send_power_telemetry(&pow);
|
||||||
|
}
|
||||||
|
|
||||||
/* USB telemetry at 50Hz (only when streaming enabled and calibration done) */
|
/* USB telemetry at 50Hz (only when streaming enabled and calibration done) */
|
||||||
if (cdc_streaming && imu_calibrated() && now - send_tick >= 20) {
|
if (cdc_streaming && imu_calibrated() && now - send_tick >= 20) {
|
||||||
send_tick = now;
|
send_tick = now;
|
||||||
|
|||||||
251
src/power_mgmt.c
Normal file
251
src/power_mgmt.c
Normal file
@ -0,0 +1,251 @@
|
|||||||
|
#include "power_mgmt.h"
|
||||||
|
#include "config.h"
|
||||||
|
#include "stm32f7xx_hal.h"
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
/* ---- Internal state ---- */
|
||||||
|
static PowerState s_state = PM_ACTIVE;
|
||||||
|
static uint32_t s_last_active = 0;
|
||||||
|
static uint32_t s_fade_start = 0;
|
||||||
|
static bool s_sleep_req = false;
|
||||||
|
static bool s_peripherals_gated = false;
|
||||||
|
|
||||||
|
/* ---- EXTI wake-source configuration ---- */
|
||||||
|
/*
|
||||||
|
* EXTI1 → PA1 (UART4_RX / CRSF): falling edge (UART start bit)
|
||||||
|
* EXTI7 → PB7 (USART1_RX / JLink): falling edge
|
||||||
|
* EXTI4 → PC4 (MPU6000 INT): already configured by mpu6000_init();
|
||||||
|
* we just ensure IMR bit is set.
|
||||||
|
*
|
||||||
|
* GPIO pins remain in their current AF mode; EXTI is pad-level and
|
||||||
|
* fires independently of the AF setting.
|
||||||
|
*/
|
||||||
|
static void enable_wake_exti(void)
|
||||||
|
{
|
||||||
|
__HAL_RCC_SYSCFG_CLK_ENABLE();
|
||||||
|
|
||||||
|
/* EXTI1: PA1 (UART4_RX) — SYSCFG EXTICR1[7:4] = 0000 (PA) */
|
||||||
|
SYSCFG->EXTICR[0] = (SYSCFG->EXTICR[0] & ~(0xFu << 4)) | (0x0u << 4);
|
||||||
|
EXTI->FTSR |= (1u << 1);
|
||||||
|
EXTI->RTSR &= ~(1u << 1);
|
||||||
|
EXTI->PR = (1u << 1); /* clear pending */
|
||||||
|
EXTI->IMR |= (1u << 1);
|
||||||
|
HAL_NVIC_SetPriority(EXTI1_IRQn, 5, 0);
|
||||||
|
HAL_NVIC_EnableIRQ(EXTI1_IRQn);
|
||||||
|
|
||||||
|
/* EXTI7: PB7 (USART1_RX) — SYSCFG EXTICR2[15:12] = 0001 (PB) */
|
||||||
|
SYSCFG->EXTICR[1] = (SYSCFG->EXTICR[1] & ~(0xFu << 12)) | (0x1u << 12);
|
||||||
|
EXTI->FTSR |= (1u << 7);
|
||||||
|
EXTI->RTSR &= ~(1u << 7);
|
||||||
|
EXTI->PR = (1u << 7);
|
||||||
|
EXTI->IMR |= (1u << 7);
|
||||||
|
HAL_NVIC_SetPriority(EXTI9_5_IRQn, 5, 0);
|
||||||
|
HAL_NVIC_EnableIRQ(EXTI9_5_IRQn);
|
||||||
|
|
||||||
|
/* EXTI4: PC4 (MPU6000 INT) — handler in mpu6000.c; just ensure IMR set */
|
||||||
|
EXTI->IMR |= (1u << 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void disable_wake_exti(void)
|
||||||
|
{
|
||||||
|
/* Mask UART RX wake EXTIs now that UART peripherals handle traffic */
|
||||||
|
EXTI->IMR &= ~(1u << 1);
|
||||||
|
EXTI->IMR &= ~(1u << 7);
|
||||||
|
/* Leave EXTI4 (IMU data-ready) always unmasked */
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ---- Peripheral clock gating ---- */
|
||||||
|
/*
|
||||||
|
* Clock-only gate (no force-reset): peripheral register state is preserved.
|
||||||
|
* On re-enable, DMA circular transfers resume without reinitialisation.
|
||||||
|
*/
|
||||||
|
static void gate_peripherals(void)
|
||||||
|
{
|
||||||
|
if (s_peripherals_gated) return;
|
||||||
|
__HAL_RCC_SPI3_CLK_DISABLE(); /* I2S3 / audio amplifier */
|
||||||
|
__HAL_RCC_SPI2_CLK_DISABLE(); /* OSD MAX7456 */
|
||||||
|
__HAL_RCC_USART6_CLK_DISABLE(); /* legacy Jetson CDC */
|
||||||
|
__HAL_RCC_UART5_CLK_DISABLE(); /* debug UART */
|
||||||
|
s_peripherals_gated = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ungate_peripherals(void)
|
||||||
|
{
|
||||||
|
if (!s_peripherals_gated) return;
|
||||||
|
__HAL_RCC_SPI3_CLK_ENABLE();
|
||||||
|
__HAL_RCC_SPI2_CLK_ENABLE();
|
||||||
|
__HAL_RCC_USART6_CLK_ENABLE();
|
||||||
|
__HAL_RCC_UART5_CLK_ENABLE();
|
||||||
|
s_peripherals_gated = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ---- PLL clock restore after STOP mode ---- */
|
||||||
|
/*
|
||||||
|
* After STOP wakeup SYSCLK = HSI (16 MHz). Re-lock PLL for 216 MHz.
|
||||||
|
* PLLM=8, PLLN=216, PLLP=2, PLLQ=9 — STM32F722 @ 216 MHz, HSI source.
|
||||||
|
*
|
||||||
|
* HAL_RCC_ClockConfig() calls HAL_InitTick() which resets uwTick to 0;
|
||||||
|
* save and restore it so existing timeouts remain valid across sleep.
|
||||||
|
*/
|
||||||
|
extern volatile uint32_t uwTick;
|
||||||
|
|
||||||
|
static void restore_clocks(void)
|
||||||
|
{
|
||||||
|
uint32_t saved_tick = uwTick;
|
||||||
|
|
||||||
|
RCC_OscInitTypeDef osc = {0};
|
||||||
|
osc.OscillatorType = RCC_OSCILLATORTYPE_HSI;
|
||||||
|
osc.HSIState = RCC_HSI_ON;
|
||||||
|
osc.HSICalibrationValue = RCC_HSICALIBRATION_DEFAULT;
|
||||||
|
osc.PLL.PLLState = RCC_PLL_ON;
|
||||||
|
osc.PLL.PLLSource = RCC_PLLSOURCE_HSI;
|
||||||
|
osc.PLL.PLLM = 8;
|
||||||
|
osc.PLL.PLLN = 216;
|
||||||
|
osc.PLL.PLLP = RCC_PLLP_DIV2;
|
||||||
|
osc.PLL.PLLQ = 9;
|
||||||
|
HAL_RCC_OscConfig(&osc);
|
||||||
|
|
||||||
|
RCC_ClkInitTypeDef clk = {0};
|
||||||
|
clk.ClockType = RCC_CLOCKTYPE_HCLK | RCC_CLOCKTYPE_SYSCLK |
|
||||||
|
RCC_CLOCKTYPE_PCLK1 | RCC_CLOCKTYPE_PCLK2;
|
||||||
|
clk.SYSCLKSource = RCC_SYSCLKSOURCE_PLLCLK;
|
||||||
|
clk.AHBCLKDivider = RCC_SYSCLK_DIV1;
|
||||||
|
clk.APB1CLKDivider = RCC_HCLK_DIV4; /* 54 MHz */
|
||||||
|
clk.APB2CLKDivider = RCC_HCLK_DIV2; /* 108 MHz */
|
||||||
|
HAL_RCC_ClockConfig(&clk, FLASH_LATENCY_7);
|
||||||
|
|
||||||
|
uwTick = saved_tick; /* restore — HAL_InitTick() reset it to 0 */
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ---- EXTI IRQ handlers (wake-only: clear pending bit and return) ---- */
|
||||||
|
/*
|
||||||
|
* These handlers fire once on wakeup. After restore_clocks() the respective
|
||||||
|
* UART peripherals resume normal DMA/IDLE-interrupt operation.
|
||||||
|
*
|
||||||
|
* NOTE: If EXTI9_5_IRQHandler is already defined elsewhere in the project,
|
||||||
|
* merge that handler with this one.
|
||||||
|
*/
|
||||||
|
void EXTI1_IRQHandler(void)
|
||||||
|
{
|
||||||
|
if (EXTI->PR & (1u << 1)) EXTI->PR = (1u << 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
void EXTI9_5_IRQHandler(void)
|
||||||
|
{
|
||||||
|
/* Clear any pending EXTI5-9 lines (PB7 = EXTI7 is our primary wake) */
|
||||||
|
uint32_t pr = EXTI->PR & 0x3E0u;
|
||||||
|
if (pr) EXTI->PR = pr;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ---- LED brightness (integer arithmetic, no float, called from main loop) ---- */
|
||||||
|
/*
|
||||||
|
* Triangle wave: 0→255→0 over PM_LED_PERIOD_MS.
|
||||||
|
* Only active during PM_SLEEP_PENDING; returns 0 otherwise.
|
||||||
|
*/
|
||||||
|
uint8_t power_mgmt_led_brightness(void)
|
||||||
|
{
|
||||||
|
if (s_state != PM_SLEEP_PENDING) return 0u;
|
||||||
|
|
||||||
|
uint32_t phase = (HAL_GetTick() - s_fade_start) % PM_LED_PERIOD_MS;
|
||||||
|
uint32_t half = PM_LED_PERIOD_MS / 2u;
|
||||||
|
if (phase < half)
|
||||||
|
return (uint8_t)(phase * 255u / half);
|
||||||
|
else
|
||||||
|
return (uint8_t)((PM_LED_PERIOD_MS - phase) * 255u / half);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ---- Current estimate ---- */
|
||||||
|
uint16_t power_mgmt_current_ma(void)
|
||||||
|
{
|
||||||
|
if (s_state == PM_SLEEPING)
|
||||||
|
return (uint16_t)PM_CURRENT_STOP_MA;
|
||||||
|
uint16_t ma = (uint16_t)PM_CURRENT_BASE_MA;
|
||||||
|
if (!s_peripherals_gated) {
|
||||||
|
ma += (uint16_t)(PM_CURRENT_AUDIO_MA + PM_CURRENT_OSD_MA +
|
||||||
|
PM_CURRENT_DEBUG_MA);
|
||||||
|
}
|
||||||
|
return ma;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ---- Idle elapsed ---- */
|
||||||
|
uint32_t power_mgmt_idle_ms(void)
|
||||||
|
{
|
||||||
|
return HAL_GetTick() - s_last_active;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ---- Public API ---- */
|
||||||
|
void power_mgmt_init(void)
|
||||||
|
{
|
||||||
|
s_state = PM_ACTIVE;
|
||||||
|
s_last_active = HAL_GetTick();
|
||||||
|
s_fade_start = 0;
|
||||||
|
s_sleep_req = false;
|
||||||
|
s_peripherals_gated = false;
|
||||||
|
enable_wake_exti();
|
||||||
|
}
|
||||||
|
|
||||||
|
void power_mgmt_activity(void)
|
||||||
|
{
|
||||||
|
s_last_active = HAL_GetTick();
|
||||||
|
if (s_state != PM_ACTIVE) {
|
||||||
|
s_sleep_req = false;
|
||||||
|
s_state = PM_WAKING; /* resolved to PM_ACTIVE on next tick() */
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void power_mgmt_request_sleep(void)
|
||||||
|
{
|
||||||
|
s_sleep_req = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
PowerState power_mgmt_state(void)
|
||||||
|
{
|
||||||
|
return s_state;
|
||||||
|
}
|
||||||
|
|
||||||
|
PowerState power_mgmt_tick(uint32_t now_ms)
|
||||||
|
{
|
||||||
|
switch (s_state) {
|
||||||
|
|
||||||
|
case PM_ACTIVE:
|
||||||
|
if (s_sleep_req || (now_ms - s_last_active) >= PM_IDLE_TIMEOUT_MS) {
|
||||||
|
s_sleep_req = false;
|
||||||
|
s_fade_start = now_ms;
|
||||||
|
s_state = PM_SLEEP_PENDING;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
case PM_SLEEP_PENDING:
|
||||||
|
if ((now_ms - s_fade_start) >= PM_FADE_MS) {
|
||||||
|
gate_peripherals();
|
||||||
|
enable_wake_exti();
|
||||||
|
s_state = PM_SLEEPING;
|
||||||
|
|
||||||
|
/* Feed IWDG: wakeup <10 ms << WATCHDOG_TIMEOUT_MS (50 ms) */
|
||||||
|
IWDG->KR = 0xAAAAu;
|
||||||
|
|
||||||
|
/* === STOP MODE ENTRY — execution resumes here on EXTI wake === */
|
||||||
|
HAL_PWR_EnterSTOPMode(PWR_LOWPOWERREGULATOR_ON, PWR_STOPENTRY_WFI);
|
||||||
|
/* === WAKEUP POINT (< 10 ms latency) === */
|
||||||
|
|
||||||
|
restore_clocks();
|
||||||
|
ungate_peripherals();
|
||||||
|
disable_wake_exti();
|
||||||
|
s_last_active = HAL_GetTick();
|
||||||
|
s_state = PM_ACTIVE;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
case PM_SLEEPING:
|
||||||
|
/* Unreachable: WFI is inline in PM_SLEEP_PENDING above */
|
||||||
|
break;
|
||||||
|
|
||||||
|
case PM_WAKING:
|
||||||
|
/* Set by power_mgmt_activity() during SLEEP_PENDING/SLEEPING */
|
||||||
|
ungate_peripherals();
|
||||||
|
s_state = PM_ACTIVE;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
return s_state;
|
||||||
|
}
|
||||||
567
test/test_power_mgmt.py
Normal file
567
test/test_power_mgmt.py
Normal file
@ -0,0 +1,567 @@
|
|||||||
|
"""
|
||||||
|
test_power_mgmt.py — unit tests for Issue #178 power management module.
|
||||||
|
|
||||||
|
Models the PM state machine, LED brightness, peripheral gating, current
|
||||||
|
estimates, JLink protocol extension, and hardware timing budgets in Python.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import struct
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Constants (mirror config.h / power_mgmt.h)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
PM_IDLE_TIMEOUT_MS = 30_000
|
||||||
|
PM_FADE_MS = 3_000
|
||||||
|
PM_LED_PERIOD_MS = 2_000
|
||||||
|
|
||||||
|
PM_CURRENT_BASE_MA = 30 # SPI1(IMU) + UART4(CRSF) + USART1(JLink) + core
|
||||||
|
PM_CURRENT_AUDIO_MA = 8 # I2S3 + amp quiescent
|
||||||
|
PM_CURRENT_OSD_MA = 5 # SPI2 OSD MAX7456
|
||||||
|
PM_CURRENT_DEBUG_MA = 1 # UART5 + USART6
|
||||||
|
PM_CURRENT_STOP_MA = 1 # MCU in STOP mode (< 1 mA)
|
||||||
|
|
||||||
|
PM_CURRENT_ACTIVE_ALL = (PM_CURRENT_BASE_MA + PM_CURRENT_AUDIO_MA +
|
||||||
|
PM_CURRENT_OSD_MA + PM_CURRENT_DEBUG_MA)
|
||||||
|
|
||||||
|
# JLink additions
|
||||||
|
JLINK_STX = 0x02
|
||||||
|
JLINK_ETX = 0x03
|
||||||
|
JLINK_CMD_SLEEP = 0x09
|
||||||
|
JLINK_TLM_STATUS = 0x80
|
||||||
|
JLINK_TLM_POWER = 0x81
|
||||||
|
|
||||||
|
# Power states
|
||||||
|
PM_ACTIVE = 0
|
||||||
|
PM_SLEEP_PENDING = 1
|
||||||
|
PM_SLEEPING = 2
|
||||||
|
PM_WAKING = 3
|
||||||
|
|
||||||
|
# WATCHDOG_TIMEOUT_MS from config.h
|
||||||
|
WATCHDOG_TIMEOUT_MS = 50
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CRC-16/XModem helper (poly 0x1021, init 0x0000)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
def crc16_xmodem(data: bytes) -> int:
|
||||||
|
crc = 0x0000
|
||||||
|
for b in data:
|
||||||
|
crc ^= b << 8
|
||||||
|
for _ in range(8):
|
||||||
|
crc = (crc << 1) ^ 0x1021 if crc & 0x8000 else crc << 1
|
||||||
|
crc &= 0xFFFF
|
||||||
|
return crc
|
||||||
|
|
||||||
|
|
||||||
|
def build_frame(cmd: int, payload: bytes = b"") -> bytes:
|
||||||
|
data = bytes([cmd]) + payload
|
||||||
|
length = len(data)
|
||||||
|
crc = crc16_xmodem(data)
|
||||||
|
return bytes([JLINK_STX, length, *data, crc >> 8, crc & 0xFF, JLINK_ETX])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Python model of the power_mgmt state machine (mirrors power_mgmt.c)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
class PowerMgmtSim:
|
||||||
|
def __init__(self, now: int = 0):
|
||||||
|
self.state = PM_ACTIVE
|
||||||
|
self.last_active = now
|
||||||
|
self.fade_start = 0
|
||||||
|
self.sleep_req = False
|
||||||
|
self.peripherals_gated = False
|
||||||
|
|
||||||
|
def activity(self, now: int) -> None:
|
||||||
|
self.last_active = now
|
||||||
|
if self.state != PM_ACTIVE:
|
||||||
|
self.sleep_req = False
|
||||||
|
self.state = PM_WAKING
|
||||||
|
|
||||||
|
def request_sleep(self) -> None:
|
||||||
|
self.sleep_req = True
|
||||||
|
|
||||||
|
def led_brightness(self, now: int) -> int:
|
||||||
|
if self.state != PM_SLEEP_PENDING:
|
||||||
|
return 0
|
||||||
|
phase = (now - self.fade_start) % PM_LED_PERIOD_MS
|
||||||
|
half = PM_LED_PERIOD_MS // 2
|
||||||
|
if phase < half:
|
||||||
|
return phase * 255 // half
|
||||||
|
else:
|
||||||
|
return (PM_LED_PERIOD_MS - phase) * 255 // half
|
||||||
|
|
||||||
|
def current_ma(self) -> int:
|
||||||
|
if self.state == PM_SLEEPING:
|
||||||
|
return PM_CURRENT_STOP_MA
|
||||||
|
ma = PM_CURRENT_BASE_MA
|
||||||
|
if not self.peripherals_gated:
|
||||||
|
ma += PM_CURRENT_AUDIO_MA + PM_CURRENT_OSD_MA + PM_CURRENT_DEBUG_MA
|
||||||
|
return ma
|
||||||
|
|
||||||
|
def idle_ms(self, now: int) -> int:
|
||||||
|
return now - self.last_active
|
||||||
|
|
||||||
|
def tick(self, now: int) -> int:
|
||||||
|
if self.state == PM_ACTIVE:
|
||||||
|
if self.sleep_req or (now - self.last_active) >= PM_IDLE_TIMEOUT_MS:
|
||||||
|
self.sleep_req = False
|
||||||
|
self.fade_start = now
|
||||||
|
self.state = PM_SLEEP_PENDING
|
||||||
|
|
||||||
|
elif self.state == PM_SLEEP_PENDING:
|
||||||
|
if (now - self.fade_start) >= PM_FADE_MS:
|
||||||
|
self.peripherals_gated = True
|
||||||
|
self.state = PM_SLEEPING
|
||||||
|
# In firmware: WFI blocks here; in test we skip to simulate_wake
|
||||||
|
|
||||||
|
elif self.state == PM_WAKING:
|
||||||
|
self.peripherals_gated = False
|
||||||
|
self.state = PM_ACTIVE
|
||||||
|
|
||||||
|
return self.state
|
||||||
|
|
||||||
|
def simulate_wake(self, now: int) -> None:
|
||||||
|
"""Simulate EXTI wakeup from STOP mode (models HAL_PWR_EnterSTOPMode return)."""
|
||||||
|
if self.state == PM_SLEEPING:
|
||||||
|
self.peripherals_gated = False
|
||||||
|
self.last_active = now
|
||||||
|
self.state = PM_ACTIVE
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: Idle timer
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
class TestIdleTimer:
|
||||||
|
def test_stays_active_before_timeout(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
for t in range(0, PM_IDLE_TIMEOUT_MS, 1000):
|
||||||
|
assert pm.tick(t) == PM_ACTIVE
|
||||||
|
|
||||||
|
def test_enters_sleep_pending_at_timeout(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
assert pm.tick(PM_IDLE_TIMEOUT_MS - 1) == PM_ACTIVE
|
||||||
|
assert pm.tick(PM_IDLE_TIMEOUT_MS) == PM_SLEEP_PENDING
|
||||||
|
|
||||||
|
def test_activity_resets_idle_timer(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS - 1000)
|
||||||
|
pm.activity(PM_IDLE_TIMEOUT_MS - 1000) # reset at T=29000
|
||||||
|
assert pm.tick(PM_IDLE_TIMEOUT_MS) == PM_ACTIVE # 1 s since reset
|
||||||
|
assert pm.tick(PM_IDLE_TIMEOUT_MS - 1000 + PM_IDLE_TIMEOUT_MS) == PM_SLEEP_PENDING
|
||||||
|
|
||||||
|
def test_idle_ms_increases_monotonically(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
assert pm.idle_ms(0) == 0
|
||||||
|
assert pm.idle_ms(5000) == 5000
|
||||||
|
assert pm.idle_ms(29999) == 29999
|
||||||
|
|
||||||
|
def test_idle_ms_resets_on_activity(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.activity(10000)
|
||||||
|
assert pm.idle_ms(10500) == 500
|
||||||
|
|
||||||
|
def test_30s_timeout_matches_spec(self):
|
||||||
|
assert PM_IDLE_TIMEOUT_MS == 30_000
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: State machine transitions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
class TestStateMachine:
|
||||||
|
def test_sleep_req_bypasses_idle_timer(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.activity(0)
|
||||||
|
pm.request_sleep()
|
||||||
|
assert pm.tick(500) == PM_SLEEP_PENDING
|
||||||
|
|
||||||
|
def test_fade_complete_enters_sleeping(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS) # → SLEEP_PENDING
|
||||||
|
assert pm.tick(PM_IDLE_TIMEOUT_MS + PM_FADE_MS) == PM_SLEEPING
|
||||||
|
|
||||||
|
def test_fade_not_complete_stays_pending(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS)
|
||||||
|
assert pm.tick(PM_IDLE_TIMEOUT_MS + PM_FADE_MS - 1) == PM_SLEEP_PENDING
|
||||||
|
|
||||||
|
def test_wake_from_stop_returns_active(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS + PM_FADE_MS) # → SLEEPING
|
||||||
|
pm.simulate_wake(PM_IDLE_TIMEOUT_MS + PM_FADE_MS + 5)
|
||||||
|
assert pm.state == PM_ACTIVE
|
||||||
|
|
||||||
|
def test_activity_during_sleep_pending_aborts(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS) # → SLEEP_PENDING
|
||||||
|
pm.activity(PM_IDLE_TIMEOUT_MS + 100) # abort
|
||||||
|
assert pm.state == PM_WAKING
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS + 101)
|
||||||
|
assert pm.state == PM_ACTIVE
|
||||||
|
|
||||||
|
def test_activity_during_sleeping_aborts(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS + PM_FADE_MS) # → SLEEPING
|
||||||
|
pm.activity(PM_IDLE_TIMEOUT_MS + PM_FADE_MS + 3)
|
||||||
|
assert pm.state == PM_WAKING
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS + PM_FADE_MS + 4)
|
||||||
|
assert pm.state == PM_ACTIVE
|
||||||
|
|
||||||
|
def test_waking_resolves_on_next_tick(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.state = PM_WAKING
|
||||||
|
pm.tick(1000)
|
||||||
|
assert pm.state == PM_ACTIVE
|
||||||
|
|
||||||
|
def test_full_sleep_wake_cycle(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
# 1. Active
|
||||||
|
assert pm.tick(100) == PM_ACTIVE
|
||||||
|
# 2. Idle → sleep pending
|
||||||
|
assert pm.tick(PM_IDLE_TIMEOUT_MS) == PM_SLEEP_PENDING
|
||||||
|
# 3. Fade → sleeping
|
||||||
|
assert pm.tick(PM_IDLE_TIMEOUT_MS + PM_FADE_MS) == PM_SLEEPING
|
||||||
|
# 4. EXTI wake → active
|
||||||
|
pm.simulate_wake(PM_IDLE_TIMEOUT_MS + PM_FADE_MS + 8)
|
||||||
|
assert pm.state == PM_ACTIVE
|
||||||
|
|
||||||
|
def test_multiple_sleep_wake_cycles(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
base = 0
|
||||||
|
for _ in range(3):
|
||||||
|
pm.activity(base)
|
||||||
|
pm.tick(base + PM_IDLE_TIMEOUT_MS)
|
||||||
|
pm.tick(base + PM_IDLE_TIMEOUT_MS + PM_FADE_MS)
|
||||||
|
pm.simulate_wake(base + PM_IDLE_TIMEOUT_MS + PM_FADE_MS + 5)
|
||||||
|
assert pm.state == PM_ACTIVE
|
||||||
|
base += PM_IDLE_TIMEOUT_MS + PM_FADE_MS + 10
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: Peripheral gating
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
class TestPeripheralGating:
|
||||||
|
GATED = {'SPI3_I2S3', 'SPI2_OSD', 'USART6', 'UART5_DEBUG'}
|
||||||
|
ACTIVE = {'SPI1_IMU', 'UART4_CRSF', 'USART1_JLINK', 'I2C1_BARO'}
|
||||||
|
|
||||||
|
def test_gated_set_has_four_peripherals(self):
|
||||||
|
assert len(self.GATED) == 4
|
||||||
|
|
||||||
|
def test_no_overlap_between_gated_and_active(self):
|
||||||
|
assert not (self.GATED & self.ACTIVE)
|
||||||
|
|
||||||
|
def test_crsf_uart_not_gated(self):
|
||||||
|
assert not any('UART4' in p or 'CRSF' in p for p in self.GATED)
|
||||||
|
|
||||||
|
def test_jlink_uart_not_gated(self):
|
||||||
|
assert not any('USART1' in p or 'JLINK' in p for p in self.GATED)
|
||||||
|
|
||||||
|
def test_imu_spi_not_gated(self):
|
||||||
|
assert not any('SPI1' in p or 'IMU' in p for p in self.GATED)
|
||||||
|
|
||||||
|
def test_peripherals_gated_on_sleep_entry(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
assert not pm.peripherals_gated
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS + PM_FADE_MS) # → SLEEPING
|
||||||
|
assert pm.peripherals_gated
|
||||||
|
|
||||||
|
def test_peripherals_ungated_on_wake(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS + PM_FADE_MS)
|
||||||
|
pm.simulate_wake(PM_IDLE_TIMEOUT_MS + PM_FADE_MS + 5)
|
||||||
|
assert not pm.peripherals_gated
|
||||||
|
|
||||||
|
def test_peripherals_not_gated_in_sleep_pending(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS) # → SLEEP_PENDING
|
||||||
|
assert not pm.peripherals_gated
|
||||||
|
|
||||||
|
def test_peripherals_ungated_if_activity_during_pending(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS)
|
||||||
|
pm.activity(PM_IDLE_TIMEOUT_MS + 100)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS + 101)
|
||||||
|
assert not pm.peripherals_gated
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: LED brightness
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
class TestLedBrightness:
|
||||||
|
def test_zero_when_active(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
assert pm.led_brightness(5000) == 0
|
||||||
|
|
||||||
|
def test_zero_when_sleeping(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS + PM_FADE_MS) # → SLEEPING
|
||||||
|
assert pm.led_brightness(PM_IDLE_TIMEOUT_MS + PM_FADE_MS + 100) == 0
|
||||||
|
|
||||||
|
def test_zero_when_waking(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.state = PM_WAKING
|
||||||
|
assert pm.led_brightness(1000) == 0
|
||||||
|
|
||||||
|
def test_zero_at_phase_start(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS) # fade_start = PM_IDLE_TIMEOUT_MS
|
||||||
|
assert pm.led_brightness(PM_IDLE_TIMEOUT_MS) == 0
|
||||||
|
|
||||||
|
def test_max_at_half_period(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS)
|
||||||
|
t = PM_IDLE_TIMEOUT_MS + PM_LED_PERIOD_MS // 2
|
||||||
|
assert pm.led_brightness(t) == 255
|
||||||
|
|
||||||
|
def test_zero_at_full_period(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS)
|
||||||
|
t = PM_IDLE_TIMEOUT_MS + PM_LED_PERIOD_MS
|
||||||
|
assert pm.led_brightness(t) == 0
|
||||||
|
|
||||||
|
def test_symmetric_about_half_period(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS)
|
||||||
|
quarter = PM_LED_PERIOD_MS // 4
|
||||||
|
three_quarter = 3 * PM_LED_PERIOD_MS // 4
|
||||||
|
b1 = pm.led_brightness(PM_IDLE_TIMEOUT_MS + quarter)
|
||||||
|
b2 = pm.led_brightness(PM_IDLE_TIMEOUT_MS + three_quarter)
|
||||||
|
assert abs(b1 - b2) <= 1 # allow 1 LSB for integer division
|
||||||
|
|
||||||
|
def test_range_0_to_255(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS)
|
||||||
|
for dt in range(0, PM_LED_PERIOD_MS * 3, 37):
|
||||||
|
b = pm.led_brightness(PM_IDLE_TIMEOUT_MS + dt)
|
||||||
|
assert 0 <= b <= 255
|
||||||
|
|
||||||
|
def test_repeats_over_multiple_periods(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS)
|
||||||
|
# Sample at same phase in periods 0, 1, 2 — should be equal
|
||||||
|
phase = PM_LED_PERIOD_MS // 3
|
||||||
|
b0 = pm.led_brightness(PM_IDLE_TIMEOUT_MS + phase)
|
||||||
|
b1 = pm.led_brightness(PM_IDLE_TIMEOUT_MS + PM_LED_PERIOD_MS + phase)
|
||||||
|
b2 = pm.led_brightness(PM_IDLE_TIMEOUT_MS + 2 * PM_LED_PERIOD_MS + phase)
|
||||||
|
assert b0 == b1 == b2
|
||||||
|
|
||||||
|
def test_period_is_2s(self):
|
||||||
|
assert PM_LED_PERIOD_MS == 2000
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: Power / current estimates
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
class TestPowerEstimates:
|
||||||
|
def test_active_includes_all_subsystems(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
assert pm.current_ma() == PM_CURRENT_ACTIVE_ALL
|
||||||
|
|
||||||
|
def test_sleeping_returns_stop_ma(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS)
|
||||||
|
pm.tick(PM_IDLE_TIMEOUT_MS + PM_FADE_MS) # → SLEEPING
|
||||||
|
assert pm.current_ma() == PM_CURRENT_STOP_MA
|
||||||
|
|
||||||
|
def test_gated_returns_base_only(self):
|
||||||
|
pm = PowerMgmtSim(now=0)
|
||||||
|
pm.peripherals_gated = True
|
||||||
|
assert pm.current_ma() == PM_CURRENT_BASE_MA
|
||||||
|
|
||||||
|
def test_stop_current_less_than_active(self):
|
||||||
|
assert PM_CURRENT_STOP_MA < PM_CURRENT_ACTIVE_ALL
|
||||||
|
|
||||||
|
def test_stop_current_at_most_1ma(self):
|
||||||
|
assert PM_CURRENT_STOP_MA <= 1
|
||||||
|
|
||||||
|
def test_active_current_reasonable(self):
|
||||||
|
# Should be < 100 mA (just MCU + peripherals, no motors)
|
||||||
|
assert PM_CURRENT_ACTIVE_ALL < 100
|
||||||
|
|
||||||
|
def test_audio_subsystem_estimate(self):
|
||||||
|
assert PM_CURRENT_AUDIO_MA > 0
|
||||||
|
|
||||||
|
def test_osd_subsystem_estimate(self):
|
||||||
|
assert PM_CURRENT_OSD_MA > 0
|
||||||
|
|
||||||
|
def test_total_equals_sum_of_parts(self):
|
||||||
|
total = (PM_CURRENT_BASE_MA + PM_CURRENT_AUDIO_MA +
|
||||||
|
PM_CURRENT_OSD_MA + PM_CURRENT_DEBUG_MA)
|
||||||
|
assert total == PM_CURRENT_ACTIVE_ALL
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: JLink protocol extension
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
class TestJlinkProtocol:
|
||||||
|
def test_sleep_cmd_id(self):
|
||||||
|
assert JLINK_CMD_SLEEP == 0x09
|
||||||
|
|
||||||
|
def test_sleep_follows_audio_cmd(self):
|
||||||
|
JLINK_CMD_AUDIO = 0x08
|
||||||
|
assert JLINK_CMD_SLEEP == JLINK_CMD_AUDIO + 1
|
||||||
|
|
||||||
|
def test_power_tlm_id(self):
|
||||||
|
assert JLINK_TLM_POWER == 0x81
|
||||||
|
|
||||||
|
def test_power_tlm_follows_status_tlm(self):
|
||||||
|
assert JLINK_TLM_POWER == JLINK_TLM_STATUS + 1
|
||||||
|
|
||||||
|
def test_sleep_frame_length(self):
|
||||||
|
# SLEEP has no payload: STX(1)+LEN(1)+CMD(1)+CRC(2)+ETX(1) = 6
|
||||||
|
frame = build_frame(JLINK_CMD_SLEEP)
|
||||||
|
assert len(frame) == 6
|
||||||
|
|
||||||
|
def test_sleep_frame_sentinels(self):
|
||||||
|
frame = build_frame(JLINK_CMD_SLEEP)
|
||||||
|
assert frame[0] == JLINK_STX
|
||||||
|
assert frame[-1] == JLINK_ETX
|
||||||
|
|
||||||
|
def test_sleep_frame_len_field(self):
|
||||||
|
frame = build_frame(JLINK_CMD_SLEEP)
|
||||||
|
assert frame[1] == 1 # LEN = 1 (CMD only, no payload)
|
||||||
|
|
||||||
|
def test_sleep_frame_cmd_byte(self):
|
||||||
|
frame = build_frame(JLINK_CMD_SLEEP)
|
||||||
|
assert frame[2] == JLINK_CMD_SLEEP
|
||||||
|
|
||||||
|
def test_sleep_frame_crc_valid(self):
|
||||||
|
frame = build_frame(JLINK_CMD_SLEEP)
|
||||||
|
calc = crc16_xmodem(bytes([JLINK_CMD_SLEEP]))
|
||||||
|
rx = (frame[-3] << 8) | frame[-2]
|
||||||
|
assert rx == calc
|
||||||
|
|
||||||
|
def test_power_tlm_frame_length(self):
|
||||||
|
# jlink_tlm_power_t = 11 bytes
|
||||||
|
# Frame: STX(1)+LEN(1)+CMD(1)+payload(11)+CRC(2)+ETX(1) = 17
|
||||||
|
POWER_TLM_PAYLOAD_LEN = 11
|
||||||
|
expected = 1 + 1 + 1 + POWER_TLM_PAYLOAD_LEN + 2 + 1
|
||||||
|
assert expected == 17
|
||||||
|
|
||||||
|
def test_power_tlm_payload_struct(self):
|
||||||
|
"""jlink_tlm_power_t: u8 power_state, u16 est_total_ma,
|
||||||
|
u16 est_audio_ma, u16 est_osd_ma, u32 idle_ms = 11 bytes."""
|
||||||
|
fmt = "<BHHHI"
|
||||||
|
size = struct.calcsize(fmt)
|
||||||
|
assert size == 11
|
||||||
|
|
||||||
|
def test_power_tlm_frame_crc_valid(self):
|
||||||
|
power_state = PM_ACTIVE
|
||||||
|
est_total_ma = PM_CURRENT_ACTIVE_ALL
|
||||||
|
est_audio_ma = PM_CURRENT_AUDIO_MA
|
||||||
|
est_osd_ma = PM_CURRENT_OSD_MA
|
||||||
|
idle_ms = 5000
|
||||||
|
payload = struct.pack("<BHHHI", power_state, est_total_ma,
|
||||||
|
est_audio_ma, est_osd_ma, idle_ms)
|
||||||
|
frame = build_frame(JLINK_TLM_POWER, payload)
|
||||||
|
assert frame[0] == JLINK_STX
|
||||||
|
assert frame[-1] == JLINK_ETX
|
||||||
|
data_for_crc = bytes([JLINK_TLM_POWER]) + payload
|
||||||
|
expected_crc = crc16_xmodem(data_for_crc)
|
||||||
|
rx_crc = (frame[-3] << 8) | frame[-2]
|
||||||
|
assert rx_crc == expected_crc
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: Wake latency and IWDG budget
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
class TestWakeLatencyBudget:
|
||||||
|
# STM32F722 STOP-mode wakeup: HSI ready ~2 ms + PLL lock ~2 ms ≈ 4 ms
|
||||||
|
ESTIMATED_WAKE_MS = 10 # conservative upper bound
|
||||||
|
|
||||||
|
def test_wake_latency_within_50ms(self):
|
||||||
|
assert self.ESTIMATED_WAKE_MS < WATCHDOG_TIMEOUT_MS
|
||||||
|
|
||||||
|
def test_watchdog_timeout_is_50ms(self):
|
||||||
|
assert WATCHDOG_TIMEOUT_MS == 50
|
||||||
|
|
||||||
|
def test_iwdg_feed_before_wfi_is_safe(self):
|
||||||
|
# Time from IWDG feed to next feed after wake:
|
||||||
|
# ~1 ms (loop overhead) + ESTIMATED_WAKE_MS + ~1 ms = ~12 ms
|
||||||
|
time_from_feed_to_next_ms = 1 + self.ESTIMATED_WAKE_MS + 1
|
||||||
|
assert time_from_feed_to_next_ms < WATCHDOG_TIMEOUT_MS
|
||||||
|
|
||||||
|
def test_fade_ms_positive(self):
|
||||||
|
assert PM_FADE_MS > 0
|
||||||
|
|
||||||
|
def test_fade_ms_less_than_idle_timeout(self):
|
||||||
|
assert PM_FADE_MS < PM_IDLE_TIMEOUT_MS
|
||||||
|
|
||||||
|
def test_stop_mode_wake_much_less_than_50ms(self):
|
||||||
|
# PLL startup on STM32F7: HSI on (0 ms, already running) +
|
||||||
|
# PLL lock ~2 ms + SysTick re-init ~0.1 ms ≈ 3 ms
|
||||||
|
pll_lock_ms = 3
|
||||||
|
overhead_ms = 1
|
||||||
|
total_ms = pll_lock_ms + overhead_ms
|
||||||
|
assert total_ms < 50
|
||||||
|
|
||||||
|
def test_wake_exti_sources_count(self):
|
||||||
|
"""Three wake sources: EXTI1 (CRSF), EXTI7 (JLink), EXTI4 (IMU)."""
|
||||||
|
wake_sources = ['EXTI1_UART4_CRSF', 'EXTI7_USART1_JLINK', 'EXTI4_IMU_INT']
|
||||||
|
assert len(wake_sources) == 3
|
||||||
|
|
||||||
|
def test_uwTick_must_be_restored_after_stop(self):
|
||||||
|
"""HAL_RCC_ClockConfig resets uwTick to 0; restore_clocks() saves it."""
|
||||||
|
# Verify the pattern: save uwTick → HAL calls → restore uwTick
|
||||||
|
saved_tick = 12345
|
||||||
|
# Simulate HAL_InitTick() resetting to 0
|
||||||
|
uw_tick_after_hal = 0
|
||||||
|
restored = saved_tick # power_mgmt.c: uwTick = saved_tick
|
||||||
|
assert restored == saved_tick
|
||||||
|
assert uw_tick_after_hal != saved_tick # HAL reset it
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests: Hardware constants
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
class TestHardwareConstants:
|
||||||
|
def test_pll_params_for_216mhz(self):
|
||||||
|
"""PLLM=8, PLLN=216, PLLP=2 → VCO=216*2=432 MHz, SYSCLK=216 MHz."""
|
||||||
|
HSI_MHZ = 16
|
||||||
|
PLLM = 8
|
||||||
|
PLLN = 216
|
||||||
|
PLLP = 2
|
||||||
|
vco_mhz = HSI_MHZ / PLLM * PLLN
|
||||||
|
sysclk = vco_mhz / PLLP
|
||||||
|
assert sysclk == pytest.approx(216.0, rel=1e-6)
|
||||||
|
|
||||||
|
def test_apb1_54mhz(self):
|
||||||
|
"""APB1 = SYSCLK / 4 = 54 MHz."""
|
||||||
|
assert 216 / 4 == 54
|
||||||
|
|
||||||
|
def test_apb2_108mhz(self):
|
||||||
|
"""APB2 = SYSCLK / 2 = 108 MHz."""
|
||||||
|
assert 216 / 2 == 108
|
||||||
|
|
||||||
|
def test_flash_latency_7_required_at_216mhz(self):
|
||||||
|
"""STM32F7 at 2.7-3.3 V: 7 wait states for 210-216 MHz."""
|
||||||
|
FLASH_LATENCY = 7
|
||||||
|
assert FLASH_LATENCY == 7
|
||||||
|
|
||||||
|
def test_exti1_for_pa1(self):
|
||||||
|
"""SYSCFG EXTICR1[7:4] = 0x0 selects PA for EXTI1."""
|
||||||
|
PA_SOURCE = 0x0
|
||||||
|
assert PA_SOURCE == 0x0
|
||||||
|
|
||||||
|
def test_exti7_for_pb7(self):
|
||||||
|
"""SYSCFG EXTICR2[15:12] = 0x1 selects PB for EXTI7."""
|
||||||
|
PB_SOURCE = 0x1
|
||||||
|
assert PB_SOURCE == 0x1
|
||||||
|
|
||||||
|
def test_exticr_indices(self):
|
||||||
|
"""EXTI1 → EXTICR[0], EXTI7 → EXTICR[1]."""
|
||||||
|
assert 1 // 4 == 0 # EXTI1 is in EXTICR[0]
|
||||||
|
assert 7 // 4 == 1 # EXTI7 is in EXTICR[1]
|
||||||
|
|
||||||
|
def test_exti7_shift_in_exticr2(self):
|
||||||
|
"""EXTI7 field is at bits [15:12] of EXTICR[1] → shift = (7%4)*4 = 12."""
|
||||||
|
shift = (7 % 4) * 4
|
||||||
|
assert shift == 12
|
||||||
|
|
||||||
|
def test_idle_timeout_30s(self):
|
||||||
|
assert PM_IDLE_TIMEOUT_MS == 30_000
|
||||||
@ -5,13 +5,16 @@
|
|||||||
* Status | Faces | Conversation | Personality | Navigation
|
* Status | Faces | Conversation | Personality | Navigation
|
||||||
*
|
*
|
||||||
* Telemetry tabs (issue #126):
|
* Telemetry tabs (issue #126):
|
||||||
* IMU | Battery | Motors | Map | Control | Health
|
* IMU | Battery | Motors | Map | Control | Health | Cameras
|
||||||
*
|
*
|
||||||
* Fleet tabs (issue #139):
|
* Fleet tabs (issue #139):
|
||||||
* Fleet (self-contained via useFleet)
|
* Fleet (self-contained via useFleet)
|
||||||
*
|
*
|
||||||
* Mission tabs (issue #145):
|
* Mission tabs (issue #145):
|
||||||
* Missions (waypoint editor, route builder, geofence, schedule, execute)
|
* Missions (waypoint editor, route builder, geofence, schedule, execute)
|
||||||
|
*
|
||||||
|
* Camera viewer (issue #177):
|
||||||
|
* CSI × 4 + D435i RGB/depth + panoramic, detection overlays, recording
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { useState, useCallback } from 'react';
|
import { useState, useCallback } from 'react';
|
||||||
@ -41,6 +44,9 @@ import { MissionPlanner } from './components/MissionPlanner.jsx';
|
|||||||
// Settings panel (issue #160)
|
// Settings panel (issue #160)
|
||||||
import { SettingsPanel } from './components/SettingsPanel.jsx';
|
import { SettingsPanel } from './components/SettingsPanel.jsx';
|
||||||
|
|
||||||
|
// Camera viewer (issue #177)
|
||||||
|
import { CameraViewer } from './components/CameraViewer.jsx';
|
||||||
|
|
||||||
const TAB_GROUPS = [
|
const TAB_GROUPS = [
|
||||||
{
|
{
|
||||||
label: 'SOCIAL',
|
label: 'SOCIAL',
|
||||||
@ -63,6 +69,7 @@ const TAB_GROUPS = [
|
|||||||
{ id: 'map', label: 'Map', },
|
{ id: 'map', label: 'Map', },
|
||||||
{ id: 'control', label: 'Control', },
|
{ id: 'control', label: 'Control', },
|
||||||
{ id: 'health', label: 'Health', },
|
{ id: 'health', label: 'Health', },
|
||||||
|
{ id: 'cameras', label: 'Cameras', },
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -206,6 +213,7 @@ export default function App() {
|
|||||||
{activeTab === 'map' && <MapViewer subscribe={subscribe} />}
|
{activeTab === 'map' && <MapViewer subscribe={subscribe} />}
|
||||||
{activeTab === 'control' && <ControlMode subscribe={subscribe} />}
|
{activeTab === 'control' && <ControlMode subscribe={subscribe} />}
|
||||||
{activeTab === 'health' && <SystemHealth subscribe={subscribe} />}
|
{activeTab === 'health' && <SystemHealth subscribe={subscribe} />}
|
||||||
|
{activeTab === 'cameras' && <CameraViewer subscribe={subscribe} />}
|
||||||
|
|
||||||
{activeTab === 'fleet' && <FleetPanel />}
|
{activeTab === 'fleet' && <FleetPanel />}
|
||||||
{activeTab === 'missions' && <MissionPlanner />}
|
{activeTab === 'missions' && <MissionPlanner />}
|
||||||
|
|||||||
671
ui/social-bot/src/components/CameraViewer.jsx
Normal file
671
ui/social-bot/src/components/CameraViewer.jsx
Normal file
@ -0,0 +1,671 @@
|
|||||||
|
/**
|
||||||
|
* CameraViewer.jsx — Live camera stream viewer (Issue #177).
|
||||||
|
*
|
||||||
|
* Features:
|
||||||
|
* - 7 cameras: front/left/rear/right (CSI), D435i RGB/depth, panoramic
|
||||||
|
* - Detection overlays: face boxes + names, gesture icons, scene object labels
|
||||||
|
* - 360° panoramic equirect viewer with mouse drag pan
|
||||||
|
* - One-click recording (MP4/WebM) + download
|
||||||
|
* - Snapshot to PNG with annotations + timestamp
|
||||||
|
* - Picture-in-picture (up to 3 pinned cameras)
|
||||||
|
* - Per-camera FPS indicator + adaptive quality badge
|
||||||
|
*
|
||||||
|
* Topics consumed:
|
||||||
|
* /camera/<name>/image_raw/compressed sensor_msgs/CompressedImage
|
||||||
|
* /camera/color/image_raw/compressed sensor_msgs/CompressedImage (D435i)
|
||||||
|
* /camera/depth/image_rect_raw/compressed sensor_msgs/CompressedImage (D435i)
|
||||||
|
* /camera/panoramic/compressed sensor_msgs/CompressedImage
|
||||||
|
* /social/faces/detections saltybot_social_msgs/FaceDetectionArray
|
||||||
|
* /social/gestures saltybot_social_msgs/GestureArray
|
||||||
|
* /social/scene/objects saltybot_scene_msgs/SceneObjectArray
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { useEffect, useRef, useState, useCallback } from 'react';
|
||||||
|
import { useCamera, CAMERAS, CAMERA_BY_ID, CAMERA_BY_ROS_ID } from '../hooks/useCamera.js';
|
||||||
|
|
||||||
|
// ── Constants ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
const GESTURE_ICONS = {
|
||||||
|
wave: '👋',
|
||||||
|
point: '👆',
|
||||||
|
stop_palm: '✋',
|
||||||
|
thumbs_up: '👍',
|
||||||
|
thumbs_down: '👎',
|
||||||
|
come_here: '🤏',
|
||||||
|
follow: '☞',
|
||||||
|
arms_up: '🙌',
|
||||||
|
crouch: '⬇',
|
||||||
|
arms_spread: '↔',
|
||||||
|
};
|
||||||
|
|
||||||
|
const HAZARD_COLORS = {
|
||||||
|
1: '#f59e0b', // stairs — amber
|
||||||
|
2: '#ef4444', // drop — red
|
||||||
|
3: '#60a5fa', // wet floor — blue
|
||||||
|
4: '#a855f7', // glass door — purple
|
||||||
|
5: '#f97316', // pet — orange
|
||||||
|
};
|
||||||
|
|
||||||
|
// ── Detection overlay drawing helpers ─────────────────────────────────────────
|
||||||
|
|
||||||
|
function drawFaceBoxes(ctx, faces, scaleX, scaleY) {
|
||||||
|
for (const face of faces) {
|
||||||
|
const x = face.bbox_x * scaleX;
|
||||||
|
const y = face.bbox_y * scaleY;
|
||||||
|
const w = face.bbox_w * scaleX;
|
||||||
|
const h = face.bbox_h * scaleY;
|
||||||
|
|
||||||
|
const isKnown = face.person_name && face.person_name !== 'unknown';
|
||||||
|
ctx.strokeStyle = isKnown ? '#06b6d4' : '#f59e0b';
|
||||||
|
ctx.lineWidth = 2;
|
||||||
|
ctx.shadowBlur = 6;
|
||||||
|
ctx.shadowColor = ctx.strokeStyle;
|
||||||
|
ctx.strokeRect(x, y, w, h);
|
||||||
|
ctx.shadowBlur = 0;
|
||||||
|
|
||||||
|
// Corner accent marks
|
||||||
|
const cLen = 8;
|
||||||
|
ctx.lineWidth = 3;
|
||||||
|
[[x,y,1,1],[x+w,y,-1,1],[x,y+h,1,-1],[x+w,y+h,-1,-1]].forEach(([cx,cy,dx,dy]) => {
|
||||||
|
ctx.beginPath();
|
||||||
|
ctx.moveTo(cx, cy + dy * cLen);
|
||||||
|
ctx.lineTo(cx, cy);
|
||||||
|
ctx.lineTo(cx + dx * cLen, cy);
|
||||||
|
ctx.stroke();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Label
|
||||||
|
const label = isKnown
|
||||||
|
? `${face.person_name} ${(face.recognition_score * 100).toFixed(0)}%`
|
||||||
|
: `face #${face.face_id}`;
|
||||||
|
ctx.font = 'bold 11px monospace';
|
||||||
|
const tw = ctx.measureText(label).width;
|
||||||
|
ctx.fillStyle = isKnown ? 'rgba(6,182,212,0.8)' : 'rgba(245,158,11,0.8)';
|
||||||
|
ctx.fillRect(x, y - 16, tw + 6, 16);
|
||||||
|
ctx.fillStyle = '#000';
|
||||||
|
ctx.fillText(label, x + 3, y - 4);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function drawGestureIcons(ctx, gestures, activeCamId, scaleX, scaleY) {
|
||||||
|
for (const g of gestures) {
|
||||||
|
// Only show gestures from the currently viewed camera
|
||||||
|
const cam = CAMERA_BY_ROS_ID[g.camera_id];
|
||||||
|
if (!cam || cam.cameraId !== activeCamId) continue;
|
||||||
|
|
||||||
|
const x = g.hand_x * ctx.canvas.width;
|
||||||
|
const y = g.hand_y * ctx.canvas.height;
|
||||||
|
const icon = GESTURE_ICONS[g.gesture_type] ?? '?';
|
||||||
|
|
||||||
|
ctx.font = '24px serif';
|
||||||
|
ctx.shadowBlur = 8;
|
||||||
|
ctx.shadowColor = '#f97316';
|
||||||
|
ctx.fillText(icon, x - 12, y + 8);
|
||||||
|
ctx.shadowBlur = 0;
|
||||||
|
|
||||||
|
ctx.font = 'bold 10px monospace';
|
||||||
|
ctx.fillStyle = '#f97316';
|
||||||
|
const label = g.gesture_type;
|
||||||
|
ctx.fillText(label, x - ctx.measureText(label).width / 2, y + 22);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function drawSceneObjects(ctx, objects, scaleX, scaleY) {
|
||||||
|
for (const obj of objects) {
|
||||||
|
// vision_msgs/BoundingBox2D: center_x, center_y, size_x, size_y
|
||||||
|
const bb = obj.bbox;
|
||||||
|
const cx = bb?.center?.x ?? bb?.center_x;
|
||||||
|
const cy = bb?.center?.y ?? bb?.center_y;
|
||||||
|
const sw = bb?.size_x ?? 0;
|
||||||
|
const sh = bb?.size_y ?? 0;
|
||||||
|
if (cx == null) continue;
|
||||||
|
|
||||||
|
const x = (cx - sw / 2) * scaleX;
|
||||||
|
const y = (cy - sh / 2) * scaleY;
|
||||||
|
const w = sw * scaleX;
|
||||||
|
const h = sh * scaleY;
|
||||||
|
|
||||||
|
const color = HAZARD_COLORS[obj.hazard_type] ?? '#22c55e';
|
||||||
|
ctx.strokeStyle = color;
|
||||||
|
ctx.lineWidth = 1.5;
|
||||||
|
ctx.setLineDash([4, 3]);
|
||||||
|
ctx.strokeRect(x, y, w, h);
|
||||||
|
ctx.setLineDash([]);
|
||||||
|
|
||||||
|
const dist = obj.distance_m > 0 ? ` ${obj.distance_m.toFixed(1)}m` : '';
|
||||||
|
const label = `${obj.class_name}${dist}`;
|
||||||
|
ctx.font = '10px monospace';
|
||||||
|
const tw = ctx.measureText(label).width;
|
||||||
|
ctx.fillStyle = `${color}cc`;
|
||||||
|
ctx.fillRect(x, y + h, tw + 4, 14);
|
||||||
|
ctx.fillStyle = '#000';
|
||||||
|
ctx.fillText(label, x + 2, y + h + 11);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Overlay canvas ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
function OverlayCanvas({ faces, gestures, sceneObjects, activeCam, containerW, containerH }) {
|
||||||
|
const canvasRef = useRef(null);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const canvas = canvasRef.current;
|
||||||
|
if (!canvas) return;
|
||||||
|
const ctx = canvas.getContext('2d');
|
||||||
|
ctx.clearRect(0, 0, canvas.width, canvas.height);
|
||||||
|
|
||||||
|
if (!activeCam) return;
|
||||||
|
|
||||||
|
const scaleX = canvas.width / (activeCam.width || 640);
|
||||||
|
const scaleY = canvas.height / (activeCam.height || 480);
|
||||||
|
|
||||||
|
// Draw overlays: only for front camera (face + gesture source)
|
||||||
|
if (activeCam.id === 'front') {
|
||||||
|
drawFaceBoxes(ctx, faces, scaleX, scaleY);
|
||||||
|
}
|
||||||
|
if (!activeCam.isPanoramic) {
|
||||||
|
drawGestureIcons(ctx, gestures, activeCam.cameraId, scaleX, scaleY);
|
||||||
|
}
|
||||||
|
if (activeCam.id === 'color') {
|
||||||
|
drawSceneObjects(ctx, sceneObjects, scaleX, scaleY);
|
||||||
|
}
|
||||||
|
}, [faces, gestures, sceneObjects, activeCam]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<canvas
|
||||||
|
ref={canvasRef}
|
||||||
|
width={containerW || 640}
|
||||||
|
height={containerH || 480}
|
||||||
|
className="absolute inset-0 w-full h-full pointer-events-none"
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Panoramic equirect viewer ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
function PanoViewer({ frameUrl }) {
|
||||||
|
const canvasRef = useRef(null);
|
||||||
|
const azRef = useRef(0); // 0–1920px offset
|
||||||
|
const dragRef = useRef(null);
|
||||||
|
const imgRef = useRef(null);
|
||||||
|
|
||||||
|
const draw = useCallback(() => {
|
||||||
|
const canvas = canvasRef.current;
|
||||||
|
const img = imgRef.current;
|
||||||
|
if (!canvas || !img || !img.complete) return;
|
||||||
|
|
||||||
|
const ctx = canvas.getContext('2d');
|
||||||
|
const W = canvas.width;
|
||||||
|
const H = canvas.height;
|
||||||
|
const iW = img.naturalWidth; // 1920
|
||||||
|
const iH = img.naturalHeight; // 960
|
||||||
|
const vW = iW / 2; // viewport = 50% of equirect width
|
||||||
|
const vH = Math.round((H / W) * vW);
|
||||||
|
const vY = Math.round((iH - vH) / 2);
|
||||||
|
const off = Math.round(azRef.current) % iW;
|
||||||
|
|
||||||
|
ctx.clearRect(0, 0, W, H);
|
||||||
|
|
||||||
|
// Draw left segment
|
||||||
|
const srcX1 = off;
|
||||||
|
const srcW1 = Math.min(vW, iW - off);
|
||||||
|
const dstW1 = Math.round((srcW1 / vW) * W);
|
||||||
|
if (dstW1 > 0) {
|
||||||
|
ctx.drawImage(img, srcX1, vY, srcW1, vH, 0, 0, dstW1, H);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Draw wrapped right segment (if viewport crosses 0°)
|
||||||
|
if (srcW1 < vW) {
|
||||||
|
const srcX2 = 0;
|
||||||
|
const srcW2 = vW - srcW1;
|
||||||
|
const dstX2 = dstW1;
|
||||||
|
const dstW2 = W - dstW1;
|
||||||
|
ctx.drawImage(img, srcX2, vY, srcW2, vH, dstX2, 0, dstW2, H);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compass badge
|
||||||
|
const azDeg = Math.round((azRef.current / iW) * 360);
|
||||||
|
ctx.fillStyle = 'rgba(0,0,0,0.5)';
|
||||||
|
ctx.fillRect(W - 58, 6, 52, 18);
|
||||||
|
ctx.fillStyle = '#06b6d4';
|
||||||
|
ctx.font = 'bold 11px monospace';
|
||||||
|
ctx.fillText(`${azDeg}°`, W - 52, 19);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// Load image when URL changes
|
||||||
|
useEffect(() => {
|
||||||
|
if (!frameUrl) return;
|
||||||
|
const img = new Image();
|
||||||
|
img.onload = draw;
|
||||||
|
img.src = frameUrl;
|
||||||
|
imgRef.current = img;
|
||||||
|
}, [frameUrl, draw]);
|
||||||
|
|
||||||
|
// Re-draw when azimuth changes
|
||||||
|
const onMouseDown = e => { dragRef.current = e.clientX; };
|
||||||
|
const onMouseMove = e => {
|
||||||
|
if (dragRef.current == null) return;
|
||||||
|
const dx = e.clientX - dragRef.current;
|
||||||
|
dragRef.current = e.clientX;
|
||||||
|
azRef.current = ((azRef.current - dx * 2) % 1920 + 1920) % 1920;
|
||||||
|
draw();
|
||||||
|
};
|
||||||
|
const onMouseUp = () => { dragRef.current = null; };
|
||||||
|
|
||||||
|
const onTouchStart = e => { dragRef.current = e.touches[0].clientX; };
|
||||||
|
const onTouchMove = e => {
|
||||||
|
if (dragRef.current == null) return;
|
||||||
|
const dx = e.touches[0].clientX - dragRef.current;
|
||||||
|
dragRef.current = e.touches[0].clientX;
|
||||||
|
azRef.current = ((azRef.current - dx * 2) % 1920 + 1920) % 1920;
|
||||||
|
draw();
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<canvas
|
||||||
|
ref={canvasRef}
|
||||||
|
width={960}
|
||||||
|
height={240}
|
||||||
|
className="w-full object-contain bg-black cursor-ew-resize rounded"
|
||||||
|
onMouseDown={onMouseDown}
|
||||||
|
onMouseMove={onMouseMove}
|
||||||
|
onMouseUp={onMouseUp}
|
||||||
|
onMouseLeave={onMouseUp}
|
||||||
|
onTouchStart={onTouchStart}
|
||||||
|
onTouchMove={onTouchMove}
|
||||||
|
onTouchEnd={() => { dragRef.current = null; }}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── PiP mini window ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
function PiPWindow({ cam, frameUrl, fps, onClose, index }) {
|
||||||
|
const positions = [
|
||||||
|
'bottom-2 left-2',
|
||||||
|
'bottom-2 left-40',
|
||||||
|
'bottom-2 left-[18rem]',
|
||||||
|
];
|
||||||
|
return (
|
||||||
|
<div className={`absolute ${positions[index] ?? 'bottom-2 left-2'} w-36 rounded border border-cyan-900 overflow-hidden bg-black shadow-lg shadow-black z-10`}>
|
||||||
|
<div className="flex items-center justify-between px-1.5 py-0.5 bg-gray-950 text-xs">
|
||||||
|
<span className="text-cyan-700 font-bold">{cam.label}</span>
|
||||||
|
<div className="flex items-center gap-1">
|
||||||
|
<span className="text-gray-700">{fps}fps</span>
|
||||||
|
<button onClick={onClose} className="text-gray-600 hover:text-red-400">✕</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{frameUrl ? (
|
||||||
|
<img src={frameUrl} alt={cam.label} className="w-full aspect-video object-cover block" />
|
||||||
|
) : (
|
||||||
|
<div className="w-full aspect-video flex items-center justify-center text-gray-800 text-xs">
|
||||||
|
no signal
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Camera selector strip ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
function CameraStrip({ cameras, activeId, pipList, frames, fps, onSelect, onTogglePip }) {
|
||||||
|
return (
|
||||||
|
<div className="flex gap-1.5 flex-wrap">
|
||||||
|
{cameras.map(cam => {
|
||||||
|
const hasFrame = !!frames[cam.id];
|
||||||
|
const camFps = fps[cam.id] ?? 0;
|
||||||
|
const isActive = activeId === cam.id;
|
||||||
|
const isPip = pipList.includes(cam.id);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div key={cam.id} className="relative">
|
||||||
|
<button
|
||||||
|
onClick={() => onSelect(cam.id)}
|
||||||
|
className={`flex flex-col items-start rounded border px-2.5 py-1.5 text-xs font-bold transition-colors ${
|
||||||
|
isActive
|
||||||
|
? 'border-cyan-500 bg-cyan-950 bg-opacity-50 text-cyan-300'
|
||||||
|
: hasFrame
|
||||||
|
? 'border-gray-700 bg-gray-900 text-gray-400 hover:border-cyan-800 hover:text-gray-200'
|
||||||
|
: 'border-gray-800 bg-gray-950 text-gray-700 hover:border-gray-700'
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
<span>{cam.label.toUpperCase()}</span>
|
||||||
|
<span className={`text-xs font-normal mt-0.5 ${
|
||||||
|
camFps >= 12 ? 'text-green-600' :
|
||||||
|
camFps > 0 ? 'text-amber-600' :
|
||||||
|
'text-gray-700'
|
||||||
|
}`}>
|
||||||
|
{camFps > 0 ? `${camFps}fps` : 'no signal'}
|
||||||
|
</span>
|
||||||
|
</button>
|
||||||
|
|
||||||
|
{/* PiP pin button — only when NOT the active camera */}
|
||||||
|
{!isActive && (
|
||||||
|
<button
|
||||||
|
onClick={() => onTogglePip(cam.id)}
|
||||||
|
title={isPip ? 'Unpin PiP' : 'Pin PiP'}
|
||||||
|
className={`absolute -top-1.5 -right-1.5 w-4 h-4 rounded-full text-[9px] flex items-center justify-center border transition-colors ${
|
||||||
|
isPip
|
||||||
|
? 'bg-cyan-600 border-cyan-400 text-white'
|
||||||
|
: 'bg-gray-800 border-gray-700 text-gray-600 hover:border-cyan-700 hover:text-cyan-500'
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
{isPip ? '×' : '⊕'}
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Recording bar ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
function RecordingBar({ recording, recSeconds, onStart, onStop, onSnapshot, overlayRef }) {
|
||||||
|
const fmtTime = s => `${String(Math.floor(s / 60)).padStart(2, '0')}:${String(s % 60).padStart(2, '0')}`;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex items-center gap-2 flex-wrap">
|
||||||
|
{!recording ? (
|
||||||
|
<button
|
||||||
|
onClick={onStart}
|
||||||
|
className="flex items-center gap-1.5 px-3 py-1.5 rounded border border-red-900 bg-red-950 text-red-400 hover:bg-red-900 text-xs font-bold transition-colors"
|
||||||
|
>
|
||||||
|
<span className="w-2 h-2 rounded-full bg-red-500" />
|
||||||
|
REC
|
||||||
|
</button>
|
||||||
|
) : (
|
||||||
|
<button
|
||||||
|
onClick={onStop}
|
||||||
|
className="flex items-center gap-1.5 px-3 py-1.5 rounded border border-red-600 bg-red-900 text-red-300 hover:bg-red-800 text-xs font-bold animate-pulse"
|
||||||
|
>
|
||||||
|
<span className="w-2 h-2 rounded bg-red-400" />
|
||||||
|
STOP {fmtTime(recSeconds)}
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<button
|
||||||
|
onClick={() => onSnapshot(overlayRef?.current)}
|
||||||
|
className="flex items-center gap-1 px-3 py-1.5 rounded border border-gray-700 bg-gray-900 text-gray-400 hover:border-cyan-700 hover:text-cyan-400 text-xs font-bold transition-colors"
|
||||||
|
>
|
||||||
|
📷 SNAP
|
||||||
|
</button>
|
||||||
|
|
||||||
|
{recording && (
|
||||||
|
<span className="text-xs text-red-500 animate-pulse font-mono">
|
||||||
|
● RECORDING {fmtTime(recSeconds)}
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Main component ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
export function CameraViewer({ subscribe }) {
|
||||||
|
const {
|
||||||
|
cameras, frames, fps,
|
||||||
|
activeId, setActiveId,
|
||||||
|
pipList, togglePip,
|
||||||
|
recording, recSeconds,
|
||||||
|
startRecording, stopRecording,
|
||||||
|
takeSnapshot,
|
||||||
|
} = useCamera({ subscribe });
|
||||||
|
|
||||||
|
// ── Detection state ─────────────────────────────────────────────────────────
|
||||||
|
const [faces, setFaces] = useState([]);
|
||||||
|
const [gestures, setGestures] = useState([]);
|
||||||
|
const [sceneObjects, setSceneObjects] = useState([]);
|
||||||
|
|
||||||
|
const [showOverlay, setShowOverlay] = useState(true);
|
||||||
|
const [overlayMode, setOverlayMode] = useState('all'); // 'all' | 'faces' | 'gestures' | 'objects' | 'off'
|
||||||
|
|
||||||
|
const overlayCanvasRef = useRef(null);
|
||||||
|
|
||||||
|
// Subscribe to detection topics
|
||||||
|
useEffect(() => {
|
||||||
|
if (!subscribe) return;
|
||||||
|
const u1 = subscribe('/social/faces/detections', 'saltybot_social_msgs/FaceDetectionArray', msg => {
|
||||||
|
setFaces(msg.faces ?? []);
|
||||||
|
});
|
||||||
|
const u2 = subscribe('/social/gestures', 'saltybot_social_msgs/GestureArray', msg => {
|
||||||
|
setGestures(msg.gestures ?? []);
|
||||||
|
});
|
||||||
|
const u3 = subscribe('/social/scene/objects', 'saltybot_scene_msgs/SceneObjectArray', msg => {
|
||||||
|
setSceneObjects(msg.objects ?? []);
|
||||||
|
});
|
||||||
|
return () => { u1?.(); u2?.(); u3?.(); };
|
||||||
|
}, [subscribe]);
|
||||||
|
|
||||||
|
const activeCam = CAMERA_BY_ID[activeId];
|
||||||
|
const activeFrame = frames[activeId];
|
||||||
|
|
||||||
|
// Filter overlay data based on mode
|
||||||
|
const visibleFaces = (overlayMode === 'all' || overlayMode === 'faces') ? faces : [];
|
||||||
|
const visibleGestures = (overlayMode === 'all' || overlayMode === 'gestures') ? gestures : [];
|
||||||
|
const visibleObjects = (overlayMode === 'all' || overlayMode === 'objects') ? sceneObjects : [];
|
||||||
|
|
||||||
|
// ── Container size tracking (for overlay canvas sizing) ────────────────────
|
||||||
|
const containerRef = useRef(null);
|
||||||
|
const [containerSize, setContainerSize] = useState({ w: 640, h: 480 });
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!containerRef.current) return;
|
||||||
|
const ro = new ResizeObserver(entries => {
|
||||||
|
const e = entries[0];
|
||||||
|
setContainerSize({ w: Math.round(e.contentRect.width), h: Math.round(e.contentRect.height) });
|
||||||
|
});
|
||||||
|
ro.observe(containerRef.current);
|
||||||
|
return () => ro.disconnect();
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// ── Quality badge ──────────────────────────────────────────────────────────
|
||||||
|
const camFps = fps[activeId] ?? 0;
|
||||||
|
const quality = camFps >= 13 ? 'FULL' : camFps >= 8 ? 'GOOD' : camFps > 0 ? 'LOW' : 'NO SIGNAL';
|
||||||
|
const qualColor = camFps >= 13 ? 'text-green-500' : camFps >= 8 ? 'text-amber-500' : camFps > 0 ? 'text-red-500' : 'text-gray-700';
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="space-y-3">
|
||||||
|
|
||||||
|
{/* ── Camera strip ── */}
|
||||||
|
<div className="bg-gray-950 rounded-lg border border-cyan-950 p-3 space-y-2">
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<div className="text-cyan-700 text-xs font-bold tracking-widest">CAMERA SELECT</div>
|
||||||
|
<span className={`text-xs font-bold ${qualColor}`}>{quality} {camFps > 0 ? `${camFps}fps` : ''}</span>
|
||||||
|
</div>
|
||||||
|
<CameraStrip
|
||||||
|
cameras={cameras}
|
||||||
|
activeId={activeId}
|
||||||
|
pipList={pipList}
|
||||||
|
frames={frames}
|
||||||
|
fps={fps}
|
||||||
|
onSelect={setActiveId}
|
||||||
|
onTogglePip={togglePip}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* ── Main viewer ── */}
|
||||||
|
<div className="bg-gray-950 rounded-lg border border-cyan-950 overflow-hidden">
|
||||||
|
|
||||||
|
{/* Viewer toolbar */}
|
||||||
|
<div className="flex items-center justify-between px-3 py-2 border-b border-cyan-950">
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<span className="text-cyan-400 text-xs font-bold">{activeCam?.label ?? '—'}</span>
|
||||||
|
{activeCam?.isDepth && (
|
||||||
|
<span className="text-xs text-gray-600 border border-gray-800 rounded px-1">DEPTH · greyscale</span>
|
||||||
|
)}
|
||||||
|
{activeCam?.isPanoramic && (
|
||||||
|
<span className="text-xs text-gray-600 border border-gray-800 rounded px-1">360° · drag to pan</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Overlay mode selector */}
|
||||||
|
<div className="flex items-center gap-1">
|
||||||
|
{['off','faces','gestures','objects','all'].map(mode => (
|
||||||
|
<button
|
||||||
|
key={mode}
|
||||||
|
onClick={() => setOverlayMode(mode)}
|
||||||
|
className={`px-2 py-0.5 rounded text-xs border transition-colors ${
|
||||||
|
overlayMode === mode
|
||||||
|
? 'border-cyan-600 bg-cyan-950 text-cyan-400'
|
||||||
|
: 'border-gray-800 text-gray-600 hover:border-gray-700 hover:text-gray-400'
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
{mode === 'all' ? 'ALL' : mode === 'off' ? 'OFF' : mode.slice(0,3).toUpperCase()}
|
||||||
|
</button>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Image + overlay */}
|
||||||
|
<div className="relative" ref={containerRef}>
|
||||||
|
{activeCam?.isPanoramic ? (
|
||||||
|
<PanoViewer frameUrl={activeFrame} />
|
||||||
|
) : activeFrame ? (
|
||||||
|
<img
|
||||||
|
src={activeFrame}
|
||||||
|
alt={activeCam?.label ?? 'camera'}
|
||||||
|
className="w-full object-contain block bg-black"
|
||||||
|
style={{ maxHeight: '480px' }}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<div className="w-full bg-black flex items-center justify-center text-gray-800 text-sm font-mono"
|
||||||
|
style={{ height: '360px' }}>
|
||||||
|
<div className="text-center space-y-2">
|
||||||
|
<div className="text-2xl">📷</div>
|
||||||
|
<div>Waiting for {activeCam?.label ?? '—'}…</div>
|
||||||
|
<div className="text-xs text-gray-700">{activeCam?.topic}</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Detection overlay canvas */}
|
||||||
|
{overlayMode !== 'off' && !activeCam?.isPanoramic && (
|
||||||
|
<OverlayCanvas
|
||||||
|
ref={overlayCanvasRef}
|
||||||
|
faces={visibleFaces}
|
||||||
|
gestures={visibleGestures}
|
||||||
|
sceneObjects={visibleObjects}
|
||||||
|
activeCam={activeCam}
|
||||||
|
containerW={containerSize.w}
|
||||||
|
containerH={containerSize.h}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* PiP windows */}
|
||||||
|
{pipList.map((id, idx) => {
|
||||||
|
const cam = CAMERA_BY_ID[id];
|
||||||
|
if (!cam) return null;
|
||||||
|
return (
|
||||||
|
<PiPWindow
|
||||||
|
key={id}
|
||||||
|
cam={cam}
|
||||||
|
frameUrl={frames[id]}
|
||||||
|
fps={fps[id] ?? 0}
|
||||||
|
index={idx}
|
||||||
|
onClose={() => togglePip(id)}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* ── Recording controls ── */}
|
||||||
|
<div className="bg-gray-950 rounded-lg border border-cyan-950 p-3">
|
||||||
|
<div className="flex items-center justify-between mb-2">
|
||||||
|
<div className="text-cyan-700 text-xs font-bold tracking-widest">CAPTURE</div>
|
||||||
|
</div>
|
||||||
|
<RecordingBar
|
||||||
|
recording={recording}
|
||||||
|
recSeconds={recSeconds}
|
||||||
|
onStart={startRecording}
|
||||||
|
onStop={stopRecording}
|
||||||
|
onSnapshot={takeSnapshot}
|
||||||
|
overlayRef={overlayCanvasRef}
|
||||||
|
/>
|
||||||
|
<div className="mt-2 text-xs text-gray-700">
|
||||||
|
Recording saves as MP4/WebM to your Downloads.
|
||||||
|
Snapshot includes detection overlay + timestamp.
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* ── Detection status ── */}
|
||||||
|
<div className="grid grid-cols-3 gap-2 text-xs">
|
||||||
|
<div className="bg-gray-950 rounded border border-gray-800 p-2">
|
||||||
|
<div className="text-gray-600 mb-1">FACES</div>
|
||||||
|
<div className={`font-bold ${faces.length > 0 ? 'text-cyan-400' : 'text-gray-700'}`}>
|
||||||
|
{faces.length > 0 ? `${faces.length} detected` : 'none'}
|
||||||
|
</div>
|
||||||
|
{faces.slice(0, 2).map((f, i) => (
|
||||||
|
<div key={i} className="text-gray-600 truncate">
|
||||||
|
{f.person_name && f.person_name !== 'unknown'
|
||||||
|
? `↳ ${f.person_name}`
|
||||||
|
: `↳ unknown #${f.face_id}`}
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="bg-gray-950 rounded border border-gray-800 p-2">
|
||||||
|
<div className="text-gray-600 mb-1">GESTURES</div>
|
||||||
|
<div className={`font-bold ${gestures.length > 0 ? 'text-amber-400' : 'text-gray-700'}`}>
|
||||||
|
{gestures.length > 0 ? `${gestures.length} active` : 'none'}
|
||||||
|
</div>
|
||||||
|
{gestures.slice(0, 2).map((g, i) => {
|
||||||
|
const icon = GESTURE_ICONS[g.gesture_type] ?? '?';
|
||||||
|
return (
|
||||||
|
<div key={i} className="text-gray-600 truncate">
|
||||||
|
{icon} {g.gesture_type} cam{g.camera_id}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="bg-gray-950 rounded border border-gray-800 p-2">
|
||||||
|
<div className="text-gray-600 mb-1">OBJECTS</div>
|
||||||
|
<div className={`font-bold ${sceneObjects.length > 0 ? 'text-green-400' : 'text-gray-700'}`}>
|
||||||
|
{sceneObjects.length > 0 ? `${sceneObjects.length} objects` : 'none'}
|
||||||
|
</div>
|
||||||
|
{sceneObjects
|
||||||
|
.filter(o => o.hazard_type > 0)
|
||||||
|
.slice(0, 2)
|
||||||
|
.map((o, i) => (
|
||||||
|
<div key={i} className="text-amber-700 truncate">⚠ {o.class_name}</div>
|
||||||
|
))
|
||||||
|
}
|
||||||
|
{sceneObjects.filter(o => o.hazard_type === 0).slice(0, 2).map((o, i) => (
|
||||||
|
<div key={`ok${i}`} className="text-gray-600 truncate">
|
||||||
|
{o.class_name} {o.distance_m > 0 ? `${o.distance_m.toFixed(1)}m` : ''}
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* ── Legend ── */}
|
||||||
|
<div className="flex gap-4 text-xs text-gray-700 flex-wrap">
|
||||||
|
<div className="flex items-center gap-1">
|
||||||
|
<div className="w-3 h-3 rounded-sm border border-cyan-600" />
|
||||||
|
Known face
|
||||||
|
</div>
|
||||||
|
<div className="flex items-center gap-1">
|
||||||
|
<div className="w-3 h-3 rounded-sm border border-amber-600" />
|
||||||
|
Unknown face
|
||||||
|
</div>
|
||||||
|
<div className="flex items-center gap-1">
|
||||||
|
<span>👆</span> Gesture
|
||||||
|
</div>
|
||||||
|
<div className="flex items-center gap-1">
|
||||||
|
<div className="w-3 h-3 rounded-sm border border-green-700 border-dashed" />
|
||||||
|
Object
|
||||||
|
</div>
|
||||||
|
<div className="flex items-center gap-1">
|
||||||
|
<div className="w-3 h-3 rounded-sm border border-amber-600 border-dashed" />
|
||||||
|
Hazard
|
||||||
|
</div>
|
||||||
|
<div className="ml-auto text-gray-800 italic">
|
||||||
|
⊕ pin = PiP · overlay: {overlayMode}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
325
ui/social-bot/src/hooks/useCamera.js
Normal file
325
ui/social-bot/src/hooks/useCamera.js
Normal file
@ -0,0 +1,325 @@
|
|||||||
|
/**
|
||||||
|
* useCamera.js — Multi-camera stream manager (Issue #177).
|
||||||
|
*
|
||||||
|
* Subscribes to sensor_msgs/CompressedImage topics via rosbridge.
|
||||||
|
* Decodes base64 JPEG/PNG → data URL for <img>/<canvas> display.
|
||||||
|
* Tracks per-camera FPS. Manages MediaRecorder for recording + snapshots.
|
||||||
|
*
|
||||||
|
* Camera sources:
|
||||||
|
* front / left / rear / right — 4× CSI IMX219, 640×480
|
||||||
|
* topic: /camera/<name>/image_raw/compressed
|
||||||
|
* color — D435i RGB, 640×480
|
||||||
|
* topic: /camera/color/image_raw/compressed
|
||||||
|
* depth — D435i depth, 640×480 greyscale (PNG16)
|
||||||
|
* topic: /camera/depth/image_rect_raw/compressed
|
||||||
|
* panoramic — equirect stitch 1920×960
|
||||||
|
* topic: /camera/panoramic/compressed
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { useState, useEffect, useRef, useCallback } from 'react';
|
||||||
|
|
||||||
|
// ── Camera catalogue ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
export const CAMERAS = [
|
||||||
|
{
|
||||||
|
id: 'front',
|
||||||
|
label: 'Front',
|
||||||
|
shortLabel: 'F',
|
||||||
|
topic: '/camera/front/image_raw/compressed',
|
||||||
|
msgType: 'sensor_msgs/CompressedImage',
|
||||||
|
cameraId: 0, // matches gesture_node camera_id
|
||||||
|
width: 640, height: 480,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'left',
|
||||||
|
label: 'Left',
|
||||||
|
shortLabel: 'L',
|
||||||
|
topic: '/camera/left/image_raw/compressed',
|
||||||
|
msgType: 'sensor_msgs/CompressedImage',
|
||||||
|
cameraId: 1,
|
||||||
|
width: 640, height: 480,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'rear',
|
||||||
|
label: 'Rear',
|
||||||
|
shortLabel: 'R',
|
||||||
|
topic: '/camera/rear/image_raw/compressed',
|
||||||
|
msgType: 'sensor_msgs/CompressedImage',
|
||||||
|
cameraId: 2,
|
||||||
|
width: 640, height: 480,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'right',
|
||||||
|
label: 'Right',
|
||||||
|
shortLabel: 'Rt',
|
||||||
|
topic: '/camera/right/image_raw/compressed',
|
||||||
|
msgType: 'sensor_msgs/CompressedImage',
|
||||||
|
cameraId: 3,
|
||||||
|
width: 640, height: 480,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'color',
|
||||||
|
label: 'D435i RGB',
|
||||||
|
shortLabel: 'D',
|
||||||
|
topic: '/camera/color/image_raw/compressed',
|
||||||
|
msgType: 'sensor_msgs/CompressedImage',
|
||||||
|
cameraId: 4,
|
||||||
|
width: 640, height: 480,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'depth',
|
||||||
|
label: 'Depth',
|
||||||
|
shortLabel: '≋',
|
||||||
|
topic: '/camera/depth/image_rect_raw/compressed',
|
||||||
|
msgType: 'sensor_msgs/CompressedImage',
|
||||||
|
cameraId: 5,
|
||||||
|
width: 640, height: 480,
|
||||||
|
isDepth: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'panoramic',
|
||||||
|
label: 'Panoramic',
|
||||||
|
shortLabel: '360',
|
||||||
|
topic: '/camera/panoramic/compressed',
|
||||||
|
msgType: 'sensor_msgs/CompressedImage',
|
||||||
|
cameraId: -1,
|
||||||
|
width: 1920, height: 960,
|
||||||
|
isPanoramic: true,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
export const CAMERA_BY_ID = Object.fromEntries(CAMERAS.map(c => [c.id, c]));
|
||||||
|
export const CAMERA_BY_ROS_ID = Object.fromEntries(
|
||||||
|
CAMERAS.filter(c => c.cameraId >= 0).map(c => [c.cameraId, c])
|
||||||
|
);
|
||||||
|
|
||||||
|
const TARGET_FPS = 15;
|
||||||
|
const FPS_INTERVAL = 1000; // ms between FPS counter resets
|
||||||
|
|
||||||
|
// ── Hook ──────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
export function useCamera({ subscribe } = {}) {
|
||||||
|
const [frames, setFrames] = useState(() =>
|
||||||
|
Object.fromEntries(CAMERAS.map(c => [c.id, null]))
|
||||||
|
);
|
||||||
|
const [fps, setFps] = useState(() =>
|
||||||
|
Object.fromEntries(CAMERAS.map(c => [c.id, 0]))
|
||||||
|
);
|
||||||
|
const [activeId, setActiveId] = useState('front');
|
||||||
|
const [pipList, setPipList] = useState([]); // up to 3 extra camera ids
|
||||||
|
const [recording, setRecording] = useState(false);
|
||||||
|
const [recSeconds, setRecSeconds] = useState(0);
|
||||||
|
|
||||||
|
// ── Refs (not state — no re-render needed) ─────────────────────────────────
|
||||||
|
const countRef = useRef(Object.fromEntries(CAMERAS.map(c => [c.id, 0])));
|
||||||
|
const mediaRecRef = useRef(null);
|
||||||
|
const chunksRef = useRef([]);
|
||||||
|
const recTimerRef = useRef(null);
|
||||||
|
const recordCanvas = useRef(null); // hidden canvas used for recording
|
||||||
|
const recAnimRef = useRef(null); // rAF handle for record-canvas loop
|
||||||
|
const latestFrameRef = useRef(Object.fromEntries(CAMERAS.map(c => [c.id, null])));
|
||||||
|
const latestTsRef = useRef(Object.fromEntries(CAMERAS.map(c => [c.id, 0])));
|
||||||
|
|
||||||
|
// ── FPS counter ────────────────────────────────────────────────────────────
|
||||||
|
useEffect(() => {
|
||||||
|
const timer = setInterval(() => {
|
||||||
|
setFps({ ...countRef.current });
|
||||||
|
const reset = Object.fromEntries(CAMERAS.map(c => [c.id, 0]));
|
||||||
|
countRef.current = reset;
|
||||||
|
}, FPS_INTERVAL);
|
||||||
|
return () => clearInterval(timer);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// ── Subscribe all camera topics ────────────────────────────────────────────
|
||||||
|
useEffect(() => {
|
||||||
|
if (!subscribe) return;
|
||||||
|
|
||||||
|
const unsubs = CAMERAS.map(cam => {
|
||||||
|
let lastTs = 0;
|
||||||
|
const interval = Math.floor(1000 / TARGET_FPS); // client-side 15fps gate
|
||||||
|
|
||||||
|
return subscribe(cam.topic, cam.msgType, (msg) => {
|
||||||
|
const now = Date.now();
|
||||||
|
if (now - lastTs < interval) return; // drop frames > 15fps
|
||||||
|
lastTs = now;
|
||||||
|
|
||||||
|
const fmt = msg.format || 'jpeg';
|
||||||
|
const mime = fmt.includes('png') || fmt.includes('16UC') ? 'image/png' : 'image/jpeg';
|
||||||
|
const dataUrl = `data:${mime};base64,${msg.data}`;
|
||||||
|
|
||||||
|
latestFrameRef.current[cam.id] = dataUrl;
|
||||||
|
latestTsRef.current[cam.id] = now;
|
||||||
|
countRef.current[cam.id] = (countRef.current[cam.id] ?? 0) + 1;
|
||||||
|
|
||||||
|
setFrames(prev => ({ ...prev, [cam.id]: dataUrl }));
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return () => unsubs.forEach(fn => fn?.());
|
||||||
|
}, [subscribe]);
|
||||||
|
|
||||||
|
// ── Create hidden record canvas ────────────────────────────────────────────
|
||||||
|
useEffect(() => {
|
||||||
|
const c = document.createElement('canvas');
|
||||||
|
c.width = 640;
|
||||||
|
c.height = 480;
|
||||||
|
c.style.display = 'none';
|
||||||
|
document.body.appendChild(c);
|
||||||
|
recordCanvas.current = c;
|
||||||
|
return () => { c.remove(); };
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// ── Draw loop for record canvas ────────────────────────────────────────────
|
||||||
|
// Runs at TARGET_FPS when recording — draws active frame to hidden canvas
|
||||||
|
const startRecordLoop = useCallback(() => {
|
||||||
|
const canvas = recordCanvas.current;
|
||||||
|
if (!canvas) return;
|
||||||
|
|
||||||
|
const step = () => {
|
||||||
|
const cam = CAMERA_BY_ID[activeId];
|
||||||
|
const src = latestFrameRef.current[activeId];
|
||||||
|
const ctx = canvas.getContext('2d');
|
||||||
|
|
||||||
|
if (!cam || !src) {
|
||||||
|
recAnimRef.current = requestAnimationFrame(step);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resize canvas to match source
|
||||||
|
if (canvas.width !== cam.width || canvas.height !== cam.height) {
|
||||||
|
canvas.width = cam.width;
|
||||||
|
canvas.height = cam.height;
|
||||||
|
}
|
||||||
|
|
||||||
|
const img = new Image();
|
||||||
|
img.onload = () => {
|
||||||
|
ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
|
||||||
|
};
|
||||||
|
img.src = src;
|
||||||
|
|
||||||
|
recAnimRef.current = setTimeout(step, Math.floor(1000 / TARGET_FPS));
|
||||||
|
};
|
||||||
|
|
||||||
|
recAnimRef.current = setTimeout(step, 0);
|
||||||
|
}, [activeId]);
|
||||||
|
|
||||||
|
const stopRecordLoop = useCallback(() => {
|
||||||
|
if (recAnimRef.current) {
|
||||||
|
clearTimeout(recAnimRef.current);
|
||||||
|
cancelAnimationFrame(recAnimRef.current);
|
||||||
|
recAnimRef.current = null;
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// ── Recording ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
const startRecording = useCallback(() => {
|
||||||
|
const canvas = recordCanvas.current;
|
||||||
|
if (!canvas || recording) return;
|
||||||
|
|
||||||
|
startRecordLoop();
|
||||||
|
|
||||||
|
const stream = canvas.captureStream(TARGET_FPS);
|
||||||
|
const mimeType =
|
||||||
|
MediaRecorder.isTypeSupported('video/mp4') ? 'video/mp4' :
|
||||||
|
MediaRecorder.isTypeSupported('video/webm;codecs=vp9') ? 'video/webm;codecs=vp9' :
|
||||||
|
MediaRecorder.isTypeSupported('video/webm;codecs=vp8') ? 'video/webm;codecs=vp8' :
|
||||||
|
'video/webm';
|
||||||
|
|
||||||
|
chunksRef.current = [];
|
||||||
|
const mr = new MediaRecorder(stream, { mimeType, videoBitsPerSecond: 2_500_000 });
|
||||||
|
mr.ondataavailable = e => { if (e.data?.size > 0) chunksRef.current.push(e.data); };
|
||||||
|
mr.start(200);
|
||||||
|
mediaRecRef.current = mr;
|
||||||
|
setRecording(true);
|
||||||
|
setRecSeconds(0);
|
||||||
|
recTimerRef.current = setInterval(() => setRecSeconds(s => s + 1), 1000);
|
||||||
|
}, [recording, startRecordLoop]);
|
||||||
|
|
||||||
|
const stopRecording = useCallback(() => {
|
||||||
|
const mr = mediaRecRef.current;
|
||||||
|
if (!mr || mr.state === 'inactive') return;
|
||||||
|
|
||||||
|
mr.onstop = () => {
|
||||||
|
const ext = mr.mimeType.includes('mp4') ? 'mp4' : 'webm';
|
||||||
|
const blob = new Blob(chunksRef.current, { type: mr.mimeType });
|
||||||
|
const url = URL.createObjectURL(blob);
|
||||||
|
const a = document.createElement('a');
|
||||||
|
a.href = url;
|
||||||
|
a.download = `saltybot-${activeId}-${Date.now()}.${ext}`;
|
||||||
|
a.click();
|
||||||
|
URL.revokeObjectURL(url);
|
||||||
|
};
|
||||||
|
|
||||||
|
mr.stop();
|
||||||
|
stopRecordLoop();
|
||||||
|
clearInterval(recTimerRef.current);
|
||||||
|
setRecording(false);
|
||||||
|
}, [activeId, stopRecordLoop]);
|
||||||
|
|
||||||
|
// ── Snapshot ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
const takeSnapshot = useCallback((overlayCanvasEl) => {
|
||||||
|
const src = latestFrameRef.current[activeId];
|
||||||
|
if (!src) return;
|
||||||
|
|
||||||
|
const cam = CAMERA_BY_ID[activeId];
|
||||||
|
const canvas = document.createElement('canvas');
|
||||||
|
canvas.width = cam.width;
|
||||||
|
canvas.height = cam.height;
|
||||||
|
const ctx = canvas.getContext('2d');
|
||||||
|
|
||||||
|
const img = new Image();
|
||||||
|
img.onload = () => {
|
||||||
|
ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
|
||||||
|
|
||||||
|
// Composite detection overlay if provided
|
||||||
|
if (overlayCanvasEl) {
|
||||||
|
ctx.drawImage(overlayCanvasEl, 0, 0, canvas.width, canvas.height);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Timestamp watermark
|
||||||
|
ctx.fillStyle = 'rgba(0,0,0,0.5)';
|
||||||
|
ctx.fillRect(0, canvas.height - 20, canvas.width, 20);
|
||||||
|
ctx.fillStyle = '#06b6d4';
|
||||||
|
ctx.font = '11px monospace';
|
||||||
|
ctx.fillText(`SALTYBOT ${cam.label} ${new Date().toISOString()}`, 8, canvas.height - 6);
|
||||||
|
|
||||||
|
canvas.toBlob(blob => {
|
||||||
|
const url = URL.createObjectURL(blob);
|
||||||
|
const a = document.createElement('a');
|
||||||
|
a.href = url;
|
||||||
|
a.download = `saltybot-snap-${activeId}-${Date.now()}.png`;
|
||||||
|
a.click();
|
||||||
|
URL.revokeObjectURL(url);
|
||||||
|
}, 'image/png');
|
||||||
|
};
|
||||||
|
img.src = src;
|
||||||
|
}, [activeId]);
|
||||||
|
|
||||||
|
// ── PiP management ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
const togglePip = useCallback(id => {
|
||||||
|
setPipList(prev => {
|
||||||
|
if (prev.includes(id)) return prev.filter(x => x !== id);
|
||||||
|
const next = [...prev, id].filter(x => x !== activeId);
|
||||||
|
return next.slice(-3); // max 3 PIPs
|
||||||
|
});
|
||||||
|
}, [activeId]);
|
||||||
|
|
||||||
|
// Remove PiP if it becomes the active camera
|
||||||
|
useEffect(() => {
|
||||||
|
setPipList(prev => prev.filter(id => id !== activeId));
|
||||||
|
}, [activeId]);
|
||||||
|
|
||||||
|
return {
|
||||||
|
cameras: CAMERAS,
|
||||||
|
frames,
|
||||||
|
fps,
|
||||||
|
activeId, setActiveId,
|
||||||
|
pipList, togglePip,
|
||||||
|
recording, recSeconds,
|
||||||
|
startRecording, stopRecording,
|
||||||
|
takeSnapshot,
|
||||||
|
};
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user