Compare commits

...

11 Commits

Author SHA1 Message Date
90c8b427fc feat(social): multi-language support — Whisper LID + per-lang Piper TTS (Issue #167)
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (push) Failing after 2s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (push) Has been skipped
social-bot integration tests / Lint (flake8 + pep257) (pull_request) Failing after 10s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (pull_request) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (push) Has been cancelled
social-bot integration tests / Latency profiling (GPU, Orin) (pull_request) Has been cancelled
- Add SpeechTranscript.language (BCP-47), ConversationResponse.language fields
- speech_pipeline_node: whisper_language param (""=auto-detect via Whisper LID);
  detected language published in every transcript
- conversation_node: track per-speaker language; inject "[Please respond in X.]"
  hint for non-English speakers; propagate language to ConversationResponse.
  _LANG_NAMES: 24 BCP-47 codes -> English names. Also adds Issue #161 emotion
  context plumbing (co-located in same branch for clean merge)
- tts_node: voice_map_json param (JSON BCP-47->ONNX path); lazy voice loading
  per language; playback queue now carries (text, lang) tuples for voice routing
- speech_params.yaml, tts_params.yaml: new language params with docs
- 47/47 tests pass (test_multilang.py)
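The per-speaker hint injection described above can be sketched as follows. This is an assumed shape: only `_LANG_NAMES` is named in the commit; the helper name, the four-entry subset, and the fallback behaviour are illustrative.

```python
# Sketch of the "[Please respond in X.]" hint logic (assumed shape;
# only _LANG_NAMES is named in the commit, the helper is illustrative).
_LANG_NAMES = {  # subset of the 24 BCP-47 -> English-name entries
    "en": "English",
    "de": "German",
    "fr": "French",
    "es": "Spanish",
}

def make_language_hint(bcp47: str) -> str:
    """Return the prompt hint injected for non-English speakers."""
    lang = (bcp47 or "en").split("-")[0].lower()
    if lang == "en":
        return ""  # English speakers get no hint
    name = _LANG_NAMES.get(lang)
    if name is None:
        return ""  # unknown code: fall back to no hint
    return f"[Please respond in {name}.]"
```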

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 10:57:34 -05:00
077f26d9d6 Merge pull request 'feat(power): STOP-mode sleep/wake power manager — Issue #178' (#186) from sl-firmware/issue-178-power-mgmt into main 2026-03-02 10:56:52 -05:00
f446e5766e feat(power): STOP-mode sleep/wake power manager — Issue #178
Adds STM32F7 STOP-mode power management with <10ms wake latency:

- power_mgmt.c: state machine (ACTIVE→SLEEP_PENDING→SLEEPING→WAKING),
  30s idle timeout (PM_IDLE_TIMEOUT_MS), 3s LED fade before STOP,
  gate SPI3/I2S3+SPI2+USART6+UART5 on sleep (clock-only, state preserved),
  EXTI1(PA1/CRSF)+EXTI7(PB7/JLink)+EXTI4(PC4/IMU) wake sources,
  PLL restore after STOP (PLLM=8/N=216/P=2 → 216MHz), uwTick save/restore
- Peripheral gating: I2S3, SPI2(OSD), USART6, UART5 disabled during STOP;
  SPI1(IMU), UART4(CRSF), USART1(JLink), I2C1 remain active as wake sources
- Sleep LED: triangle-wave pulse (2s period) on LED1 during SLEEP_PENDING,
  software PWM in main loop (1-bit, pm_pwm_phase vs brightness)
- IWDG: fed just before WFI; <10ms wake << 50ms WATCHDOG_TIMEOUT_MS
- JLink: JLINK_CMD_SLEEP=0x09, JLINK_TLM_POWER=0x81 (11-byte power frame
  at 1Hz: power_state, est_total_ma, est_audio_ma, est_osd_ma, idle_ms)
- main.c: power_mgmt_init(), activity() on CRSF/JLink/armed, tick() when
  disarmed, sleep_req handler, LED PWM, JLINK_TLM_POWER telemetry
- config.h: PM_* constants, PM_CURRENT_*_MA estimates, PM_TLM_HZ
- test_power_mgmt.py: 72 tests passing (state machine, LED, gating,
  current estimates, JLink protocol, wake latency, hardware constants)
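The timeout-driven transitions above can be modelled in a few lines of Python, in the spirit of `test_power_mgmt.py`. The constants mirror `config.h`; the model is deliberately simplified (no real WFI, no LED fade, and SLEEPING is left only by calling `activity()`), so it is a sketch of the state machine, not the firmware.

```python
# Minimal model of the ACTIVE -> SLEEP_PENDING -> SLEEPING path.
# Constants mirror config.h; transitions out of SLEEPING (EXTI wake,
# PLL restore) are omitted for brevity.
PM_IDLE_TIMEOUT_MS = 30000
PM_FADE_MS = 3000

ACTIVE, SLEEP_PENDING, SLEEPING = 0, 1, 2

class PowerModel:
    def __init__(self):
        self.state = ACTIVE
        self.last_activity_ms = 0
        self.fade_start_ms = 0

    def activity(self, now_ms):
        """Any CRSF/JLink activity resets the idle timer and aborts sleep."""
        self.last_activity_ms = now_ms
        self.state = ACTIVE

    def tick(self, now_ms):
        if self.state == ACTIVE:
            if now_ms - self.last_activity_ms >= PM_IDLE_TIMEOUT_MS:
                self.state = SLEEP_PENDING
                self.fade_start_ms = now_ms
        elif self.state == SLEEP_PENDING:
            if now_ms - self.fade_start_ms >= PM_FADE_MS:
                self.state = SLEEPING  # real firmware executes WFI here
        return self.state
```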

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 10:53:02 -05:00
728d1b0c0e Merge pull request 'feat(webui): live camera viewer — multi-stream + detection overlays (Issue #177)' (#182) from sl-webui/issue-177-camera-viewer into main 2026-03-02 10:48:50 -05:00
57420807ca feat(webui): live camera viewer — multi-stream + detection overlays (Issue #177)
UI (src/hooks/useCamera.js, src/components/CameraViewer.jsx):
  - 7 camera sources: front/left/rear/right CSI, D435i RGB/depth, panoramic
  - Compressed image subscription via rosbridge (sensor_msgs/CompressedImage)
  - Client-side 15fps gate (drops excess frames, reduces JS pressure)
  - Per-camera FPS indicator with quality badge (FULL/GOOD/LOW/NO SIGNAL)
  - Detection overlays: face boxes + names (/social/faces/detections),
    gesture icons (/social/gestures), scene object labels + hazard colours
    (/social/scene/objects); overlay mode selector (off/faces/gestures/objects/all)
  - 360° panoramic equirect viewer with mouse/touch drag azimuth pan
  - Picture-in-picture: up to 3 pinned cameras via ⊕ button
  - One-click recording (MediaRecorder → MP4/WebM download)
  - Snapshot to PNG with detection overlay composite + timestamp watermark
  - Cameras tab added to TELEMETRY group in App.jsx
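The client-side 15 fps gate lives in `useCamera.js` (JavaScript); the drop policy it describes can be sketched language-neutrally as: keep a frame only if enough time has passed since the last kept frame, otherwise drop it. The class name and exact policy here are assumptions.

```python
# Logic-only sketch of the client-side frame gate: keep at most
# max_fps frames per second, drop the rest (a newer frame will
# arrive shortly anyway, so dropping is cheaper than queueing).
class FrameGate:
    def __init__(self, max_fps=15.0):
        self.min_interval = 1.0 / max_fps  # ~66.7 ms between kept frames
        self.last_kept = None

    def accept(self, t_seconds):
        """Return True if the frame arriving at t_seconds should render."""
        if self.last_kept is None or t_seconds - self.last_kept >= self.min_interval:
            self.last_kept = t_seconds
            return True
        return False
```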

Jetson (rosbridge bringup):
  - rosbridge_params.yaml: whitelist + /camera/depth/image_rect_raw/compressed,
    /camera/panoramic/compressed, /social/faces/detections,
    /social/gestures, /social/scene/objects
  - rosbridge.launch.py: D435i colour republisher (JPEG 75%) +
    depth republisher (compressedDepth/PNG16 preserving uint16 values)
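The point of compressedDepth/PNG16 over JPEG is that D435i depth frames are raw uint16 millimetre values: PNG16 round-trips them losslessly, while JPEG would quantize to 8 bits and smear edges. The browser-side greyscale view is then just a rescale; this numpy sketch of that mapping is illustrative (the function name and 5 m clamp are assumptions, not the web UI's actual code).

```python
# Illustrative mapping from raw uint16 depth (mm) to an 8-bit
# greyscale view, darker = closer. The 5000 mm clamp is a guess.
import numpy as np

def depth_to_grayscale(depth_mm: np.ndarray, max_mm: int = 5000) -> np.ndarray:
    d = np.clip(depth_mm.astype(np.float32), 0, max_mm)
    return (d / max_mm * 255).astype(np.uint8)
```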

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 10:47:01 -05:00
9ca0e0844c Merge pull request 'feat(social): facial expression recognition — TRT FP16 emotion CNN (Issue #161)' (#180) from sl-jetson/issue-161-emotion into main
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (push) Failing after 10s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (push) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (push) Has been cancelled
2026-03-02 10:46:21 -05:00
54668536c1 Merge pull request 'feat(jetson): dynamic obstacle tracking — LIDAR motion detection, Kalman tracking, trajectory prediction, Nav2 costmap (#176)' (#181) from sl-perception/issue-176-dynamic-obstacles into main 2026-03-02 10:45:22 -05:00
c4bf8c371f Merge pull request 'feat(#169): emergency behavior system — obstacle stop, fall prevention, stuck detection, recovery FSM' (#179) from sl-controls/issue-169-emergency into main 2026-03-02 10:44:49 -05:00
2f4540f1d3 feat(jetson): add dynamic obstacle tracking package (issue #176)
Implements real-time moving obstacle detection, Kalman tracking, trajectory
prediction, and Nav2 costmap integration at 10 Hz / <50ms latency:

saltybot_dynamic_obs_msgs (ament_cmake):
• TrackedObject.msg      — id, PoseWithCovariance, velocity, predicted_path,
                           predicted_times, speed, confidence, age, hits
• MovingObjectArray.msg  — TrackedObject[], active_count, tentative_count,
                           detector_latency_ms

saltybot_dynamic_obstacles (ament_python):
• object_detector.py    — LIDAR background subtraction (EMA occupancy grid),
                           foreground dilation + scipy connected-component
                           clustering → Detection list
• kalman_tracker.py     — CV Kalman filter, state [px,py,vx,vy], Joseph-form
                           covariance update, predict_horizon() (non-mutating)
• tracker_manager.py    — up to 20 tracks, Hungarian assignment
                           (scipy.optimize.linear_sum_assignment), TENTATIVE→
                           CONFIRMED lifecycle, miss-prune
• dynamic_obs_node.py   — 10 Hz timer: detect→track→publish
                           /saltybot/moving_objects + MarkerArray viz
• costmap_layer_node.py — predicted paths → PointCloud2 inflation smear
                           → /saltybot/dynamic_obs_cloud for Nav2 ObstacleLayer
• launch/dynamic_obstacles.launch.py + config/dynamic_obstacles_params.yaml
• test/test_dynamic_obstacles.py — 27 unit tests (27/27 pass, no ROS2 needed)
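The association step described for `tracker_manager.py` (Hungarian assignment via `scipy.optimize.linear_sum_assignment`) can be sketched as below. The Euclidean cost matrix and the post-assignment distance gate are assumptions consistent with the `assoc_dist_m` parameter; the real gating details may differ.

```python
# Sketch of Hungarian track<->detection association with a distance
# gate: assignments farther than assoc_dist_m are rejected.
import numpy as np
from scipy.optimize import linear_sum_assignment

def associate(track_xy, det_xy, assoc_dist_m=1.5):
    """Return (track_idx, det_idx) pairs within the gate."""
    if len(track_xy) == 0 or len(det_xy) == 0:
        return []
    cost = np.linalg.norm(
        np.asarray(track_xy)[:, None, :] - np.asarray(det_xy)[None, :, :],
        axis=2,
    )
    rows, cols = linear_sum_assignment(cost)
    return [(r, c) for r, c in zip(rows, cols) if cost[r, c] <= assoc_dist_m]
```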

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 10:44:32 -05:00
50971c0946 feat(social): facial expression recognition — TRT FP16 emotion CNN (Issue #161)
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (push) Failing after 2s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (push) Has been skipped
social-bot integration tests / Lint (flake8 + pep257) (pull_request) Failing after 2s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (pull_request) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (push) Has been cancelled
social-bot integration tests / Latency profiling (GPU, Orin) (pull_request) Has been cancelled
- Add Expression.msg / ExpressionArray.msg ROS2 message definitions
- Add emotion_classifier.py: 7-class CNN (happy/sad/angry/surprised/fearful/disgusted/neutral)
  via TensorRT FP16 with landmark-geometry fallback; EMA per-person smoothing; opt-out registry
- Add emotion_node.py: subscribes /social/faces/detections, runs TRT crop inference (<5ms),
  publishes /social/faces/expressions and /social/emotion/context JSON for LLM
- Wire emotion context into conversation_node.py: emotion hint injected into LLM prompt
  when speaker shows non-neutral affect; subscribes /social/emotion/context
- Add emotion_params.yaml config and emotion.launch.py launch file
- Add 67-test suite (test_emotion_classifier.py): classifier, tracker, opt-out, heuristic
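The per-person EMA smoothing mentioned above can be sketched as follows. The commit only states that expressions are EMA-smoothed per tracked person; the smoothing factor, class-vector shape, and class names here are assumptions.

```python
# Sketch of per-person EMA smoothing of class probabilities:
# smoothed = alpha * new + (1 - alpha) * previous, keyed by person id.
class EmotionSmoother:
    def __init__(self, alpha=0.3):
        self.alpha = alpha
        self._state = {}  # person_id -> smoothed probability vector

    def update(self, person_id, probs):
        prev = self._state.get(person_id)
        if prev is None:
            sm = list(probs)  # first observation: take as-is
        else:
            sm = [self.alpha * p + (1 - self.alpha) * q
                  for p, q in zip(probs, prev)]
        self._state[person_id] = sm
        return sm
```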

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 10:40:54 -05:00
3b2f219d66 feat(#169): emergency behavior system — obstacle stop, fall prevention, stuck detection, recovery FSM
Two new packages:
- saltybot_emergency_msgs: EmergencyEvent.msg, RecoveryAction.msg
- saltybot_emergency: threat_detector, alert_manager, recovery_sequencer, emergency_fsm, emergency_node

Implements:
- ObstacleDetector: MAJOR <30cm, CRITICAL <10cm; suppressed when not moving
- FallDetector: MINOR/MAJOR/CRITICAL tilt thresholds; floor-drop edge detection
- StuckDetector: MAJOR after 3s wheel stall (cmd>threshold, actual~0)
- BumpDetector: jerk = |Δ|a||/dt with gravity removal; MAJOR/CRITICAL thresholds
- AlertManager: per-(type,level) suppression; MAJOR×N within window → CRITICAL escalation
- RecoverySequencer: REVERSING→TURNING→RETRYING FSM; max_retries before GAVE_UP
- EmergencyFSM: NOMINAL→STOPPING→RECOVERING→ESCALATED; acknowledge to clear
- EmergencyNode: 20Hz ROS2 node; /saltybot/emergency, /saltybot/e_stop, cmd_vel mux

59/59 tests passing.
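The BumpDetector formula above (jerk = |Δ|a||/dt with gravity removal) can be worked through directly. The thresholds and the per-sample gravity subtraction below are placeholders, not the tuned firmware values; note that a constant gravity offset cancels in the difference anyway for a fixed orientation.

```python
# Worked example of jerk = |delta |a|| / dt with the static gravity
# magnitude removed per sample. Units: m/s^2 in, m/s^3 out.
G = 9.81  # gravity magnitude, m/s^2

def jerk_mps3(a_prev, a_cur, dt_s):
    """a_prev, a_cur: (ax, ay, az) accelerometer samples."""
    mag = lambda a: (a[0] ** 2 + a[1] ** 2 + a[2] ** 2) ** 0.5
    return abs((mag(a_cur) - G) - (mag(a_prev) - G)) / dt_s
```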

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 10:39:37 -05:00
63 changed files with 7425 additions and 343 deletions


@ -189,6 +189,18 @@
/* Full blend transition time: MANUAL→AUTO takes this many ms */
#define MODE_BLEND_MS 500
// --- Power Management (STOP mode, Issue #178) ---
#define PM_IDLE_TIMEOUT_MS 30000u // 30s no activity → PM_SLEEP_PENDING
#define PM_FADE_MS 3000u // LED fade-out duration before STOP entry
#define PM_LED_PERIOD_MS 2000u // sleep-pending triangle-wave period (ms)
// Estimated per-subsystem currents (mA) — used for JLINK_TLM_POWER telemetry
#define PM_CURRENT_BASE_MA 30 // SPI1(IMU)+UART4(CRSF)+USART1(JLink)+core
#define PM_CURRENT_AUDIO_MA 8 // I2S3 + amplifier quiescent
#define PM_CURRENT_OSD_MA 5 // SPI2 OSD (MAX7456)
#define PM_CURRENT_DEBUG_MA 1 // UART5 + USART6
#define PM_CURRENT_STOP_MA 1 // MCU in STOP mode (< 1 mA)
#define PM_TLM_HZ 1 // JLINK_TLM_POWER transmit rate (Hz)
// --- Audio Amplifier (I2S3, Issue #143) ---
// SPI3 repurposed as I2S3; blackbox flash unused on balance bot
#define AUDIO_BCLK_PORT GPIOC


@ -54,9 +54,11 @@
#define JLINK_CMD_DFU_ENTER 0x06u
#define JLINK_CMD_ESTOP 0x07u
#define JLINK_CMD_AUDIO 0x08u /* PCM audio chunk: int16 samples, up to 126 */
#define JLINK_CMD_SLEEP 0x09u /* no payload; request STOP-mode sleep */
/* ---- Telemetry IDs (STM32 → Jetson) ---- */
#define JLINK_TLM_STATUS 0x80u
#define JLINK_TLM_POWER 0x81u /* jlink_tlm_power_t (11 bytes) */
/* ---- Telemetry STATUS payload (20 bytes, packed) ---- */
typedef struct __attribute__((packed)) {
@ -77,6 +79,15 @@ typedef struct __attribute__((packed)) {
    uint8_t fw_patch;
} jlink_tlm_status_t; /* 20 bytes */
/* ---- Telemetry POWER payload (11 bytes, packed) ---- */
typedef struct __attribute__((packed)) {
    uint8_t  power_state;  /* PowerState: 0=ACTIVE,1=SLEEP_PENDING,2=SLEEPING,3=WAKING */
    uint16_t est_total_ma; /* estimated total current draw (mA) */
    uint16_t est_audio_ma; /* estimated I2S3+amp current (mA); 0 if gated */
    uint16_t est_osd_ma;   /* estimated OSD SPI2 current (mA); 0 if gated */
    uint32_t idle_ms;      /* ms since last cmd_vel activity */
} jlink_tlm_power_t; /* 11 bytes */
/* ---- Volatile state (read from main loop) ---- */
typedef struct {
    /* Drive command — updated on JLINK_CMD_DRIVE */
@ -99,6 +110,8 @@ typedef struct {
    /* DFU reboot request — set by parser, cleared by main loop */
    volatile uint8_t dfu_req;

    /* Sleep request — set by JLINK_CMD_SLEEP, cleared by main loop */
    volatile uint8_t sleep_req;
} JLinkState;

extern volatile JLinkState jlink_state;
@ -130,4 +143,10 @@ void jlink_send_telemetry(const jlink_tlm_status_t *status);
 */
void jlink_process(void);
/*
* jlink_send_power_telemetry(power) — build and transmit a JLINK_TLM_POWER
* frame (17 bytes) at PM_TLM_HZ. Call from main loop when not in STOP mode.
*/
void jlink_send_power_telemetry(const jlink_tlm_power_t *power);
#endif /* JLINK_H */
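The 11-byte packed layout above can be cross-checked from the Jetson side with Python's `struct` module. The `'<BHHHI'` format string assumes little-endian byte order, which matches the Cortex-M7 and the `__attribute__((packed))` struct with no padding; the decoder function itself is illustrative, not the actual Jetson code.

```python
# Jetson-side decode of jlink_tlm_power_t: '<' = little-endian,
# no padding; B/H/H/H/I = 1+2+2+2+4 = 11 bytes.
import struct

POWER_FMT = '<BHHHI'  # power_state, est_total_ma, est_audio_ma, est_osd_ma, idle_ms

def unpack_power(payload: bytes) -> dict:
    state, total, audio, osd, idle = struct.unpack(POWER_FMT, payload)
    return {'power_state': state, 'est_total_ma': total,
            'est_audio_ma': audio, 'est_osd_ma': osd, 'idle_ms': idle}
```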

include/power_mgmt.h Normal file

@ -0,0 +1,96 @@
#ifndef POWER_MGMT_H
#define POWER_MGMT_H
#include <stdint.h>
#include <stdbool.h>
/*
* power_mgmt — STM32F7 STOP-mode sleep/wake manager (Issue #178).
*
* State machine:
* PM_ACTIVE → (idle PM_IDLE_TIMEOUT_MS or sleep cmd) → PM_SLEEP_PENDING
* PM_SLEEP_PENDING → (fade complete, PM_FADE_MS) → PM_SLEEPING (WFI)
* PM_SLEEPING → (EXTI wake) → PM_WAKING → (clocks restored) → PM_ACTIVE
*
* Any call to power_mgmt_activity() during SLEEP_PENDING or SLEEPING
* immediately transitions back toward PM_ACTIVE.
*
* Wake sources (EXTI, falling edge on UART idle-high RX pin or IMU INT):
* EXTI1 PA1 UART4_RX CRSF/ELRS start bit
* EXTI7 PB7 USART1_RX JLink start bit
* EXTI4 PC4 MPU6000 INT IMU motion (handler owned by mpu6000.c)
*
* Peripheral gating on sleep entry (clock disable, state preserved):
* Disabled: SPI3/I2S3 (audio amp), SPI2 (OSD), USART6, UART5 (debug)
* Active: SPI1 (IMU), UART4 (CRSF), USART1 (JLink), I2C1 (baro/mag)
*
* Sleep LED (LED1, active-low PC15):
* PM_SLEEP_PENDING: triangle-wave pulse, period PM_LED_PERIOD_MS
* All other states: 0 (caller uses normal LED logic)
*
* IWDG:
* Fed immediately before WFI. STOP wakeup is <10 ms typical — well within
* WATCHDOG_TIMEOUT_MS (50 ms).
*
* Safety interlock:
* Caller MUST NOT call power_mgmt_tick() while armed; call
* power_mgmt_activity() instead to keep the idle timer reset.
*
* JLink integration:
* JLINK_CMD_SLEEP (0x09) → power_mgmt_request_sleep()
* Any valid JLink frame → power_mgmt_activity() (handled in main loop)
*/
typedef enum {
    PM_ACTIVE        = 0, /* Normal, all peripherals running */
    PM_SLEEP_PENDING = 1, /* Idle timeout reached; LED fade-out in progress */
    PM_SLEEPING      = 2, /* In STOP mode (WFI); execution blocked in tick() */
    PM_WAKING        = 3, /* Transitional; clocks/peripherals being restored */
} PowerState;
/* ---- API ---- */
/*
* power_mgmt_init() — configure wake EXTI lines (EXTI1, EXTI7).
* Call after crsf_init() and jlink_init().
*/
void power_mgmt_init(void);
/*
* power_mgmt_activity() — record cmd_vel activity (CRSF frame, JLink frame).
* Resets idle timer; aborts any pending/active sleep.
*/
void power_mgmt_activity(void);
/*
* power_mgmt_request_sleep() — force sleep regardless of idle timer
* (called on JLINK_CMD_SLEEP). Next tick() enters PM_SLEEP_PENDING.
*/
void power_mgmt_request_sleep(void);
/*
* power_mgmt_tick(now_ms) — drive the state machine. May block in WFI during
* STOP mode. Returns state after this tick.
* MUST NOT be called while balance_state == BALANCE_ARMED.
*/
PowerState power_mgmt_tick(uint32_t now_ms);
/* power_mgmt_state() — non-blocking read of current state. */
PowerState power_mgmt_state(void);
/*
* power_mgmt_led_brightness() — 0-255 brightness for the sleep-pending pulse.
* Returns 0 when not in PM_SLEEP_PENDING; caller uses normal LED logic.
*/
uint8_t power_mgmt_led_brightness(void);
/*
* power_mgmt_current_ma() — estimated total current draw (mA) based on
* gating state; populated in JLINK_TLM_POWER telemetry.
*/
uint16_t power_mgmt_current_ma(void);
/* power_mgmt_idle_ms() — ms elapsed since last power_mgmt_activity() call. */
uint32_t power_mgmt_idle_ms(void);
#endif /* POWER_MGMT_H */
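The sleep-pending triangle wave that `power_mgmt_led_brightness()` describes can be sketched as below, using PM_LED_PERIOD_MS = 2000 from `config.h`. The exact ramp shape in `power_mgmt.c` is an assumption; this version rises 0→255 over the first half-period and falls back to 0 over the second.

```python
# Triangle-wave brightness for the sleep-pending LED pulse
# (period PM_LED_PERIOD_MS = 2000 ms, per config.h).
PM_LED_PERIOD_MS = 2000

def led_brightness(t_ms: int) -> int:
    phase = t_ms % PM_LED_PERIOD_MS
    half = PM_LED_PERIOD_MS // 2
    if phase < half:
        return phase * 255 // half               # rising half
    return (PM_LED_PERIOD_MS - phase) * 255 // half  # falling half
```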


@ -40,6 +40,11 @@ rosbridge_websocket:
"/person/target",
"/person/detections",
"/camera/*/image_raw/compressed",
"/camera/depth/image_rect_raw/compressed",
"/camera/panoramic/compressed",
"/social/faces/detections",
"/social/gestures",
"/social/scene/objects",
"/scan",
"/cmd_vel",
"/saltybot/imu",


@ -94,4 +94,33 @@ def generate_launch_description():
        for name in _CAMERAS
    ]

    # ── D435i colour republisher (Issue #177) ────────────────────────────────
    d435i_color = Node(
        package='image_transport',
        executable='republish',
        name='compress_d435i_color',
        arguments=['raw', 'compressed'],
        remappings=[
            ('in', '/camera/color/image_raw'),
            ('out/compressed', '/camera/color/image_raw/compressed'),
        ],
        parameters=[{'compressed.jpeg_quality': _JPEG_QUALITY}],
        output='screen',
    )

    # ── D435i depth republisher (Issue #177) ─────────────────────────────────
    # Depth stream as compressedDepth (PNG16) — preserves uint16 depth values.
    # Browser displays as greyscale PNG (darker = closer).
    d435i_depth = Node(
        package='image_transport',
        executable='republish',
        name='compress_d435i_depth',
        arguments=['raw', 'compressedDepth'],
        remappings=[
            ('in', '/camera/depth/image_rect_raw'),
            ('out/compressedDepth', '/camera/depth/image_rect_raw/compressed'),
        ],
        output='screen',
    )

    return LaunchDescription([rosbridge] + republishers + [d435i_color, d435i_depth])


@ -0,0 +1,16 @@
cmake_minimum_required(VERSION 3.8)
project(saltybot_dynamic_obs_msgs)
find_package(ament_cmake REQUIRED)
find_package(rosidl_default_generators REQUIRED)
find_package(std_msgs REQUIRED)
find_package(geometry_msgs REQUIRED)
rosidl_generate_interfaces(${PROJECT_NAME}
"msg/TrackedObject.msg"
"msg/MovingObjectArray.msg"
DEPENDENCIES std_msgs geometry_msgs
)
ament_export_dependencies(rosidl_default_runtime)
ament_package()


@ -0,0 +1,12 @@
# MovingObjectArray — all currently tracked moving obstacles.
#
# Published at ~10 Hz on /saltybot/moving_objects.
# Only confirmed tracks (hits >= confirm_frames) appear here.
std_msgs/Header header
saltybot_dynamic_obs_msgs/TrackedObject[] objects
uint32 active_count # number of confirmed tracks
uint32 tentative_count # tracks not yet confirmed
float32 detector_latency_ms # pipeline latency hint


@ -0,0 +1,21 @@
# TrackedObject — a single tracked moving obstacle.
#
# predicted_path[i] is the estimated pose at predicted_times[i] seconds from now.
# All poses are in the same frame as header.frame_id (typically 'odom').
std_msgs/Header header
uint32 object_id # stable ID across frames (monotonically increasing)
geometry_msgs/PoseWithCovariance pose # current best-estimate pose (x, y, yaw)
geometry_msgs/Vector3 velocity # vx, vy in m/s (vz = 0 for ground objects)
geometry_msgs/Pose[] predicted_path # future positions at predicted_times
float32[] predicted_times # seconds from header.stamp for each pose
float32 speed_mps # scalar |v|
float32 confidence # 0.0–1.0 (higher after more confirmed frames)
uint32 age_frames # frames since first detection
uint32 hits # number of successful associations
bool is_valid # false if in tentative / just-created state
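A consumer of this message can linearly interpolate the predicted position at an arbitrary look-ahead time from `predicted_path` and `predicted_times`. The sketch below uses plain `(x, y)` tuples in place of `geometry_msgs/Pose` to stay self-contained; the helper name is illustrative.

```python
# Interpolate the predicted (x, y) at look-ahead t_s from the
# TrackedObject fields: path_xy[i] is the position at times_s[i].
def predict_at(path_xy, times_s, t_s):
    if t_s <= times_s[0]:
        return path_xy[0]
    for (t0, p0), (t1, p1) in zip(zip(times_s, path_xy),
                                  zip(times_s[1:], path_xy[1:])):
        if t0 <= t_s <= t1:
            u = (t_s - t0) / (t1 - t0)
            return (p0[0] + u * (p1[0] - p0[0]),
                    p0[1] + u * (p1[1] - p0[1]))
    return path_xy[-1]  # beyond the horizon: clamp to last waypoint
```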


@ -0,0 +1,23 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_dynamic_obs_msgs</name>
<version>0.1.0</version>
<description>Custom message types for dynamic obstacle tracking.</description>
<maintainer email="robot@saltylab.local">SaltyLab</maintainer>
<license>MIT</license>
<buildtool_depend>ament_cmake</buildtool_depend>
<buildtool_depend>rosidl_default_generators</buildtool_depend>
<depend>std_msgs</depend>
<depend>geometry_msgs</depend>
<exec_depend>rosidl_default_runtime</exec_depend>
<member_of_group>rosidl_interface_packages</member_of_group>
<export>
<build_type>ament_cmake</build_type>
</export>
</package>


@ -0,0 +1,52 @@
# saltybot_dynamic_obstacles — runtime parameters
#
# Requires:
# /scan (sensor_msgs/LaserScan) — RPLIDAR A1M8 at ~5.5 Hz
#
# LIDAR scan is published by rplidar_ros node.
# Make sure RPLIDAR is running before starting this stack.
dynamic_obs_tracker:
ros__parameters:
max_tracks: 20 # max simultaneous tracked objects
confirm_frames: 3 # hits before a track is published
max_missed_frames: 6 # missed frames before track deletion
assoc_dist_m: 1.5 # max assignment distance (Hungarian)
prediction_hz: 10.0 # tracker update + publish rate
horizon_s: 2.5 # prediction look-ahead
pred_step_s: 0.5 # time between predicted waypoints
odom_frame: 'odom'
min_speed_mps: 0.05 # suppress near-stationary tracks
max_range_m: 8.0 # ignore detections beyond this
dynamic_obs_costmap:
ros__parameters:
inflation_radius_m: 0.35 # safety bubble around each predicted point
ring_points: 8 # polygon points for inflation circle
clear_on_empty: true # push empty cloud to clear stale Nav2 markings
# ── Nav2 costmap integration ───────────────────────────────────────────────────
# In your nav2_params.yaml, under local_costmap or global_costmap > plugins, add
# an ObstacleLayer with:
#
# obstacle_layer:
# plugin: "nav2_costmap_2d::ObstacleLayer"
# enabled: true
# observation_sources: static_scan dynamic_obs
# static_scan:
# topic: /scan
# data_type: LaserScan
# ...
# dynamic_obs:
# topic: /saltybot/dynamic_obs_cloud
# data_type: PointCloud2
# sensor_frame: odom
# obstacle_max_range: 10.0
# raytrace_max_range: 10.0
# marking: true
# clearing: false
#
# This feeds the predicted trajectory smear directly into Nav2's obstacle
# inflation, forcing the planner to route around the predicted future path
# of every tracked moving object.


@ -0,0 +1,75 @@
"""
dynamic_obstacles.launch.py — Dynamic obstacle tracking + Nav2 costmap layer.
Starts:
  dynamic_obs_tracker — LIDAR motion detection + Kalman tracking @10 Hz
  dynamic_obs_costmap — Predicted-trajectory PointCloud2 for Nav2
Launch args:
max_tracks int '20'
assoc_dist_m float '1.5'
horizon_s float '2.5'
inflation_radius_m float '0.35'
Verify:
ros2 topic hz /saltybot/moving_objects # ~10 Hz
ros2 topic echo /saltybot/moving_objects # TrackedObject list
ros2 topic hz /saltybot/dynamic_obs_cloud # ~10 Hz (when objects present)
rviz2 add MarkerArray /saltybot/moving_objects_viz
Nav2 costmap integration:
In your costmap_params.yaml ObstacleLayer observation_sources, add:
dynamic_obs:
topic: /saltybot/dynamic_obs_cloud
data_type: PointCloud2
marking: true
clearing: false
"""
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
    args = [
        DeclareLaunchArgument('max_tracks', default_value='20'),
        DeclareLaunchArgument('assoc_dist_m', default_value='1.5'),
        DeclareLaunchArgument('horizon_s', default_value='2.5'),
        DeclareLaunchArgument('inflation_radius_m', default_value='0.35'),
        DeclareLaunchArgument('min_speed_mps', default_value='0.05'),
    ]
    tracker = Node(
        package='saltybot_dynamic_obstacles',
        executable='dynamic_obs_tracker',
        name='dynamic_obs_tracker',
        output='screen',
        parameters=[{
            'max_tracks': LaunchConfiguration('max_tracks'),
            'assoc_dist_m': LaunchConfiguration('assoc_dist_m'),
            'horizon_s': LaunchConfiguration('horizon_s'),
            'min_speed_mps': LaunchConfiguration('min_speed_mps'),
            'prediction_hz': 10.0,
            'confirm_frames': 3,
            'max_missed_frames': 6,
            'pred_step_s': 0.5,
            'odom_frame': 'odom',
            'max_range_m': 8.0,
        }],
    )
    costmap = Node(
        package='saltybot_dynamic_obstacles',
        executable='dynamic_obs_costmap',
        name='dynamic_obs_costmap',
        output='screen',
        parameters=[{
            'inflation_radius_m': LaunchConfiguration('inflation_radius_m'),
            'ring_points': 8,
            'clear_on_empty': True,
        }],
    )
    return LaunchDescription(args + [tracker, costmap])


@ -0,0 +1,29 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_dynamic_obstacles</name>
<version>0.1.0</version>
<description>
Dynamic obstacle detection, multi-object Kalman tracking, trajectory
prediction, and Nav2 costmap layer integration for SaltyBot.
</description>
<maintainer email="robot@saltylab.local">SaltyLab</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>std_msgs</depend>
<depend>sensor_msgs</depend>
<depend>geometry_msgs</depend>
<depend>nav_msgs</depend>
<depend>visualization_msgs</depend>
<depend>saltybot_dynamic_obs_msgs</depend>
<exec_depend>python3-numpy</exec_depend>
<exec_depend>python3-scipy</exec_depend>
<test_depend>pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>


@ -0,0 +1,178 @@
"""
costmap_layer_node.py — Nav2 costmap integration for dynamic obstacles.
Converts predicted trajectories from /saltybot/moving_objects into a
PointCloud2 fed into Nav2's ObstacleLayer. Each predicted future position
is added as a point, creating a "smeared" dynamic obstacle zone that
covers the full 2-3 s prediction horizon.
Nav2 ObstacleLayer config (in costmap_params.yaml):
obstacle_layer:
enabled: true
observation_sources: dynamic_obs
dynamic_obs:
topic: /saltybot/dynamic_obs_cloud
sensor_frame: odom
data_type: PointCloud2
obstacle_max_range: 12.0
obstacle_min_range: 0.0
raytrace_max_range: 12.0
marking: true
clearing: false # let the tracker handle clearing
The node also clears old obstacle points when tracks are dropped, by
publishing a clearing cloud to /saltybot/dynamic_obs_clear.
Subscribes:
/saltybot/moving_objects saltybot_dynamic_obs_msgs/MovingObjectArray
Publishes:
/saltybot/dynamic_obs_cloud sensor_msgs/PointCloud2 marking cloud
/saltybot/dynamic_obs_clear sensor_msgs/PointCloud2 clearing cloud
Parameters:
inflation_radius_m float 0.35 (each predicted point inflated by this radius)
ring_points int 8 (polygon approximation of inflation circle)
clear_on_empty bool true (publish clear cloud when no objects tracked)
"""
from __future__ import annotations
import math
import struct
from typing import List
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from sensor_msgs.msg import PointCloud2, PointField
from std_msgs.msg import Header
try:
    from saltybot_dynamic_obs_msgs.msg import MovingObjectArray
    _MSGS_AVAILABLE = True
except ImportError:
    _MSGS_AVAILABLE = False

_RELIABLE_QOS = QoSProfile(
    reliability=ReliabilityPolicy.RELIABLE,
    history=HistoryPolicy.KEEP_LAST,
    depth=10,
)
def _make_pc2(header: Header, points_xyz: List[tuple]) -> PointCloud2:
    """Pack a list of (x, y, z) into a PointCloud2 message."""
    fields = [
        PointField(name='x', offset=0, datatype=PointField.FLOAT32, count=1),
        PointField(name='y', offset=4, datatype=PointField.FLOAT32, count=1),
        PointField(name='z', offset=8, datatype=PointField.FLOAT32, count=1),
    ]
    point_step = 12  # 3 × float32
    data = bytearray(len(points_xyz) * point_step)
    for i, (x, y, z) in enumerate(points_xyz):
        struct.pack_into('<fff', data, i * point_step, x, y, z)
    pc = PointCloud2()
    pc.header = header
    pc.height = 1
    pc.width = len(points_xyz)
    pc.fields = fields
    pc.is_bigendian = False
    pc.point_step = point_step
    pc.row_step = point_step * len(points_xyz)
    pc.data = bytes(data)
    pc.is_dense = True
    return pc
class CostmapLayerNode(Node):
    def __init__(self):
        super().__init__('dynamic_obs_costmap')
        self.declare_parameter('inflation_radius_m', 0.35)
        self.declare_parameter('ring_points', 8)
        self.declare_parameter('clear_on_empty', True)
        self._infl_r = self.get_parameter('inflation_radius_m').value
        self._ring_n = self.get_parameter('ring_points').value
        self._clear_empty = self.get_parameter('clear_on_empty').value
        # Pre-compute ring offsets for inflation
        self._ring_offsets = [
            (self._infl_r * math.cos(2 * math.pi * i / self._ring_n),
             self._infl_r * math.sin(2 * math.pi * i / self._ring_n))
            for i in range(self._ring_n)
        ]
        if _MSGS_AVAILABLE:
            self.create_subscription(
                MovingObjectArray,
                '/saltybot/moving_objects',
                self._on_objects,
                _RELIABLE_QOS,
            )
        else:
            self.get_logger().warning(
                '[costmap_layer] saltybot_dynamic_obs_msgs not built — '
                'will not subscribe to MovingObjectArray'
            )
        self._mark_pub = self.create_publisher(
            PointCloud2, '/saltybot/dynamic_obs_cloud', 10
        )
        self._clear_pub = self.create_publisher(
            PointCloud2, '/saltybot/dynamic_obs_clear', 10
        )
        self.get_logger().info(
            f'dynamic_obs_costmap ready — '
            f'inflation={self._infl_r}m ring_pts={self._ring_n}'
        )

    # ── Callback ──────────────────────────────────────────────────────────
    def _on_objects(self, msg: 'MovingObjectArray') -> None:
        hdr = msg.header
        mark_pts: List[tuple] = []
        for obj in msg.objects:
            if not obj.is_valid:
                continue
            # Current position
            self._add_inflated(obj.pose.pose.position.x,
                               obj.pose.pose.position.y, mark_pts)
            # Predicted future positions
            for pose in obj.predicted_path:
                self._add_inflated(pose.position.x, pose.position.y, mark_pts)
        if mark_pts:
            self._mark_pub.publish(_make_pc2(hdr, mark_pts))
        elif self._clear_empty:
            # Publish tiny clear cloud so Nav2 clears stale markings
            self._clear_pub.publish(_make_pc2(hdr, []))

    def _add_inflated(self, cx: float, cy: float, pts: List[tuple]) -> None:
        """Add the centre + ring of inflation points at height 0.5 m."""
        pts.append((cx, cy, 0.5))
        for ox, oy in self._ring_offsets:
            pts.append((cx + ox, cy + oy, 0.5))


def main(args=None):
    rclpy.init(args=args)
    node = CostmapLayerNode()
    try:
        rclpy.spin(node)
    finally:
        node.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()


@ -0,0 +1,319 @@
"""
dynamic_obs_node.py — ROS2 node: LIDAR moving-object detection + Kalman tracking.
Pipeline:
1. Subscribe /scan (RPLIDAR LaserScan, ~5.5 Hz).
2. ObjectDetector performs background subtraction → moving blobs.
3. TrackerManager runs Hungarian assignment + Kalman predict/update at 10 Hz.
4. Publish /saltybot/moving_objects (MovingObjectArray).
5. Publish /saltybot/moving_objects_viz (MarkerArray) for RViz.
The 10 Hz timer drives the tracker regardless of scan rate, so prediction
continues between scans (pure-predict steps).
Subscribes:
/scan sensor_msgs/LaserScan RPLIDAR A1M8
Publishes:
/saltybot/moving_objects saltybot_dynamic_obs_msgs/MovingObjectArray
/saltybot/moving_objects_viz visualization_msgs/MarkerArray
Parameters:
max_tracks int 20
confirm_frames int 3
max_missed_frames int 6
assoc_dist_m float 1.5
prediction_hz float 10.0 (tracker + publish rate)
horizon_s float 2.5
pred_step_s float 0.5
odom_frame str 'odom'
min_speed_mps float 0.05 (suppress near-zero velocity tracks)
max_range_m float 8.0
"""
import time
import math
import numpy as np
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from sensor_msgs.msg import LaserScan
from geometry_msgs.msg import Pose, Point, Quaternion, Vector3
from std_msgs.msg import Header, ColorRGBA
from visualization_msgs.msg import Marker, MarkerArray
try:
    from saltybot_dynamic_obs_msgs.msg import TrackedObject, MovingObjectArray
    _MSGS_AVAILABLE = True
except ImportError:
    _MSGS_AVAILABLE = False

from .object_detector import ObjectDetector
from .tracker_manager import TrackerManager, Track

_SENSOR_QOS = QoSProfile(
    reliability=ReliabilityPolicy.BEST_EFFORT,
    history=HistoryPolicy.KEEP_LAST,
    depth=5,
)
def _yaw_quat(yaw: float) -> Quaternion:
q = Quaternion()
q.w = math.cos(yaw * 0.5)
q.z = math.sin(yaw * 0.5)
return q
class DynamicObsNode(Node):
def __init__(self):
super().__init__('dynamic_obs_tracker')
# ── Parameters ──────────────────────────────────────────────────────
self.declare_parameter('max_tracks', 20)
self.declare_parameter('confirm_frames', 3)
self.declare_parameter('max_missed_frames', 6)
self.declare_parameter('assoc_dist_m', 1.5)
self.declare_parameter('prediction_hz', 10.0)
self.declare_parameter('horizon_s', 2.5)
self.declare_parameter('pred_step_s', 0.5)
self.declare_parameter('odom_frame', 'odom')
self.declare_parameter('min_speed_mps', 0.05)
self.declare_parameter('max_range_m', 8.0)
max_tracks = self.get_parameter('max_tracks').value
confirm_f = self.get_parameter('confirm_frames').value
max_missed = self.get_parameter('max_missed_frames').value
assoc_dist = self.get_parameter('assoc_dist_m').value
pred_hz = self.get_parameter('prediction_hz').value
horizon_s = self.get_parameter('horizon_s').value
pred_step = self.get_parameter('pred_step_s').value
self._frame = self.get_parameter('odom_frame').value
self._min_spd = self.get_parameter('min_speed_mps').value
self._max_rng = self.get_parameter('max_range_m').value
# ── Core modules ────────────────────────────────────────────────────
self._detector = ObjectDetector(
grid_radius_m=min(self._max_rng + 2.0, 12.0),
max_cluster=int((self._max_rng / 0.1) ** 2 * 0.5),
)
self._tracker = TrackerManager(
max_tracks=max_tracks,
confirm_frames=confirm_f,
max_missed=max_missed,
assoc_dist_m=assoc_dist,
horizon_s=horizon_s,
pred_step_s=pred_step,
)
self._horizon_s = horizon_s
self._pred_step = pred_step
# ── State ────────────────────────────────────────────────────────────
self._latest_scan: LaserScan | None = None
self._last_track_t: float = time.monotonic()
self._scan_processed_stamp: float | None = None
# ── Subscriptions ────────────────────────────────────────────────────
self.create_subscription(LaserScan, '/scan', self._on_scan, _SENSOR_QOS)
# ── Publishers ───────────────────────────────────────────────────────
if _MSGS_AVAILABLE:
self._obj_pub = self.create_publisher(
MovingObjectArray, '/saltybot/moving_objects', 10
)
else:
self._obj_pub = None
self.get_logger().warning(
'[dyn_obs] saltybot_dynamic_obs_msgs not built — '
'MovingObjectArray will not be published'
)
self._viz_pub = self.create_publisher(
MarkerArray, '/saltybot/moving_objects_viz', 10
)
# ── Timer ────────────────────────────────────────────────────────────
self.create_timer(1.0 / pred_hz, self._track_tick)
self.get_logger().info(
f'dynamic_obs_tracker ready — '
f'max_tracks={max_tracks} horizon={horizon_s}s assoc={assoc_dist}m'
)
# ── Scan callback ─────────────────────────────────────────────────────────
def _on_scan(self, msg: LaserScan) -> None:
self._latest_scan = msg
# ── 10 Hz tracker tick ────────────────────────────────────────────────────
def _track_tick(self) -> None:
t0 = time.monotonic()
now_mono = t0
dt = now_mono - self._last_track_t
dt = max(1e-3, min(dt, 0.5))
self._last_track_t = now_mono
scan = self._latest_scan
detections = []
if scan is not None:
stamp_sec = scan.header.stamp.sec + scan.header.stamp.nanosec * 1e-9
if stamp_sec != self._scan_processed_stamp:
self._scan_processed_stamp = stamp_sec
ranges = np.asarray(scan.ranges, dtype=np.float32)
detections = self._detector.update(
ranges,
scan.angle_min,
scan.angle_increment,
min(scan.range_max, self._max_rng),
)
confirmed = self._tracker.update(detections, dt)
latency_ms = (time.monotonic() - t0) * 1000.0
stamp = self.get_clock().now().to_msg()
if _MSGS_AVAILABLE and self._obj_pub is not None:
self._publish_objects(confirmed, stamp, latency_ms)
self._publish_viz(confirmed, stamp)
# ── Publish helpers ───────────────────────────────────────────────────────
def _publish_objects(self, confirmed: list, stamp, latency_ms: float) -> None:
arr = MovingObjectArray()
arr.header.stamp = stamp
arr.header.frame_id = self._frame
arr.active_count = len(confirmed)
arr.tentative_count = self._tracker.tentative_count
arr.detector_latency_ms = float(latency_ms)
for tr in confirmed:
px, py = tr.kalman.position
vx, vy = tr.kalman.velocity
speed = tr.kalman.speed
if speed < self._min_spd:
continue
obj = TrackedObject()
obj.header = arr.header
obj.object_id = tr.track_id
obj.pose.pose.position.x = px
obj.pose.pose.position.y = py
obj.pose.pose.orientation = _yaw_quat(math.atan2(vy, vx))
cov = tr.kalman.covariance_2x2
obj.pose.covariance[0] = float(cov[0, 0])
obj.pose.covariance[1] = float(cov[0, 1])
obj.pose.covariance[6] = float(cov[1, 0])
obj.pose.covariance[7] = float(cov[1, 1])
obj.velocity.x = vx
obj.velocity.y = vy
obj.speed_mps = speed
obj.confidence = min(1.0, tr.hits / (self._tracker._confirm_frames * 3))
obj.age_frames = tr.age
obj.hits = tr.hits
obj.is_valid = True
# Predicted path
for px_f, py_f, t_f in tr.kalman.predict_horizon(
self._horizon_s, self._pred_step
):
p = Pose()
p.position.x = px_f
p.position.y = py_f
p.orientation.w = 1.0
obj.predicted_path.append(p)
obj.predicted_times.append(float(t_f))
arr.objects.append(obj)
self._obj_pub.publish(arr)
def _publish_viz(self, confirmed: list, stamp) -> None:
markers = MarkerArray()
# Delete old markers
del_marker = Marker()
del_marker.header.stamp = stamp
del_marker.header.frame_id = self._frame
del_marker.action = Marker.DELETEALL
markers.markers.append(del_marker)
for tr in confirmed:
px, py = tr.kalman.position
vx, vy = tr.kalman.velocity
speed = tr.kalman.speed
if speed < self._min_spd:
continue
# Cylinder at current position
m = Marker()
m.header.stamp = stamp
m.header.frame_id = self._frame
m.ns = 'dyn_obs'
m.id = tr.track_id
m.type = Marker.CYLINDER
m.action = Marker.ADD
m.pose.position.x = px
m.pose.position.y = py
m.pose.position.z = 0.5
m.pose.orientation.w = 1.0
m.scale.x = 0.4
m.scale.y = 0.4
m.scale.z = 1.0
m.color = ColorRGBA(r=1.0, g=0.2, b=0.0, a=0.7)
m.lifetime.sec = 1
markers.markers.append(m)
# Arrow for velocity
vel_m = Marker()
vel_m.header = m.header
vel_m.ns = 'dyn_obs_vel'
vel_m.id = tr.track_id
vel_m.type = Marker.ARROW
vel_m.action = Marker.ADD
p_start = Point(x=px, y=py, z=1.0)  # Point is imported at module top
p_end = Point(x=px + vx, y=py + vy, z=1.0)
vel_m.points = [p_start, p_end]
vel_m.scale.x = 0.05
vel_m.scale.y = 0.10
vel_m.color = ColorRGBA(r=1.0, g=1.0, b=0.0, a=0.9)
vel_m.lifetime.sec = 1
markers.markers.append(vel_m)
# Line strip for predicted path
path_m = Marker()
path_m.header = m.header
path_m.ns = 'dyn_obs_path'
path_m.id = tr.track_id
path_m.type = Marker.LINE_STRIP
path_m.action = Marker.ADD
path_m.scale.x = 0.04
path_m.color = ColorRGBA(r=1.0, g=0.5, b=0.0, a=0.5)
path_m.lifetime.sec = 1
path_m.points.append(Point(x=px, y=py, z=0.5))
for fx, fy, _ in tr.kalman.predict_horizon(self._horizon_s, self._pred_step):
path_m.points.append(Point(x=fx, y=fy, z=0.5))
markers.markers.append(path_m)
self._viz_pub.publish(markers)
def main(args=None):
rclpy.init(args=args)
node = DynamicObsNode()
try:
rclpy.spin(node)
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()
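The tick/scan decoupling described in the module docstring (the 10 Hz timer always advances the tracker, but each scan is consumed at most once) can be sketched standalone; the tracker and detector are omitted here, and this only shows the stamp-keyed gating and dt clamp:

```python
# Minimal standalone sketch of the tick/scan decoupling: a 10 Hz timer
# always advances the tracker, but each scan (arriving at ~5.5 Hz) is
# processed at most once, keyed by its header stamp.
processed_stamp = None

def tick(scan_stamp, last_t, now):
    """Return (dt, run_detector) for one timer tick."""
    global processed_stamp
    dt = max(1e-3, min(now - last_t, 0.5))  # clamp dt against timer jitter
    run_detector = scan_stamp is not None and scan_stamp != processed_stamp
    if run_detector:
        processed_stamp = scan_stamp  # mark this scan as consumed
    return dt, run_detector

dt, run1 = tick(1.0, 0.0, 0.1)  # fresh scan -> detect + Kalman update
_, run2 = tick(1.0, 0.1, 0.2)   # same scan again -> pure-predict step
print(run1, run2)  # → True False
```

Between scans the tracker still receives a clamped dt, so confirmed tracks keep predicting forward even when no new detections arrive.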


@ -0,0 +1,132 @@
"""
kalman_tracker.py: Single-object Kalman filter for 2-D ground-plane tracking.
State vector: x = [px, py, vx, vy] (position + velocity)
Motion model: constant-velocity (CV) with process noise on acceleration
Predict step:
F = | 1 0 dt 0 | x_{k|k-1} = F @ x_{k-1|k-1}
| 0 1 0 dt | P_{k|k-1} = F @ P @ F^T + Q
| 0 0 1 0 |
| 0 0 0 1 |
Update step (position observation only):
H = | 1 0 0 0 | y = z - H @ x
| 0 1 0 0 | S = H @ P @ H^T + R
K = P @ H^T @ inv(S)
x = x + K @ y
P = (I - K @ H) @ P (Joseph form for stability)
Trajectory prediction: unrolls the CV model forward at fixed time steps.
"""
from __future__ import annotations
from typing import List, Tuple
import numpy as np
# ── Default noise matrices ────────────────────────────────────────────────────
# Process noise: models uncertainty in acceleration between frames
_Q_BASE = np.diag([0.02, 0.02, 0.8, 0.8]).astype(np.float64)
# Measurement noise: LIDAR centroid uncertainty (variance 0.025 → ~0.16 m std)
_R = np.diag([0.025, 0.025]).astype(np.float64)
# Observation matrix
_H = np.array([[1, 0, 0, 0],
[0, 1, 0, 0]], dtype=np.float64)
_I4 = np.eye(4, dtype=np.float64)
class KalmanTracker:
"""
Kalman filter tracking one object.
Parameters
----------
x0, y0 : initial position (metres, odom frame)
process_noise : scalar multiplier on _Q_BASE
"""
def __init__(self, x0: float, y0: float, process_noise: float = 1.0):
self._x = np.array([x0, y0, 0.0, 0.0], dtype=np.float64)
self._P = np.eye(4, dtype=np.float64) * 0.5
self._Q = _Q_BASE * process_noise
# ── Core filter ───────────────────────────────────────────────────────────
def predict(self, dt: float) -> None:
"""Propagate state by dt seconds."""
F = np.array([
[1, 0, dt, 0],
[0, 1, 0, dt],
[0, 0, 1, 0],
[0, 0, 0, 1],
], dtype=np.float64)
self._x = F @ self._x
self._P = F @ self._P @ F.T + self._Q
def update(self, z: np.ndarray) -> None:
"""
Incorporate a position measurement z = [x, y].
Uses Joseph-form covariance update for numerical stability.
"""
y = z.astype(np.float64) - _H @ self._x
S = _H @ self._P @ _H.T + _R
K = self._P @ _H.T @ np.linalg.inv(S)
self._x = self._x + K @ y
IKH = _I4 - K @ _H
# Joseph form: (I-KH) P (I-KH)^T + K R K^T
self._P = IKH @ self._P @ IKH.T + K @ _R @ K.T
# ── Prediction horizon ────────────────────────────────────────────────────
def predict_horizon(
self,
horizon_s: float = 2.5,
step_s: float = 0.5,
) -> List[Tuple[float, float, float]]:
"""
Return [(x, y, t), ...] at equally-spaced future times.
Does NOT modify internal filter state.
"""
predictions: List[Tuple[float, float, float]] = []
state = self._x.copy()
t = 0.0
F_step = np.array([
[1, 0, step_s, 0],
[0, 1, 0, step_s],
[0, 0, 1, 0],
[0, 0, 0, 1],
], dtype=np.float64)
while t < horizon_s - 1e-6:
state = F_step @ state
t += step_s
predictions.append((float(state[0]), float(state[1]), t))
return predictions
# ── Properties ────────────────────────────────────────────────────────────
@property
def position(self) -> Tuple[float, float]:
return float(self._x[0]), float(self._x[1])
@property
def velocity(self) -> Tuple[float, float]:
return float(self._x[2]), float(self._x[3])
@property
def speed(self) -> float:
return float(np.hypot(self._x[2], self._x[3]))
@property
def covariance_2x2(self) -> np.ndarray:
"""Position covariance (top-left 2×2 of P)."""
return self._P[:2, :2].copy()
@property
def state(self) -> np.ndarray:
return self._x.copy()
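The predict/update equations in the docstring above can be checked numerically with a self-contained sketch (plain NumPy, not the KalmanTracker class itself). Measurements are noiseless here so the run is deterministic; for a constant-velocity target the CV filter converges to the true velocity:

```python
import numpy as np

# Target moving at 1 m/s along x, observed every 0.1 s (noiseless).
dt = 0.1
F = np.array([[1, 0, dt, 0],
              [0, 1, 0, dt],
              [0, 0, 1, 0],
              [0, 0, 0, 1]], dtype=float)
H = np.array([[1, 0, 0, 0],
              [0, 1, 0, 0]], dtype=float)
Q = np.diag([0.02, 0.02, 0.8, 0.8])   # process noise (as in _Q_BASE)
R = np.diag([0.025, 0.025])           # measurement noise (as in _R)
x = np.zeros(4)                       # [px, py, vx, vy]
P = np.eye(4) * 0.5
for k in range(1, 51):
    x = F @ x                          # predict
    P = F @ P @ F.T + Q
    z = np.array([k * dt, 0.0])        # true position at t = k*dt
    y = z - H @ x                      # innovation
    S = H @ P @ H.T + R
    K = P @ H.T @ np.linalg.inv(S)
    x = x + K @ y
    IKH = np.eye(4) - K @ H
    P = IKH @ P @ IKH.T + K @ R @ K.T  # Joseph form keeps P symmetric
print(round(float(x[2]), 2))  # estimated vx converges toward 1.0
```

The Joseph-form update is what `test_joseph_form_stays_symmetric` exercises: after many cycles P remains symmetric and positive definite.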


@ -0,0 +1,168 @@
"""
object_detector.py: LIDAR-based moving object detector.
Algorithm:
1. Convert each LaserScan to a 2-D occupancy grid (robot-centred, fixed size).
2. Maintain a background model via exponential moving average (EMA):
bg_t = α * current + (1-α) * bg_{t-1} (only for non-moving cells)
3. Foreground = cells whose occupancy significantly exceeds the background.
4. Cluster foreground cells with scipy connected-components into a Detection list.
The grid is robot-relative (origin at robot centre) so it naturally tracks
the robot's motion without needing TF at this stage. The caller is responsible
for transforming detections into a stable frame (odom) before passing to the
tracker.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import List
import numpy as np
from scipy import ndimage
@dataclass
class Detection:
"""A clustered moving foreground blob from one scan."""
x: float # centroid x in sensor frame (m)
y: float # centroid y in sensor frame (m)
size_m2: float # approximate area of the cluster (m²)
range_m: float # distance from robot (m)
class ObjectDetector:
"""
Detects moving objects in consecutive 2-D LIDAR scans.
Parameters
----------
grid_radius_m : half-size of the occupancy grid (grid covers ±radius)
resolution : metres per cell
bg_alpha : EMA update rate for background (small = slow forgetting)
motion_thr : occupancy delta above background to count as moving
min_cluster : minimum cells to keep a cluster
max_cluster : maximum cells before a cluster is considered static wall
"""
def __init__(
self,
grid_radius_m: float = 10.0,
resolution: float = 0.10,
bg_alpha: float = 0.04,
motion_thr: float = 0.45,
min_cluster: int = 3,
max_cluster: int = 200,
):
cells = int(2 * grid_radius_m / resolution)
self._cells = cells
self._res = resolution
self._origin = -grid_radius_m # world x/y at grid index 0
self._bg_alpha = bg_alpha
self._motion_thr = motion_thr
self._min_clust = min_cluster
self._max_clust = max_cluster
self._bg = np.zeros((cells, cells), dtype=np.float32)
self._initialized = False
# ── Public API ────────────────────────────────────────────────────────────
def update(
self,
ranges: np.ndarray,
angle_min: float,
angle_increment: float,
range_max: float,
) -> List[Detection]:
"""
Process one LaserScan and return detected moving blobs.
Parameters
----------
ranges : 1-D array of range readings (metres)
angle_min : angle of first beam (radians)
angle_increment : angular step between beams (radians)
range_max : maximum valid range (metres)
"""
curr = self._scan_to_grid(ranges, angle_min, angle_increment, range_max)
if not self._initialized:
self._bg = curr.copy()
self._initialized = True
return []
# Foreground mask
motion_mask = (curr - self._bg) > self._motion_thr
# Update background only on non-moving cells
static = ~motion_mask
self._bg[static] = (
self._bg[static] * (1.0 - self._bg_alpha)
+ curr[static] * self._bg_alpha
)
return self._cluster(motion_mask)
def reset(self) -> None:
self._bg[:] = 0.0
self._initialized = False
# ── Internals ─────────────────────────────────────────────────────────────
def _scan_to_grid(
self,
ranges: np.ndarray,
angle_min: float,
angle_increment: float,
range_max: float,
) -> np.ndarray:
grid = np.zeros((self._cells, self._cells), dtype=np.float32)
n = len(ranges)
angles = angle_min + np.arange(n) * angle_increment
r = np.asarray(ranges, dtype=np.float32)
valid = (r > 0.05) & (r < range_max) & np.isfinite(r)
r, a = r[valid], angles[valid]
x = r * np.cos(a)
y = r * np.sin(a)
ix = np.clip(
((x - self._origin) / self._res).astype(np.int32), 0, self._cells - 1
)
iy = np.clip(
((y - self._origin) / self._res).astype(np.int32), 0, self._cells - 1
)
grid[iy, ix] = 1.0
return grid
def _cluster(self, mask: np.ndarray) -> List[Detection]:
# Dilate slightly to connect nearby hit cells into one blob
struct = ndimage.generate_binary_structure(2, 2)
dilated = ndimage.binary_dilation(mask, structure=struct, iterations=1)
labeled, n_labels = ndimage.label(dilated)
detections: List[Detection] = []
for label_id in range(1, n_labels + 1):
coords = np.argwhere(labeled == label_id)
n_cells = len(coords)
if n_cells < self._min_clust or n_cells > self._max_clust:
continue
ys, xs = coords[:, 0], coords[:, 1]
cx_grid = float(np.mean(xs))
cy_grid = float(np.mean(ys))
cx = cx_grid * self._res + self._origin
cy = cy_grid * self._res + self._origin
detections.append(Detection(
x=cx,
y=cy,
size_m2=n_cells * self._res ** 2,
range_m=float(np.hypot(cx, cy)),
))
return detections
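The EMA background model from the docstring (steps 2-3) can be illustrated on a toy 1-D occupancy array, standalone rather than via the ObjectDetector class: a static wall is absorbed into the background, so only a newly appearing cell exceeds the motion threshold:

```python
import numpy as np

# A static wall occupies cell 2 in every frame; a moving object appears
# at cell 7 only in the last frame.
alpha, thr = 0.04, 0.45
wall = np.zeros(10, dtype=np.float32)
wall[2] = 1.0
bg = wall.copy()                 # background seeded with the first frame
for _ in range(30):              # static frames: wall stays in the background
    motion = (wall - bg) > thr   # nothing exceeds the threshold
    static = ~motion
    bg[static] = bg[static] * (1.0 - alpha) + wall[static] * alpha
frame = wall.copy()
frame[7] = 1.0                   # moving object enters the scene
motion = (frame - bg) > thr      # only the new cell exceeds the background
print(np.flatnonzero(motion))  # → [7]
```

The same masking order matters in the real detector: the background is updated only on static cells, so a moving object does not pollute the model it is being compared against.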


@ -0,0 +1,206 @@
"""
tracker_manager.py: Multi-object tracker with Hungarian data association.
Track lifecycle:
TENTATIVE : confirmed after `confirm_frames` consecutive hits
CONFIRMED : normal tracked state
LOST      : missed for 1..max_missed frames (still predicts, not published)
DEAD      : missed > max_missed; removed
Association:
Uses scipy.optimize.linear_sum_assignment (Hungarian algorithm) on a cost
matrix of Euclidean distances between predicted track positions and new
detections. Assignments with cost > assoc_dist_m are rejected.
Up to `max_tracks` simultaneous live tracks (tentative + confirmed).
"""
from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum
from typing import Dict, List, Tuple
import numpy as np
from scipy.optimize import linear_sum_assignment
from .kalman_tracker import KalmanTracker
from .object_detector import Detection
class TrackState(IntEnum):
TENTATIVE = 0
CONFIRMED = 1
LOST = 2
@dataclass
class Track:
track_id: int
kalman: KalmanTracker
state: TrackState = TrackState.TENTATIVE
hits: int = 1
age: int = 1 # frames since creation
missed: int = 0 # consecutive missed frames
class TrackerManager:
"""
Manages a pool of Kalman tracks.
Parameters
----------
max_tracks : hard cap on simultaneously alive tracks
confirm_frames : hits needed before a track is CONFIRMED
max_missed : consecutive missed frames before a track is DEAD
assoc_dist_m : max allowed distance (m) for a valid assignment
horizon_s : prediction horizon for trajectory output (seconds)
pred_step_s : time step between predicted waypoints
process_noise : KalmanTracker process-noise multiplier
"""
def __init__(
self,
max_tracks: int = 20,
confirm_frames: int = 3,
max_missed: int = 6,
assoc_dist_m: float = 1.5,
horizon_s: float = 2.5,
pred_step_s: float = 0.5,
process_noise: float = 1.0,
):
self._max_tracks = max_tracks
self._confirm_frames = confirm_frames
self._max_missed = max_missed
self._assoc_dist = assoc_dist_m
self._horizon_s = horizon_s
self._pred_step = pred_step_s
self._proc_noise = process_noise
self._tracks: Dict[int, Track] = {}
self._next_id: int = 1
# ── Public API ────────────────────────────────────────────────────────────
def update(self, detections: List[Detection], dt: float) -> List[Track]:
"""
Process one frame of detections.
1. Predict all tracks by dt.
2. Hungarian assignment of predictions to detections.
3. Update matched tracks; mark unmatched tracks as LOST.
4. Promote tracks crossing `confirm_frames`.
5. Create new tracks for unmatched detections (if room).
6. Remove DEAD tracks.
Returns confirmed tracks only.
"""
# 1. Predict
for tr in self._tracks.values():
tr.kalman.predict(dt)
tr.age += 1
# 2. Assign
matched, unmatched_tracks, unmatched_dets = self._assign(detections)
# 3a. Update matched
for tid, did in matched:
tr = self._tracks[tid]
det = detections[did]
tr.kalman.update(np.array([det.x, det.y]))
tr.hits += 1
tr.missed = 0
if tr.state == TrackState.LOST:
tr.state = TrackState.CONFIRMED
elif tr.state == TrackState.TENTATIVE and tr.hits >= self._confirm_frames:
tr.state = TrackState.CONFIRMED
# 3b. Unmatched tracks
for tid in unmatched_tracks:
tr = self._tracks[tid]
tr.missed += 1
if tr.missed > 1:
tr.state = TrackState.LOST
# 4. New tracks for unmatched detections
live = sum(1 for t in self._tracks.values() if t.state != TrackState.LOST
or t.missed <= self._max_missed)
for did in unmatched_dets:
if live >= self._max_tracks:
break
det = detections[did]
init_state = (TrackState.CONFIRMED
if self._confirm_frames <= 1
else TrackState.TENTATIVE)
new_tr = Track(
track_id=self._next_id,
kalman=KalmanTracker(det.x, det.y, self._proc_noise),
state=init_state,
)
self._tracks[self._next_id] = new_tr
self._next_id += 1
live += 1
# 5. Prune dead
dead = [tid for tid, t in self._tracks.items() if t.missed > self._max_missed]
for tid in dead:
del self._tracks[tid]
return [t for t in self._tracks.values() if t.state == TrackState.CONFIRMED]
@property
def all_tracks(self) -> List[Track]:
return list(self._tracks.values())
@property
def tentative_count(self) -> int:
return sum(1 for t in self._tracks.values()
if t.state == TrackState.TENTATIVE)
def reset(self) -> None:
self._tracks.clear()
self._next_id = 1
# ── Hungarian assignment ──────────────────────────────────────────────────
def _assign(
self,
detections: List[Detection],
) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
"""
Returns:
matched list of (track_id, det_index)
unmatched_tids track IDs with no detection assigned
unmatched_dids detection indices with no track assigned
"""
track_ids = list(self._tracks.keys())
if not track_ids or not detections:
return [], track_ids, list(range(len(detections)))
# Build cost matrix: rows=tracks, cols=detections
cost = np.full((len(track_ids), len(detections)), fill_value=np.inf)
for r, tid in enumerate(track_ids):
tx, ty = self._tracks[tid].kalman.position
for c, det in enumerate(detections):
cost[r, c] = np.hypot(tx - det.x, ty - det.y)
row_ind, col_ind = linear_sum_assignment(cost)
matched: List[Tuple[int, int]] = []
matched_track_idx: set = set()
matched_det_idx: set = set()
for r, c in zip(row_ind, col_ind):
if cost[r, c] > self._assoc_dist:
continue
matched.append((track_ids[r], c))
matched_track_idx.add(r)
matched_det_idx.add(c)
unmatched_tids = [track_ids[r] for r in range(len(track_ids))
if r not in matched_track_idx]
unmatched_dids = [c for c in range(len(detections))
if c not in matched_det_idx]
return matched, unmatched_tids, unmatched_dids
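The gated Hungarian step above can be demonstrated standalone with `scipy.optimize.linear_sum_assignment`: two predicted track positions versus two new detections, where one globally optimal pairing exceeds the 1.5 m gate and is rejected, leaving that track and detection unmatched (the detection would then seed a new track):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

tracks = np.array([[1.0, 0.0], [5.0, 0.0]])  # predicted track positions
dets = np.array([[1.2, 0.1], [9.0, 0.0]])    # detection centroids
cost = np.linalg.norm(tracks[:, None, :] - dets[None, :, :], axis=2)
rows, cols = linear_sum_assignment(cost)     # globally optimal pairing
gate = 1.5                                   # assoc_dist_m
matched = [(int(r), int(c)) for r, c in zip(rows, cols) if cost[r, c] <= gate]
print(matched)  # → [(0, 0)]; the 4 m track-1/det-1 pairing is gated out
```

Note that the Hungarian algorithm minimizes total cost, so gating must happen after assignment, exactly as `_assign` does.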


@ -0,0 +1,4 @@
[develop]
script_dir=$base/lib/saltybot_dynamic_obstacles
[install]
install_scripts=$base/lib/saltybot_dynamic_obstacles


@ -0,0 +1,32 @@
from setuptools import setup, find_packages
from glob import glob
package_name = 'saltybot_dynamic_obstacles'
setup(
name=package_name,
version='0.1.0',
packages=find_packages(exclude=['test']),
data_files=[
('share/ament_index/resource_index/packages',
['resource/' + package_name]),
('share/' + package_name, ['package.xml']),
('share/' + package_name + '/launch',
glob('launch/*.launch.py')),
('share/' + package_name + '/config',
glob('config/*.yaml')),
],
install_requires=['setuptools'],
zip_safe=True,
maintainer='SaltyLab',
maintainer_email='robot@saltylab.local',
description='Dynamic obstacle tracking: LIDAR motion detection, Kalman tracking, Nav2 costmap',
license='MIT',
tests_require=['pytest'],
entry_points={
'console_scripts': [
'dynamic_obs_tracker = saltybot_dynamic_obstacles.dynamic_obs_node:main',
'dynamic_obs_costmap = saltybot_dynamic_obstacles.costmap_layer_node:main',
],
},
)


@ -0,0 +1,262 @@
"""
test_dynamic_obstacles.py: Unit tests for KalmanTracker, TrackerManager,
and ObjectDetector.
Runs without ROS2 / hardware (no rclpy imports).
"""
from __future__ import annotations
import math
import sys
import os
import numpy as np
import pytest
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from saltybot_dynamic_obstacles.kalman_tracker import KalmanTracker
from saltybot_dynamic_obstacles.tracker_manager import TrackerManager, TrackState
from saltybot_dynamic_obstacles.object_detector import ObjectDetector, Detection
# ── KalmanTracker ─────────────────────────────────────────────────────────────
class TestKalmanTracker:
def test_initial_position(self):
kt = KalmanTracker(3.0, 4.0)
px, py = kt.position
assert px == pytest.approx(3.0)
assert py == pytest.approx(4.0)
def test_initial_velocity_zero(self):
kt = KalmanTracker(0.0, 0.0)
vx, vy = kt.velocity
assert vx == pytest.approx(0.0)
assert vy == pytest.approx(0.0)
def test_predict_moves_position(self):
kt = KalmanTracker(0.0, 0.0)
# Give it some velocity via update sequence
kt.update(np.array([0.1, 0.0]))
kt.update(np.array([0.2, 0.0]))
kt.predict(0.1)
px, _ = kt.position
assert px > 0.0 # should have moved forward
def test_pure_predict_constant_velocity(self):
"""After velocity is established, predict() should move linearly."""
kt = KalmanTracker(0.0, 0.0)
# Force velocity by repeated updates
for i in range(10):
kt.update(np.array([i * 0.1, 0.0]))
kt.predict(0.1)
vx, _ = kt.velocity
px0, _ = kt.position
kt.predict(1.0)
px1, _ = kt.position
# Should advance roughly vx * 1.0 metres
assert px1 == pytest.approx(px0 + vx * 1.0, abs=0.3)
def test_update_corrects_position(self):
kt = KalmanTracker(0.0, 0.0)
# Predict way off
kt.predict(10.0)
# Then update to ground truth
kt.update(np.array([1.0, 2.0]))
px, py = kt.position
# Should move toward (1, 2)
assert px == pytest.approx(1.0, abs=0.5)
assert py == pytest.approx(2.0, abs=0.5)
def test_predict_horizon_length(self):
kt = KalmanTracker(0.0, 0.0)
preds = kt.predict_horizon(horizon_s=2.5, step_s=0.5)
assert len(preds) == 5 # 0.5, 1.0, 1.5, 2.0, 2.5
def test_predict_horizon_times(self):
kt = KalmanTracker(0.0, 0.0)
preds = kt.predict_horizon(horizon_s=2.0, step_s=0.5)
times = [t for _, _, t in preds]
assert times == pytest.approx([0.5, 1.0, 1.5, 2.0], abs=0.01)
def test_predict_horizon_does_not_mutate_state(self):
kt = KalmanTracker(1.0, 2.0)
kt.predict_horizon(horizon_s=3.0, step_s=0.5)
px, py = kt.position
assert px == pytest.approx(1.0)
assert py == pytest.approx(2.0)
def test_speed_zero_at_init(self):
kt = KalmanTracker(5.0, 5.0)
assert kt.speed == pytest.approx(0.0)
def test_covariance_shape(self):
kt = KalmanTracker(0.0, 0.0)
cov = kt.covariance_2x2
assert cov.shape == (2, 2)
def test_covariance_positive_definite(self):
kt = KalmanTracker(0.0, 0.0)
for _ in range(5):
kt.predict(0.1)
kt.update(np.array([0.1, 0.0]))
eigvals = np.linalg.eigvalsh(kt.covariance_2x2)
assert np.all(eigvals > 0)
def test_joseph_form_stays_symmetric(self):
"""Covariance should remain symmetric after many updates."""
kt = KalmanTracker(0.0, 0.0)
for i in range(50):
kt.predict(0.1)
kt.update(np.array([i * 0.01, 0.0]))
P = kt._P
assert np.allclose(P, P.T, atol=1e-10)
# ── TrackerManager ────────────────────────────────────────────────────────────
class TestTrackerManager:
def _det(self, x, y):
return Detection(x=x, y=y, size_m2=0.1, range_m=math.hypot(x, y))
def test_empty_detections_no_tracks(self):
tm = TrackerManager()
confirmed = tm.update([], 0.1)
assert confirmed == []
def test_track_created_on_detection(self):
tm = TrackerManager(confirm_frames=1)
confirmed = tm.update([self._det(1.0, 0.0)], 0.1)
assert len(confirmed) == 1
def test_track_tentative_before_confirm(self):
tm = TrackerManager(confirm_frames=3)
tm.update([self._det(1.0, 0.0)], 0.1)
# Only 1 hit — should still be tentative
assert tm.tentative_count == 1
def test_track_confirmed_after_N_hits(self):
tm = TrackerManager(confirm_frames=3, assoc_dist_m=2.0)
for _ in range(4):
confirmed = tm.update([self._det(1.0, 0.0)], 0.1)
assert len(confirmed) == 1
def test_track_deleted_after_max_missed(self):
tm = TrackerManager(confirm_frames=1, max_missed=3, assoc_dist_m=2.0)
tm.update([self._det(1.0, 0.0)], 0.1) # create
for _ in range(5):
tm.update([], 0.1) # no detections → missed++
assert len(tm.all_tracks) == 0
def test_max_tracks_cap(self):
tm = TrackerManager(max_tracks=5, confirm_frames=1)
dets = [self._det(float(i), 0.0) for i in range(10)]
tm.update(dets, 0.1)
assert len(tm.all_tracks) <= 5
def test_consistent_track_id(self):
tm = TrackerManager(confirm_frames=3, assoc_dist_m=1.5)
for i in range(5):
confirmed = tm.update([self._det(1.0 + i * 0.01, 0.0)], 0.1)
assert len(confirmed) == 1
track_id = confirmed[0].track_id
# One more tick — ID should be stable
confirmed2 = tm.update([self._det(1.06, 0.0)], 0.1)
assert confirmed2[0].track_id == track_id
def test_two_independent_tracks(self):
tm = TrackerManager(confirm_frames=3, assoc_dist_m=0.8)
for _ in range(5):
confirmed = tm.update([self._det(1.0, 0.0), self._det(5.0, 0.0)], 0.1)
assert len(confirmed) == 2
def test_reset_clears_all(self):
tm = TrackerManager(confirm_frames=1)
tm.update([self._det(1.0, 0.0)], 0.1)
tm.reset()
assert len(tm.all_tracks) == 0
def test_far_detection_not_assigned(self):
tm = TrackerManager(confirm_frames=1, assoc_dist_m=0.5)
tm.update([self._det(1.0, 0.0)], 0.1) # create track at (1,0)
# Detection 3 m away → new track, not update
tm.update([self._det(4.0, 0.0)], 0.1)
assert len(tm.all_tracks) == 2
# ── ObjectDetector ────────────────────────────────────────────────────────────
class TestObjectDetector:
def _empty_scan(self, n=360, rmax=8.0) -> tuple:
"""All readings at max range (static background)."""
ranges = np.full(n, rmax - 0.1, dtype=np.float32)
return ranges, -math.pi, 2 * math.pi / n, rmax
def _scan_with_blob(self, blob_r=2.0, blob_theta=0.0, n=360, rmax=8.0) -> tuple:
"""Background scan + a short-range cluster at blob_theta."""
ranges = np.full(n, rmax - 0.1, dtype=np.float32)
angle_inc = 2 * math.pi / n
angle_min = -math.pi
# Put a cluster of ~10 beams at blob_r
center_idx = int((blob_theta - angle_min) / angle_inc) % n
for di in range(-5, 6):
idx = (center_idx + di) % n
ranges[idx] = blob_r
return ranges, angle_min, angle_inc, rmax
def test_empty_scan_no_detections_after_warmup(self):
od = ObjectDetector()
r, a_min, a_inc, rmax = self._empty_scan()
od.update(r, a_min, a_inc, rmax) # init background
for _ in range(3):
dets = od.update(r, a_min, a_inc, rmax)
assert len(dets) == 0
def test_moving_blob_detected(self):
od = ObjectDetector()
bg_r, a_min, a_inc, rmax = self._empty_scan()
od.update(bg_r, a_min, a_inc, rmax) # seed background
for _ in range(5):
od.update(bg_r, a_min, a_inc, rmax)
# Now inject a foreground blob
fg_r, _, _, _ = self._scan_with_blob(blob_r=2.0, blob_theta=0.0)
dets = od.update(fg_r, a_min, a_inc, rmax)
assert len(dets) >= 1
def test_detection_centroid_approximate(self):
od = ObjectDetector()
bg_r, a_min, a_inc, rmax = self._empty_scan()
for _ in range(8):
od.update(bg_r, a_min, a_inc, rmax)
fg_r, _, _, _ = self._scan_with_blob(blob_r=3.0, blob_theta=0.0)
dets = od.update(fg_r, a_min, a_inc, rmax)
assert len(dets) >= 1
# Blob is at ~3 m along x-axis (theta=0)
cx = dets[0].x
cy = dets[0].y
assert abs(cx - 3.0) < 0.5
assert abs(cy) < 0.5
def test_reset_clears_background(self):
od = ObjectDetector()
bg_r, a_min, a_inc, rmax = self._empty_scan()
for _ in range(5):
od.update(bg_r, a_min, a_inc, rmax)
od.reset()
assert not od._initialized
def test_no_inf_nan_ranges(self):
od = ObjectDetector()
r = np.array([np.inf, np.nan, 5.0, -1.0, 0.0] * 72, dtype=np.float32)
a_min = -math.pi
a_inc = 2 * math.pi / len(r)
od.update(r, a_min, a_inc, 8.0) # should not raise
if __name__ == '__main__':
pytest.main([__file__, '-v'])


@ -0,0 +1,43 @@
/**:
ros__parameters:
# Control loop rate (Hz)
control_rate: 20.0
# Odometry topic for stuck detection
odom_topic: "/saltybot/rover_odom"
# ── LaserScan forward sector ───────────────────────────────────────────────
forward_scan_angle_rad: 0.785 # ±45° forward sector
# ── Obstacle proximity ────────────────────────────────────────────────────
stop_distance_m: 0.30 # MAJOR threshold (spec: <30 cm)
critical_distance_m: 0.10 # CRITICAL threshold
min_cmd_speed_ms: 0.05 # ignore obstacle when nearly stopped
# ── Fall detection (IMU tilt) ─────────────────────────────────────────────
minor_tilt_rad: 0.20 # advisory
major_tilt_rad: 0.35 # stop + recover
critical_tilt_rad: 0.52 # ~30° — full shutdown
floor_drop_m: 0.15 # depth discontinuity triggering MAJOR
# ── Stuck detection ───────────────────────────────────────────────────────
stuck_timeout_s: 3.0 # (spec: 3 s wheel stall)
# ── Bump / jerk detection ─────────────────────────────────────────────────
jerk_threshold_ms3: 8.0
critical_jerk_threshold_ms3: 25.0
# ── FSM / recovery ────────────────────────────────────────────────────────
stopped_ms: 0.03 # speed below which robot is "stopped" (m/s)
major_count_threshold: 3 # MAJOR alerts before escalation to CRITICAL
escalation_window_s: 10.0 # sliding window for escalation counter (s)
suppression_s: 1.0 # de-bounce period for duplicate alerts (s)
# Recovery sequence
reverse_speed_ms: -0.15 # back-up speed (m/s; must be negative)
reverse_distance_m: 0.30 # distance to reverse each cycle (m)
angular_speed_rads: 0.60 # turn speed (rad/s)
turn_angle_rad: 1.5708 # ~90° turn (rad)
retry_timeout_s: 3.0 # time in RETRYING per attempt (s)
clear_hold_s: 0.50 # consecutive clear time to declare success (s)
max_retries: 3 # maximum reverse+turn attempts before GAVE_UP
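The escalation parameters above (`major_count_threshold`, `escalation_window_s`) imply a sliding-window counter; a hedged sketch of that rule, where the names mirror the YAML but the logic itself is an assumption rather than the node's actual implementation:

```python
from collections import deque

# 3 MAJOR alerts inside a 10 s sliding window escalate to CRITICAL.
WINDOW_S, THRESHOLD = 10.0, 3
majors = deque()

def on_major(t):
    """Record a MAJOR alert at time t; return True if it should escalate."""
    majors.append(t)
    while majors and t - majors[0] > WINDOW_S:
        majors.popleft()  # drop alerts that fell out of the window
    return len(majors) >= THRESHOLD

r1, r2, r3, r4 = on_major(0.0), on_major(4.0), on_major(8.0), on_major(25.0)
print(r1, r2, r3, r4)  # → False False True False
```

The fourth alert arrives 17 s after the third, so the window has emptied and the counter starts over.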


@ -0,0 +1,53 @@
"""
emergency.launch.py: Launch the emergency behavior system (Issue #169).
Usage
-----
ros2 launch saltybot_emergency emergency.launch.py
ros2 launch saltybot_emergency emergency.launch.py \
stop_distance_m:=0.30 max_retries:=3
"""
import os
from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
pkg_share = get_package_share_directory("saltybot_emergency")
default_params = os.path.join(pkg_share, "config", "emergency_params.yaml")
return LaunchDescription([
DeclareLaunchArgument(
"params_file",
default_value=default_params,
description="Path to emergency_params.yaml",
),
DeclareLaunchArgument(
"stop_distance_m",
default_value="0.30",
description="Obstacle distance triggering MAJOR stop (m)",
),
DeclareLaunchArgument(
"max_retries",
default_value="3",
description="Maximum recovery cycles before ESCALATED",
),
Node(
package="saltybot_emergency",
executable="emergency_node",
name="emergency",
output="screen",
parameters=[
LaunchConfiguration("params_file"),
{
"stop_distance_m": LaunchConfiguration("stop_distance_m"),
"max_retries": LaunchConfiguration("max_retries"),
},
],
),
])


@ -0,0 +1,24 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_emergency</name>
<version>0.1.0</version>
<description>Emergency behavior system — collision avoidance, fall prevention, stuck detection, recovery (Issue #169)</description>
<maintainer email="sl-controls@saltylab.local">sl-controls</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>sensor_msgs</depend>
<depend>nav_msgs</depend>
<depend>geometry_msgs</depend>
<depend>std_msgs</depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>


@ -0,0 +1,139 @@
"""
alert_manager.py — Alert severity escalation for emergency behavior (Issue #169).
Alert levels
NONE : no action
MINOR : advisory beep — publish to /saltybot/alert_beep
MAJOR : stop + LED flash — publish to /saltybot/alert_flash; cmd_vel override
CRITICAL : full shutdown + MQTT — publish to /saltybot/e_stop + /saltybot/critical_alert
Escalation
If major_count_threshold MAJOR alerts occur within escalation_window_s, the
next MAJOR is promoted to CRITICAL. This catches persistent stuck / repeated
collision scenarios.
Suppression
Identical (type, level) alerts are suppressed within suppression_s to avoid
flooding downstream topics.
Pure module — no ROS2 dependency.
"""
from __future__ import annotations
from collections import deque
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from saltybot_emergency.threat_detector import ThreatEvent, ThreatLevel, ThreatType
# ── Alert level ───────────────────────────────────────────────────────────────
class AlertLevel(Enum):
NONE = 0
MINOR = 1 # beep
MAJOR = 2 # stop + flash
CRITICAL = 3 # shutdown + MQTT
# ── Alert ─────────────────────────────────────────────────────────────────────
@dataclass
class Alert:
level: AlertLevel
source: str # ThreatType value string
message: str
timestamp_s: float
# ── AlertManager ─────────────────────────────────────────────────────────────
class AlertManager:
"""
Converts ThreatEvents to Alerts with escalation and suppression logic.
Parameters
----------
major_count_threshold : number of MAJOR alerts within window to escalate
escalation_window_s : sliding window for escalation counting (s)
suppression_s : suppress duplicate (type, level) alerts within this period
"""
def __init__(
self,
major_count_threshold: int = 3,
escalation_window_s: float = 10.0,
suppression_s: float = 1.0,
):
self._major_threshold = max(1, int(major_count_threshold))
self._esc_window = float(escalation_window_s)
self._suppress = float(suppression_s)
self._major_times: deque = deque() # timestamps of MAJOR alerts
self._last_seen: dict = {} # (type, level) → timestamp
# ── Update ────────────────────────────────────────────────────────────────
def update(self, threat: ThreatEvent) -> Optional[Alert]:
"""
Convert one ThreatEvent to an Alert, applying escalation and suppression.
Returns None if threat is CLEAR or the alert is suppressed.
"""
if threat.level == ThreatLevel.CLEAR:
return None
now = threat.timestamp_s
alert_level = _threat_to_alert(threat.level)
# ── Suppression ───────────────────────────────────────────────────────
key = (threat.threat_type, threat.level)
last = self._last_seen.get(key)
if last is not None and (now - last) < self._suppress:
return None
self._last_seen[key] = now
# ── Escalation ────────────────────────────────────────────────────────
if alert_level == AlertLevel.MAJOR:
# Prune old timestamps
while self._major_times and (now - self._major_times[0]) > self._esc_window:
self._major_times.popleft()
self._major_times.append(now)
if len(self._major_times) >= self._major_threshold:
alert_level = AlertLevel.CRITICAL
msg = _build_message(alert_level, threat)
return Alert(
level=alert_level,
source=threat.threat_type.value,
message=msg,
timestamp_s=now,
)
def reset(self) -> None:
"""Clear escalation history and suppression state."""
self._major_times.clear()
self._last_seen.clear()
# ── Helpers ───────────────────────────────────────────────────────────────────
def _threat_to_alert(level: ThreatLevel) -> AlertLevel:
return {
ThreatLevel.MINOR: AlertLevel.MINOR,
ThreatLevel.MAJOR: AlertLevel.MAJOR,
ThreatLevel.CRITICAL: AlertLevel.CRITICAL,
}.get(level, AlertLevel.NONE)
def _build_message(level: AlertLevel, threat: ThreatEvent) -> str:
prefix = {
AlertLevel.MINOR: "[MINOR]",
AlertLevel.MAJOR: "[MAJOR]",
AlertLevel.CRITICAL: "[CRITICAL]",
}.get(level, "[?]")
return f"{prefix} {threat.threat_type.value}: {threat.detail}"
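The escalation rule (major_count_threshold MAJOR alerts inside escalation_window_s) can be exercised in isolation; a minimal standalone sketch of the same deque-pruning logic, independent of the package:

```python
from collections import deque

# Standalone sketch of AlertManager's escalation window: the Nth MAJOR
# within the sliding window is promoted to CRITICAL.
class EscalationWindow:
    def __init__(self, threshold=3, window_s=10.0):
        self.threshold = threshold
        self.window_s = window_s
        self.times = deque()  # timestamps of recent MAJOR alerts

    def on_major(self, now_s):
        """Return 'CRITICAL' if this MAJOR should be promoted, else 'MAJOR'."""
        # Drop timestamps that fell out of the window, then count this one.
        while self.times and (now_s - self.times[0]) > self.window_s:
            self.times.popleft()
        self.times.append(now_s)
        return "CRITICAL" if len(self.times) >= self.threshold else "MAJOR"

w = EscalationWindow()
print(w.on_major(0.0), w.on_major(4.0), w.on_major(8.0))  # third MAJOR within 10 s
```

Three MAJOR events at t=0, 4, 8 s land inside one 10-second window, so the third is promoted; a fourth event at t=20 s would start a fresh window.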


@ -0,0 +1,232 @@
"""
emergency_fsm.py — Master emergency FSM integrating all detectors (Issue #169).
States
NOMINAL : normal operation; minor alerts pass through; major/critical → STOPPING.
STOPPING : commanding zero velocity until robot speed drops below stopped_ms.
Critical threats skip RECOVERING and go straight to ESCALATED.
RECOVERING : RecoverySequencer executing reverse+turn sequence.
Success → NOMINAL; gave-up → ESCALATED.
ESCALATED : full stop; critical alert emitted once; stays until acknowledged.
Alert actions produced by state
NOMINAL : emit MINOR alert (beep only); no velocity override.
STOPPING : suppress nav, publish zero; emit MAJOR alert once.
RECOVERING : suppress nav, publish recovery cmds; no new alerts.
ESCALATED : suppress nav, publish zero; emit CRITICAL alert once per entry.
Pure module — no ROS2 dependency.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional
from saltybot_emergency.alert_manager import Alert, AlertLevel, AlertManager
from saltybot_emergency.recovery_sequencer import RecoveryInputs, RecoverySequencer
from saltybot_emergency.threat_detector import ThreatEvent, ThreatLevel
# ── States ────────────────────────────────────────────────────────────────────
class EmergencyState(Enum):
NOMINAL = "NOMINAL"
STOPPING = "STOPPING"
RECOVERING = "RECOVERING"
ESCALATED = "ESCALATED"
# ── I/O ───────────────────────────────────────────────────────────────────────
@dataclass
class EmergencyInputs:
threat: ThreatEvent # highest-severity threat this tick
robot_speed_ms: float = 0.0 # actual speed from odometry (m/s)
acknowledge: bool = False # operator cleared the escalation
@dataclass
class EmergencyOutputs:
state: EmergencyState = EmergencyState.NOMINAL
cmd_override: bool = False # True = emergency owns cmd_vel
cmd_linear: float = 0.0
cmd_angular: float = 0.0
alert: Optional[Alert] = None
e_stop: bool = False # assert /saltybot/e_stop
state_changed: bool = False
recovery_progress: float = 0.0
recovery_retry_count: int = 0
# ── EmergencyFSM ──────────────────────────────────────────────────────────────
class EmergencyFSM:
"""
Master emergency FSM.
Owns an AlertManager and a RecoverySequencer; coordinates them each tick.
Parameters
----------
stopped_ms : speed below which robot is considered stopped (m/s)
major_count_threshold : MAJOR events within window before escalation
escalation_window_s : sliding window for escalation (s)
suppression_s : alert de-bounce period (s)
reverse_speed_ms : reverse speed during recovery (m/s)
reverse_distance_m : reverse travel per cycle (m)
angular_speed_rads : turn speed during recovery (rad/s)
turn_angle_rad : turn per cycle (rad)
retry_timeout_s : time in RETRYING before next cycle (s)
clear_hold_s : clear duration required to declare success (s)
max_retries : recovery cycles before GAVE_UP
"""
def __init__(
self,
stopped_ms: float = 0.03,
major_count_threshold: int = 3,
escalation_window_s: float = 10.0,
suppression_s: float = 1.0,
reverse_speed_ms: float = -0.15,
reverse_distance_m: float = 0.30,
angular_speed_rads: float = 0.60,
turn_angle_rad: float = 1.5708,
retry_timeout_s: float = 3.0,
clear_hold_s: float = 0.5,
max_retries: int = 3,
):
self._stopped_ms = max(0.0, stopped_ms)
self._alert_mgr = AlertManager(
major_count_threshold=major_count_threshold,
escalation_window_s=escalation_window_s,
suppression_s=suppression_s,
)
self._recovery = RecoverySequencer(
reverse_speed_ms=reverse_speed_ms,
reverse_distance_m=reverse_distance_m,
angular_speed_rads=angular_speed_rads,
turn_angle_rad=turn_angle_rad,
retry_timeout_s=retry_timeout_s,
clear_hold_s=clear_hold_s,
max_retries=max_retries,
)
self._state = EmergencyState.NOMINAL
self._critical_pending = False # STOPPING → ESCALATED (not RECOVERING)
self._escalation_alerted = False # CRITICAL alert emitted once per ESCALATED entry
# ── Public API ────────────────────────────────────────────────────────────
@property
def state(self) -> EmergencyState:
return self._state
def reset(self) -> None:
self._state = EmergencyState.NOMINAL
self._critical_pending = False
self._escalation_alerted = False
self._alert_mgr.reset()
self._recovery.reset()
def tick(self, inputs: EmergencyInputs) -> EmergencyOutputs:
prev = self._state
out = self._step(inputs)
out.state = self._state
out.state_changed = (self._state != prev)
return out
# ── Step ─────────────────────────────────────────────────────────────────
def _step(self, inp: EmergencyInputs) -> EmergencyOutputs:
out = EmergencyOutputs(state=self._state)
# Run alert manager for this threat
alert = self._alert_mgr.update(inp.threat)
# ── NOMINAL ───────────────────────────────────────────────────────────
if self._state == EmergencyState.NOMINAL:
if inp.threat.level == ThreatLevel.CRITICAL:
self._state = EmergencyState.STOPPING
self._critical_pending = True
out.alert = alert
out.cmd_override = True # start overriding on entry tick
elif inp.threat.level == ThreatLevel.MAJOR:
self._state = EmergencyState.STOPPING
self._critical_pending = False
out.alert = alert
out.cmd_override = True # start overriding on entry tick
elif inp.threat.level == ThreatLevel.MINOR:
# Advisory only — no override
out.alert = alert
# ── STOPPING ──────────────────────────────────────────────────────────
elif self._state == EmergencyState.STOPPING:
out.cmd_override = True
out.cmd_linear = 0.0
out.cmd_angular = 0.0
# Upgrade to critical if new critical arrives
if inp.threat.level == ThreatLevel.CRITICAL:
self._critical_pending = True
if abs(inp.robot_speed_ms) <= self._stopped_ms:
if self._critical_pending:
self._state = EmergencyState.ESCALATED
self._escalation_alerted = False
else:
self._state = EmergencyState.RECOVERING
self._recovery.reset()
self._recovery.tick(RecoveryInputs(trigger=True, dt=0.0))
# ── RECOVERING ────────────────────────────────────────────────────────
elif self._state == EmergencyState.RECOVERING:
threat_cleared = inp.threat.level == ThreatLevel.CLEAR
rec_inp = RecoveryInputs(
trigger=False,
threat_cleared=threat_cleared,
dt=0.02, # nominal dt; the node ticks at control_rate=20 Hz (0.05 s), so it should pass measured dt
)
rec_out = self._recovery.tick(rec_inp)
out.cmd_override = True
out.cmd_linear = rec_out.cmd_linear
out.cmd_angular = rec_out.cmd_angular
out.recovery_progress = rec_out.progress
out.recovery_retry_count = rec_out.retry_count
if rec_out.gave_up:
self._state = EmergencyState.ESCALATED
self._escalation_alerted = False
elif rec_out.state.value == "IDLE":
# RecoverySequencer returned to IDLE = success
self._state = EmergencyState.NOMINAL
self._recovery.reset()
# ── ESCALATED ─────────────────────────────────────────────────────────
elif self._state == EmergencyState.ESCALATED:
out.cmd_override = True
out.cmd_linear = 0.0
out.cmd_angular = 0.0
out.e_stop = True
if not self._escalation_alerted:
# Force a CRITICAL alert regardless of suppression
out.alert = Alert(
level=AlertLevel.CRITICAL,
source=inp.threat.threat_type.value,
message=f"[CRITICAL] ESCALATED: {inp.threat.detail or 'Recovery gave up'}",
timestamp_s=inp.threat.timestamp_s,
)
self._escalation_alerted = True
if inp.acknowledge:
self._state = EmergencyState.NOMINAL
self._critical_pending = False
self._escalation_alerted = False
out.e_stop = False
self._alert_mgr.reset()
self._recovery.reset()
return out
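The top-level transitions can be summarized as a toy standalone sketch (simplified: it collapses _critical_pending into the threat level seen at the stop tick, and ignores alert emission entirely):

```python
# Toy walk through the top-level transitions (not the real classes):
# MAJOR threat -> STOPPING; robot halts -> RECOVERING (or ESCALATED if
# critical); recovery gives up -> ESCALATED; acknowledge -> NOMINAL.
def step(state, threat, speed, recovery_gave_up=False, ack=False, stopped_ms=0.03):
    if state == "NOMINAL" and threat in ("MAJOR", "CRITICAL"):
        return "STOPPING"
    if state == "STOPPING" and abs(speed) <= stopped_ms:
        return "ESCALATED" if threat == "CRITICAL" else "RECOVERING"
    if state == "RECOVERING" and recovery_gave_up:
        return "ESCALATED"
    if state == "ESCALATED" and ack:
        return "NOMINAL"
    return state

s = "NOMINAL"
s = step(s, "MAJOR", 0.20)                          # threat seen -> STOPPING
s = step(s, "MAJOR", 0.01)                          # halted -> RECOVERING
s = step(s, "MAJOR", 0.0, recovery_gave_up=True)    # gave up -> ESCALATED
s = step(s, "CLEAR", 0.0, ack=True)                 # operator ack -> NOMINAL
print(s)
```

In the real FSM the ESCALATED entry additionally asserts e_stop and emits one forced CRITICAL alert, and acknowledge also resets the AlertManager and RecoverySequencer.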


@ -0,0 +1,383 @@
"""
emergency_node.py — Emergency behavior system orchestration (Issue #169).
Overview
Aggregates threats from four independent detectors and drives the
EmergencyFSM. Overrides /cmd_vel when an emergency is active. Escalates
via /saltybot/e_stop and /saltybot/critical_alert for CRITICAL events.
Pipeline (20 Hz)
1. LaserScan callback → ObstacleDetector → ThreatEvent
2. IMU callback → FallDetector + BumpDetector → ThreatEvent (×2)
3. Odom callback → StuckDetector (fed in timer) → ThreatEvent
4. 20 Hz timer → highest_threat() → EmergencyFSM.tick()
→ publish overriding cmd_vel or pass-through
→ publish /saltybot/emergency + /saltybot/recovery_action
Subscribes
/scan sensor_msgs/LaserScan obstacle detection
/saltybot/imu sensor_msgs/Imu fall + bump detection
<odom_topic> nav_msgs/Odometry stuck + speed tracking
/cmd_vel geometry_msgs/Twist nav commands (pass-through)
Publishes
/saltybot/cmd_vel_out geometry_msgs/Twist muxed cmd_vel (to drive nodes)
/saltybot/e_stop std_msgs/Bool emergency stop flag
/saltybot/alert_beep std_msgs/Empty beep trigger (MINOR)
/saltybot/alert_flash std_msgs/Empty LED flash trigger (MAJOR)
/saltybot/critical_alert std_msgs/String (JSON) CRITICAL event for MQTT bridge
/saltybot/emergency saltybot_emergency_msgs/EmergencyEvent
/saltybot/recovery_action saltybot_emergency_msgs/RecoveryAction
Parameters
control_rate 20.0 Hz
odom_topic /saltybot/rover_odom
forward_scan_angle_rad 0.785 rad (±45° forward sector for obstacle check)
stop_distance_m 0.30 m
critical_distance_m 0.10 m
min_cmd_speed_ms 0.05 m/s
minor_tilt_rad 0.20 rad
major_tilt_rad 0.35 rad
critical_tilt_rad 0.52 rad
floor_drop_m 0.15 m
stuck_timeout_s 3.0 s
jerk_threshold_ms3 8.0 m/s³
critical_jerk_threshold_ms3 25.0 m/s³
stopped_ms 0.03 m/s
major_count_threshold 3
escalation_window_s 10.0 s
suppression_s 1.0 s
reverse_speed_ms -0.15 m/s
reverse_distance_m 0.30 m
angular_speed_rads 0.60 rad/s
turn_angle_rad 1.5708 rad
retry_timeout_s 3.0 s
clear_hold_s 0.50 s
max_retries 3
"""
import json
import math
import time
import rclpy
from rclpy.node import Node
from rclpy.qos import HistoryPolicy, QoSProfile, ReliabilityPolicy
from geometry_msgs.msg import Twist
from nav_msgs.msg import Odometry
from sensor_msgs.msg import Imu, LaserScan
from std_msgs.msg import Bool, Empty, String
from saltybot_emergency.alert_manager import AlertLevel
from saltybot_emergency.emergency_fsm import EmergencyFSM, EmergencyInputs
from saltybot_emergency.threat_detector import (
BumpDetector,
FallDetector,
ObstacleDetector,
StuckDetector,
ThreatEvent,
highest_threat,
)
try:
from saltybot_emergency_msgs.msg import EmergencyEvent, RecoveryAction
_MSGS_OK = True
except ImportError:
_MSGS_OK = False
def _quaternion_to_pitch_roll(qx, qy, qz, qw):
pitch = math.asin(max(-1.0, min(1.0, 2.0 * (qw * qy - qz * qx))))
roll = math.atan2(2.0 * (qw * qx + qy * qz), 1.0 - 2.0 * (qx * qx + qy * qy))
return pitch, roll
class EmergencyNode(Node):
def __init__(self):
super().__init__("emergency")
self._declare_params()
p = self._load_params()
# ── Detectors ────────────────────────────────────────────────────────
self._obstacle = ObstacleDetector(
stop_distance_m=p["stop_distance_m"],
critical_distance_m=p["critical_distance_m"],
min_speed_ms=p["min_cmd_speed_ms"],
)
self._fall = FallDetector(
minor_tilt_rad=p["minor_tilt_rad"],
major_tilt_rad=p["major_tilt_rad"],
critical_tilt_rad=p["critical_tilt_rad"],
floor_drop_m=p["floor_drop_m"],
)
self._stuck = StuckDetector(
stuck_timeout_s=p["stuck_timeout_s"],
min_cmd_ms=p["min_cmd_speed_ms"],
)
self._bump = BumpDetector(
jerk_threshold_ms3=p["jerk_threshold_ms3"],
critical_jerk_threshold_ms3=p["critical_jerk_threshold_ms3"],
)
self._fsm = EmergencyFSM(
stopped_ms=p["stopped_ms"],
major_count_threshold=p["major_count_threshold"],
escalation_window_s=p["escalation_window_s"],
suppression_s=p["suppression_s"],
reverse_speed_ms=p["reverse_speed_ms"],
reverse_distance_m=p["reverse_distance_m"],
angular_speed_rads=p["angular_speed_rads"],
turn_angle_rad=p["turn_angle_rad"],
retry_timeout_s=p["retry_timeout_s"],
clear_hold_s=p["clear_hold_s"],
max_retries=p["max_retries"],
)
# ── State ────────────────────────────────────────────────────────────
self._latest_obstacle_threat = ThreatEvent()
self._latest_fall_threat = ThreatEvent()
self._latest_bump_threat = ThreatEvent()
self._cmd_speed_ms = 0.0
self._actual_speed_ms = 0.0
self._last_ctrl_t = time.monotonic()
self._scan_forward_angle = p["forward_scan_angle_rad"]
self._acknowledge_flag = False
# ── QoS ──────────────────────────────────────────────────────────────
reliable = QoSProfile(
reliability=ReliabilityPolicy.RELIABLE,
history=HistoryPolicy.KEEP_LAST,
depth=10,
)
best_effort = QoSProfile(
reliability=ReliabilityPolicy.BEST_EFFORT,
history=HistoryPolicy.KEEP_LAST,
depth=1,
)
# ── Subscriptions ────────────────────────────────────────────────────
self.create_subscription(LaserScan, "/scan", self._scan_cb, best_effort)
self.create_subscription(Imu, "/saltybot/imu", self._imu_cb, best_effort)
self.create_subscription(Odometry, p["odom_topic"], self._odom_cb, reliable)
self.create_subscription(Twist, "/cmd_vel", self._cmd_vel_cb, reliable)
self.create_subscription(Bool, "/saltybot/emergency_ack", self._ack_cb, reliable)
# ── Publishers ───────────────────────────────────────────────────────
self._cmd_out_pub = self.create_publisher(Twist, "/saltybot/cmd_vel_out", reliable)
self._estop_pub = self.create_publisher(Bool, "/saltybot/e_stop", reliable)
self._beep_pub = self.create_publisher(Empty, "/saltybot/alert_beep", reliable)
self._flash_pub = self.create_publisher(Empty, "/saltybot/alert_flash", reliable)
self._critical_pub = self.create_publisher(String, "/saltybot/critical_alert", reliable)
self._event_pub = None
self._recovery_pub = None
if _MSGS_OK:
self._event_pub = self.create_publisher(EmergencyEvent, "/saltybot/emergency", reliable)
self._recovery_pub = self.create_publisher(RecoveryAction, "/saltybot/recovery_action", reliable)
# ── Timer ────────────────────────────────────────────────────────────
rate = p["control_rate"]
self._timer = self.create_timer(1.0 / rate, self._control_cb)
self.get_logger().info(f"EmergencyNode ready rate={rate}Hz")
# ── Parameters ────────────────────────────────────────────────────────────
def _declare_params(self) -> None:
self.declare_parameter("control_rate", 20.0)
self.declare_parameter("odom_topic", "/saltybot/rover_odom")
self.declare_parameter("forward_scan_angle_rad", 0.785)
self.declare_parameter("stop_distance_m", 0.30)
self.declare_parameter("critical_distance_m", 0.10)
self.declare_parameter("min_cmd_speed_ms", 0.05)
self.declare_parameter("minor_tilt_rad", 0.20)
self.declare_parameter("major_tilt_rad", 0.35)
self.declare_parameter("critical_tilt_rad", 0.52)
self.declare_parameter("floor_drop_m", 0.15)
self.declare_parameter("stuck_timeout_s", 3.0)
self.declare_parameter("jerk_threshold_ms3", 8.0)
self.declare_parameter("critical_jerk_threshold_ms3", 25.0)
self.declare_parameter("stopped_ms", 0.03)
self.declare_parameter("major_count_threshold", 3)
self.declare_parameter("escalation_window_s", 10.0)
self.declare_parameter("suppression_s", 1.0)
self.declare_parameter("reverse_speed_ms", -0.15)
self.declare_parameter("reverse_distance_m", 0.30)
self.declare_parameter("angular_speed_rads", 0.60)
self.declare_parameter("turn_angle_rad", 1.5708)
self.declare_parameter("retry_timeout_s", 3.0)
self.declare_parameter("clear_hold_s", 0.50)
self.declare_parameter("max_retries", 3)
def _load_params(self) -> dict:
g = self.get_parameter
return {k: g(k).value for k in [
"control_rate", "odom_topic",
"forward_scan_angle_rad",
"stop_distance_m", "critical_distance_m", "min_cmd_speed_ms",
"minor_tilt_rad", "major_tilt_rad", "critical_tilt_rad", "floor_drop_m",
"stuck_timeout_s", "jerk_threshold_ms3", "critical_jerk_threshold_ms3",
"stopped_ms",
"major_count_threshold", "escalation_window_s", "suppression_s",
"reverse_speed_ms", "reverse_distance_m",
"angular_speed_rads", "turn_angle_rad",
"retry_timeout_s", "clear_hold_s", "max_retries",
]}
# ── Callbacks ─────────────────────────────────────────────────────────────
def _scan_cb(self, msg: LaserScan) -> None:
# Extract minimum range within forward sector (±forward_scan_angle_rad)
half = self._scan_forward_angle
ranges = []
for i, r in enumerate(msg.ranges):
angle = msg.angle_min + i * msg.angle_increment
if abs(angle) <= half and msg.range_min < r < msg.range_max:
ranges.append(r)
min_r = min(ranges) if ranges else float("inf")
self._latest_obstacle_threat = self._obstacle.update(
min_r, self._cmd_speed_ms, time.monotonic()
)
def _imu_cb(self, msg: Imu) -> None:
now = time.monotonic()
ax = msg.linear_acceleration.x
ay = msg.linear_acceleration.y
az = msg.linear_acceleration.z
pitch, roll = _quaternion_to_pitch_roll(
msg.orientation.x, msg.orientation.y,
msg.orientation.z, msg.orientation.w,
)
# dt for jerk is approximated; bump detector handles None on first call
dt = 0.02 # assumed nominal IMU period (~50 Hz); jerk scales with 1/dt
self._latest_fall_threat = self._fall.update(pitch, roll, 0.0, now) # floor_drop_m=0.0: no depth input wired here
self._latest_bump_threat = self._bump.update(ax, ay, az, dt, now)
def _odom_cb(self, msg: Odometry) -> None:
self._actual_speed_ms = msg.twist.twist.linear.x
def _cmd_vel_cb(self, msg: Twist) -> None:
self._cmd_speed_ms = msg.linear.x
def _ack_cb(self, msg: Bool) -> None:
if msg.data:
self._acknowledge_flag = True
# ── 20 Hz control loop ────────────────────────────────────────────────────
def _control_cb(self) -> None:
now = time.monotonic()
dt = now - self._last_ctrl_t
self._last_ctrl_t = now
stuck_threat = self._stuck.update(
self._cmd_speed_ms, self._actual_speed_ms, dt, now
)
threat = highest_threat([
self._latest_obstacle_threat,
self._latest_fall_threat,
self._latest_bump_threat,
stuck_threat,
])
inp = EmergencyInputs(
threat=threat,
robot_speed_ms=self._actual_speed_ms,
acknowledge=self._acknowledge_flag,
)
self._acknowledge_flag = False
out = self._fsm.tick(inp)
if out.state_changed:
self.get_logger().info(f"Emergency FSM → {out.state.value}")
# ── Alert dispatch ────────────────────────────────────────────────────
if out.alert is not None:
lvl = out.alert.level
self.get_logger().warn(out.alert.message)
if lvl == AlertLevel.MINOR:
self._beep_pub.publish(Empty())
elif lvl == AlertLevel.MAJOR:
self._flash_pub.publish(Empty())
elif lvl == AlertLevel.CRITICAL:
self._flash_pub.publish(Empty())
self._publish_critical_alert(out.alert.message, threat)
# ── E-stop ───────────────────────────────────────────────────────────
estop_msg = Bool()
estop_msg.data = out.e_stop
self._estop_pub.publish(estop_msg)
# ── cmd_vel mux ───────────────────────────────────────────────────────
twist = Twist()
if out.cmd_override:
twist.linear.x = out.cmd_linear
twist.angular.z = out.cmd_angular
else:
twist.linear.x = self._cmd_speed_ms
self._cmd_out_pub.publish(twist)
# ── Status topics ─────────────────────────────────────────────────────
if self._event_pub is not None:
self._publish_event(out, threat)
if self._recovery_pub is not None:
self._publish_recovery(out)
# ── Publishers ────────────────────────────────────────────────────────────
def _publish_critical_alert(self, message: str, threat: ThreatEvent) -> None:
msg = String()
msg.data = json.dumps({
"severity": "CRITICAL",
"threat": threat.threat_type.value,
"value": round(threat.value, 3),
"detail": threat.detail,
"message": message,
})
self._critical_pub.publish(msg)
def _publish_event(self, out, threat: ThreatEvent) -> None:
msg = EmergencyEvent()
msg.stamp = self.get_clock().now().to_msg()
msg.state = out.state.value
msg.threat_type = threat.threat_type.value
msg.severity = threat.level.name
msg.threat_value = float(threat.value)
msg.detail = threat.detail
msg.cmd_override = out.cmd_override
self._event_pub.publish(msg)
def _publish_recovery(self, out) -> None:
msg = RecoveryAction()
msg.stamp = self.get_clock().now().to_msg()
msg.action = self._fsm._recovery.state.value
msg.retry_count = out.recovery_retry_count
msg.progress = float(out.recovery_progress)
self._recovery_pub.publish(msg)
# ── Entry point ───────────────────────────────────────────────────────────────
def main(args=None):
rclpy.init(args=args)
node = EmergencyNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.try_shutdown()
if __name__ == "__main__":
main()
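The /saltybot/critical_alert payload is plain JSON, so an MQTT-bridge consumer only needs the stdlib; the field names match _publish_critical_alert above, while the values below are illustrative:

```python
import json

# Build and parse a critical-alert payload with the same schema that
# _publish_critical_alert() serializes (values here are made up).
payload = json.dumps({
    "severity": "CRITICAL",
    "threat": "OBSTACLE_PROXIMITY",
    "value": round(0.0834, 3),   # triggering metric, rounded as in the node
    "detail": "min range 0.08 m < 0.10 m",
    "message": "[CRITICAL] OBSTACLE_PROXIMITY: min range 0.08 m < 0.10 m",
})
decoded = json.loads(payload)
print(decoded["severity"], decoded["threat"], decoded["value"])
```

Keeping the schema flat and stringly-typed like this means the MQTT side needs no message-package dependency.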


@ -0,0 +1,193 @@
"""
recovery_sequencer.py — Reverse + turn recovery FSM for emergency behavior (Issue #169).
Sequence
IDLE → REVERSING → TURNING → RETRYING → (IDLE on success)
(→ REVERSING on re-threat, retry loop)
(→ GAVE_UP after max_retries)
REVERSING : command reverse at reverse_speed_ms until reverse_distance_m covered.
TURNING : command angular_speed_rads until turn_angle_rad covered (default 90°).
RETRYING : zero velocity; wait up to retry_timeout_s for threat to clear.
If threat stays clear for clear_hold_s → back to IDLE (success).
If timeout without clearance → start another REVERSING cycle.
If retry_count >= max_retries → GAVE_UP.
Pure module — no ROS2 dependency.
"""
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
# ── States ────────────────────────────────────────────────────────────────────
class RecoveryState(Enum):
IDLE = "IDLE"
REVERSING = "REVERSING"
TURNING = "TURNING"
RETRYING = "RETRYING"
GAVE_UP = "GAVE_UP"
# ── I/O ───────────────────────────────────────────────────────────────────────
@dataclass
class RecoveryInputs:
trigger: bool = False # True to start (or restart) recovery
threat_cleared: bool = False # True when all threats are CLEAR
dt: float = 0.02 # time step (s)
@dataclass
class RecoveryOutputs:
state: RecoveryState = RecoveryState.IDLE
cmd_linear: float = 0.0 # m/s
cmd_angular: float = 0.0 # rad/s
progress: float = 0.0 # [0, 1] completion of current phase
retry_count: int = 0
gave_up: bool = False
state_changed: bool = False
# ── RecoverySequencer ────────────────────────────────────────────────────────
class RecoverySequencer:
"""
Tick-based FSM for executing reverse + turn recovery sequences.
Parameters
----------
reverse_speed_ms : backward speed during REVERSING (m/s; stored as negative)
reverse_distance_m: total reverse travel before turning (m)
angular_speed_rads: yaw rate during TURNING (rad/s; positive = left)
turn_angle_rad : total turn angle before RETRYING (rad; default π/2)
retry_timeout_s : max RETRYING time per attempt before next reverse cycle
clear_hold_s : consecutive clear time needed to declare success
max_retries : maximum reverse+turn attempts before GAVE_UP
"""
def __init__(
self,
reverse_speed_ms: float = -0.15,
reverse_distance_m: float = 0.30,
angular_speed_rads: float = 0.60,
turn_angle_rad: float = 1.5708, # π/2
retry_timeout_s: float = 3.0,
clear_hold_s: float = 0.5,
max_retries: int = 3,
):
self._rev_speed = -abs(float(reverse_speed_ms)) # force negative (backward); min(0.0, x) would zero a positive input and stall REVERSING
self._rev_dist = max(0.01, float(reverse_distance_m))
self._ang_speed = abs(float(angular_speed_rads))
self._turn_angle = max(0.01, float(turn_angle_rad))
self._retry_tout = max(0.1, float(retry_timeout_s))
self._clear_hold = max(0.0, float(clear_hold_s))
self._max_retry = max(1, int(max_retries))
self._state = RecoveryState.IDLE
self._rev_done = 0.0 # distance reversed so far
self._turn_done = 0.0 # angle turned so far
self._retry_count = 0
self._retry_timer = 0.0 # time spent in RETRYING
self._clear_timer = 0.0 # consecutive clear time in RETRYING
# ── Public API ────────────────────────────────────────────────────────────
@property
def state(self) -> RecoveryState:
return self._state
@property
def retry_count(self) -> int:
return self._retry_count
def reset(self) -> None:
"""Return to IDLE and clear all counters."""
self._state = RecoveryState.IDLE
self._rev_done = 0.0
self._turn_done = 0.0
self._retry_count = 0
self._retry_timer = 0.0
self._clear_timer = 0.0
def tick(self, inputs: RecoveryInputs) -> RecoveryOutputs:
prev = self._state
out = self._step(inputs)
out.state = self._state
out.retry_count = self._retry_count
out.state_changed = (self._state != prev)
if out.state_changed:
self._on_enter(self._state)
return out
# ── Internal step ─────────────────────────────────────────────────────────
def _step(self, inp: RecoveryInputs) -> RecoveryOutputs:
out = RecoveryOutputs(state=self._state)
dt = max(0.0, inp.dt)
# ── IDLE ──────────────────────────────────────────────────────────────
if self._state == RecoveryState.IDLE:
if inp.trigger:
self._state = RecoveryState.REVERSING
# ── REVERSING ─────────────────────────────────────────────────────────
elif self._state == RecoveryState.REVERSING:
step = abs(self._rev_speed) * dt
self._rev_done += step
out.cmd_linear = self._rev_speed
out.progress = min(1.0, self._rev_done / self._rev_dist)
if self._rev_done >= self._rev_dist:
self._state = RecoveryState.TURNING
# ── TURNING ───────────────────────────────────────────────────────────
elif self._state == RecoveryState.TURNING:
step = self._ang_speed * dt
self._turn_done += step
out.cmd_angular = self._ang_speed
out.progress = min(1.0, self._turn_done / self._turn_angle)
if self._turn_done >= self._turn_angle:
self._retry_count += 1
self._state = RecoveryState.RETRYING
# ── RETRYING ──────────────────────────────────────────────────────────
elif self._state == RecoveryState.RETRYING:
self._retry_timer += dt
if inp.threat_cleared:
self._clear_timer += dt
if self._clear_timer >= self._clear_hold:
# Success — threat has cleared
self._state = RecoveryState.IDLE
return out
else:
self._clear_timer = 0.0
if self._retry_timer >= self._retry_tout:
if self._retry_count >= self._max_retry:
self._state = RecoveryState.GAVE_UP
out.gave_up = True
else:
self._state = RecoveryState.REVERSING
# ── GAVE_UP ───────────────────────────────────────────────────────────
elif self._state == RecoveryState.GAVE_UP:
# Stay in GAVE_UP until external reset()
out.gave_up = True
return out
# ── Entry actions ─────────────────────────────────────────────────────────
def _on_enter(self, state: RecoveryState) -> None:
if state == RecoveryState.REVERSING:
self._rev_done = 0.0
elif state == RecoveryState.TURNING:
self._turn_done = 0.0
elif state == RecoveryState.RETRYING:
self._retry_timer = 0.0
self._clear_timer = 0.0
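Phase completion in the sequencer is open-loop time integration (accumulate |v|·dt until the target is covered); a standalone check with the REVERSING defaults and the FSM's nominal 0.02 s tick:

```python
# Open-loop integration as used by REVERSING: accumulate |v|*dt per tick
# until the target distance is covered (defaults: 0.15 m/s over 0.30 m).
rev_speed_ms = 0.15
rev_dist_m = 0.30
dt = 0.02  # nominal tick passed by the FSM

done = 0.0
ticks = 0
while done < rev_dist_m:
    done += rev_speed_ms * dt
    ticks += 1
print(ticks, round(ticks * dt, 2))  # ~100 ticks, ~2 s of commanded reversing
```

Because this is dead reckoning on commanded speed, wheel slip during recovery shortens the actual distance traveled; the RETRYING re-check is what catches that case.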


@ -0,0 +1,354 @@
"""
threat_detector.py — Multi-source threat detection for emergency behavior (Issue #169).
Detectors
ObstacleDetector : Forward-sector minimum range < stop thresholds at speed.
Inputs: min_range_m (pre-filtered from LaserScan forward
sector), cmd_speed_ms.
FallDetector : Extreme pitch/roll from IMU, or depth floor-drop ahead.
Inputs: pitch_rad, roll_rad, floor_drop_m (depth-derived;
0.0 if depth unavailable).
StuckDetector : Commanded speed vs actual speed mismatch sustained for
stuck_timeout_s. Tracks elapsed time with dt argument.
BumpDetector : IMU acceleration jerk (|Δ|a||/dt) above threshold.
MAJOR at jerk_threshold_ms3, CRITICAL at
critical_jerk_threshold_ms3.
ThreatLevel
CLEAR : no threat; normal operation
MINOR : advisory; log/beep only
MAJOR : stop and execute recovery
CRITICAL : full shutdown + MQTT escalation
Pure module — no ROS2 dependency.
"""
from __future__ import annotations
import math
import time
from dataclasses import dataclass
from enum import Enum
from typing import Optional
# ── Enumerations ──────────────────────────────────────────────────────────────
class ThreatLevel(Enum):
CLEAR = 0
MINOR = 1
MAJOR = 2
CRITICAL = 3
class ThreatType(Enum):
NONE = "NONE"
OBSTACLE_PROXIMITY = "OBSTACLE_PROXIMITY"
FALL_RISK = "FALL_RISK"
WHEEL_STUCK = "WHEEL_STUCK"
BUMP = "BUMP"
# ── ThreatEvent ───────────────────────────────────────────────────────────────
@dataclass
class ThreatEvent:
"""Snapshot of a single detected threat."""
threat_type: ThreatType = ThreatType.NONE
level: ThreatLevel = ThreatLevel.CLEAR
value: float = 0.0 # triggering metric
detail: str = ""
timestamp_s: float = 0.0
@staticmethod
def clear(timestamp_s: float = 0.0) -> "ThreatEvent":
return ThreatEvent(timestamp_s=timestamp_s)
_CLEAR = ThreatEvent()
# ── ObstacleDetector ─────────────────────────────────────────────────────────
class ObstacleDetector:
"""
Obstacle proximity threat from forward-sector minimum range.
Parameters
----------
stop_distance_m : range below which MAJOR is raised (default 0.30 m)
critical_distance_m : range below which CRITICAL is raised (default 0.10 m)
min_speed_ms : only active above this commanded speed (default 0.05 m/s)
"""
def __init__(
self,
stop_distance_m: float = 0.30,
critical_distance_m: float = 0.10,
min_speed_ms: float = 0.05,
):
self._stop = max(1e-3, stop_distance_m)
self._critical = max(1e-3, min(self._stop, critical_distance_m))
self._min_spd = abs(min_speed_ms)
def update(
self,
min_range_m: float,
cmd_speed_ms: float,
timestamp_s: float = 0.0,
) -> ThreatEvent:
"""
Parameters
----------
min_range_m : minimum obstacle range in forward sector (m)
cmd_speed_ms : signed commanded forward speed (m/s)
"""
if abs(cmd_speed_ms) < self._min_spd:
return ThreatEvent.clear(timestamp_s)
if min_range_m <= self._critical:
return ThreatEvent(
threat_type=ThreatType.OBSTACLE_PROXIMITY,
level=ThreatLevel.CRITICAL,
value=min_range_m,
detail=f"Obstacle {min_range_m:.2f} m (critical zone)",
timestamp_s=timestamp_s,
)
if min_range_m <= self._stop:
return ThreatEvent(
threat_type=ThreatType.OBSTACLE_PROXIMITY,
level=ThreatLevel.MAJOR,
value=min_range_m,
detail=f"Obstacle {min_range_m:.2f} m ahead",
timestamp_s=timestamp_s,
)
return ThreatEvent.clear(timestamp_s)
# ── FallDetector ──────────────────────────────────────────────────────────────
class FallDetector:
"""
Fall / tipping risk from IMU pitch/roll and optional depth floor-drop.
Parameters
----------
minor_tilt_rad : |pitch| or |roll| above which MINOR fires (default 0.20 rad)
major_tilt_rad : above which MAJOR fires (default 0.35 rad)
critical_tilt_rad : above which CRITICAL fires (default 0.52 rad ≈ 30°)
floor_drop_m : depth discontinuity (m) triggering MAJOR (default 0.15 m)
"""
def __init__(
self,
minor_tilt_rad: float = 0.20,
major_tilt_rad: float = 0.35,
critical_tilt_rad: float = 0.52,
floor_drop_m: float = 0.15,
):
self._minor = float(minor_tilt_rad)
self._major = float(major_tilt_rad)
self._critical = float(critical_tilt_rad)
self._drop = float(floor_drop_m)
def update(
self,
pitch_rad: float,
roll_rad: float,
floor_drop_m: float = 0.0,
timestamp_s: float = 0.0,
) -> ThreatEvent:
"""
Parameters
----------
pitch_rad : forward tilt (rad); +ve = nose up
roll_rad : lateral tilt (rad); +ve = left side up
floor_drop_m : depth discontinuity ahead of robot (m); 0 = not measured
"""
tilt = max(abs(pitch_rad), abs(roll_rad))
if tilt >= self._critical:
return ThreatEvent(
threat_type=ThreatType.FALL_RISK,
level=ThreatLevel.CRITICAL,
value=tilt,
detail=f"Critical tilt {math.degrees(tilt):.1f}°",
timestamp_s=timestamp_s,
)
if tilt >= self._major or floor_drop_m >= self._drop:
value = max(tilt, floor_drop_m)
detail = (
f"Floor drop {floor_drop_m:.2f} m" if floor_drop_m >= self._drop
else f"Major tilt {math.degrees(tilt):.1f}°"
)
return ThreatEvent(
threat_type=ThreatType.FALL_RISK,
level=ThreatLevel.MAJOR,
value=value,
detail=detail,
timestamp_s=timestamp_s,
)
if tilt >= self._minor:
return ThreatEvent(
threat_type=ThreatType.FALL_RISK,
level=ThreatLevel.MINOR,
value=tilt,
detail=f"Tilt advisory {math.degrees(tilt):.1f}°",
timestamp_s=timestamp_s,
)
return ThreatEvent.clear(timestamp_s)
# ── StuckDetector ─────────────────────────────────────────────────────────────
class StuckDetector:
"""
Wheel stall / stuck detection from cmd_vel vs odometry mismatch.
Accumulates stuck time while |cmd| ≥ min_cmd_ms AND |actual| < moving_threshold_ms.
Resets when motion resumes or commanded speed drops below min_cmd_ms.
Parameters
----------
stuck_timeout_s : accumulated stuck time before MAJOR fires (default 3.0 s)
min_cmd_ms : minimum commanded speed to consider stalling (0.05 m/s)
moving_threshold_ms : actual speed above which robot is considered moving (default 0.05 m/s)
"""
def __init__(
self,
stuck_timeout_s: float = 3.0,
min_cmd_ms: float = 0.05,
moving_threshold_ms: float = 0.05,
):
self._timeout = max(0.1, stuck_timeout_s)
self._min_cmd = abs(min_cmd_ms)
self._moving = abs(moving_threshold_ms)
self._stuck_time: float = 0.0
@property
def stuck_time(self) -> float:
"""Accumulated stuck duration (s)."""
return self._stuck_time
def update(
self,
cmd_speed_ms: float,
actual_speed_ms: float,
dt: float,
timestamp_s: float = 0.0,
) -> ThreatEvent:
"""
Parameters
----------
cmd_speed_ms : commanded forward speed (m/s)
actual_speed_ms : measured forward speed from odometry (m/s)
dt : elapsed time since last call (s)
"""
commanding = abs(cmd_speed_ms) >= self._min_cmd
moving = abs(actual_speed_ms) >= self._moving
if not commanding or moving:
self._stuck_time = 0.0
return ThreatEvent.clear(timestamp_s)
self._stuck_time += max(0.0, dt)
if self._stuck_time >= self._timeout:
return ThreatEvent(
threat_type=ThreatType.WHEEL_STUCK,
level=ThreatLevel.MAJOR,
value=self._stuck_time,
detail=f"Wheels stuck for {self._stuck_time:.1f} s",
timestamp_s=timestamp_s,
)
return ThreatEvent.clear(timestamp_s)
def reset(self) -> None:
self._stuck_time = 0.0
# ── BumpDetector ─────────────────────────────────────────────────────────────
class BumpDetector:
"""
Collision / bump detection via IMU acceleration jerk.
Jerk = |Δ|a|| / dt where |a| = |sqrt(ax²+ay²+az²) − g| (gravity removed)
Parameters
----------
jerk_threshold_ms3 : MAJOR at jerk above this (m/s³, default 8.0)
critical_jerk_threshold_ms3 : CRITICAL at jerk above this (m/s³, default 25.0)
gravity_ms2 : gravity magnitude to subtract (default 9.81)
"""
def __init__(
self,
jerk_threshold_ms3: float = 8.0,
critical_jerk_threshold_ms3: float = 25.0,
gravity_ms2: float = 9.81,
):
self._jerk_major = float(jerk_threshold_ms3)
self._jerk_critical = float(critical_jerk_threshold_ms3)
self._gravity = float(gravity_ms2)
self._prev_dyn_mag: Optional[float] = None # previous |a_dynamic|
def update(
self,
ax: float,
ay: float,
az: float,
dt: float,
timestamp_s: float = 0.0,
) -> ThreatEvent:
"""
Parameters
----------
ax, ay, az : IMU linear acceleration (m/s²)
dt : elapsed time since last call (s)
"""
raw_mag = math.sqrt(ax * ax + ay * ay + az * az)
dyn_mag = abs(raw_mag - self._gravity) # remove gravity component
if self._prev_dyn_mag is None or dt <= 0.0:
self._prev_dyn_mag = dyn_mag
return ThreatEvent.clear(timestamp_s)
jerk = abs(dyn_mag - self._prev_dyn_mag) / dt
self._prev_dyn_mag = dyn_mag
if jerk >= self._jerk_critical:
return ThreatEvent(
threat_type=ThreatType.BUMP,
level=ThreatLevel.CRITICAL,
value=jerk,
detail=f"Critical impact: jerk {jerk:.1f} m/s³",
timestamp_s=timestamp_s,
)
if jerk >= self._jerk_major:
return ThreatEvent(
threat_type=ThreatType.BUMP,
level=ThreatLevel.MAJOR,
value=jerk,
detail=f"Bump detected: jerk {jerk:.1f} m/s³",
timestamp_s=timestamp_s,
)
return ThreatEvent.clear(timestamp_s)
def reset(self) -> None:
self._prev_dyn_mag = None
# ── Utility: pick highest-severity threat ─────────────────────────────────────
def highest_threat(events: list[ThreatEvent]) -> ThreatEvent:
"""Return the ThreatEvent with the highest ThreatLevel from a list."""
if not events:
return _CLEAR
return max(events, key=lambda e: e.level.value)
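The jerk arithmetic above is easy to check standalone; a minimal self-contained sketch of the gravity-removed jerk computation BumpDetector performs, using the default thresholds (constants mirror the defaults in this file):

```python
import math

GRAVITY = 9.81  # gravity_ms2 default


def dyn_mag(ax: float, ay: float, az: float) -> float:
    """Gravity-removed dynamic acceleration magnitude: |sqrt(ax²+ay²+az²) − g|."""
    return abs(math.sqrt(ax * ax + ay * ay + az * az) - GRAVITY)


# Seed sample at rest: |a| = g, so the dynamic magnitude is ~0.
prev = dyn_mag(0.0, 0.0, GRAVITY)
# A 4.5 m/s² lateral spike arriving over dt = 0.1 s:
cur = dyn_mag(4.5, 0.0, GRAVITY)
jerk = abs(cur - prev) / 0.1
# jerk ≈ 9.8 m/s³ — above the 8.0 MAJOR threshold, below the 25.0 CRITICAL one,
# matching the worked values in test_major_on_jerk below.
```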


@ -0,0 +1,5 @@
[develop]
script_dir=$base/lib/saltybot_emergency
[install]
install_scripts=$base/lib/saltybot_emergency


@ -0,0 +1,32 @@
from setuptools import setup, find_packages
import os
from glob import glob
package_name = "saltybot_emergency"
setup(
name=package_name,
version="0.1.0",
packages=find_packages(exclude=["test"]),
data_files=[
("share/ament_index/resource_index/packages",
[f"resource/{package_name}"]),
(f"share/{package_name}", ["package.xml"]),
(os.path.join("share", package_name, "config"),
glob("config/*.yaml")),
(os.path.join("share", package_name, "launch"),
glob("launch/*.py")),
],
install_requires=["setuptools"],
zip_safe=True,
maintainer="sl-controls",
maintainer_email="sl-controls@saltylab.local",
description="Emergency behavior system — collision avoidance, fall prevention, stuck detection, recovery",
license="MIT",
tests_require=["pytest"],
entry_points={
"console_scripts": [
f"emergency_node = {package_name}.emergency_node:main",
],
},
)


@ -0,0 +1,560 @@
"""
test_emergency.py — Unit tests for Issue #169 emergency behavior modules.
Covers:
ObstacleDetector — proximity thresholds, speed gate
FallDetector — tilt levels, floor drop
StuckDetector — timeout accumulation, reset on motion
BumpDetector — jerk thresholds, first-call safety
AlertManager — severity mapping, escalation, suppression
RecoverySequencer — full sequence, retry, give-up
EmergencyFSM — all state transitions and guard conditions
"""
import math
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import pytest
from saltybot_emergency.threat_detector import (
BumpDetector,
FallDetector,
ObstacleDetector,
StuckDetector,
ThreatEvent,
ThreatLevel,
ThreatType,
highest_threat,
)
from saltybot_emergency.alert_manager import Alert, AlertLevel, AlertManager
from saltybot_emergency.recovery_sequencer import (
RecoveryInputs,
RecoverySequencer,
RecoveryState,
)
from saltybot_emergency.emergency_fsm import (
EmergencyFSM,
EmergencyInputs,
EmergencyState,
)
# ── Helpers ───────────────────────────────────────────────────────────────────
def _obs(**kw):
d = dict(stop_distance_m=0.30, critical_distance_m=0.10, min_speed_ms=0.05)
d.update(kw)
return ObstacleDetector(**d)
def _fall(**kw):
d = dict(minor_tilt_rad=0.20, major_tilt_rad=0.35,
critical_tilt_rad=0.52, floor_drop_m=0.15)
d.update(kw)
return FallDetector(**d)
def _stuck(**kw):
d = dict(stuck_timeout_s=3.0, min_cmd_ms=0.05, moving_threshold_ms=0.05)
d.update(kw)
return StuckDetector(**d)
def _bump(**kw):
d = dict(jerk_threshold_ms3=8.0, critical_jerk_threshold_ms3=25.0)
d.update(kw)
return BumpDetector(**d)
def _alert_mgr(**kw):
d = dict(major_count_threshold=3, escalation_window_s=10.0, suppression_s=1.0)
d.update(kw)
return AlertManager(**d)
def _seq(**kw):
d = dict(
reverse_speed_ms=-0.15, reverse_distance_m=0.30,
angular_speed_rads=0.60, turn_angle_rad=1.5708,
retry_timeout_s=3.0, clear_hold_s=0.5, max_retries=3,
)
d.update(kw)
return RecoverySequencer(**d)
def _fsm(**kw):
d = dict(
stopped_ms=0.03, major_count_threshold=3, escalation_window_s=10.0,
suppression_s=0.0, # disable suppression for cleaner tests
reverse_speed_ms=-0.15, reverse_distance_m=0.30,
angular_speed_rads=0.60, turn_angle_rad=1.5708,
retry_timeout_s=3.0, clear_hold_s=0.5, max_retries=3,
)
d.update(kw)
return EmergencyFSM(**d)
def _major_threat(threat_type=ThreatType.OBSTACLE_PROXIMITY, ts=0.0):
return ThreatEvent(threat_type=threat_type, level=ThreatLevel.MAJOR,
value=1.0, detail="test", timestamp_s=ts)
def _critical_threat(ts=0.0):
return ThreatEvent(threat_type=ThreatType.OBSTACLE_PROXIMITY,
level=ThreatLevel.CRITICAL, value=0.05,
detail="critical test", timestamp_s=ts)
def _minor_threat(ts=0.0):
return ThreatEvent(threat_type=ThreatType.FALL_RISK, level=ThreatLevel.MINOR,
value=0.21, detail="tilt", timestamp_s=ts)
def _clear_threat():
return ThreatEvent.clear()
def _inp(threat=None, speed=0.0, ack=False):
return EmergencyInputs(
threat=threat or _clear_threat(),
robot_speed_ms=speed,
acknowledge=ack,
)
# ══════════════════════════════════════════════════════════════════════════════
# ObstacleDetector
# ══════════════════════════════════════════════════════════════════════════════
class TestObstacleDetector:
def test_clear_when_far(self):
ev = _obs().update(0.5, 0.3)
assert ev.level == ThreatLevel.CLEAR
def test_major_within_stop_distance(self):
ev = _obs(stop_distance_m=0.30).update(0.25, 0.3)
assert ev.level == ThreatLevel.MAJOR
assert ev.threat_type == ThreatType.OBSTACLE_PROXIMITY
def test_critical_within_critical_distance(self):
ev = _obs(critical_distance_m=0.10).update(0.05, 0.3)
assert ev.level == ThreatLevel.CRITICAL
def test_clear_when_stopped(self):
"""Obstacle detection suppressed when not moving."""
ev = _obs(min_speed_ms=0.05).update(0.05, 0.01)
assert ev.level == ThreatLevel.CLEAR
def test_active_at_min_speed(self):
ev = _obs(min_speed_ms=0.05).update(0.20, 0.06)
assert ev.level == ThreatLevel.MAJOR
def test_value_is_distance(self):
ev = _obs().update(0.20, 0.3)
assert ev.value == pytest.approx(0.20, abs=1e-9)
# ══════════════════════════════════════════════════════════════════════════════
# FallDetector
# ══════════════════════════════════════════════════════════════════════════════
class TestFallDetector:
def test_clear_on_flat(self):
ev = _fall().update(0.0, 0.0)
assert ev.level == ThreatLevel.CLEAR
def test_minor_moderate_tilt(self):
ev = _fall(minor_tilt_rad=0.20, major_tilt_rad=0.35).update(0.25, 0.0)
assert ev.level == ThreatLevel.MINOR
def test_major_high_tilt(self):
ev = _fall(major_tilt_rad=0.35, critical_tilt_rad=0.52).update(0.40, 0.0)
assert ev.level == ThreatLevel.MAJOR
def test_critical_extreme_tilt(self):
ev = _fall(critical_tilt_rad=0.52).update(0.60, 0.0)
assert ev.level == ThreatLevel.CRITICAL
def test_major_on_floor_drop(self):
ev = _fall(floor_drop_m=0.15).update(0.0, 0.0, floor_drop_m=0.20)
assert ev.level == ThreatLevel.MAJOR
assert "drop" in ev.detail.lower()
def test_roll_triggers_same_as_pitch(self):
"""Roll beyond minor threshold also fires."""
ev = _fall(minor_tilt_rad=0.20).update(0.0, 0.25)
assert ev.level == ThreatLevel.MINOR
# ══════════════════════════════════════════════════════════════════════════════
# StuckDetector
# ══════════════════════════════════════════════════════════════════════════════
class TestStuckDetector:
def test_clear_when_not_commanded(self):
s = _stuck(stuck_timeout_s=1.0, min_cmd_ms=0.05)
ev = s.update(0.01, 0.0, dt=1.0) # cmd below threshold
assert ev.level == ThreatLevel.CLEAR
def test_clear_when_moving(self):
s = _stuck(stuck_timeout_s=1.0)
ev = s.update(0.2, 0.2, dt=1.0) # actually moving
assert ev.level == ThreatLevel.CLEAR
def test_major_after_timeout(self):
s = _stuck(stuck_timeout_s=3.0, min_cmd_ms=0.05, moving_threshold_ms=0.05)
for _ in range(6):
ev = s.update(0.2, 0.0, dt=0.5) # cmd=0.2, actual=0 → stuck
assert ev.level == ThreatLevel.MAJOR
def test_no_major_before_timeout(self):
s = _stuck(stuck_timeout_s=3.0)
ev = s.update(0.2, 0.0, dt=1.0) # only 1s — not yet
assert ev.level == ThreatLevel.CLEAR
def test_reset_on_motion_resume(self):
s = _stuck(stuck_timeout_s=1.0)
s.update(0.2, 0.0, dt=0.8) # accumulate stuck time
s.update(0.2, 0.3, dt=0.1) # motion resumes → reset
ev = s.update(0.2, 0.0, dt=0.3) # only 0.3s since reset → still clear
assert ev.level == ThreatLevel.CLEAR
def test_stuck_time_property(self):
s = _stuck(stuck_timeout_s=5.0)
s.update(0.2, 0.0, dt=1.0)
s.update(0.2, 0.0, dt=1.0)
assert s.stuck_time == pytest.approx(2.0, abs=1e-6)
# ══════════════════════════════════════════════════════════════════════════════
# BumpDetector
# ══════════════════════════════════════════════════════════════════════════════
class TestBumpDetector:
def test_clear_on_first_call(self):
"""No jerk on first sample (no previous value)."""
ev = _bump().update(0.0, 0.0, 9.81, dt=0.05)
assert ev.level == ThreatLevel.CLEAR
def test_major_on_jerk(self):
b = _bump(jerk_threshold_ms3=5.0, critical_jerk_threshold_ms3=20.0)
b.update(0.0, 0.0, 9.81, dt=0.05) # seed → dyn_mag = 0
# ax=4.5: raw≈10.79, dyn≈0.98, jerk≈9.8 m/s³ → MAJOR (5.0 < 9.8 < 20.0)
ev = b.update(4.5, 0.0, 9.81, dt=0.1)
assert ev.level == ThreatLevel.MAJOR
def test_critical_on_severe_jerk(self):
b = _bump(jerk_threshold_ms3=5.0, critical_jerk_threshold_ms3=20.0)
b.update(0.0, 0.0, 9.81, dt=0.05)
# Very large spike
ev = b.update(50.0, 0.0, 9.81, dt=0.1)
assert ev.level == ThreatLevel.CRITICAL
def test_clear_on_gentle_acceleration(self):
b = _bump(jerk_threshold_ms3=8.0)
b.update(0.0, 0.0, 9.81, dt=0.05)
ev = b.update(0.1, 0.0, 9.81, dt=0.05) # tiny change
assert ev.level == ThreatLevel.CLEAR
# ══════════════════════════════════════════════════════════════════════════════
# highest_threat
# ══════════════════════════════════════════════════════════════════════════════
class TestHighestThreat:
def test_empty_returns_clear(self):
assert highest_threat([]).level == ThreatLevel.CLEAR
def test_picks_highest(self):
a = ThreatEvent(level=ThreatLevel.MINOR)
b = ThreatEvent(level=ThreatLevel.CRITICAL)
c = ThreatEvent(level=ThreatLevel.MAJOR)
assert highest_threat([a, b, c]).level == ThreatLevel.CRITICAL
def test_single_item(self):
ev = ThreatEvent(level=ThreatLevel.MAJOR)
assert highest_threat([ev]) is ev
# ══════════════════════════════════════════════════════════════════════════════
# AlertManager
# ══════════════════════════════════════════════════════════════════════════════
class TestAlertManager:
def test_clear_returns_none(self):
am = _alert_mgr()
assert am.update(_clear_threat()) is None
def test_minor_threat_gives_minor_alert(self):
am = _alert_mgr(suppression_s=0.0)
alert = am.update(_minor_threat(ts=0.0))
assert alert is not None
assert alert.level == AlertLevel.MINOR
def test_major_threat_gives_major_alert(self):
am = _alert_mgr(suppression_s=0.0)
alert = am.update(_major_threat(ts=0.0))
assert alert is not None
assert alert.level == AlertLevel.MAJOR
def test_critical_threat_gives_critical_alert(self):
am = _alert_mgr(suppression_s=0.0)
alert = am.update(_critical_threat(ts=0.0))
assert alert is not None
assert alert.level == AlertLevel.CRITICAL
def test_suppression_blocks_duplicate(self):
am = _alert_mgr(suppression_s=5.0)
am.update(_major_threat(ts=0.0))
alert = am.update(_major_threat(ts=1.0)) # within 5s window
assert alert is None
def test_suppression_expires(self):
am = _alert_mgr(suppression_s=2.0)
am.update(_major_threat(ts=0.0))
alert = am.update(_major_threat(ts=3.0)) # after 2s window
assert alert is not None
def test_escalation_major_to_critical(self):
"""After major_count_threshold major alerts, next one becomes CRITICAL."""
am = _alert_mgr(major_count_threshold=3, escalation_window_s=60.0,
suppression_s=0.0)
for i in range(3):
am.update(_major_threat(ts=float(i)))
# 4th should be escalated
alert = am.update(_major_threat(ts=4.0))
assert alert is not None
assert alert.level == AlertLevel.CRITICAL
def test_escalation_resets_after_window(self):
"""Major alerts outside the window don't contribute to escalation."""
am = _alert_mgr(major_count_threshold=3, escalation_window_s=5.0,
suppression_s=0.0)
am.update(_major_threat(ts=0.0))
am.update(_major_threat(ts=1.0))
am.update(_major_threat(ts=2.0))
# All 3 are old; new one at t=10 is outside window
alert = am.update(_major_threat(ts=10.0))
assert alert is not None
assert alert.level == AlertLevel.MAJOR # not escalated
def test_reset_clears_escalation_state(self):
am = _alert_mgr(major_count_threshold=2, suppression_s=0.0)
am.update(_major_threat(ts=0.0))
am.update(_major_threat(ts=1.0)) # now at threshold
am.reset()
alert = am.update(_major_threat(ts=2.0))
assert alert.level == AlertLevel.MAJOR # back to major after reset
# ══════════════════════════════════════════════════════════════════════════════
# RecoverySequencer
# ══════════════════════════════════════════════════════════════════════════════
class TestRecoverySequencer:
def _trigger(self, seq):
return seq.tick(RecoveryInputs(trigger=True, dt=0.02))
def test_idle_on_init(self):
seq = _seq()
assert seq.state == RecoveryState.IDLE
def test_trigger_starts_reversing(self):
seq = _seq()
out = self._trigger(seq)
assert seq.state == RecoveryState.REVERSING
def test_reversing_backward_velocity(self):
seq = _seq(reverse_speed_ms=-0.15)
self._trigger(seq)
out = seq.tick(RecoveryInputs(dt=0.02))
assert out.cmd_linear < 0.0
def test_reversing_completes_to_turning(self):
seq = _seq(reverse_speed_ms=-1.0, reverse_distance_m=0.5)
self._trigger(seq)
for _ in range(30):
out = seq.tick(RecoveryInputs(dt=0.02))
assert seq.state == RecoveryState.TURNING
def test_turning_positive_angular(self):
seq = _seq(reverse_speed_ms=-1.0, reverse_distance_m=0.1,
angular_speed_rads=1.0)
self._trigger(seq)
# Skip through reversing quickly
for _ in range(20):
seq.tick(RecoveryInputs(dt=0.02))
if seq.state == RecoveryState.TURNING:
out = seq.tick(RecoveryInputs(dt=0.02))
assert out.cmd_angular > 0.0
def test_retrying_increments_count(self):
seq = _seq(reverse_speed_ms=-1.0, reverse_distance_m=0.05,
angular_speed_rads=10.0, turn_angle_rad=0.1)
self._trigger(seq)
for _ in range(100):
seq.tick(RecoveryInputs(dt=0.02))
assert seq.state == RecoveryState.RETRYING
assert seq.retry_count == 1
def test_threat_cleared_returns_idle(self):
seq = _seq(reverse_speed_ms=-1.0, reverse_distance_m=0.05,
angular_speed_rads=10.0, turn_angle_rad=0.1,
clear_hold_s=0.1)
self._trigger(seq)
# Fast-forward to RETRYING
for _ in range(100):
seq.tick(RecoveryInputs(dt=0.02))
assert seq.state == RecoveryState.RETRYING
# Feed cleared ticks until clear_hold met
for _ in range(20):
seq.tick(RecoveryInputs(threat_cleared=True, dt=0.02))
assert seq.state == RecoveryState.IDLE
def test_max_retries_gives_up(self):
seq = _seq(reverse_speed_ms=-1.0, reverse_distance_m=0.05,
angular_speed_rads=10.0, turn_angle_rad=0.1,
retry_timeout_s=0.1, max_retries=2)
self._trigger(seq)
for _ in range(500):
out = seq.tick(RecoveryInputs(threat_cleared=False, dt=0.05))
if seq.state == RecoveryState.GAVE_UP:
break
assert seq.state == RecoveryState.GAVE_UP
def test_reset_returns_to_idle(self):
seq = _seq()
self._trigger(seq)
seq.reset()
assert seq.state == RecoveryState.IDLE
assert seq.retry_count == 0
# ══════════════════════════════════════════════════════════════════════════════
# EmergencyFSM
# ══════════════════════════════════════════════════════════════════════════════
class TestEmergencyFSMBasic:
def test_initial_state_nominal(self):
fsm = _fsm()
assert fsm.state == EmergencyState.NOMINAL
def test_nominal_stays_on_clear(self):
fsm = _fsm()
out = fsm.tick(_inp())
assert fsm.state == EmergencyState.NOMINAL
assert out.cmd_override is False
def test_minor_alert_no_override(self):
fsm = _fsm()
out = fsm.tick(_inp(_minor_threat(ts=0.0)))
assert fsm.state == EmergencyState.NOMINAL
assert out.cmd_override is False
assert out.alert is not None
assert out.alert.level == AlertLevel.MINOR
def test_major_threat_enters_stopping(self):
fsm = _fsm()
out = fsm.tick(_inp(_major_threat()))
assert fsm.state == EmergencyState.STOPPING
assert out.cmd_override is True
def test_critical_threat_enters_stopping_critical_pending(self):
fsm = _fsm()
fsm.tick(_inp(_critical_threat()))
assert fsm.state == EmergencyState.STOPPING
assert fsm._critical_pending is True
class TestEmergencyFSMStopping:
def test_stopping_commands_zero(self):
fsm = _fsm()
fsm.tick(_inp(_major_threat()))
out = fsm.tick(_inp(_major_threat(), speed=0.5))
assert out.cmd_linear == pytest.approx(0.0, abs=1e-9)
assert out.cmd_angular == pytest.approx(0.0, abs=1e-9)
def test_stopped_enters_recovering(self):
fsm = _fsm(stopped_ms=0.03)
fsm.tick(_inp(_major_threat()))
out = fsm.tick(_inp(_major_threat(), speed=0.01)) # below stopped_ms
assert fsm.state == EmergencyState.RECOVERING
def test_critical_pending_enters_escalated(self):
fsm = _fsm(stopped_ms=0.03)
fsm.tick(_inp(_critical_threat()))
fsm.tick(_inp(_critical_threat(), speed=0.01)) # stopped → ESCALATED
assert fsm.state == EmergencyState.ESCALATED
class TestEmergencyFSMRecovering:
def _reach_recovering(self, fsm):
fsm.tick(_inp(_major_threat()))
fsm.tick(_inp(_major_threat(), speed=0.0)) # stopped → RECOVERING
assert fsm.state == EmergencyState.RECOVERING
def test_recovering_has_cmd_override(self):
fsm = _fsm()
self._reach_recovering(fsm)
out = fsm.tick(_inp(_clear_threat()))
assert out.cmd_override is True
def test_recovering_gave_up_escalates(self):
fsm = _fsm(max_retries=1, retry_timeout_s=0.05)
self._reach_recovering(fsm)
# Drive recovery to GAVE_UP by feeding many non-clearing ticks
for _ in range(500):
out = fsm.tick(_inp(_major_threat()))
if fsm.state == EmergencyState.ESCALATED:
break
assert fsm.state == EmergencyState.ESCALATED
class TestEmergencyFSMEscalated:
def _reach_escalated(self, fsm):
fsm.tick(_inp(_critical_threat()))
fsm.tick(_inp(_critical_threat(), speed=0.0))
assert fsm.state == EmergencyState.ESCALATED
def test_escalated_emits_critical_alert_once(self):
fsm = _fsm()
self._reach_escalated(fsm)
out1 = fsm.tick(_inp())
out2 = fsm.tick(_inp())
assert out1.alert is not None
assert out1.alert.level == AlertLevel.CRITICAL
assert out2.alert is None # suppressed after first emission
def test_escalated_e_stop_asserted(self):
fsm = _fsm()
self._reach_escalated(fsm)
out = fsm.tick(_inp())
assert out.e_stop is True
def test_escalated_stays_without_ack(self):
fsm = _fsm()
self._reach_escalated(fsm)
for _ in range(5):
fsm.tick(_inp())
assert fsm.state == EmergencyState.ESCALATED
def test_acknowledge_returns_to_nominal(self):
fsm = _fsm()
self._reach_escalated(fsm)
fsm.tick(_inp(ack=True))
assert fsm.state == EmergencyState.NOMINAL
def test_reset_returns_to_nominal(self):
fsm = _fsm()
self._reach_escalated(fsm)
fsm.reset()
assert fsm.state == EmergencyState.NOMINAL
def test_e_stop_cleared_on_ack(self):
fsm = _fsm()
self._reach_escalated(fsm)
out = fsm.tick(_inp(ack=True))
assert out.e_stop is False


@ -0,0 +1,15 @@
cmake_minimum_required(VERSION 3.8)
project(saltybot_emergency_msgs)
find_package(ament_cmake REQUIRED)
find_package(rosidl_default_generators REQUIRED)
find_package(builtin_interfaces REQUIRED)
rosidl_generate_interfaces(${PROJECT_NAME}
"msg/EmergencyEvent.msg"
"msg/RecoveryAction.msg"
DEPENDENCIES builtin_interfaces
)
ament_export_dependencies(rosidl_default_runtime)
ament_package()


@ -0,0 +1,25 @@
# EmergencyEvent.msg — Real-time emergency system state snapshot (Issue #169)
# Published by: /saltybot/emergency_node
# Topic: /saltybot/emergency
builtin_interfaces/Time stamp
# Overall FSM state
# Values: "NOMINAL" | "STOPPING" | "RECOVERING" | "ESCALATED"
string state
# Active threat (highest severity across all detectors)
# threat_type values: "NONE" | "OBSTACLE_PROXIMITY" | "FALL_RISK" | "WHEEL_STUCK" | "BUMP"
string threat_type
# Severity: "CLEAR" | "MINOR" | "MAJOR" | "CRITICAL"
string severity
# Triggering metric value (e.g. distance in m, jerk in m/s³, stuck seconds)
float32 threat_value
# Human-readable description of the active threat
string detail
# True when emergency system is overriding normal cmd_vel with its own commands
bool cmd_override
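Because severity travels in this message as a string, consumers need an explicit ordering before they can compare or gate on it; a minimal sketch of a consumer-side helper (helper name hypothetical) whose ranking mirrors the ThreatLevel enum in threat_detector.py:

```python
# Ordering mirrors ThreatLevel (CLEAR=0 .. CRITICAL=3) in threat_detector.py.
_SEVERITY_RANK = {"CLEAR": 0, "MINOR": 1, "MAJOR": 2, "CRITICAL": 3}


def severity_at_least(severity: str, floor: str) -> bool:
    """True if `severity` ranks at or above `floor`; unknown strings rank lowest."""
    return _SEVERITY_RANK.get(severity, -1) >= _SEVERITY_RANK[floor]
```

A subscriber could, for example, gate an MQTT escalation on `severity_at_least(msg.severity, "MAJOR")`.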


@ -0,0 +1,15 @@
# RecoveryAction.msg — Recovery sequencer state (Issue #169)
# Published by: /saltybot/emergency_node
# Topic: /saltybot/recovery_action
builtin_interfaces/Time stamp
# Current recovery action
# Values: "IDLE" | "REVERSING" | "TURNING" | "RETRYING" | "GAVE_UP"
string action
# Number of reverse+turn attempts completed so far
int32 retry_count
# Progress through current phase [0.0, 1.0]
float32 progress


@ -0,0 +1,22 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_emergency_msgs</name>
<version>0.1.0</version>
<description>Emergency behavior message definitions for SaltyBot (Issue #169)</description>
<maintainer email="sl-controls@saltylab.local">sl-controls</maintainer>
<license>MIT</license>
<buildtool_depend>ament_cmake</buildtool_depend>
<buildtool_depend>rosidl_default_generators</buildtool_depend>
<depend>builtin_interfaces</depend>
<exec_depend>rosidl_default_runtime</exec_depend>
<member_of_group>rosidl_interface_packages</member_of_group>
<export>
<build_type>ament_cmake</build_type>
</export>
</package>


@ -0,0 +1,38 @@
emotion_node:
ros__parameters:
# Path to the TRT FP16 emotion engine (built from emotion_cnn.onnx).
# Build with:
# trtexec --onnx=emotion_cnn.onnx --fp16 --saveEngine=emotion_fp16.trt
# Leave empty to use landmark heuristic only (no GPU required).
engine_path: "/models/emotion_fp16.trt"
# Minimum smoothed confidence to publish an expression.
# Lower = more detections but more noise; 0.40 is a good production default.
min_confidence: 0.40
# EMA smoothing weight applied to each new observation.
# 0.0 = frozen (never updates); 1.0 = no smoothing (raw output)
# 0.30 gives stable ~3-frame moving average at 10 Hz face detection rate.
smoothing_alpha: 0.30
# Comma-separated person_ids that have opted out of emotion tracking.
# Empty string = everyone is tracked by default.
# Example: "1,3,7"
opt_out_persons: ""
# Skip face crops smaller than this side length (pixels).
# Very small crops produce unreliable emotion predictions.
face_min_size: 24
# If true and TRT engine is unavailable, fall back to the 5-point
# landmark heuristic. Set false to suppress output entirely without TRT.
landmark_fallback: true
# Camera names matching saltybot_cameras topic naming convention.
camera_names: "front,left,rear,right"
# Number of active CSI camera streams to subscribe to.
n_cameras: 4
# Publish /social/emotion/context (JSON string) for conversation_node.
publish_context: true
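The smoothing_alpha comment above describes a standard exponential moving average; a self-contained sketch of the update rule it implies, showing why three confident frames clear the default publish gate:

```python
def ema(smoothed: float, observation: float, alpha: float = 0.30) -> float:
    """EMA update: alpha weights the new observation, (1 - alpha) the history."""
    return (1.0 - alpha) * smoothed + alpha * observation


# Three consecutive fully-confident (1.0) observations at the 10 Hz face rate:
conf = 0.0
for _ in range(3):
    conf = ema(conf, 1.0)
# conf ≈ 0.657 — already above the min_confidence 0.40 gate, consistent with
# the "~3-frame moving average" note above.
```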


@ -8,6 +8,7 @@ speech_pipeline_node:
    use_silero_vad: true
    whisper_model: "small"          # small (~500ms), medium (better quality, ~900ms)
    whisper_compute_type: "float16"
    whisper_language: ""            # "" = auto-detect; set e.g. "fr" to force
    speaker_threshold: 0.65
    speaker_db_path: "/social_db/speaker_embeddings.json"
    publish_partial: true


@ -1,6 +1,8 @@
tts_node:
  ros__parameters:
    voice_path: "/models/piper/en_US-lessac-medium.onnx"
    voice_map_json: "{}"
    default_language: "en"
    sample_rate: 22050
    volume: 1.0
    audio_device: ""                # "" = system default; set to device name if needed
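Per the commit message, voice_map_json is a JSON map from BCP-47 codes to Piper ONNX voice paths, with voices loaded lazily per language; a sketch of that routing, assuming a hypothetical load_voice stub (the "fr" voice path is an illustrative example, not a file from this repo):

```python
import json

# voice_map_json and default_language parameters, as a standalone example.
voice_map = json.loads(
    '{"en": "/models/piper/en_US-lessac-medium.onnx",'
    ' "fr": "/models/piper/fr_FR-example-medium.onnx"}'
)
default_language = "en"
_loaded = {}  # lazy cache: language code -> loaded voice handle


def load_voice(path):
    # Hypothetical stand-in for constructing a Piper ONNX voice from disk.
    return {"path": path}


def voice_for(lang):
    """Resolve a BCP-47 code to a loaded voice, falling back to the default."""
    lang = lang if lang in voice_map else default_language
    if lang not in _loaded:  # load each voice at most once
        _loaded[lang] = load_voice(voice_map[lang])
    return _loaded[lang]
```

Each (text, lang) tuple popped from the playback queue would then be synthesized with `voice_for(lang)`; unmapped languages fall back to the default voice rather than failing.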


@ -0,0 +1,67 @@
"""emotion.launch.py — Launch emotion_node for facial expression recognition (Issue #161)."""
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration, PathJoinSubstitution
from launch_ros.actions import Node
from launch_ros.substitutions import FindPackageShare
def generate_launch_description() -> LaunchDescription:
pkg = FindPackageShare("saltybot_social")
params_file = PathJoinSubstitution([pkg, "config", "emotion_params.yaml"])
return LaunchDescription([
DeclareLaunchArgument(
"params_file",
default_value=params_file,
description="Path to emotion_node parameter YAML",
),
DeclareLaunchArgument(
"engine_path",
default_value="/models/emotion_fp16.trt",
description="Path to TensorRT FP16 emotion engine (empty = landmark heuristic)",
),
DeclareLaunchArgument(
"min_confidence",
default_value="0.40",
description="Minimum detection confidence to publish an expression",
),
DeclareLaunchArgument(
"smoothing_alpha",
default_value="0.30",
description="EMA smoothing weight (0=frozen, 1=no smoothing)",
),
DeclareLaunchArgument(
"opt_out_persons",
default_value="",
description="Comma-separated person_ids that opted out",
),
DeclareLaunchArgument(
"landmark_fallback",
default_value="true",
description="Use landmark heuristic when TRT engine unavailable",
),
DeclareLaunchArgument(
"publish_context",
default_value="true",
description="Publish /social/emotion/context JSON for LLM context",
),
Node(
package="saltybot_social",
executable="emotion_node",
name="emotion_node",
output="screen",
parameters=[
LaunchConfiguration("params_file"),
{
"engine_path": LaunchConfiguration("engine_path"),
"min_confidence": LaunchConfiguration("min_confidence"),
"smoothing_alpha": LaunchConfiguration("smoothing_alpha"),
"opt_out_persons": LaunchConfiguration("opt_out_persons"),
"landmark_fallback": LaunchConfiguration("landmark_fallback"),
"publish_context": LaunchConfiguration("publish_context"),
},
],
),
])


@ -1,54 +1,30 @@
"""conversation_node.py — Local LLM conversation engine with per-person context. """conversation_node.py — Local LLM conversation engine with per-person context.
Issue #83/#161/#167
Issue #83: Conversation engine for social-bot.
Stack: Phi-3-mini or Llama-3.2-3B GGUF Q4_K_M via llama-cpp-python (CUDA).
Subscribes /social/speech/transcript builds per-person prompt streams
token output publishes /social/conversation/response.
Streaming: publishes partial=true tokens as they arrive, then final=false
at end of generation. TTS node can begin synthesis on first sentence boundary.
ROS2 topics:
Subscribe: /social/speech/transcript (saltybot_social_msgs/SpeechTranscript)
Publish: /social/conversation/response (saltybot_social_msgs/ConversationResponse)
Parameters:
model_path (str, "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf")
n_ctx (int, 4096)
n_gpu_layers (int, 20) GPU offload layers (increase for more VRAM usage)
max_tokens (int, 200)
temperature (float, 0.7)
top_p (float, 0.9)
soul_path (str, "/soul/SOUL.md")
context_db_path (str, "/social_db/conversation_context.json")
save_interval_s (float, 30.0) how often to persist context to disk
stream (bool, true)
""" """
from __future__ import annotations from __future__ import annotations
import json, threading, time
import threading from typing import Dict, Optional
import time
from typing import Optional
import rclpy import rclpy
from rclpy.node import Node from rclpy.node import Node
from rclpy.qos import QoSProfile from rclpy.qos import QoSProfile
from std_msgs.msg import String
from saltybot_social_msgs.msg import SpeechTranscript, ConversationResponse from saltybot_social_msgs.msg import SpeechTranscript, ConversationResponse
from .llm_context import ContextStore, build_llama_prompt, load_system_prompt, needs_summary_prompt from .llm_context import ContextStore, build_llama_prompt, load_system_prompt, needs_summary_prompt
_LANG_NAMES: Dict[str, str] = {
"en": "English", "fr": "French", "es": "Spanish", "de": "German",
"it": "Italian", "pt": "Portuguese", "ja": "Japanese", "zh": "Chinese",
"ko": "Korean", "ar": "Arabic", "ru": "Russian", "nl": "Dutch",
"pl": "Polish", "sv": "Swedish", "da": "Danish", "fi": "Finnish",
"no": "Norwegian", "tr": "Turkish", "hi": "Hindi", "uk": "Ukrainian",
"cs": "Czech", "ro": "Romanian", "hu": "Hungarian", "el": "Greek",
}
class ConversationNode(Node): class ConversationNode(Node):
"""Local LLM inference node with per-person conversation memory.""" """Local LLM inference node with per-person conversation memory."""
def __init__(self) -> None: def __init__(self) -> None:
super().__init__("conversation_node") super().__init__("conversation_node")
self.declare_parameter("model_path", "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf")
# ── Parameters ──────────────────────────────────────────────────────
self.declare_parameter("model_path",
"/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf")
self.declare_parameter("n_ctx", 4096) self.declare_parameter("n_ctx", 4096)
self.declare_parameter("n_gpu_layers", 20) self.declare_parameter("n_gpu_layers", 20)
self.declare_parameter("max_tokens", 200) self.declare_parameter("max_tokens", 200)
@ -58,7 +34,6 @@ class ConversationNode(Node):
self.declare_parameter("context_db_path", "/social_db/conversation_context.json") self.declare_parameter("context_db_path", "/social_db/conversation_context.json")
self.declare_parameter("save_interval_s", 30.0) self.declare_parameter("save_interval_s", 30.0)
self.declare_parameter("stream", True) self.declare_parameter("stream", True)
self._model_path = self.get_parameter("model_path").value self._model_path = self.get_parameter("model_path").value
self._n_ctx = self.get_parameter("n_ctx").value self._n_ctx = self.get_parameter("n_ctx").value
self._n_gpu = self.get_parameter("n_gpu_layers").value self._n_gpu = self.get_parameter("n_gpu_layers").value
@ -69,18 +44,9 @@ class ConversationNode(Node):
self._db_path = self.get_parameter("context_db_path").value self._db_path = self.get_parameter("context_db_path").value
self._save_interval = self.get_parameter("save_interval_s").value self._save_interval = self.get_parameter("save_interval_s").value
self._stream = self.get_parameter("stream").value self._stream = self.get_parameter("stream").value
# ── Publishers / Subscribers ─────────────────────────────────────────
qos = QoSProfile(depth=10) qos = QoSProfile(depth=10)
self._resp_pub = self.create_publisher( self._resp_pub = self.create_publisher(ConversationResponse, "/social/conversation/response", qos)
ConversationResponse, "/social/conversation/response", qos self._transcript_sub = self.create_subscription(SpeechTranscript, "/social/speech/transcript", self._on_transcript, qos)
)
self._transcript_sub = self.create_subscription(
SpeechTranscript, "/social/speech/transcript",
self._on_transcript, qos
)
# ── State ────────────────────────────────────────────────────────────
self._llm = None self._llm = None
self._system_prompt = load_system_prompt(self._soul_path) self._system_prompt = load_system_prompt(self._soul_path)
self._ctx_store = ContextStore(self._db_path) self._ctx_store = ContextStore(self._db_path)
@ -88,180 +54,114 @@ class ConversationNode(Node):
self._turn_counter = 0 self._turn_counter = 0
self._generating = False self._generating = False
self._last_save = time.time() self._last_save = time.time()
self._speaker_lang: Dict[str, str] = {}
# ── Load LLM in background ──────────────────────────────────────────── self._emotions: Dict[str, str] = {}
self.create_subscription(String, "/social/emotion/context", self._on_emotion_context, 10)
threading.Thread(target=self._load_llm, daemon=True).start() threading.Thread(target=self._load_llm, daemon=True).start()
# ── Periodic context save ────────────────────────────────────────────
self._save_timer = self.create_timer(self._save_interval, self._save_context) self._save_timer = self.create_timer(self._save_interval, self._save_context)
self.get_logger().info( self.get_logger().info(
f"ConversationNode init (model={self._model_path}, " f"ConversationNode init (model={self._model_path}, "
f"gpu_layers={self._n_gpu}, ctx={self._n_ctx})" f"gpu_layers={self._n_gpu}, ctx={self._n_ctx})"
) )
# ── Model loading ─────────────────────────────────────────────────────────
def _load_llm(self) -> None: def _load_llm(self) -> None:
t0 = time.time() t0 = time.time()
self.get_logger().info(f"Loading LLM: {self._model_path}")
try: try:
from llama_cpp import Llama from llama_cpp import Llama
self._llm = Llama( self._llm = Llama(model_path=self._model_path, n_ctx=self._n_ctx, n_gpu_layers=self._n_gpu, n_threads=4, verbose=False)
model_path=self._model_path, self.get_logger().info(f"LLM ready ({time.time()-t0:.1f}s)")
n_ctx=self._n_ctx,
n_gpu_layers=self._n_gpu,
n_threads=4,
verbose=False,
)
self.get_logger().info(
f"LLM ready ({time.time()-t0:.1f}s). "
f"Context: {self._n_ctx} tokens, GPU layers: {self._n_gpu}"
)
except Exception as e: except Exception as e:
self.get_logger().error(f"LLM load failed: {e}") self.get_logger().error(f"LLM load failed: {e}")
# ── Transcript callback ───────────────────────────────────────────────────
def _on_transcript(self, msg: SpeechTranscript) -> None: def _on_transcript(self, msg: SpeechTranscript) -> None:
"""Handle final transcripts only (skip streaming partials).""" if msg.is_partial or not msg.text.strip():
if msg.is_partial:
return return
if not msg.text.strip(): if msg.language:
return self._speaker_lang[msg.speaker_id] = msg.language
self.get_logger().info(f"Transcript [{msg.speaker_id}/{msg.language or '?'}]: '{msg.text}'")
self.get_logger().info( threading.Thread(target=self._generate_response, args=(msg.text.strip(), msg.speaker_id), daemon=True).start()
f"Transcript [{msg.speaker_id}]: '{msg.text}'"
)
threading.Thread(
target=self._generate_response,
args=(msg.text.strip(), msg.speaker_id),
daemon=True,
).start()
# ── LLM inference ─────────────────────────────────────────────────────────
def _generate_response(self, user_text: str, speaker_id: str) -> None: def _generate_response(self, user_text: str, speaker_id: str) -> None:
"""Generate LLM response with streaming. Runs in thread."""
if self._llm is None: if self._llm is None:
self.get_logger().warn("LLM not loaded yet, dropping utterance") self.get_logger().warn("LLM not loaded yet, dropping utterance"); return
return
with self._lock: with self._lock:
if self._generating: if self._generating:
self.get_logger().warn("LLM busy, dropping utterance") self.get_logger().warn("LLM busy, dropping utterance"); return
return
self._generating = True self._generating = True
self._turn_counter += 1 self._turn_counter += 1
turn_id = self._turn_counter turn_id = self._turn_counter
lang = self._speaker_lang.get(speaker_id, "en")
try: try:
ctx = self._ctx_store.get(speaker_id) ctx = self._ctx_store.get(speaker_id)
# Summary compression if context is long
if ctx.needs_compression(): if ctx.needs_compression():
self._compress_context(ctx) self._compress_context(ctx)
emotion_hint = self._emotion_hint(speaker_id)
ctx.add_user(user_text) lang_hint = self._language_hint(speaker_id)
hints = " ".join(h for h in (emotion_hint, lang_hint) if h)
prompt = build_llama_prompt( annotated = f"{user_text} {hints}".rstrip() if hints else user_text
ctx, user_text, self._system_prompt ctx.add_user(annotated)
) prompt = build_llama_prompt(ctx, annotated, self._system_prompt)
t0 = time.perf_counter() t0 = time.perf_counter()
full_response = "" full_response = ""
if self._stream: if self._stream:
output = self._llm( output = self._llm(prompt, max_tokens=self._max_tokens, temperature=self._temperature, top_p=self._top_p, stream=True, stop=["<|user|>", "<|system|>", "\n\n\n"])
prompt,
max_tokens=self._max_tokens,
temperature=self._temperature,
top_p=self._top_p,
stream=True,
stop=["<|user|>", "<|system|>", "\n\n\n"],
)
for chunk in output: for chunk in output:
token = chunk["choices"][0]["text"] token = chunk["choices"][0]["text"]
full_response += token full_response += token
# Publish partial after each sentence boundary for low TTS latency
if token.endswith((".", "!", "?", "\n")): if token.endswith((".", "!", "?", "\n")):
self._publish_response( self._publish_response(full_response.strip(), speaker_id, turn_id, language=lang, is_partial=True)
full_response.strip(), speaker_id, turn_id, is_partial=True
)
else: else:
output = self._llm( output = self._llm(prompt, max_tokens=self._max_tokens, temperature=self._temperature, top_p=self._top_p, stream=False, stop=["<|user|>", "<|system|>"])
prompt,
max_tokens=self._max_tokens,
temperature=self._temperature,
top_p=self._top_p,
stream=False,
stop=["<|user|>", "<|system|>"],
)
full_response = output["choices"][0]["text"] full_response = output["choices"][0]["text"]
full_response = full_response.strip() full_response = full_response.strip()
latency_ms = (time.perf_counter() - t0) * 1000 self.get_logger().info(f"LLM [{speaker_id}/{lang}] ({(time.perf_counter()-t0)*1000:.0f}ms): '{full_response[:80]}'")
self.get_logger().info(
f"LLM [{speaker_id}] ({latency_ms:.0f}ms): '{full_response[:80]}'"
)
ctx.add_assistant(full_response) ctx.add_assistant(full_response)
self._publish_response(full_response, speaker_id, turn_id, is_partial=False) self._publish_response(full_response, speaker_id, turn_id, language=lang, is_partial=False)
except Exception as e: except Exception as e:
self.get_logger().error(f"LLM inference error: {e}") self.get_logger().error(f"LLM inference error: {e}")
finally: finally:
with self._lock: with self._lock: self._generating = False
self._generating = False
def _compress_context(self, ctx) -> None: def _compress_context(self, ctx) -> None:
"""Ask LLM to summarize old turns for context compression.""" if self._llm is None: ctx.compress("(history omitted)"); return
if self._llm is None:
ctx.compress("(history omitted)")
return
try: try:
summary_prompt = needs_summary_prompt(ctx) result = self._llm(needs_summary_prompt(ctx), max_tokens=80, temperature=0.3, stream=False)
result = self._llm(summary_prompt, max_tokens=80, temperature=0.3, stream=False) ctx.compress(result["choices"][0]["text"].strip())
summary = result["choices"][0]["text"].strip() except Exception: ctx.compress("(history omitted)")
ctx.compress(summary)
self.get_logger().debug(
f"Context compressed for {ctx.person_id}: '{summary[:60]}'"
)
except Exception:
ctx.compress("(history omitted)")
# ── Publish ─────────────────────────────────────────────────────────────── def _language_hint(self, speaker_id: str) -> str:
lang = self._speaker_lang.get(speaker_id, "en")
if lang and lang != "en":
return f"[Please respond in {_LANG_NAMES.get(lang, lang)}.]"
return ""
def _publish_response( def _on_emotion_context(self, msg: String) -> None:
self, text: str, speaker_id: str, turn_id: int, is_partial: bool try:
) -> None: for k, v in json.loads(msg.data).get("emotions", {}).items():
self._emotions[k] = v
except Exception: pass
def _emotion_hint(self, speaker_id: str) -> str:
emo = self._emotions.get(speaker_id, "")
return f"[The person seems {emo} right now.]" if emo and emo != "neutral" else ""
def _publish_response(self, text: str, speaker_id: str, turn_id: int, language: str = "en", is_partial: bool = False) -> None:
msg = ConversationResponse() msg = ConversationResponse()
msg.header.stamp = self.get_clock().now().to_msg() msg.header.stamp = self.get_clock().now().to_msg()
msg.text = text msg.text = text; msg.speaker_id = speaker_id; msg.is_partial = is_partial
msg.speaker_id = speaker_id msg.turn_id = turn_id; msg.language = language
msg.is_partial = is_partial
msg.turn_id = turn_id
self._resp_pub.publish(msg) self._resp_pub.publish(msg)
def _save_context(self) -> None: def _save_context(self) -> None:
try: try: self._ctx_store.save()
self._ctx_store.save() except Exception as e: self.get_logger().error(f"Context save error: {e}")
except Exception as e:
self.get_logger().error(f"Context save error: {e}")
def destroy_node(self) -> None: def destroy_node(self) -> None:
self._save_context() self._save_context(); super().destroy_node()
super().destroy_node()
def main(args=None) -> None: def main(args=None) -> None:
rclpy.init(args=args) rclpy.init(args=args)
node = ConversationNode() node = ConversationNode()
try: try: rclpy.spin(node)
rclpy.spin(node) except KeyboardInterrupt: pass
except KeyboardInterrupt: finally: node.destroy_node(); rclpy.shutdown()
pass
finally:
node.destroy_node()
rclpy.shutdown()
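The `/social/emotion/context` payload consumed by `_on_emotion_context` is a JSON object with an `"emotions"` mapping; the publisher side in emotion_node is truncated in this diff, so the shape below is inferred from the consumer, not confirmed. A round-trip check of that assumed shape, with `emotion_hint` mirroring `_emotion_hint`:

```python
import json

# Assumed payload shape, inferred from _on_emotion_context:
#   {"emotions": {speaker_id: emotion_label, ...}}
payload = json.dumps({"emotions": {"person_3": "happy", "person_7": "neutral"}})

emotions = {}
for k, v in json.loads(payload).get("emotions", {}).items():
    emotions[k] = v                      # same merge the node performs


def emotion_hint(emotions, speaker_id):
    """Mirror of _emotion_hint: only known, non-neutral emotions yield a hint."""
    emo = emotions.get(speaker_id, "")
    return f"[The person seems {emo} right now.]" if emo and emo != "neutral" else ""
```

Neutral and unknown speakers produce no hint, so the LLM prompt is only annotated when there is signal worth acting on.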


@ -0,0 +1,414 @@
"""emotion_classifier.py — 7-class facial expression classifier (Issue #161).
Pure Python, no ROS2 / TensorRT / OpenCV dependencies.
Wraps a TensorRT FP16 emotion CNN and provides:
- On-device TRT inference on 48×48 grayscale face crops
- Heuristic fallback from 5-point SCRFD facial landmarks
- Per-person EMA temporal smoothing for stable outputs
- Per-person opt-out registry
Emotion classes (index order matches CNN output layer)
------------------------------------------------------
0 = happy
1 = sad
2 = angry
3 = surprised
4 = fearful
5 = disgusted
6 = neutral
Coordinate convention
---------------------
Face crop: BGR uint8 ndarray, any size (resized to INPUT_SIZE internally).
Landmarks (lm10): 10 floats from FaceDetection.landmarks
[left_eye_x, left_eye_y, right_eye_x, right_eye_y,
nose_x, nose_y, left_mouth_x, left_mouth_y,
right_mouth_x, right_mouth_y]
All coordinates are normalised image-space (0.0 to 1.0).
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
# ── Constants ─────────────────────────────────────────────────────────────────
EMOTIONS: List[str] = [
"happy", "sad", "angry", "surprised", "fearful", "disgusted", "neutral"
]
N_CLASSES: int = len(EMOTIONS)
# Landmark slot indices within the 10-value array
_LEX, _LEY = 0, 1 # left eye
_REX, _REY = 2, 3 # right eye
_NX, _NY = 4, 5 # nose
_LMX, _LMY = 6, 7 # left mouth corner
_RMX, _RMY = 8, 9 # right mouth corner
# CNN input size
INPUT_SIZE: int = 48
# ── Result type ───────────────────────────────────────────────────────────────
@dataclass
class ClassifiedEmotion:
emotion: str # dominant emotion label
confidence: float # smoothed softmax probability (0.0 to 1.0)
scores: List[float] # per-class scores, len == N_CLASSES
source: str = "cnn_trt" # "cnn_trt" | "landmark_heuristic" | "opt_out"
# ── Math helpers ──────────────────────────────────────────────────────────────
def _softmax(logits: List[float]) -> List[float]:
"""Numerically stable softmax."""
m = max(logits)
exps = [math.exp(x - m) for x in logits]
total = sum(exps)
return [e / total for e in exps]
def _argmax(values: List[float]) -> int:
best = 0
for i in range(1, len(values)):
if values[i] > values[best]:
best = i
return best
def _uniform_scores() -> List[float]:
return [1.0 / N_CLASSES] * N_CLASSES
# ── Per-person temporal smoother + opt-out registry ───────────────────────────
class PersonEmotionTracker:
"""Exponential moving-average smoother + opt-out registry per person.
Args:
alpha: EMA weight for the newest observation (0.0 = frozen, 1.0 = no smooth).
max_age: Frames without update before the EMA is expired and reset.
"""
def __init__(self, alpha: float = 0.3, max_age: int = 16) -> None:
self._alpha = alpha
self._max_age = max_age
self._ema: Dict[int, List[float]] = {} # person_id → EMA scores
self._age: Dict[int, int] = {} # person_id → frames since last update
self._opt_out: set = set() # person_ids that opted out
# ── Opt-out management ────────────────────────────────────────────────────
def set_opt_out(self, person_id: int, value: bool) -> None:
if value:
self._opt_out.add(person_id)
self._ema.pop(person_id, None)
self._age.pop(person_id, None)
else:
self._opt_out.discard(person_id)
def is_opt_out(self, person_id: int) -> bool:
return person_id in self._opt_out
# ── Smoothing ─────────────────────────────────────────────────────────────
def smooth(self, person_id: int, scores: List[float]) -> List[float]:
"""Apply EMA to raw scores. Returns smoothed scores (length N_CLASSES)."""
if person_id < 0:
return scores[:]
# Reset stale entries
if person_id in self._age and self._age[person_id] > self._max_age:
self.reset(person_id)
if person_id not in self._ema:
self._ema[person_id] = scores[:]
self._age[person_id] = 0
return scores[:]
ema = self._ema[person_id]
a = self._alpha
smoothed = [a * s + (1.0 - a) * e for s, e in zip(scores, ema)]
self._ema[person_id] = smoothed
self._age[person_id] = 0
return smoothed
def tick(self) -> None:
"""Advance age counter for all tracked persons (call once per frame)."""
for pid in list(self._age.keys()):
self._age[pid] += 1
def reset(self, person_id: int) -> None:
self._ema.pop(person_id, None)
self._age.pop(person_id, None)
def reset_all(self) -> None:
self._ema.clear()
self._age.clear()
@property
def tracked_ids(self) -> List[int]:
return list(self._ema.keys())
# ── TensorRT engine loader ────────────────────────────────────────────────────
class _TrtEngine:
"""Thin wrapper around a TensorRT FP16 emotion CNN engine.
Expected engine:
- Input binding: name "input" shape (1, 1, 48, 48) float32
- Output binding: name "output" shape (1, 7) float32 (softmax logits)
The engine is built offline from a MobileNetV2-based 48×48 grayscale CNN
(FER+ or AffectNet trained) via:
trtexec --onnx=emotion_cnn.onnx --fp16 --saveEngine=emotion_fp16.trt
"""
def __init__(self) -> None:
self._context = None
self._h_input = None
self._h_output = None
self._d_input = None
self._d_output = None
self._stream = None
self._bindings: List = []
def load(self, engine_path: str) -> bool:
"""Load TRT engine from file. Returns True on success."""
try:
import tensorrt as trt # type: ignore
import pycuda.autoinit # type: ignore # noqa: F401
import pycuda.driver as cuda # type: ignore
import numpy as np
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
with open(engine_path, "rb") as f:
engine_data = f.read()
runtime = trt.Runtime(TRT_LOGGER)
engine = runtime.deserialize_cuda_engine(engine_data)
self._context = engine.create_execution_context()
# Allocate host + device buffers
self._h_input = cuda.pagelocked_empty((1, 1, INPUT_SIZE, INPUT_SIZE),
dtype=np.float32)
self._h_output = cuda.pagelocked_empty((1, N_CLASSES), dtype=np.float32)
self._d_input = cuda.mem_alloc(self._h_input.nbytes)
self._d_output = cuda.mem_alloc(self._h_output.nbytes)
self._stream = cuda.Stream()
self._bindings = [int(self._d_input), int(self._d_output)]
return True
except Exception:
self._context = None
return False
@property
def ready(self) -> bool:
return self._context is not None
def infer(self, face_bgr) -> List[float]:
"""Run TRT inference on a BGR face crop. Returns 7 softmax scores."""
import pycuda.driver as cuda # type: ignore
import numpy as np
img = _preprocess(face_bgr) # (1, 1, 48, 48) float32
np.copyto(self._h_input, img)
cuda.memcpy_htod_async(self._d_input, self._h_input, self._stream)
self._context.execute_async_v2(self._bindings, self._stream.handle, None)
cuda.memcpy_dtoh_async(self._h_output, self._d_output, self._stream)
self._stream.synchronize()
logits = list(self._h_output[0])
return _softmax(logits)
# ── Image pre-processing helper (importable without cv2 in test mode) ─────────
def _preprocess(face_bgr) -> "np.ndarray": # type: ignore
"""Resize BGR crop to 48×48 grayscale, normalise to [-1, 1].
Returns ndarray shape (1, 1, 48, 48) float32.
"""
import numpy as np # type: ignore
import cv2 # type: ignore
gray = cv2.cvtColor(face_bgr, cv2.COLOR_BGR2GRAY)
resized = cv2.resize(gray, (INPUT_SIZE, INPUT_SIZE),
interpolation=cv2.INTER_LINEAR)
norm = resized.astype(np.float32) / 127.5 - 1.0
return norm.reshape(1, 1, INPUT_SIZE, INPUT_SIZE)
# ── Landmark heuristic ────────────────────────────────────────────────────────
def classify_from_landmarks(lm10: List[float]) -> ClassifiedEmotion:
"""Estimate emotion from 5-point SCRFD landmarks (10 floats).
Uses geometric ratios between eyes, nose, and mouth corners.
Accuracy is limited; treat this as a soft prior, not a definitive label.
Returns ClassifiedEmotion with source="landmark_heuristic".
"""
if len(lm10) < 10:
scores = _uniform_scores()
scores[6] = 0.5 # bias neutral
scores = _renorm(scores)
return ClassifiedEmotion("neutral", scores[6], scores, "landmark_heuristic")
eye_y = (lm10[_LEY] + lm10[_REY]) / 2.0
nose_y = lm10[_NY]
mouth_y = (lm10[_LMY] + lm10[_RMY]) / 2.0
eye_span = max(abs(lm10[_REX] - lm10[_LEX]), 1e-4)
mouth_width = abs(lm10[_RMX] - lm10[_LMX])
mouth_asymm = abs(lm10[_LMY] - lm10[_RMY])
face_h = max(mouth_y - eye_y, 1e-4)
# Ratio of mouth span to interocular distance
width_ratio = mouth_width / eye_span # >1.0 = wide open / happy smile
# How far mouth is below nose, relative to face height
mouth_below_nose = (mouth_y - nose_y) / face_h # ~0.3 to 0.6 typical
# Relative asymmetry of mouth corners
asym_ratio = mouth_asymm / face_h # >0.05 = notable asymmetry
# Build soft scores
scores = _uniform_scores() # start uniform
if width_ratio > 0.85 and mouth_below_nose > 0.35:
# Wide mouth, normal vertical position → happy
scores[0] = 0.55 + 0.25 * min(1.0, (width_ratio - 0.85) / 0.5)
elif mouth_below_nose < 0.20 and width_ratio < 0.7:
# Tight, compressed mouth high up → surprised OR angry
scores[3] = 0.35 # surprised
scores[2] = 0.30 # angry
elif asym_ratio > 0.06:
# Asymmetric mouth → disgust or sadness
scores[5] = 0.30 # disgusted
scores[1] = 0.25 # sad
elif width_ratio < 0.65 and mouth_below_nose < 0.30:
# Tight and compressed → sad/angry
scores[1] = 0.35 # sad
scores[2] = 0.25 # angry
else:
# Default to neutral
scores[6] = 0.45
scores = _renorm(scores)
top_idx = _argmax(scores)
return ClassifiedEmotion(
emotion=EMOTIONS[top_idx],
confidence=round(scores[top_idx], 3),
scores=[round(s, 4) for s in scores],
source="landmark_heuristic",
)
def _renorm(scores: List[float]) -> List[float]:
"""Re-normalise scores so they sum to 1.0."""
total = sum(scores)
if total <= 0:
return _uniform_scores()
return [s / total for s in scores]
# ── Public classifier ─────────────────────────────────────────────────────────
class EmotionClassifier:
"""Facade combining TRT inference + landmark fallback + per-person smoothing.
Usage
-----
>>> clf = EmotionClassifier(engine_path="/models/emotion_fp16.trt")
>>> clf.load()
True
>>> result = clf.classify_crop(face_bgr, person_id=42, tracker=tracker)
>>> result.emotion, result.confidence
('happy', 0.87)
"""
def __init__(
self,
engine_path: str = "",
alpha: float = 0.3,
) -> None:
self._engine_path = engine_path
self._alpha = alpha
self._engine = _TrtEngine()
def load(self) -> bool:
"""Load TRT engine. Returns True if engine is ready."""
if not self._engine_path:
return False
return self._engine.load(self._engine_path)
@property
def ready(self) -> bool:
return self._engine.ready
def classify_crop(
self,
face_bgr,
person_id: int = -1,
tracker: Optional[PersonEmotionTracker] = None,
) -> ClassifiedEmotion:
"""Classify a BGR face crop.
Args:
face_bgr: BGR ndarray from cv_bridge / direct crop.
person_id: For temporal smoothing. -1 = no smoothing.
tracker: Optional PersonEmotionTracker for EMA smoothing.
Returns:
ClassifiedEmotion with smoothed scores.
"""
if self._engine.ready:
raw = self._engine.infer(face_bgr)
source = "cnn_trt"
else:
# Uniform fallback — no inference without engine
raw = _uniform_scores()
raw[6] = 0.40 # mild neutral bias
raw = _renorm(raw)
source = "landmark_heuristic"
smoothed = raw
if tracker is not None and person_id >= 0:
smoothed = tracker.smooth(person_id, raw)
top_idx = _argmax(smoothed)
return ClassifiedEmotion(
emotion=EMOTIONS[top_idx],
confidence=round(smoothed[top_idx], 3),
scores=[round(s, 4) for s in smoothed],
source=source,
)
def classify_from_landmarks(
self,
lm10: List[float],
person_id: int = -1,
tracker: Optional[PersonEmotionTracker] = None,
) -> ClassifiedEmotion:
"""Classify using landmark geometry only (no crop required)."""
result = classify_from_landmarks(lm10)
if tracker is not None and person_id >= 0:
smoothed = tracker.smooth(person_id, result.scores)
top_idx = _argmax(smoothed)
result = ClassifiedEmotion(
emotion=EMOTIONS[top_idx],
confidence=round(smoothed[top_idx], 3),
scores=[round(s, 4) for s in smoothed],
source=result.source,
)
return result
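The smoothing in `PersonEmotionTracker.smooth` is a per-class EMA: the first observation seeds the average, after which each step computes `alpha * obs + (1 - alpha) * prev`. The update rule can be checked standalone; this re-implements only the arithmetic for illustration, not the class itself:

```python
def ema_update(prev, obs, alpha):
    """One EMA step over per-class scores; alpha weights the newest observation."""
    if prev is None:                      # first observation seeds the average
        return list(obs)
    return [alpha * s + (1.0 - alpha) * e for s, e in zip(obs, prev)]

# Two classes: class 0 dominant, then two frames of class 1.
state = None
for frame in ([1.0, 0.0], [0.0, 1.0], [0.0, 1.0]):
    state = ema_update(state, frame, alpha=0.3)
# After the first class-1 frame the smoothed scores are [0.7, 0.3]:
# class 0 still dominates, so a one-frame flicker is suppressed.
```

With alpha=0.3 it takes two consecutive contradicting frames before the dominant class flips, which is exactly the stability the tracker is there to provide.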


@ -0,0 +1,380 @@
"""emotion_node.py — Facial expression recognition node (Issue #161).
Piggybacks on the face detection pipeline: subscribes to
/social/faces/detections (FaceDetectionArray), extracts face crops from the
latest camera frames, runs a TensorRT FP16 emotion CNN, applies per-person
EMA temporal smoothing, and publishes results on /social/faces/expressions.
Architecture
------------
FaceDetectionArray → crop extraction → TRT FP16 inference (< 5 ms)
→ EMA smoothing → ExpressionArray publish
If TRT engine is not available the node falls back to the 5-point landmark
heuristic (classify_from_landmarks) which requires no GPU and adds < 0.1 ms.
ROS2 topics
-----------
Subscribe:
/social/faces/detections (saltybot_social_msgs/FaceDetectionArray)
/camera/{name}/image_raw (sensor_msgs/Image) × 4
/social/persons (saltybot_social_msgs/PersonStateArray)
Publish:
/social/faces/expressions (saltybot_social_msgs/ExpressionArray)
/social/emotion/context (std_msgs/String) JSON for LLM context
Parameters
----------
engine_path (str) "" path to emotion_fp16.trt; empty = landmark only
min_confidence (float) 0.40 suppress results below this threshold
smoothing_alpha (float) 0.30 EMA weight (higher = faster, less stable)
opt_out_persons (str) "" comma-separated person_ids that opted out
face_min_size (int) 24 skip faces whose bbox is smaller (px side)
landmark_fallback (bool) true use landmark heuristic when TRT unavailable
camera_names (str) "front,left,rear,right"
n_cameras (int) 4
publish_context (bool) true publish /social/emotion/context JSON
"""
from __future__ import annotations
import json
import threading
import time
from typing import Dict, List, Optional, Tuple
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from std_msgs.msg import String
from sensor_msgs.msg import Image
from saltybot_social_msgs.msg import (
FaceDetectionArray,
PersonStateArray,
ExpressionArray,
Expression,
)
from saltybot_social.emotion_classifier import (
EmotionClassifier,
PersonEmotionTracker,
EMOTIONS,
)
try:
from cv_bridge import CvBridge
_CV_BRIDGE_OK = True
except ImportError:
_CV_BRIDGE_OK = False
try:
import cv2
_CV2_OK = True
except ImportError:
_CV2_OK = False
# ── Per-camera latest-frame buffer ────────────────────────────────────────────
class _FrameBuffer:
"""Thread-safe store of the most recent image per camera."""
def __init__(self) -> None:
self._lock = threading.Lock()
self._frames: Dict[int, object] = {} # camera_id → cv2 BGR image
def put(self, camera_id: int, img) -> None:
with self._lock:
self._frames[camera_id] = img
def get(self, camera_id: int):
with self._lock:
return self._frames.get(camera_id)
# ── ROS2 Node ─────────────────────────────────────────────────────────────────
class EmotionNode(Node):
"""Facial expression recognition — TRT FP16 + landmark fallback."""
def __init__(self) -> None:
super().__init__("emotion_node")
# ── Parameters ────────────────────────────────────────────────────────
self.declare_parameter("engine_path", "")
self.declare_parameter("min_confidence", 0.40)
self.declare_parameter("smoothing_alpha", 0.30)
self.declare_parameter("opt_out_persons", "")
self.declare_parameter("face_min_size", 24)
self.declare_parameter("landmark_fallback", True)
self.declare_parameter("camera_names", "front,left,rear,right")
self.declare_parameter("n_cameras", 4)
self.declare_parameter("publish_context", True)
engine_path = self.get_parameter("engine_path").value
self._min_conf = self.get_parameter("min_confidence").value
alpha = self.get_parameter("smoothing_alpha").value
opt_out_str = self.get_parameter("opt_out_persons").value
self._face_min = self.get_parameter("face_min_size").value
self._lm_fallback = self.get_parameter("landmark_fallback").value
cam_names_str = self.get_parameter("camera_names").value
n_cameras = self.get_parameter("n_cameras").value
self._pub_ctx = self.get_parameter("publish_context").value
# ── Classifier + tracker ───────────────────────────────────────────
self._classifier = EmotionClassifier(engine_path=engine_path, alpha=alpha)
self._tracker = PersonEmotionTracker(alpha=alpha)
# Parse opt-out list
for pid_str in opt_out_str.split(","):
pid_str = pid_str.strip()
if pid_str.isdigit():
self._tracker.set_opt_out(int(pid_str), True)
# ── Camera frame buffer + cv_bridge ───────────────────────────────
self._frame_buf = _FrameBuffer()
self._bridge = CvBridge() if _CV_BRIDGE_OK else None
self._cam_names = [n.strip() for n in cam_names_str.split(",")]
self._cam_id_map: Dict[str, int] = {
name: idx for idx, name in enumerate(self._cam_names)
}
# ── Latest persons for person_id ↔ face_id correlation ────────────
self._persons_lock = threading.Lock()
self._face_to_person: Dict[int, int] = {} # face_id → person_id
# ── Context for LLM (latest emotion per person) ───────────────────
self._emotion_context: Dict[int, str] = {} # person_id → emotion label
# ── QoS profiles ──────────────────────────────────────────────────
be_qos = QoSProfile(
reliability=ReliabilityPolicy.BEST_EFFORT,
history=HistoryPolicy.KEEP_LAST,
depth=4,
)
# ── Subscriptions ──────────────────────────────────────────────────
self.create_subscription(
FaceDetectionArray,
"/social/faces/detections",
self._on_faces,
be_qos,
)
self.create_subscription(
PersonStateArray,
"/social/persons",
self._on_persons,
be_qos,
)
for name in self._cam_names[:n_cameras]:
cam_id = self._cam_id_map.get(name, 0)
self.create_subscription(
Image,
f"/camera/{name}/image_raw",
lambda msg, cid=cam_id: self._on_image(msg, cid),
be_qos,
)
# ── Publishers ─────────────────────────────────────────────────────
self._expr_pub = self.create_publisher(
ExpressionArray, "/social/faces/expressions", be_qos
)
if self._pub_ctx:
self._ctx_pub = self.create_publisher(
String, "/social/emotion/context", 10
)
else:
self._ctx_pub = None
# ── Load TRT engine in background ──────────────────────────────────
threading.Thread(target=self._load_engine, daemon=True).start()
self.get_logger().info(
f"EmotionNode ready — engine='{engine_path}', "
f"alpha={alpha:.2f}, min_conf={self._min_conf:.2f}, "
f"cameras={self._cam_names[:n_cameras]}"
)
# ── Engine loading ────────────────────────────────────────────────────────
def _load_engine(self) -> None:
if not self._classifier._engine_path:
if self._lm_fallback:
self.get_logger().info(
"No TRT engine path — using landmark heuristic fallback"
)
else:
self.get_logger().warn(
"No TRT engine and landmark_fallback=false — "
"emotion_node will not classify"
)
return
ok = self._classifier.load()
if ok:
self.get_logger().info("TRT emotion engine loaded ✓")
else:
self.get_logger().warn(
"TRT engine load failed — falling back to landmark heuristic"
)
# ── Camera frame ingestion ────────────────────────────────────────────────
def _on_image(self, msg: Image, camera_id: int) -> None:
if not _CV_BRIDGE_OK or self._bridge is None:
return
try:
bgr = self._bridge.imgmsg_to_cv2(msg, desired_encoding="bgr8")
self._frame_buf.put(camera_id, bgr)
except Exception as exc:
self.get_logger().debug(f"cv_bridge error cam{camera_id}: {exc}")
# ── Person-state for face_id → person_id mapping ─────────────────────────
def _on_persons(self, msg: PersonStateArray) -> None:
mapping: Dict[int, int] = {}
for ps in msg.persons:
if ps.face_id >= 0:
mapping[ps.face_id] = ps.person_id
with self._persons_lock:
self._face_to_person = mapping
def _get_person_id(self, face_id: int) -> int:
with self._persons_lock:
return self._face_to_person.get(face_id, -1)
# ── Main face-detection callback ──────────────────────────────────────────
def _on_faces(self, msg: FaceDetectionArray) -> None:
if not msg.faces:
return
expressions: List[Expression] = []
for face in msg.faces:
person_id = self._get_person_id(face.face_id)
# Opt-out check
if person_id >= 0 and self._tracker.is_opt_out(person_id):
expr = Expression()
expr.header = msg.header
expr.person_id = person_id
expr.face_id = face.face_id
expr.emotion = ""
expr.confidence = 0.0
expr.scores = [0.0] * 7
expr.is_opt_out = True
expr.source = "opt_out"
expressions.append(expr)
continue
# Try TRT crop classification first
result = None
if self._classifier.ready and _CV2_OK:
result = self._classify_crop(face, person_id)
# Fallback to landmark heuristic
if result is None and self._lm_fallback:
result = self._classifier.classify_from_landmarks(
list(face.landmarks),
person_id=person_id,
tracker=self._tracker,
)
if result is None:
continue
if result.confidence < self._min_conf:
continue
# Build ROS message
expr = Expression()
expr.header = msg.header
expr.person_id = person_id
expr.face_id = face.face_id
expr.emotion = result.emotion
expr.confidence = result.confidence
# Pad/trim scores to exactly 7
sc = result.scores
expr.scores = (sc + [0.0] * 7)[:7]
expr.is_opt_out = False
expr.source = result.source
expressions.append(expr)
# Update LLM context cache
if person_id >= 0:
self._emotion_context[person_id] = result.emotion
if not expressions:
return
# Publish ExpressionArray
arr = ExpressionArray()
arr.header = msg.header
arr.expressions = expressions
self._expr_pub.publish(arr)
# Publish JSON context for conversation node
if self._ctx_pub is not None:
self._publish_context()
self._tracker.tick()
def _classify_crop(self, face, person_id: int):
"""Extract crop from frame buffer and run TRT inference."""
# Resolve camera_id from face header frame_id
frame_id = face.header.frame_id if face.header.frame_id else "front"
cam_name = frame_id.split("/")[-1].split("_")[0] # "front", "left", etc.
camera_id = self._cam_id_map.get(cam_name, 0)
frame = self._frame_buf.get(camera_id)
if frame is None:
return None
h, w = frame.shape[:2]
x1 = max(0, int(face.bbox_x * w))
y1 = max(0, int(face.bbox_y * h))
x2 = min(w, int((face.bbox_x + face.bbox_w) * w))
y2 = min(h, int((face.bbox_y + face.bbox_h) * h))
if (x2 - x1) < self._face_min or (y2 - y1) < self._face_min:
return None
crop = frame[y1:y2, x1:x2]
if crop.size == 0:
return None
return self._classifier.classify_crop(
crop, person_id=person_id, tracker=self._tracker
)
# ── LLM context publisher ─────────────────────────────────────────────────
def _publish_context(self) -> None:
"""Publish latest-emotions dict as JSON for conversation_node."""
ctx = {str(pid): emo for pid, emo in self._emotion_context.items()}
msg = String()
msg.data = json.dumps({"emotions": ctx, "ts": time.time()})
self._ctx_pub.publish(msg)
# ── Entry point ───────────────────────────────────────────────────────────────
def main(args=None) -> None:
rclpy.init(args=args)
node = EmotionNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.try_shutdown()
if __name__ == "__main__":
main()
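The `/social/emotion/context` topic carries the JSON built in `_publish_context()` above: an `"emotions"` object mapping person IDs (stringified) to emotion labels, plus a `"ts"` timestamp. A minimal sketch of a consumer (the helper name `latest_emotions` and the 5 s staleness cutoff are illustrative, not part of the node):

```python
import json
import time

# Shape of the /social/emotion/context payload published by _publish_context():
# {"emotions": {"<person_id>": "<label>", ...}, "ts": <unix time>}
sample = json.dumps({"emotions": {"3": "happy", "7": "neutral"}, "ts": time.time()})

def latest_emotions(msg_data: str, max_age_s: float = 5.0) -> dict:
    """Parse the context JSON; drop stale payloads (hypothetical consumer)."""
    payload = json.loads(msg_data)
    if time.time() - payload.get("ts", 0.0) > max_age_s:
        return {}
    # Keys arrive as strings because JSON objects require string keys.
    return {int(pid): emo for pid, emo in payload.get("emotions", {}).items()}

print(latest_emotions(sample))  # {3: 'happy', 7: 'neutral'}
```

A subscriber such as conversation_node can call this in its `String` callback and merge the result into its prompt context.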


@@ -66,6 +66,7 @@ class SpeechPipelineNode(Node):
         self.declare_parameter("use_silero_vad", True)
         self.declare_parameter("whisper_model", "small")
         self.declare_parameter("whisper_compute_type", "float16")
+        self.declare_parameter("whisper_language", "")
         self.declare_parameter("speaker_threshold", 0.65)
         self.declare_parameter("speaker_db_path", "/social_db/speaker_embeddings.json")
         self.declare_parameter("publish_partial", True)
@@ -78,6 +79,7 @@ class SpeechPipelineNode(Node):
         self._use_silero = self.get_parameter("use_silero_vad").value
         self._whisper_model_name = self.get_parameter("whisper_model").value
         self._compute_type = self.get_parameter("whisper_compute_type").value
+        self._whisper_language = self.get_parameter("whisper_language").value or None
         self._speaker_thresh = self.get_parameter("speaker_threshold").value
         self._speaker_db = self.get_parameter("speaker_db_path").value
         self._publish_partial = self.get_parameter("publish_partial").value
@@ -315,20 +317,24 @@ class SpeechPipelineNode(Node):
             except Exception as e:
                 self.get_logger().debug(f"Speaker ID error: {e}")

-        # Streaming Whisper transcription
+        # Streaming Whisper transcription with language detection
         partial_text = ""
+        detected_lang = self._whisper_language or "en"
         try:
-            segments_gen, _info = self._whisper.transcribe(
+            segments_gen, info = self._whisper.transcribe(
                 audio_np,
-                language="en",
+                language=self._whisper_language,  # None = auto-detect
                 beam_size=3,
                 vad_filter=False,
             )
+            if hasattr(info, "language") and info.language:
+                detected_lang = info.language
             for seg in segments_gen:
                 partial_text += seg.text.strip() + " "
                 if self._publish_partial:
                     self._publish_transcript(
-                        partial_text.strip(), speaker_id, 0.0, duration, is_partial=True
+                        partial_text.strip(), speaker_id, 0.0, duration,
+                        language=detected_lang, is_partial=True,
                     )
         except Exception as e:
             self.get_logger().error(f"Whisper error: {e}")
@@ -340,15 +346,19 @@ class SpeechPipelineNode(Node):
         latency_ms = (time.perf_counter() - t0) * 1000
         self.get_logger().info(
-            f"STT [{speaker_id}] ({duration:.1f}s, {latency_ms:.0f}ms): '{final_text}'"
+            f"STT [{speaker_id}/{detected_lang}] ({duration:.1f}s, {latency_ms:.0f}ms): "
+            f"'{final_text}'"
         )
-        self._publish_transcript(final_text, speaker_id, 0.9, duration, is_partial=False)
+        self._publish_transcript(
+            final_text, speaker_id, 0.9, duration,
+            language=detected_lang, is_partial=False,
+        )

     # ── Publishers ────────────────────────────────────────────────────────────
     def _publish_transcript(
         self, text: str, speaker_id: str, confidence: float,
-        duration: float, is_partial: bool
+        duration: float, language: str = "en", is_partial: bool = False,
     ) -> None:
         msg = SpeechTranscript()
         msg.header.stamp = self.get_clock().now().to_msg()
@@ -356,6 +366,7 @@ class SpeechPipelineNode(Node):
         msg.speaker_id = speaker_id
         msg.confidence = confidence
         msg.audio_duration = duration
+        msg.language = language
         msg.is_partial = is_partial
         self._transcript_pub.publish(msg)


@@ -1,228 +1,136 @@
-"""tts_node.py — Streaming TTS with Piper / first-chunk streaming.
-
-Issue #85: Streaming TTS — Piper/XTTS integration with first-chunk streaming.
-
-Pipeline:
-    /social/conversation/response (ConversationResponse)
-      → sentence split → Piper ONNX synthesis (sentence by sentence)
-      → PCM16 chunks → USB speaker (sounddevice) + publish /social/tts/audio
-
-First-chunk strategy:
-  - On partial=true ConversationResponse, extract first sentence and synthesize
-    immediately → audio starts before LLM finishes generating
-  - On final=false, synthesize remaining sentences
-
-Latency target: <200ms to first audio chunk.
-
-ROS2 topics:
-    Subscribe: /social/conversation/response (saltybot_social_msgs/ConversationResponse)
-    Publish:   /social/tts/audio (audio_msgs/Audio or std_msgs/UInt8MultiArray fallback)
-
-Parameters:
-    voice_path (str, "/models/piper/en_US-lessac-medium.onnx")
-    sample_rate (int, 22050)
-    volume (float, 1.0)
-    audio_device (str, "")           sounddevice device name; "" = system default
-    playback_enabled (bool, true)
-    publish_audio (bool, false)      publish PCM to ROS2 topic
-    sentence_streaming (bool, true)  synthesize sentence-by-sentence
-"""
+"""tts_node.py -- Streaming TTS with Piper / first-chunk streaming.
+Issue #85/#167
+"""
 from __future__ import annotations

-import queue
-import threading
-import time
-from typing import Optional
+import json, queue, threading, time
+from typing import Any, Dict, Optional

 import rclpy
 from rclpy.node import Node
 from rclpy.qos import QoSProfile
 from std_msgs.msg import UInt8MultiArray

 from saltybot_social_msgs.msg import ConversationResponse
 from .tts_utils import split_sentences, strip_ssml, apply_volume, chunk_pcm, estimate_duration_ms


 class TtsNode(Node):
-    """Streaming TTS node using Piper ONNX."""
+    """Streaming TTS node using Piper ONNX with per-language voice switching."""

     def __init__(self) -> None:
         super().__init__("tts_node")
-        # ── Parameters ──────────────────────────────────────────────────────
         self.declare_parameter("voice_path", "/models/piper/en_US-lessac-medium.onnx")
+        self.declare_parameter("voice_map_json", "{}")
+        self.declare_parameter("default_language", "en")
         self.declare_parameter("sample_rate", 22050)
         self.declare_parameter("volume", 1.0)
         self.declare_parameter("audio_device", "")
         self.declare_parameter("playback_enabled", True)
         self.declare_parameter("publish_audio", False)
         self.declare_parameter("sentence_streaming", True)

         self._voice_path = self.get_parameter("voice_path").value
+        self._voice_map_json = self.get_parameter("voice_map_json").value
+        self._default_language = self.get_parameter("default_language").value or "en"
         self._sample_rate = self.get_parameter("sample_rate").value
         self._volume = self.get_parameter("volume").value
         self._audio_device = self.get_parameter("audio_device").value or None
         self._playback = self.get_parameter("playback_enabled").value
         self._publish_audio = self.get_parameter("publish_audio").value
         self._sentence_streaming = self.get_parameter("sentence_streaming").value

-        # ── Publishers / Subscribers ─────────────────────────────────────────
+        try:
+            extra: Dict[str, str] = json.loads(self._voice_map_json) if self._voice_map_json.strip() not in ("{}", "") else {}
+        except Exception as e:
+            self.get_logger().warn(f"voice_map_json parse error: {e}"); extra = {}
+        self._voice_paths: Dict[str, str] = {self._default_language: self._voice_path}
+        self._voice_paths.update(extra)
+
         qos = QoSProfile(depth=10)
-        self._resp_sub = self.create_subscription(
-            ConversationResponse, "/social/conversation/response",
-            self._on_response, qos
-        )
+        self._resp_sub = self.create_subscription(ConversationResponse, "/social/conversation/response", self._on_response, qos)
         if self._publish_audio:
-            self._audio_pub = self.create_publisher(
-                UInt8MultiArray, "/social/tts/audio", qos
-            )
+            self._audio_pub = self.create_publisher(UInt8MultiArray, "/social/tts/audio", qos)

-        # ── TTS engine ────────────────────────────────────────────────────────
-        self._voice = None
+        self._voices: Dict[str, Any] = {}
+        self._voices_lock = threading.Lock()
         self._playback_queue: queue.Queue = queue.Queue(maxsize=16)
         self._current_turn = -1
-        self._synthesized_turns: set = set()  # turn_ids already synthesized
+        self._synthesized_turns: set = set()
         self._lock = threading.Lock()

-        threading.Thread(target=self._load_voice, daemon=True).start()
+        threading.Thread(target=self._load_voice_for_lang, args=(self._default_language,), daemon=True).start()
         threading.Thread(target=self._playback_worker, daemon=True).start()
+        self.get_logger().info(f"TtsNode init (langs={list(self._voice_paths.keys())})")

-        self.get_logger().info(
-            f"TtsNode init (voice={self._voice_path}, "
-            f"streaming={self._sentence_streaming})"
-        )
-
-    # ── Voice loading ─────────────────────────────────────────────────────────
-    def _load_voice(self) -> None:
-        t0 = time.time()
-        self.get_logger().info(f"Loading Piper voice: {self._voice_path}")
+    def _load_voice_for_lang(self, lang: str) -> None:
+        path = self._voice_paths.get(lang)
+        if not path:
+            self.get_logger().warn(f"No voice for '{lang}', fallback to '{self._default_language}'"); return
+        with self._voices_lock:
+            if lang in self._voices: return
         try:
             from piper import PiperVoice
-            self._voice = PiperVoice.load(self._voice_path)
-            # Warmup synthesis to pre-JIT ONNX graph
-            warmup_text = "Hello."
-            list(self._voice.synthesize_stream_raw(warmup_text))
-            self.get_logger().info(f"Piper voice ready ({time.time()-t0:.1f}s)")
+            voice = PiperVoice.load(path)
+            list(voice.synthesize_stream_raw("Hello."))
+            with self._voices_lock: self._voices[lang] = voice
+            self.get_logger().info(f"Piper [{lang}] ready")
         except Exception as e:
-            self.get_logger().error(f"Piper voice load failed: {e}")
+            self.get_logger().error(f"Piper voice load failed [{lang}]: {e}")

-    # ── Response handler ──────────────────────────────────────────────────────
+    def _get_voice(self, lang: str):
+        with self._voices_lock:
+            v = self._voices.get(lang)
+        if v is not None: return v
+        if lang in self._voice_paths:
+            threading.Thread(target=self._load_voice_for_lang, args=(lang,), daemon=True).start()
+        return self._voices.get(self._default_language)
+
     def _on_response(self, msg: ConversationResponse) -> None:
-        """Handle streaming LLM response — synthesize sentence by sentence."""
-        if not msg.text.strip():
-            return
+        if not msg.text.strip(): return
+        lang = msg.language if msg.language else self._default_language
         with self._lock:
-            is_new_turn = msg.turn_id != self._current_turn
-            if is_new_turn:
-                self._current_turn = msg.turn_id
-                # Clear old synthesized sentence cache for this new turn
-                self._synthesized_turns = set()
+            if msg.turn_id != self._current_turn:
+                self._current_turn = msg.turn_id; self._synthesized_turns = set()
         text = strip_ssml(msg.text)
         if self._sentence_streaming:
-            sentences = split_sentences(text)
-            for sentence in sentences:
-                # Track which sentences we've already queued by content hash
+            for sentence in split_sentences(text):
                 key = (msg.turn_id, hash(sentence))
                 with self._lock:
-                    if key in self._synthesized_turns:
-                        continue
+                    if key in self._synthesized_turns: continue
                     self._synthesized_turns.add(key)
-                self._queue_synthesis(sentence)
+                self._queue_synthesis(sentence, lang)
         elif not msg.is_partial:
-            # Non-streaming: synthesize full response at end
-            self._queue_synthesis(text)
+            self._queue_synthesis(text, lang)

-    def _queue_synthesis(self, text: str) -> None:
-        """Queue a text segment for synthesis in the playback worker."""
-        if not text.strip():
-            return
-        try:
-            self._playback_queue.put_nowait(text.strip())
-        except queue.Full:
-            self.get_logger().warn("TTS playback queue full, dropping segment")
+    def _queue_synthesis(self, text: str, lang: str) -> None:
+        if not text.strip(): return
+        try: self._playback_queue.put_nowait((text.strip(), lang))
+        except queue.Full: self.get_logger().warn("TTS queue full")

-    # ── Playback worker ───────────────────────────────────────────────────────
     def _playback_worker(self) -> None:
-        """Consume synthesis queue: synthesize → play → publish."""
         while rclpy.ok():
-            try:
-                text = self._playback_queue.get(timeout=0.5)
-            except queue.Empty:
-                continue
-
-            if self._voice is None:
-                self.get_logger().warn("TTS voice not loaded yet")
-                self._playback_queue.task_done()
-                continue
-
+            try: item = self._playback_queue.get(timeout=0.5)
+            except queue.Empty: continue
+            text, lang = item
+            voice = self._get_voice(lang)
+            if voice is None:
+                self.get_logger().warn(f"No voice for '{lang}'"); self._playback_queue.task_done(); continue
             t0 = time.perf_counter()
-            pcm_data = self._synthesize(text)
-            if pcm_data is None:
-                self._playback_queue.task_done()
-                continue
-
-            synth_ms = (time.perf_counter() - t0) * 1000
-            dur_ms = estimate_duration_ms(pcm_data, self._sample_rate)
-            self.get_logger().debug(
-                f"TTS '{text[:40]}' synth={synth_ms:.0f}ms, dur={dur_ms:.0f}ms"
-            )
-            if self._volume != 1.0:
-                pcm_data = apply_volume(pcm_data, self._volume)
-            if self._playback:
-                self._play_audio(pcm_data)
-            if self._publish_audio:
-                self._publish_pcm(pcm_data)
+            pcm_data = self._synthesize(text, voice)
+            if pcm_data is None: self._playback_queue.task_done(); continue
+            if self._volume != 1.0: pcm_data = apply_volume(pcm_data, self._volume)
+            if self._playback: self._play_audio(pcm_data)
+            if self._publish_audio: self._publish_pcm(pcm_data)
             self._playback_queue.task_done()

-    def _synthesize(self, text: str) -> Optional[bytes]:
-        """Synthesize text to PCM16 bytes using Piper streaming."""
-        if self._voice is None:
-            return None
-        try:
-            chunks = list(self._voice.synthesize_stream_raw(text))
-            return b"".join(chunks)
-        except Exception as e:
-            self.get_logger().error(f"TTS synthesis error: {e}")
-            return None
+    def _synthesize(self, text: str, voice) -> Optional[bytes]:
+        try: return b"".join(voice.synthesize_stream_raw(text))
+        except Exception as e: self.get_logger().error(f"TTS error: {e}"); return None

     def _play_audio(self, pcm_data: bytes) -> None:
-        """Play PCM16 data on USB speaker via sounddevice."""
         try:
-            import sounddevice as sd
-            import numpy as np
-            samples = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0
-            sd.play(samples, samplerate=self._sample_rate, device=self._audio_device,
-                    blocking=True)
-        except Exception as e:
-            self.get_logger().error(f"Audio playback error: {e}")
+            import sounddevice as sd, numpy as np
+            sd.play(np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0, samplerate=self._sample_rate, device=self._audio_device, blocking=True)
+        except Exception as e: self.get_logger().error(f"Playback error: {e}")

     def _publish_pcm(self, pcm_data: bytes) -> None:
-        """Publish PCM data as UInt8MultiArray."""
-        if not hasattr(self, "_audio_pub"):
-            return
-        msg = UInt8MultiArray()
-        msg.data = list(pcm_data)
-        self._audio_pub.publish(msg)
+        if not hasattr(self, "_audio_pub"): return
+        msg = UInt8MultiArray(); msg.data = list(pcm_data); self._audio_pub.publish(msg)


 def main(args=None) -> None:
     rclpy.init(args=args)
     node = TtsNode()
-    try:
-        rclpy.spin(node)
-    except KeyboardInterrupt:
-        pass
-    finally:
-        node.destroy_node()
-        rclpy.shutdown()
+    try: rclpy.spin(node)
+    except KeyboardInterrupt: pass
+    finally: node.destroy_node(); rclpy.shutdown()
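The `voice_map_json` parameter is a JSON string embedded in the ROS 2 params file, merged over `{default_language: voice_path}` at startup. An illustrative `tts_params.yaml` fragment (the Spanish and German model paths are examples, not paths shipped by this repo):

```yaml
# tts_params.yaml -- illustrative per-language voice configuration
tts_node:
  ros__parameters:
    voice_path: "/models/piper/en_US-lessac-medium.onnx"
    default_language: "en"
    # JSON string mapping BCP-47 codes to Piper ONNX voices. The node merges
    # this over {default_language: voice_path}, so "en" need not be repeated.
    voice_map_json: '{"es": "/models/piper/es_ES-davefx-medium.onnx", "de": "/models/piper/de_DE-thorsten-medium.onnx"}'
```

Voices for non-default languages load lazily on first use via `_get_voice`, so startup cost stays bounded by the single default voice.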


@@ -37,6 +37,8 @@ setup(
             'voice_command_node = saltybot_social.voice_command_node:main',
             # Multi-camera gesture recognition (Issue #140)
             'gesture_node = saltybot_social.gesture_node:main',
+            # Facial expression recognition (Issue #161)
+            'emotion_node = saltybot_social.emotion_node:main',
         ],
     },
 )


@ -0,0 +1,528 @@
"""test_emotion_classifier.py — Unit tests for emotion_classifier (Issue #161).
Tests cover:
- Emotion constant definitions
- Softmax normalisation
- PersonEmotionTracker: EMA smoothing, age expiry, opt-out, reset
- Landmark heuristic classifier: geometric edge cases
- EmotionClassifier: classify_crop fallback, classify_from_landmarks
- Score renormalisation, argmax, utility helpers
- Integration: full classify pipeline with mock engine
No ROS2, TensorRT, or OpenCV runtime required.
"""
import math
import pytest
from saltybot_social.emotion_classifier import (
EMOTIONS,
N_CLASSES,
ClassifiedEmotion,
PersonEmotionTracker,
EmotionClassifier,
classify_from_landmarks,
_softmax,
_argmax,
_uniform_scores,
_renorm,
)
# ── Helpers ───────────────────────────────────────────────────────────────────
def _lm_neutral() -> list:
"""5-point SCRFD landmarks for a frontal neutral face (normalised)."""
return [
0.35, 0.38, # left_eye
0.65, 0.38, # right_eye
0.50, 0.52, # nose
0.38, 0.72, # left_mouth
0.62, 0.72, # right_mouth
]
def _lm_happy() -> list:
"""Wide mouth, symmetric corners, mouth well below nose."""
return [
0.35, 0.38,
0.65, 0.38,
0.50, 0.52,
0.30, 0.74, # wide mouth
0.70, 0.74,
]
def _lm_sad() -> list:
"""Compressed mouth, tight width."""
return [
0.35, 0.38,
0.65, 0.38,
0.50, 0.52,
0.44, 0.62, # tight, close to nose
0.56, 0.62,
]
def _lm_asymmetric() -> list:
"""Asymmetric mouth corners → disgust/sad signal."""
return [
0.35, 0.38,
0.65, 0.38,
0.50, 0.52,
0.38, 0.65,
0.62, 0.74, # right corner lower → asymmetric
]
# ── TestEmotionConstants ──────────────────────────────────────────────────────
class TestEmotionConstants:
def test_emotions_length(self):
assert len(EMOTIONS) == 7
def test_n_classes(self):
assert N_CLASSES == 7
def test_emotions_labels(self):
expected = ["happy", "sad", "angry", "surprised", "fearful", "disgusted", "neutral"]
assert EMOTIONS == expected
def test_happy_index(self):
assert EMOTIONS[0] == "happy"
def test_neutral_index(self):
assert EMOTIONS[6] == "neutral"
def test_all_lowercase(self):
for e in EMOTIONS:
assert e == e.lower()
# ── TestSoftmax ───────────────────────────────────────────────────────────────
class TestSoftmax:
def test_sums_to_one(self):
logits = [1.0, 2.0, 0.5, -1.0, 3.0, 0.0, 1.5]
result = _softmax(logits)
assert abs(sum(result) - 1.0) < 1e-6
def test_all_positive(self):
logits = [1.0, 2.0, 3.0, 0.0, -1.0, -2.0, 0.5]
result = _softmax(logits)
assert all(s > 0 for s in result)
def test_max_logit_gives_max_prob(self):
logits = [0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0]
result = _softmax(logits)
assert _argmax(result) == 2
def test_uniform_logits_uniform_probs(self):
logits = [1.0] * 7
result = _softmax(logits)
for p in result:
assert abs(p - 1.0 / 7.0) < 1e-6
def test_length_preserved(self):
logits = [0.0] * 7
assert len(_softmax(logits)) == 7
def test_numerically_stable_large_values(self):
logits = [1000.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
result = _softmax(logits)
assert abs(result[0] - 1.0) < 1e-6
# ── TestArgmax ────────────────────────────────────────────────────────────────
class TestArgmax:
def test_finds_max(self):
assert _argmax([0.1, 0.2, 0.9, 0.3, 0.1, 0.1, 0.1]) == 2
def test_single_element(self):
assert _argmax([0.5]) == 0
def test_first_when_all_equal(self):
result = _argmax([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
assert result == 0
# ── TestUniformScores ─────────────────────────────────────────────────────────
class TestUniformScores:
def test_length(self):
assert len(_uniform_scores()) == N_CLASSES
def test_sums_to_one(self):
assert abs(sum(_uniform_scores()) - 1.0) < 1e-9
def test_all_equal(self):
s = _uniform_scores()
assert all(abs(x - s[0]) < 1e-9 for x in s)
# ── TestRenorm ────────────────────────────────────────────────────────────────
class TestRenorm:
def test_sums_to_one(self):
s = [0.1, 0.2, 0.3, 0.1, 0.1, 0.1, 0.1]
r = _renorm(s)
assert abs(sum(r) - 1.0) < 1e-9
def test_all_zero_returns_uniform(self):
r = _renorm([0.0] * 7)
assert len(r) == 7
assert abs(sum(r) - 1.0) < 1e-9
def test_preserves_order(self):
s = [0.5, 0.3, 0.0, 0.0, 0.0, 0.0, 0.2]
r = _renorm(s)
assert r[0] > r[1] > r[6]
# ── TestPersonEmotionTracker ──────────────────────────────────────────────────
class TestPersonEmotionTracker:
    def _scores(self, idx: int, val: float = 0.8) -> list:
        """Distribution with `val` at index `idx` and 0.02 elsewhere, renormalised."""
        s = [0.02] * N_CLASSES
        s[idx] = val
        t = sum(s)
        return [x / t for x in s]
def test_first_call_returns_input(self):
tracker = PersonEmotionTracker(alpha=0.5)
scores = self._scores(0)
result = tracker.smooth(1, scores)
assert abs(result[0] - scores[0]) < 1e-9
def test_ema_converges_toward_new_dominant(self):
tracker = PersonEmotionTracker(alpha=0.5)
happy = self._scores(0)
sad = self._scores(1)
tracker.smooth(1, happy)
# Push sad repeatedly
prev_sad = tracker.smooth(1, sad)[1]
for _ in range(20):
result = tracker.smooth(1, sad)
assert result[1] > prev_sad # sad score increased
def test_alpha_1_no_smoothing(self):
tracker = PersonEmotionTracker(alpha=1.0)
s1 = self._scores(0)
s2 = self._scores(1)
tracker.smooth(1, s1)
result = tracker.smooth(1, s2)
for a, b in zip(result, s2):
assert abs(a - b) < 1e-9
def test_alpha_0_frozen(self):
tracker = PersonEmotionTracker(alpha=0.0)
s1 = self._scores(0)
s2 = self._scores(1)
tracker.smooth(1, s1)
result = tracker.smooth(1, s2)
for a, b in zip(result, s1):
assert abs(a - b) < 1e-9
def test_different_persons_independent(self):
tracker = PersonEmotionTracker(alpha=0.5)
happy = self._scores(0)
sad = self._scores(1)
tracker.smooth(1, happy)
tracker.smooth(2, sad)
r1 = tracker.smooth(1, happy)
r2 = tracker.smooth(2, sad)
assert r1[0] > r2[0] # person 1 more happy
assert r2[1] > r1[1] # person 2 more sad
def test_negative_person_id_no_tracking(self):
tracker = PersonEmotionTracker()
scores = self._scores(2)
result = tracker.smooth(-1, scores)
assert result == scores # unchanged, not stored
def test_reset_clears_ema(self):
tracker = PersonEmotionTracker(alpha=0.5)
s1 = self._scores(0)
tracker.smooth(1, s1)
tracker.reset(1)
assert 1 not in tracker.tracked_ids
def test_reset_all_clears_all(self):
tracker = PersonEmotionTracker(alpha=0.5)
for pid in range(5):
tracker.smooth(pid, self._scores(pid % N_CLASSES))
tracker.reset_all()
assert tracker.tracked_ids == []
def test_tracked_ids_populated(self):
tracker = PersonEmotionTracker()
tracker.smooth(10, self._scores(0))
tracker.smooth(20, self._scores(1))
assert set(tracker.tracked_ids) == {10, 20}
def test_age_expiry_resets_ema(self):
tracker = PersonEmotionTracker(alpha=0.5, max_age=3)
tracker.smooth(1, self._scores(0))
# Advance age beyond max_age
for _ in range(4):
tracker.tick()
# Next smooth after expiry should reset (first call returns input unchanged)
fresh_scores = self._scores(1)
result = tracker.smooth(1, fresh_scores)
# After reset, result should equal fresh_scores exactly
for a, b in zip(result, fresh_scores):
assert abs(a - b) < 1e-9
# ── TestOptOut ────────────────────────────────────────────────────────────────
class TestOptOut:
def test_set_opt_out_true(self):
tracker = PersonEmotionTracker()
tracker.set_opt_out(42, True)
assert tracker.is_opt_out(42)
def test_set_opt_out_false(self):
tracker = PersonEmotionTracker()
tracker.set_opt_out(42, True)
tracker.set_opt_out(42, False)
assert not tracker.is_opt_out(42)
def test_opt_out_clears_ema(self):
tracker = PersonEmotionTracker()
tracker.smooth(42, [0.5, 0.1, 0.1, 0.1, 0.1, 0.0, 0.1])
assert 42 in tracker.tracked_ids
tracker.set_opt_out(42, True)
assert 42 not in tracker.tracked_ids
def test_unknown_person_not_opted_out(self):
tracker = PersonEmotionTracker()
assert not tracker.is_opt_out(99)
def test_multiple_opt_outs(self):
tracker = PersonEmotionTracker()
for pid in [1, 2, 3]:
tracker.set_opt_out(pid, True)
for pid in [1, 2, 3]:
assert tracker.is_opt_out(pid)
# ── TestClassifyFromLandmarks ─────────────────────────────────────────────────
class TestClassifyFromLandmarks:
def test_returns_classified_emotion(self):
result = classify_from_landmarks(_lm_neutral())
assert isinstance(result, ClassifiedEmotion)
def test_emotion_is_valid_label(self):
result = classify_from_landmarks(_lm_neutral())
assert result.emotion in EMOTIONS
def test_scores_length(self):
result = classify_from_landmarks(_lm_neutral())
assert len(result.scores) == N_CLASSES
def test_scores_sum_to_one(self):
result = classify_from_landmarks(_lm_neutral())
# Scores are rounded to 4 dp; allow 1e-3 accumulation across 7 terms
assert abs(sum(result.scores) - 1.0) < 1e-3
def test_confidence_matches_top_score(self):
result = classify_from_landmarks(_lm_neutral())
# confidence is round(score, 3) and max score is round(s, 4) → ≤0.5e-3 diff
assert abs(result.confidence - max(result.scores)) < 5e-3
def test_source_is_landmark_heuristic(self):
result = classify_from_landmarks(_lm_neutral())
assert result.source == "landmark_heuristic"
def test_happy_landmarks_boost_happy(self):
happy = classify_from_landmarks(_lm_happy())
neutral = classify_from_landmarks(_lm_neutral())
# Happy landmarks should give relatively higher happy score
assert happy.scores[0] >= neutral.scores[0]
def test_sad_landmarks_suppress_happy(self):
sad_result = classify_from_landmarks(_lm_sad())
# Happy score should be relatively low for sad landmarks
assert sad_result.scores[0] < 0.5
def test_asymmetric_mouth_non_neutral(self):
asym = classify_from_landmarks(_lm_asymmetric())
# Asymmetric → disgust or sad should be elevated
assert asym.scores[5] > 0.10 or asym.scores[1] > 0.10
def test_empty_landmarks_returns_neutral(self):
result = classify_from_landmarks([])
assert result.emotion == "neutral"
def test_short_landmarks_returns_neutral(self):
result = classify_from_landmarks([0.5, 0.5])
assert result.emotion == "neutral"
def test_all_positive_scores(self):
result = classify_from_landmarks(_lm_happy())
assert all(s >= 0.0 for s in result.scores)
def test_confidence_in_range(self):
result = classify_from_landmarks(_lm_neutral())
assert 0.0 <= result.confidence <= 1.0
# ── TestEmotionClassifier ─────────────────────────────────────────────────────
class TestEmotionClassifier:
def test_init_no_engine_not_ready(self):
clf = EmotionClassifier(engine_path="")
assert not clf.ready
def test_load_empty_path_returns_false(self):
clf = EmotionClassifier(engine_path="")
assert clf.load() is False
def test_load_nonexistent_path_returns_false(self):
clf = EmotionClassifier(engine_path="/nonexistent/engine.trt")
result = clf.load()
assert result is False
def test_classify_crop_fallback_without_engine(self):
"""Without TRT engine, classify_crop should return a valid result
with landmark heuristic source."""
clf = EmotionClassifier(engine_path="")
# Build a minimal synthetic 48x48 BGR image
try:
import numpy as np
fake_crop = np.zeros((48, 48, 3), dtype="uint8")
result = clf.classify_crop(fake_crop, person_id=-1, tracker=None)
assert isinstance(result, ClassifiedEmotion)
assert result.emotion in EMOTIONS
assert 0.0 <= result.confidence <= 1.0
assert len(result.scores) == N_CLASSES
except ImportError:
pytest.skip("numpy not available")
def test_classify_from_landmarks_delegates(self):
clf = EmotionClassifier(engine_path="")
result = clf.classify_from_landmarks(_lm_happy())
assert isinstance(result, ClassifiedEmotion)
assert result.source == "landmark_heuristic"
def test_classify_from_landmarks_with_tracker_smooths(self):
clf = EmotionClassifier(engine_path="", alpha=0.5)
tracker = PersonEmotionTracker(alpha=0.5)
r1 = clf.classify_from_landmarks(_lm_happy(), person_id=1, tracker=tracker)
r2 = clf.classify_from_landmarks(_lm_sad(), person_id=1, tracker=tracker)
# After smoothing, r2's top score should not be identical to raw sad scores
# (EMA blends r1 history into r2)
raw_sad = classify_from_landmarks(_lm_sad())
assert r2.scores != raw_sad.scores
def test_classify_from_landmarks_no_tracker_no_smooth(self):
clf = EmotionClassifier(engine_path="")
r1 = clf.classify_from_landmarks(_lm_happy(), person_id=1, tracker=None)
r2 = clf.classify_from_landmarks(_lm_happy(), person_id=1, tracker=None)
# Without tracker, same input → same output
assert r1.scores == r2.scores
def test_source_fallback_when_no_engine(self):
try:
import numpy as np
except ImportError:
pytest.skip("numpy not available")
clf = EmotionClassifier(engine_path="")
crop = np.zeros((48, 48, 3), dtype="uint8")
result = clf.classify_crop(crop)
assert result.source == "landmark_heuristic"
def test_classifier_alpha_propagated_to_tracker(self):
clf = EmotionClassifier(engine_path="", alpha=0.1)
assert clf._alpha == 0.1
# ── TestClassifiedEmotionDataclass ───────────────────────────────────────────
class TestClassifiedEmotionDataclass:
def test_fields(self):
ce = ClassifiedEmotion(
emotion="happy",
confidence=0.85,
scores=[0.85, 0.02, 0.02, 0.03, 0.02, 0.02, 0.04],
source="cnn_trt",
)
assert ce.emotion == "happy"
assert ce.confidence == 0.85
assert len(ce.scores) == 7
assert ce.source == "cnn_trt"
def test_default_source(self):
ce = ClassifiedEmotion("neutral", 0.6, [0.0] * 7)
assert ce.source == "cnn_trt"
def test_emotion_is_mutable(self):
ce = ClassifiedEmotion("neutral", 0.6, [0.0] * 7)
ce.emotion = "happy"
assert ce.emotion == "happy"
# ── TestEdgeCases ─────────────────────────────────────────────────────────────
class TestEdgeCases:
def test_softmax_single_element(self):
result = _softmax([5.0])
assert abs(result[0] - 1.0) < 1e-9
def test_tracker_negative_id_no_stored_state(self):
tracker = PersonEmotionTracker()
scores = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.4]
tracker.smooth(-1, scores)
assert tracker.tracked_ids == []
def test_tick_increments_age(self):
tracker = PersonEmotionTracker(max_age=2)
tracker.smooth(1, [0.1] * 7)
tracker.tick()
tracker.tick()
tracker.tick()
# Age should exceed max_age → next smooth resets
fresh = [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
result = tracker.smooth(1, fresh)
assert abs(result[2] - 1.0) < 1e-9
def test_renorm_negative_scores_safe(self):
# Negative scores shouldn't crash (though unusual in practice)
scores = [0.1, 0.2, 0.0, 0.0, 0.0, 0.0, 0.3]
r = _renorm(scores)
assert abs(sum(r) - 1.0) < 1e-9
def test_landmark_heuristic_very_close_eye_mouth(self):
# Degenerate face where everything is at same y → should not crash
lm = [0.3, 0.5, 0.7, 0.5, 0.5, 0.5, 0.4, 0.5, 0.6, 0.5]
result = classify_from_landmarks(lm)
assert result.emotion in EMOTIONS
def test_opt_out_then_re_enable(self):
tracker = PersonEmotionTracker()
tracker.smooth(5, [0.1] * 7)
tracker.set_opt_out(5, True)
assert tracker.is_opt_out(5)
tracker.set_opt_out(5, False)
assert not tracker.is_opt_out(5)
# Should be able to smooth again after re-enable
result = tracker.smooth(5, [0.2] * 7)
assert len(result) == N_CLASSES
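The EMA reset-on-expiry behavior exercised above can be summarized with a minimal sketch. `EmaTracker` below is a hypothetical stand-in written only to illustrate the semantics these tests assert for `PersonEmotionTracker` (per-id exponential smoothing, age-out after `max_age` ticks, pass-through on first call and for negative ids); it is not the actual implementation.

```python
class EmaTracker:
    """EMA smoother keyed by person id; entries expire after max_age ticks."""

    def __init__(self, alpha=0.5, max_age=3):
        self.alpha, self.max_age = alpha, max_age
        self._state = {}  # pid -> (ema_scores, age)

    def tick(self):
        # Age every entry; drop those that exceed max_age
        self._state = {pid: (ema, age + 1)
                       for pid, (ema, age) in self._state.items()
                       if age + 1 <= self.max_age}

    def smooth(self, pid, scores):
        if pid < 0:
            return scores  # anonymous detections are never stored
        prev = self._state.get(pid)
        if prev is None:
            ema = list(scores)  # first call (or post-expiry): input unchanged
        else:
            a = self.alpha
            ema = [a * s + (1 - a) * p for s, p in zip(scores, prev[0])]
        self._state[pid] = (ema, 0)  # storing resets the entry's age
        return ema

t = EmaTracker(alpha=0.5, max_age=3)
t.smooth(1, [1.0, 0.0])
for _ in range(4):
    t.tick()                                    # entry for id 1 ages out
assert t.smooth(1, [0.0, 1.0]) == [0.0, 1.0]    # reset: input returned unchanged
assert t.smooth(1, [1.0, 0.0]) == [0.5, 0.5]    # second call blends in history
```

This mirrors `test_age_expiry_resets_ema` and `test_tracker_negative_id_no_stored_state` above: smoothing only begins once a per-id history exists, and expiry or opt-out simply deletes that history.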

View File

@ -0,0 +1,122 @@
"""test_multilang.py -- Unit tests for Issue #167 multi-language support."""
from __future__ import annotations
import json, os
from typing import Any, Dict, Optional
import pytest
def _pkg_root():
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def _read_src(rel_path):
with open(os.path.join(_pkg_root(), rel_path)) as f:
return f.read()
def _extract_lang_names():
import ast
src = _read_src("saltybot_social/conversation_node.py")
start = src.index("_LANG_NAMES: Dict[str, str] = {")
end = src.index("\n}", start) + 2
return ast.literal_eval(src[start:end].split("=",1)[1].strip())
class TestLangNames:
@pytest.fixture(scope="class")
def ln(self): return _extract_lang_names()
def test_english(self, ln): assert ln["en"] == "English"
def test_french(self, ln): assert ln["fr"] == "French"
def test_spanish(self, ln): assert ln["es"] == "Spanish"
def test_german(self, ln): assert ln["de"] == "German"
def test_japanese(self, ln):assert ln["ja"] == "Japanese"
def test_chinese(self, ln): assert ln["zh"] == "Chinese"
def test_arabic(self, ln): assert ln["ar"] == "Arabic"
def test_at_least_15(self, ln): assert len(ln) >= 15
def test_lowercase_keys(self, ln):
for k in ln: assert k == k.lower() and 2 <= len(k) <= 3
def test_nonempty_values(self, ln):
for k, v in ln.items(): assert v
class TestLanguageHint:
def _h(self, sl, sid, ln):
lang = sl.get(sid, "en")
return f"[Please respond in {ln.get(lang, lang)}.]" if lang and lang != "en" else ""
@pytest.fixture(scope="class")
def ln(self): return _extract_lang_names()
def test_english_no_hint(self, ln): assert self._h({"p": "en"}, "p", ln) == ""
def test_unknown_no_hint(self, ln): assert self._h({}, "p", ln) == ""
def test_french(self, ln): assert self._h({"p":"fr"},"p",ln) == "[Please respond in French.]"
def test_spanish(self, ln): assert self._h({"p":"es"},"p",ln) == "[Please respond in Spanish.]"
def test_unknown_code(self, ln): assert "xx" in self._h({"p":"xx"},"p",ln)
def test_brackets(self, ln):
h = self._h({"p":"de"},"p",ln)
assert h.startswith("[") and h.endswith("]")
class TestVoiceMap:
def _parse(self, jstr, dl, dp):
try: extra = json.loads(jstr) if jstr.strip() not in ("{}","") else {}
except ValueError: extra = {}
r = {dl: dp}; r.update(extra); return r
def test_empty(self): assert self._parse("{}","en","/e") == {"en":"/e"}
def test_extra(self):
vm = self._parse('{"fr":"/f"}', "en", "/e")
assert vm["fr"] == "/f"
def test_invalid(self): assert self._parse("BAD","en","/e") == {"en":"/e"}
def test_multi(self):
assert len(self._parse(json.dumps({"fr":"/f","es":"/s"}),"en","/e")) == 3
class TestVoiceSelect:
def _s(self, voices, lang, default):
return voices.get(lang) or voices.get(default)
def test_exact(self): assert self._s({"en":"E","fr":"F"},"fr","en") == "F"
def test_fallback(self): assert self._s({"en":"E"},"fr","en") == "E"
def test_none(self): assert self._s({},"fr","en") is None
class TestSttFields:
@pytest.fixture(scope="class")
def src(self): return _read_src("saltybot_social/speech_pipeline_node.py")
def test_param(self, src): assert "whisper_language" in src
def test_detected_lang(self, src): assert "detected_lang" in src
def test_msg_language(self, src): assert "msg.language = language" in src
def test_auto_detect(self, src): assert "language=self._whisper_language" in src
class TestConvFields:
@pytest.fixture(scope="class")
def src(self): return _read_src("saltybot_social/conversation_node.py")
def test_speaker_lang(self, src): assert "_speaker_lang" in src
def test_lang_hint_method(self, src): assert "_language_hint" in src
def test_msg_language(self, src): assert "msg.language = language" in src
def test_lang_names(self, src): assert "_LANG_NAMES" in src
def test_please_respond(self, src): assert "Please respond in" in src
def test_emotion_coexists(self, src): assert "_emotion_hint" in src
class TestTtsFields:
@pytest.fixture(scope="class")
def src(self): return _read_src("saltybot_social/tts_node.py")
def test_voice_map_json(self, src): assert "voice_map_json" in src
def test_default_lang(self, src): assert "default_language" in src
def test_voices_dict(self, src): assert "_voices" in src
def test_get_voice(self, src): assert "_get_voice" in src
def test_load_voice_for_lang(self, src): assert "_load_voice_for_lang" in src
def test_queue_tuple(self, src): assert "(text.strip(), lang)" in src
def test_synthesize_voice_arg(self, src): assert "_synthesize(text, voice)" in src
class TestMsgDefs:
@pytest.fixture(scope="class")
def tr(self): return _read_src("../saltybot_social_msgs/msg/SpeechTranscript.msg")
@pytest.fixture(scope="class")
def re(self): return _read_src("../saltybot_social_msgs/msg/ConversationResponse.msg")
def test_transcript_lang(self, tr): assert "string language" in tr
def test_transcript_bcp47(self, tr): assert "BCP-47" in tr
def test_response_lang(self, re): assert "string language" in re
class TestEdgeCases:
def test_empty_lang_no_hint(self):
lang = "" or "en"; assert lang == "en"
def test_lang_flows(self):
d: Dict[str,str] = {}; d["p1"] = "fr"
assert d.get("p1","en") == "fr"
def test_multi_speakers(self):
d = {"p1":"fr","p2":"es"}
assert d["p1"] == "fr" and d["p2"] == "es"
def test_voice_map_code_in_tts(self):
src = _read_src("saltybot_social/tts_node.py")
assert "voice_map_json" in src and "json.loads" in src

View File

@ -38,6 +38,9 @@ rosidl_generate_interfaces(${PROJECT_NAME}
# Issue #140 gesture recognition
"msg/Gesture.msg"
"msg/GestureArray.msg"
# Issue #161 emotion detection
"msg/Expression.msg"
"msg/ExpressionArray.msg"
DEPENDENCIES std_msgs geometry_msgs builtin_interfaces
)

View File

@ -7,3 +7,4 @@ string text # Full or partial response text
string speaker_id # Who the response is addressed to
bool is_partial # true = streaming token chunk, false = final response
int32 turn_id # Conversation turn counter (for deduplication)
string language # BCP-47 language code for TTS voice selection e.g. "en" "fr" "es"

View File

@ -0,0 +1,17 @@
# Expression.msg — Detected facial expression for one person (Issue #161).
# Published by emotion_node on /social/faces/expressions
std_msgs/Header header
int32 person_id # -1 = unidentified; matches PersonState.person_id
int32 face_id # matches FaceDetection.face_id
string emotion # one of: happy sad angry surprised fearful disgusted neutral
float32 confidence # smoothed confidence of the top emotion (0.0-1.0)
float32[7] scores # per-class softmax scores, order:
# [0]=happy [1]=sad [2]=angry [3]=surprised
# [4]=fearful [5]=disgusted [6]=neutral
bool is_opt_out # true = this person opted out; no emotion data published
string source # "cnn_trt" | "landmark_heuristic" | "opt_out"

View File

@ -0,0 +1,5 @@
# ExpressionArray.msg — Batch of detected facial expressions (Issue #161).
# Published by emotion_node on /social/faces/expressions
std_msgs/Header header
Expression[] expressions

View File

@ -8,3 +8,4 @@ string speaker_id # e.g. "person_42" or "unknown"
float32 confidence # ASR confidence 0..1
float32 audio_duration # Duration of the utterance in seconds
bool is_partial # true = intermediate streaming result, false = final
string language # BCP-47 detected language code e.g. "en" "fr" "es" (empty = unknown)

View File

@ -176,6 +176,11 @@ static void dispatch(const uint8_t *payload, uint8_t cmd, uint8_t plen)
}
break;
case JLINK_CMD_SLEEP:
/* Payload-less; main loop calls power_mgmt_request_sleep() */
jlink_state.sleep_req = 1u;
break;
default:
break;
}
@ -290,3 +295,28 @@ void jlink_send_telemetry(const jlink_tlm_status_t *status)
HAL_UART_Transmit(&s_uart, frame, sizeof(frame), 5u);
}
/* ---- jlink_send_power_telemetry() ---- */
void jlink_send_power_telemetry(const jlink_tlm_power_t *power)
{
/*
* Frame: [STX][LEN][0x81][11 bytes POWER][CRC_hi][CRC_lo][ETX]
* LEN = 1 (CMD) + 11 (payload) = 12; total = 17 bytes
* At 921600 baud: 17×10/921600 ≈ 0.18 ms → safe to block.
*/
static uint8_t frame[17];
const uint8_t plen = (uint8_t)sizeof(jlink_tlm_power_t); /* 11 */
const uint8_t len = 1u + plen; /* 12 */
frame[0] = JLINK_STX;
frame[1] = len;
frame[2] = JLINK_TLM_POWER;
memcpy(&frame[3], power, plen);
uint16_t crc = crc16_xmodem(&frame[2], len);
frame[3 + plen] = (uint8_t)(crc >> 8);
frame[3 + plen + 1] = (uint8_t)(crc & 0xFFu);
frame[3 + plen + 2] = JLINK_ETX;
HAL_UART_Transmit(&s_uart, frame, sizeof(frame), 5u);
}
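The frame layout described in the comment above can be cross-checked with a small host-side sketch. This mirrors the CRC-16/XModem helper used elsewhere in this branch's tests; the packed 11-byte layout of `jlink_tlm_power_t` (u8 state, three u16 current estimates, u32 idle ms) is an assumption inferred from the fields filled in `main.c`, not a definitive struct definition.

```python
import struct

JLINK_STX, JLINK_ETX, JLINK_TLM_POWER = 0x02, 0x03, 0x81

def crc16_xmodem(data: bytes) -> int:
    # CRC-16/XModem: poly 0x1021, init 0x0000 (matches the firmware helper)
    crc = 0
    for b in data:
        crc ^= b << 8
        for _ in range(8):
            crc = ((crc << 1) ^ 0x1021 if crc & 0x8000 else crc << 1) & 0xFFFF
    return crc

def build_power_frame(power_state, total_ma, audio_ma, osd_ma, idle_ms):
    # Assumed packed layout: u8 + u16 + u16 + u16 + u32 = 11 bytes
    payload = struct.pack("<BHHHI", power_state, total_ma, audio_ma, osd_ma, idle_ms)
    data = bytes([JLINK_TLM_POWER]) + payload
    crc = crc16_xmodem(data)
    return bytes([JLINK_STX, len(data), *data, crc >> 8, crc & 0xFF, JLINK_ETX])

frame = build_power_frame(0, 44, 8, 5, 1234)
assert len(frame) == 17    # [STX][LEN][CMD][11B payload][CRC_hi][CRC_lo][ETX]
assert frame[1] == 12      # LEN = 1 (CMD) + 11 (payload)
```

The CRC is computed over `[CMD][payload]` only, matching `crc16_xmodem(&frame[2], len)` in the C code.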

View File

@ -19,6 +19,7 @@
#include "jlink.h"
#include "ota.h"
#include "audio.h"
#include "power_mgmt.h"
#include "battery.h"
#include <math.h>
#include <string.h>
@ -149,6 +150,9 @@ int main(void) {
audio_init();
audio_play_tone(AUDIO_TONE_STARTUP);
/* Init power management — STOP-mode sleep/wake, wake EXTIs configured */
power_mgmt_init();
/* Init mode manager (RC/autonomous blend; CH6 mode switch) */
mode_manager_t mode;
mode_manager_init(&mode);
@ -183,6 +187,8 @@ int main(void) {
uint32_t esc_tick = 0;
uint32_t crsf_telem_tick = 0; /* CRSF uplink telemetry TX timer */
uint32_t jlink_tlm_tick = 0; /* Jetson binary telemetry TX timer */
uint32_t pm_tlm_tick = 0; /* JLINK_TLM_POWER transmit timer */
uint8_t pm_pwm_phase = 0; /* Software PWM counter for sleep LED */
const float dt = 1.0f / PID_LOOP_HZ; /* 1ms at 1kHz */
while (1) {
@ -196,6 +202,23 @@ int main(void) {
/* Advance audio tone sequencer (non-blocking, call every tick) */
audio_tick(now);
/* Sleep LED: software PWM on LED1 (active-low PC15) driven by PM brightness.
* pm_pwm_phase increments each 1 ms tick (wraps every 256 ms); brightness sets duty cycle 0-255. */
pm_pwm_phase++;
{
uint8_t pm_bright = power_mgmt_led_brightness();
if (pm_bright > 0u) {
bool led_on = (pm_pwm_phase < pm_bright);
HAL_GPIO_WritePin(LED1_PORT, LED1_PIN,
led_on ? GPIO_PIN_RESET : GPIO_PIN_SET);
}
}
/* Power manager tick — may block in WFI (STOP mode) when disarmed */
if (bal.state != BALANCE_ARMED) {
power_mgmt_tick(now);
}
/* Mode manager: update RC liveness, CH6 mode selection, blend ramp */
mode_manager_update(&mode, now);
@ -249,6 +272,16 @@ int main(void) {
* never returns when disarmed → MCU resets into DFU mode. */
ota_enter_dfu(bal.state == BALANCE_ARMED);
}
if (jlink_state.sleep_req) {
jlink_state.sleep_req = 0u;
power_mgmt_request_sleep();
}
/* Power management: CRSF/JLink activity or armed state resets idle timer */
if (crsf_is_active(now) || jlink_is_active(now) ||
bal.state == BALANCE_ARMED) {
power_mgmt_activity();
}
/* RC CH5 kill switch: disarm immediately if RC is alive and CH5 off.
* Applies regardless of active mode (CH5 always has kill authority). */
@ -424,6 +457,18 @@ int main(void) {
jlink_send_telemetry(&tlm);
}
/* JLINK_TLM_POWER telemetry at PM_TLM_HZ (1 Hz) */
if (now - pm_tlm_tick >= (1000u / PM_TLM_HZ)) {
pm_tlm_tick = now;
jlink_tlm_power_t pow;
pow.power_state = (uint8_t)power_mgmt_state();
pow.est_total_ma = power_mgmt_current_ma();
pow.est_audio_ma = (uint16_t)(power_mgmt_state() == PM_SLEEPING ? 0u : PM_CURRENT_AUDIO_MA);
pow.est_osd_ma = (uint16_t)(power_mgmt_state() == PM_SLEEPING ? 0u : PM_CURRENT_OSD_MA);
pow.idle_ms = power_mgmt_idle_ms();
jlink_send_power_telemetry(&pow);
}
/* USB telemetry at 50Hz (only when streaming enabled and calibration done) */
if (cdc_streaming && imu_calibrated() && now - send_tick >= 20) {
send_tick = now;

src/power_mgmt.c (new file)
View File

@ -0,0 +1,251 @@
#include "power_mgmt.h"
#include "config.h"
#include "stm32f7xx_hal.h"
#include <string.h>
/* ---- Internal state ---- */
static PowerState s_state = PM_ACTIVE;
static uint32_t s_last_active = 0;
static uint32_t s_fade_start = 0;
static bool s_sleep_req = false;
static bool s_peripherals_gated = false;
/* ---- EXTI wake-source configuration ---- */
/*
* EXTI1 PA1 (UART4_RX / CRSF): falling edge (UART start bit)
* EXTI7 PB7 (USART1_RX / JLink): falling edge
* EXTI4 PC4 (MPU6000 INT): already configured by mpu6000_init();
* we just ensure IMR bit is set.
*
* GPIO pins remain in their current AF mode; EXTI is pad-level and
* fires independently of the AF setting.
*/
static void enable_wake_exti(void)
{
__HAL_RCC_SYSCFG_CLK_ENABLE();
/* EXTI1: PA1 (UART4_RX) — SYSCFG EXTICR1[7:4] = 0000 (PA) */
SYSCFG->EXTICR[0] = (SYSCFG->EXTICR[0] & ~(0xFu << 4)) | (0x0u << 4);
EXTI->FTSR |= (1u << 1);
EXTI->RTSR &= ~(1u << 1);
EXTI->PR = (1u << 1); /* clear pending */
EXTI->IMR |= (1u << 1);
HAL_NVIC_SetPriority(EXTI1_IRQn, 5, 0);
HAL_NVIC_EnableIRQ(EXTI1_IRQn);
/* EXTI7: PB7 (USART1_RX) — SYSCFG EXTICR2[15:12] = 0001 (PB) */
SYSCFG->EXTICR[1] = (SYSCFG->EXTICR[1] & ~(0xFu << 12)) | (0x1u << 12);
EXTI->FTSR |= (1u << 7);
EXTI->RTSR &= ~(1u << 7);
EXTI->PR = (1u << 7);
EXTI->IMR |= (1u << 7);
HAL_NVIC_SetPriority(EXTI9_5_IRQn, 5, 0);
HAL_NVIC_EnableIRQ(EXTI9_5_IRQn);
/* EXTI4: PC4 (MPU6000 INT) — handler in mpu6000.c; just ensure IMR set */
EXTI->IMR |= (1u << 4);
}
static void disable_wake_exti(void)
{
/* Mask UART RX wake EXTIs now that UART peripherals handle traffic */
EXTI->IMR &= ~(1u << 1);
EXTI->IMR &= ~(1u << 7);
/* Leave EXTI4 (IMU data-ready) always unmasked */
}
/* ---- Peripheral clock gating ---- */
/*
* Clock-only gate (no force-reset): peripheral register state is preserved.
* On re-enable, DMA circular transfers resume without reinitialisation.
*/
static void gate_peripherals(void)
{
if (s_peripherals_gated) return;
__HAL_RCC_SPI3_CLK_DISABLE(); /* I2S3 / audio amplifier */
__HAL_RCC_SPI2_CLK_DISABLE(); /* OSD MAX7456 */
__HAL_RCC_USART6_CLK_DISABLE(); /* legacy Jetson CDC */
__HAL_RCC_UART5_CLK_DISABLE(); /* debug UART */
s_peripherals_gated = true;
}
static void ungate_peripherals(void)
{
if (!s_peripherals_gated) return;
__HAL_RCC_SPI3_CLK_ENABLE();
__HAL_RCC_SPI2_CLK_ENABLE();
__HAL_RCC_USART6_CLK_ENABLE();
__HAL_RCC_UART5_CLK_ENABLE();
s_peripherals_gated = false;
}
/* ---- PLL clock restore after STOP mode ---- */
/*
* After STOP wakeup SYSCLK = HSI (16 MHz). Re-lock PLL for 216 MHz.
* PLLM=8, PLLN=216, PLLP=2, PLLQ=9 → STM32F722 @ 216 MHz, HSI source.
*
* HAL_RCC_ClockConfig() calls HAL_InitTick() which resets uwTick to 0;
* save and restore it so existing timeouts remain valid across sleep.
*/
extern volatile uint32_t uwTick;
static void restore_clocks(void)
{
uint32_t saved_tick = uwTick;
RCC_OscInitTypeDef osc = {0};
osc.OscillatorType = RCC_OSCILLATORTYPE_HSI;
osc.HSIState = RCC_HSI_ON;
osc.HSICalibrationValue = RCC_HSICALIBRATION_DEFAULT;
osc.PLL.PLLState = RCC_PLL_ON;
osc.PLL.PLLSource = RCC_PLLSOURCE_HSI;
osc.PLL.PLLM = 8;
osc.PLL.PLLN = 216;
osc.PLL.PLLP = RCC_PLLP_DIV2;
osc.PLL.PLLQ = 9;
HAL_RCC_OscConfig(&osc);
RCC_ClkInitTypeDef clk = {0};
clk.ClockType = RCC_CLOCKTYPE_HCLK | RCC_CLOCKTYPE_SYSCLK |
RCC_CLOCKTYPE_PCLK1 | RCC_CLOCKTYPE_PCLK2;
clk.SYSCLKSource = RCC_SYSCLKSOURCE_PLLCLK;
clk.AHBCLKDivider = RCC_SYSCLK_DIV1;
clk.APB1CLKDivider = RCC_HCLK_DIV4; /* 54 MHz */
clk.APB2CLKDivider = RCC_HCLK_DIV2; /* 108 MHz */
HAL_RCC_ClockConfig(&clk, FLASH_LATENCY_7);
uwTick = saved_tick; /* restore — HAL_InitTick() reset it to 0 */
}
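As a sanity check on the PLL constants in `restore_clocks()`, the clock arithmetic works out as follows (a quick sketch of the standard STM32F7 PLL formula, not firmware code):

```python
HSI_HZ = 16_000_000                  # STM32F7 internal oscillator
PLLM, PLLN, PLLP, PLLQ = 8, 216, 2, 9

vco_in  = HSI_HZ // PLLM             # 2 MHz VCO input (must be 1-2 MHz)
vco_out = vco_in * PLLN              # 432 MHz VCO output
sysclk  = vco_out // PLLP            # 216 MHz SYSCLK
pll48   = vco_out // PLLQ            # 48 MHz (USB/SDMMC domain)

assert sysclk == 216_000_000
assert pll48 == 48_000_000
assert sysclk // 4 == 54_000_000     # APB1 with HCLK_DIV4, matches comment
assert sysclk // 2 == 108_000_000    # APB2 with HCLK_DIV2, matches comment
```

The APB divider checks correspond to the `/* 54 MHz */` and `/* 108 MHz */` annotations in the C code above.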
/* ---- EXTI IRQ handlers (wake-only: clear pending bit and return) ---- */
/*
* These handlers fire once on wakeup. After restore_clocks() the respective
* UART peripherals resume normal DMA/IDLE-interrupt operation.
*
* NOTE: If EXTI9_5_IRQHandler is already defined elsewhere in the project,
* merge that handler with this one.
*/
void EXTI1_IRQHandler(void)
{
if (EXTI->PR & (1u << 1)) EXTI->PR = (1u << 1);
}
void EXTI9_5_IRQHandler(void)
{
/* Clear any pending EXTI5-9 lines (PB7 = EXTI7 is our primary wake) */
uint32_t pr = EXTI->PR & 0x3E0u;
if (pr) EXTI->PR = pr;
}
/* ---- LED brightness (integer arithmetic, no float, called from main loop) ---- */
/*
* Triangle wave: 0 → 255 → 0 over PM_LED_PERIOD_MS.
* Only active during PM_SLEEP_PENDING; returns 0 otherwise.
*/
uint8_t power_mgmt_led_brightness(void)
{
if (s_state != PM_SLEEP_PENDING) return 0u;
uint32_t phase = (HAL_GetTick() - s_fade_start) % PM_LED_PERIOD_MS;
uint32_t half = PM_LED_PERIOD_MS / 2u;
if (phase < half)
return (uint8_t)(phase * 255u / half);
else
return (uint8_t)((PM_LED_PERIOD_MS - phase) * 255u / half);
}
/* ---- Current estimate ---- */
uint16_t power_mgmt_current_ma(void)
{
if (s_state == PM_SLEEPING)
return (uint16_t)PM_CURRENT_STOP_MA;
uint16_t ma = (uint16_t)PM_CURRENT_BASE_MA;
if (!s_peripherals_gated) {
ma += (uint16_t)(PM_CURRENT_AUDIO_MA + PM_CURRENT_OSD_MA +
PM_CURRENT_DEBUG_MA);
}
return ma;
}
/* ---- Idle elapsed ---- */
uint32_t power_mgmt_idle_ms(void)
{
return HAL_GetTick() - s_last_active;
}
/* ---- Public API ---- */
void power_mgmt_init(void)
{
s_state = PM_ACTIVE;
s_last_active = HAL_GetTick();
s_fade_start = 0;
s_sleep_req = false;
s_peripherals_gated = false;
enable_wake_exti();
}
void power_mgmt_activity(void)
{
s_last_active = HAL_GetTick();
if (s_state != PM_ACTIVE) {
s_sleep_req = false;
s_state = PM_WAKING; /* resolved to PM_ACTIVE on next tick() */
}
}
void power_mgmt_request_sleep(void)
{
s_sleep_req = true;
}
PowerState power_mgmt_state(void)
{
return s_state;
}
PowerState power_mgmt_tick(uint32_t now_ms)
{
switch (s_state) {
case PM_ACTIVE:
if (s_sleep_req || (now_ms - s_last_active) >= PM_IDLE_TIMEOUT_MS) {
s_sleep_req = false;
s_fade_start = now_ms;
s_state = PM_SLEEP_PENDING;
}
break;
case PM_SLEEP_PENDING:
if ((now_ms - s_fade_start) >= PM_FADE_MS) {
gate_peripherals();
enable_wake_exti();
s_state = PM_SLEEPING;
/* Feed IWDG: wakeup <10 ms << WATCHDOG_TIMEOUT_MS (50 ms) */
IWDG->KR = 0xAAAAu;
/* === STOP MODE ENTRY — execution resumes here on EXTI wake === */
HAL_PWR_EnterSTOPMode(PWR_LOWPOWERREGULATOR_ON, PWR_STOPENTRY_WFI);
/* === WAKEUP POINT (< 10 ms latency) === */
restore_clocks();
ungate_peripherals();
disable_wake_exti();
s_last_active = HAL_GetTick();
s_state = PM_ACTIVE;
}
break;
case PM_SLEEPING:
/* Unreachable: WFI is inline in PM_SLEEP_PENDING above */
break;
case PM_WAKING:
/* Set by power_mgmt_activity() during SLEEP_PENDING/SLEEPING */
ungate_peripherals();
s_state = PM_ACTIVE;
break;
}
return s_state;
}

test/test_power_mgmt.py (new file)
View File

@ -0,0 +1,567 @@
"""
test_power_mgmt.py -- unit tests for Issue #178 power management module.
Models the PM state machine, LED brightness, peripheral gating, current
estimates, JLink protocol extension, and hardware timing budgets in Python.
"""
import struct
import pytest
# ---------------------------------------------------------------------------
# Constants (mirror config.h / power_mgmt.h)
# ---------------------------------------------------------------------------
PM_IDLE_TIMEOUT_MS = 30_000
PM_FADE_MS = 3_000
PM_LED_PERIOD_MS = 2_000
PM_CURRENT_BASE_MA = 30 # SPI1(IMU) + UART4(CRSF) + USART1(JLink) + core
PM_CURRENT_AUDIO_MA = 8 # I2S3 + amp quiescent
PM_CURRENT_OSD_MA = 5 # SPI2 OSD MAX7456
PM_CURRENT_DEBUG_MA = 1 # UART5 + USART6
PM_CURRENT_STOP_MA = 1 # MCU in STOP mode (< 1 mA)
PM_CURRENT_ACTIVE_ALL = (PM_CURRENT_BASE_MA + PM_CURRENT_AUDIO_MA +
PM_CURRENT_OSD_MA + PM_CURRENT_DEBUG_MA)
# JLink additions
JLINK_STX = 0x02
JLINK_ETX = 0x03
JLINK_CMD_SLEEP = 0x09
JLINK_TLM_STATUS = 0x80
JLINK_TLM_POWER = 0x81
# Power states
PM_ACTIVE = 0
PM_SLEEP_PENDING = 1
PM_SLEEPING = 2
PM_WAKING = 3
# WATCHDOG_TIMEOUT_MS from config.h
WATCHDOG_TIMEOUT_MS = 50
# ---------------------------------------------------------------------------
# CRC-16/XModem helper (poly 0x1021, init 0x0000)
# ---------------------------------------------------------------------------
def crc16_xmodem(data: bytes) -> int:
crc = 0x0000
for b in data:
crc ^= b << 8
for _ in range(8):
crc = (crc << 1) ^ 0x1021 if crc & 0x8000 else crc << 1
crc &= 0xFFFF
return crc
def build_frame(cmd: int, payload: bytes = b"") -> bytes:
data = bytes([cmd]) + payload
length = len(data)
crc = crc16_xmodem(data)
return bytes([JLINK_STX, length, *data, crc >> 8, crc & 0xFF, JLINK_ETX])
# ---------------------------------------------------------------------------
# Python model of the power_mgmt state machine (mirrors power_mgmt.c)
# ---------------------------------------------------------------------------
class PowerMgmtSim:
def __init__(self, now: int = 0):
self.state = PM_ACTIVE
self.last_active = now
self.fade_start = 0
self.sleep_req = False
self.peripherals_gated = False
def activity(self, now: int) -> None:
self.last_active = now
if self.state != PM_ACTIVE:
self.sleep_req = False
self.state = PM_WAKING
def request_sleep(self) -> None:
self.sleep_req = True
def led_brightness(self, now: int) -> int:
if self.state != PM_SLEEP_PENDING:
return 0
phase = (now - self.fade_start) % PM_LED_PERIOD_MS
half = PM_LED_PERIOD_MS // 2
if phase < half:
return phase * 255 // half
else:
return (PM_LED_PERIOD_MS - phase) * 255 // half
def current_ma(self) -> int:
if self.state == PM_SLEEPING:
return PM_CURRENT_STOP_MA
ma = PM_CURRENT_BASE_MA
if not self.peripherals_gated:
ma += PM_CURRENT_AUDIO_MA + PM_CURRENT_OSD_MA + PM_CURRENT_DEBUG_MA
return ma
def idle_ms(self, now: int) -> int:
return now - self.last_active
def tick(self, now: int) -> int:
if self.state == PM_ACTIVE:
if self.sleep_req or (now - self.last_active) >= PM_IDLE_TIMEOUT_MS:
self.sleep_req = False
self.fade_start = now
self.state = PM_SLEEP_PENDING
elif self.state == PM_SLEEP_PENDING:
if (now - self.fade_start) >= PM_FADE_MS:
self.peripherals_gated = True
self.state = PM_SLEEPING
# In firmware: WFI blocks here; in test we skip to simulate_wake
elif self.state == PM_WAKING:
self.peripherals_gated = False
self.state = PM_ACTIVE
return self.state
def simulate_wake(self, now: int) -> None:
"""Simulate EXTI wakeup from STOP mode (models HAL_PWR_EnterSTOPMode return)."""
if self.state == PM_SLEEPING:
self.peripherals_gated = False
self.last_active = now
self.state = PM_ACTIVE
# ---------------------------------------------------------------------------
# Tests: Idle timer
# ---------------------------------------------------------------------------
class TestIdleTimer:
def test_stays_active_before_timeout(self):
pm = PowerMgmtSim(now=0)
for t in range(0, PM_IDLE_TIMEOUT_MS, 1000):
assert pm.tick(t) == PM_ACTIVE
def test_enters_sleep_pending_at_timeout(self):
pm = PowerMgmtSim(now=0)
assert pm.tick(PM_IDLE_TIMEOUT_MS - 1) == PM_ACTIVE
assert pm.tick(PM_IDLE_TIMEOUT_MS) == PM_SLEEP_PENDING
def test_activity_resets_idle_timer(self):
pm = PowerMgmtSim(now=0)
pm.tick(PM_IDLE_TIMEOUT_MS - 1000)
pm.activity(PM_IDLE_TIMEOUT_MS - 1000) # reset at T=29000
assert pm.tick(PM_IDLE_TIMEOUT_MS) == PM_ACTIVE # 1 s since reset
assert pm.tick(PM_IDLE_TIMEOUT_MS - 1000 + PM_IDLE_TIMEOUT_MS) == PM_SLEEP_PENDING
def test_idle_ms_increases_monotonically(self):
pm = PowerMgmtSim(now=0)
assert pm.idle_ms(0) == 0
assert pm.idle_ms(5000) == 5000
assert pm.idle_ms(29999) == 29999
def test_idle_ms_resets_on_activity(self):
pm = PowerMgmtSim(now=0)
pm.activity(10000)
assert pm.idle_ms(10500) == 500
def test_30s_timeout_matches_spec(self):
assert PM_IDLE_TIMEOUT_MS == 30_000
# ---------------------------------------------------------------------------
# Tests: State machine transitions
# ---------------------------------------------------------------------------
class TestStateMachine:
def test_sleep_req_bypasses_idle_timer(self):
pm = PowerMgmtSim(now=0)
pm.activity(0)
pm.request_sleep()
assert pm.tick(500) == PM_SLEEP_PENDING
def test_fade_complete_enters_sleeping(self):
pm = PowerMgmtSim(now=0)
pm.tick(PM_IDLE_TIMEOUT_MS) # → SLEEP_PENDING
assert pm.tick(PM_IDLE_TIMEOUT_MS + PM_FADE_MS) == PM_SLEEPING
def test_fade_not_complete_stays_pending(self):
pm = PowerMgmtSim(now=0)
pm.tick(PM_IDLE_TIMEOUT_MS)
assert pm.tick(PM_IDLE_TIMEOUT_MS + PM_FADE_MS - 1) == PM_SLEEP_PENDING
def test_wake_from_stop_returns_active(self):
pm = PowerMgmtSim(now=0)
pm.tick(PM_IDLE_TIMEOUT_MS)
pm.tick(PM_IDLE_TIMEOUT_MS + PM_FADE_MS) # → SLEEPING
pm.simulate_wake(PM_IDLE_TIMEOUT_MS + PM_FADE_MS + 5)
assert pm.state == PM_ACTIVE
def test_activity_during_sleep_pending_aborts(self):
pm = PowerMgmtSim(now=0)
pm.tick(PM_IDLE_TIMEOUT_MS) # → SLEEP_PENDING
pm.activity(PM_IDLE_TIMEOUT_MS + 100) # abort
assert pm.state == PM_WAKING
pm.tick(PM_IDLE_TIMEOUT_MS + 101)
assert pm.state == PM_ACTIVE
def test_activity_during_sleeping_aborts(self):
pm = PowerMgmtSim(now=0)
pm.tick(PM_IDLE_TIMEOUT_MS)
pm.tick(PM_IDLE_TIMEOUT_MS + PM_FADE_MS) # → SLEEPING
pm.activity(PM_IDLE_TIMEOUT_MS + PM_FADE_MS + 3)
assert pm.state == PM_WAKING
pm.tick(PM_IDLE_TIMEOUT_MS + PM_FADE_MS + 4)
assert pm.state == PM_ACTIVE
def test_waking_resolves_on_next_tick(self):
pm = PowerMgmtSim(now=0)
pm.state = PM_WAKING
pm.tick(1000)
assert pm.state == PM_ACTIVE
def test_full_sleep_wake_cycle(self):
pm = PowerMgmtSim(now=0)
# 1. Active
assert pm.tick(100) == PM_ACTIVE
# 2. Idle → sleep pending
assert pm.tick(PM_IDLE_TIMEOUT_MS) == PM_SLEEP_PENDING
# 3. Fade → sleeping
assert pm.tick(PM_IDLE_TIMEOUT_MS + PM_FADE_MS) == PM_SLEEPING
# 4. EXTI wake → active
pm.simulate_wake(PM_IDLE_TIMEOUT_MS + PM_FADE_MS + 8)
assert pm.state == PM_ACTIVE
def test_multiple_sleep_wake_cycles(self):
pm = PowerMgmtSim(now=0)
base = 0
for _ in range(3):
pm.activity(base)
pm.tick(base + PM_IDLE_TIMEOUT_MS)
pm.tick(base + PM_IDLE_TIMEOUT_MS + PM_FADE_MS)
pm.simulate_wake(base + PM_IDLE_TIMEOUT_MS + PM_FADE_MS + 5)
assert pm.state == PM_ACTIVE
base += PM_IDLE_TIMEOUT_MS + PM_FADE_MS + 10
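The transitions exercised above pin down a small host-side model. A minimal `PowerMgmtSim` sketch that satisfies them (`PM_FADE_MS = 3000` is taken from the 3 s fade mentioned in the commit message; everything else is inferred from the assertions, not from power_mgmt.c itself):

```python
PM_IDLE_TIMEOUT_MS = 30_000  # per test_30s_timeout_matches_spec
PM_FADE_MS = 3_000           # assumed: 3 s LED fade, per the commit message
PM_ACTIVE, PM_SLEEP_PENDING, PM_SLEEPING, PM_WAKING = range(4)

class PowerMgmtSim:
    """Host-side model of the power_mgmt.c state machine (sketch)."""

    def __init__(self, now=0):
        self.state = PM_ACTIVE
        self.last_activity = now
        self.fade_start = None
        self.peripherals_gated = False
        self._sleep_requested = False

    def idle_ms(self, now):
        return now - self.last_activity

    def activity(self, now):
        """Any input event: reset idle timer, abort pending or active sleep."""
        self.last_activity = now
        if self.state in (PM_SLEEP_PENDING, PM_SLEEPING):
            self.state = PM_WAKING
            self.peripherals_gated = False

    def request_sleep(self):
        """JLINK_CMD_SLEEP: enter SLEEP_PENDING on the next tick, bypassing the idle timer."""
        self._sleep_requested = True

    def simulate_wake(self, now):
        """EXTI wake from STOP: clocks restored, peripherals ungated."""
        self.last_activity = now
        self.peripherals_gated = False
        self.state = PM_ACTIVE

    def tick(self, now):
        if self.state == PM_WAKING:
            self.state = PM_ACTIVE
        elif self.state == PM_SLEEP_PENDING:
            if now - self.fade_start >= PM_FADE_MS:
                self.state = PM_SLEEPING
                self.peripherals_gated = True
        elif self.state == PM_ACTIVE:
            if self._sleep_requested or self.idle_ms(now) >= PM_IDLE_TIMEOUT_MS:
                self._sleep_requested = False
                self.state = PM_SLEEP_PENDING
                self.fade_start = now
        return self.state
```

Note the abort path routes through `PM_WAKING` rather than jumping straight to `PM_ACTIVE`, matching the two-step assertions in the abort tests.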
# ---------------------------------------------------------------------------
# Tests: Peripheral gating
# ---------------------------------------------------------------------------
class TestPeripheralGating:
GATED = {'SPI3_I2S3', 'SPI2_OSD', 'USART6', 'UART5_DEBUG'}
ACTIVE = {'SPI1_IMU', 'UART4_CRSF', 'USART1_JLINK', 'I2C1_BARO'}
def test_gated_set_has_four_peripherals(self):
assert len(self.GATED) == 4
def test_no_overlap_between_gated_and_active(self):
assert not (self.GATED & self.ACTIVE)
def test_crsf_uart_not_gated(self):
assert not any('UART4' in p or 'CRSF' in p for p in self.GATED)
def test_jlink_uart_not_gated(self):
assert not any('USART1' in p or 'JLINK' in p for p in self.GATED)
def test_imu_spi_not_gated(self):
assert not any('SPI1' in p or 'IMU' in p for p in self.GATED)
def test_peripherals_gated_on_sleep_entry(self):
pm = PowerMgmtSim(now=0)
assert not pm.peripherals_gated
pm.tick(PM_IDLE_TIMEOUT_MS)
pm.tick(PM_IDLE_TIMEOUT_MS + PM_FADE_MS) # → SLEEPING
assert pm.peripherals_gated
def test_peripherals_ungated_on_wake(self):
pm = PowerMgmtSim(now=0)
pm.tick(PM_IDLE_TIMEOUT_MS)
pm.tick(PM_IDLE_TIMEOUT_MS + PM_FADE_MS)
pm.simulate_wake(PM_IDLE_TIMEOUT_MS + PM_FADE_MS + 5)
assert not pm.peripherals_gated
def test_peripherals_not_gated_in_sleep_pending(self):
pm = PowerMgmtSim(now=0)
pm.tick(PM_IDLE_TIMEOUT_MS) # → SLEEP_PENDING
assert not pm.peripherals_gated
def test_peripherals_ungated_if_activity_during_pending(self):
pm = PowerMgmtSim(now=0)
pm.tick(PM_IDLE_TIMEOUT_MS)
pm.activity(PM_IDLE_TIMEOUT_MS + 100)
pm.tick(PM_IDLE_TIMEOUT_MS + 101)
assert not pm.peripherals_gated
# ---------------------------------------------------------------------------
# Tests: LED brightness
# ---------------------------------------------------------------------------
class TestLedBrightness:
def test_zero_when_active(self):
pm = PowerMgmtSim(now=0)
assert pm.led_brightness(5000) == 0
def test_zero_when_sleeping(self):
pm = PowerMgmtSim(now=0)
pm.tick(PM_IDLE_TIMEOUT_MS)
pm.tick(PM_IDLE_TIMEOUT_MS + PM_FADE_MS) # → SLEEPING
assert pm.led_brightness(PM_IDLE_TIMEOUT_MS + PM_FADE_MS + 100) == 0
def test_zero_when_waking(self):
pm = PowerMgmtSim(now=0)
pm.state = PM_WAKING
assert pm.led_brightness(1000) == 0
def test_zero_at_phase_start(self):
pm = PowerMgmtSim(now=0)
pm.tick(PM_IDLE_TIMEOUT_MS) # fade_start = PM_IDLE_TIMEOUT_MS
assert pm.led_brightness(PM_IDLE_TIMEOUT_MS) == 0
def test_max_at_half_period(self):
pm = PowerMgmtSim(now=0)
pm.tick(PM_IDLE_TIMEOUT_MS)
t = PM_IDLE_TIMEOUT_MS + PM_LED_PERIOD_MS // 2
assert pm.led_brightness(t) == 255
def test_zero_at_full_period(self):
pm = PowerMgmtSim(now=0)
pm.tick(PM_IDLE_TIMEOUT_MS)
t = PM_IDLE_TIMEOUT_MS + PM_LED_PERIOD_MS
assert pm.led_brightness(t) == 0
def test_symmetric_about_half_period(self):
pm = PowerMgmtSim(now=0)
pm.tick(PM_IDLE_TIMEOUT_MS)
quarter = PM_LED_PERIOD_MS // 4
three_quarter = 3 * PM_LED_PERIOD_MS // 4
b1 = pm.led_brightness(PM_IDLE_TIMEOUT_MS + quarter)
b2 = pm.led_brightness(PM_IDLE_TIMEOUT_MS + three_quarter)
assert abs(b1 - b2) <= 1 # allow 1 LSB for integer division
def test_range_0_to_255(self):
pm = PowerMgmtSim(now=0)
pm.tick(PM_IDLE_TIMEOUT_MS)
for dt in range(0, PM_LED_PERIOD_MS * 3, 37):
b = pm.led_brightness(PM_IDLE_TIMEOUT_MS + dt)
assert 0 <= b <= 255
def test_repeats_over_multiple_periods(self):
pm = PowerMgmtSim(now=0)
pm.tick(PM_IDLE_TIMEOUT_MS)
# Sample at same phase in periods 0, 1, 2 — should be equal
phase = PM_LED_PERIOD_MS // 3
b0 = pm.led_brightness(PM_IDLE_TIMEOUT_MS + phase)
b1 = pm.led_brightness(PM_IDLE_TIMEOUT_MS + PM_LED_PERIOD_MS + phase)
b2 = pm.led_brightness(PM_IDLE_TIMEOUT_MS + 2 * PM_LED_PERIOD_MS + phase)
assert b0 == b1 == b2
def test_period_is_2s(self):
assert PM_LED_PERIOD_MS == 2000
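One brightness curve satisfying every assertion above is a triangular ramp (a sketch; the firmware may instead use a sine table, which the symmetry test's 1-LSB tolerance also allows):

```python
PM_LED_PERIOD_MS = 2000  # 2 s fade period, per test_period_is_2s

def led_brightness(phase_ms: int) -> int:
    """Triangular ramp: 0 at phase 0, 255 at half period, back to 0 at the full period."""
    t = phase_ms % PM_LED_PERIOD_MS
    half = PM_LED_PERIOD_MS // 2
    if t <= half:
        return (t * 255) // half
    return ((PM_LED_PERIOD_MS - t) * 255) // half
```

Integer division keeps the output in 0..255 without floating point, which maps directly onto an 8-bit PWM compare register.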
# ---------------------------------------------------------------------------
# Tests: Power / current estimates
# ---------------------------------------------------------------------------
class TestPowerEstimates:
def test_active_includes_all_subsystems(self):
pm = PowerMgmtSim(now=0)
assert pm.current_ma() == PM_CURRENT_ACTIVE_ALL
def test_sleeping_returns_stop_ma(self):
pm = PowerMgmtSim(now=0)
pm.tick(PM_IDLE_TIMEOUT_MS)
pm.tick(PM_IDLE_TIMEOUT_MS + PM_FADE_MS) # → SLEEPING
assert pm.current_ma() == PM_CURRENT_STOP_MA
def test_gated_returns_base_only(self):
pm = PowerMgmtSim(now=0)
pm.peripherals_gated = True
assert pm.current_ma() == PM_CURRENT_BASE_MA
def test_stop_current_less_than_active(self):
assert PM_CURRENT_STOP_MA < PM_CURRENT_ACTIVE_ALL
def test_stop_current_at_most_1ma(self):
assert PM_CURRENT_STOP_MA <= 1
def test_active_current_reasonable(self):
# Should be < 100 mA (just MCU + peripherals, no motors)
assert PM_CURRENT_ACTIVE_ALL < 100
def test_audio_subsystem_estimate(self):
assert PM_CURRENT_AUDIO_MA > 0
def test_osd_subsystem_estimate(self):
assert PM_CURRENT_OSD_MA > 0
def test_total_equals_sum_of_parts(self):
total = (PM_CURRENT_BASE_MA + PM_CURRENT_AUDIO_MA +
PM_CURRENT_OSD_MA + PM_CURRENT_DEBUG_MA)
assert total == PM_CURRENT_ACTIVE_ALL
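The additive current model these tests pin down can be sketched as follows. The per-subsystem milliamp values here are hypothetical placeholders; only the relations (STOP ≤ 1 mA, total < 100 mA, total = sum of parts) are actually asserted above:

```python
# Hypothetical per-subsystem estimates (mA); only the relations are asserted.
PM_CURRENT_BASE_MA = 40    # assumed: MCU @ 216 MHz + always-on sensors
PM_CURRENT_AUDIO_MA = 15   # assumed
PM_CURRENT_OSD_MA = 10     # assumed
PM_CURRENT_DEBUG_MA = 5    # assumed
PM_CURRENT_ACTIVE_ALL = (PM_CURRENT_BASE_MA + PM_CURRENT_AUDIO_MA +
                         PM_CURRENT_OSD_MA + PM_CURRENT_DEBUG_MA)
PM_CURRENT_STOP_MA = 1     # STOP-mode budget, <= 1 mA per spec

def current_ma(sleeping: bool, peripherals_gated: bool) -> int:
    """Estimated draw: STOP budget when sleeping, base only when gated, else full."""
    if sleeping:
        return PM_CURRENT_STOP_MA
    if peripherals_gated:
        return PM_CURRENT_BASE_MA
    return PM_CURRENT_ACTIVE_ALL
```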
# ---------------------------------------------------------------------------
# Tests: JLink protocol extension
# ---------------------------------------------------------------------------
class TestJlinkProtocol:
def test_sleep_cmd_id(self):
assert JLINK_CMD_SLEEP == 0x09
def test_sleep_follows_audio_cmd(self):
JLINK_CMD_AUDIO = 0x08
assert JLINK_CMD_SLEEP == JLINK_CMD_AUDIO + 1
def test_power_tlm_id(self):
assert JLINK_TLM_POWER == 0x81
def test_power_tlm_follows_status_tlm(self):
assert JLINK_TLM_POWER == JLINK_TLM_STATUS + 1
def test_sleep_frame_length(self):
# SLEEP has no payload: STX(1)+LEN(1)+CMD(1)+CRC(2)+ETX(1) = 6
frame = build_frame(JLINK_CMD_SLEEP)
assert len(frame) == 6
def test_sleep_frame_sentinels(self):
frame = build_frame(JLINK_CMD_SLEEP)
assert frame[0] == JLINK_STX
assert frame[-1] == JLINK_ETX
def test_sleep_frame_len_field(self):
frame = build_frame(JLINK_CMD_SLEEP)
assert frame[1] == 1 # LEN = 1 (CMD only, no payload)
def test_sleep_frame_cmd_byte(self):
frame = build_frame(JLINK_CMD_SLEEP)
assert frame[2] == JLINK_CMD_SLEEP
def test_sleep_frame_crc_valid(self):
frame = build_frame(JLINK_CMD_SLEEP)
calc = crc16_xmodem(bytes([JLINK_CMD_SLEEP]))
rx = (frame[-3] << 8) | frame[-2]
assert rx == calc
def test_power_tlm_frame_length(self):
# jlink_tlm_power_t = 11 bytes
# Frame: STX(1)+LEN(1)+CMD(1)+payload(11)+CRC(2)+ETX(1) = 17
POWER_TLM_PAYLOAD_LEN = 11
expected = 1 + 1 + 1 + POWER_TLM_PAYLOAD_LEN + 2 + 1
assert expected == 17
def test_power_tlm_payload_struct(self):
"""jlink_tlm_power_t: u8 power_state, u16 est_total_ma,
u16 est_audio_ma, u16 est_osd_ma, u32 idle_ms = 11 bytes."""
fmt = "<BHHHI"
size = struct.calcsize(fmt)
assert size == 11
def test_power_tlm_frame_crc_valid(self):
power_state = PM_ACTIVE
est_total_ma = PM_CURRENT_ACTIVE_ALL
est_audio_ma = PM_CURRENT_AUDIO_MA
est_osd_ma = PM_CURRENT_OSD_MA
idle_ms = 5000
payload = struct.pack("<BHHHI", power_state, est_total_ma,
est_audio_ma, est_osd_ma, idle_ms)
frame = build_frame(JLINK_TLM_POWER, payload)
assert frame[0] == JLINK_STX
assert frame[-1] == JLINK_ETX
data_for_crc = bytes([JLINK_TLM_POWER]) + payload
expected_crc = crc16_xmodem(data_for_crc)
rx_crc = (frame[-3] << 8) | frame[-2]
assert rx_crc == expected_crc
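The frame layout these tests encode can be reproduced with a helper like this (a sketch: the STX/ETX sentinel values 0x02/0x03 are assumptions, since only frame length and CRC placement are asserted above):

```python
import struct

JLINK_STX, JLINK_ETX = 0x02, 0x03  # assumed sentinel bytes

def crc16_xmodem(data: bytes) -> int:
    """CRC-16/XMODEM: poly 0x1021, init 0x0000, no reflection, no final XOR."""
    crc = 0
    for byte in data:
        crc ^= byte << 8
        for _ in range(8):
            crc = ((crc << 1) ^ 0x1021) & 0xFFFF if crc & 0x8000 else (crc << 1) & 0xFFFF
    return crc

def build_frame(cmd: int, payload: bytes = b"") -> bytes:
    """STX | LEN | CMD | payload | CRC16 (big-endian, over CMD+payload) | ETX."""
    body = bytes([cmd]) + payload
    return (bytes([JLINK_STX, len(body)]) + body +
            struct.pack(">H", crc16_xmodem(body)) + bytes([JLINK_ETX]))
```

With no payload, LEN is 1 (the CMD byte alone) and the total frame is 6 bytes, matching test_sleep_frame_length.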
# ---------------------------------------------------------------------------
# Tests: Wake latency and IWDG budget
# ---------------------------------------------------------------------------
class TestWakeLatencyBudget:
# STM32F722 STOP-mode wakeup: HSI ready ~2 ms + PLL lock ~2 ms ≈ 4 ms
ESTIMATED_WAKE_MS = 10 # conservative upper bound
def test_wake_latency_within_50ms(self):
assert self.ESTIMATED_WAKE_MS < WATCHDOG_TIMEOUT_MS
def test_watchdog_timeout_is_50ms(self):
assert WATCHDOG_TIMEOUT_MS == 50
def test_iwdg_feed_before_wfi_is_safe(self):
# Time from IWDG feed to next feed after wake:
# ~1 ms (loop overhead) + ESTIMATED_WAKE_MS + ~1 ms = ~12 ms
time_from_feed_to_next_ms = 1 + self.ESTIMATED_WAKE_MS + 1
assert time_from_feed_to_next_ms < WATCHDOG_TIMEOUT_MS
def test_fade_ms_positive(self):
assert PM_FADE_MS > 0
def test_fade_ms_less_than_idle_timeout(self):
assert PM_FADE_MS < PM_IDLE_TIMEOUT_MS
def test_stop_mode_wake_much_less_than_50ms(self):
# STM32F7 wakes from STOP on HSI (restart is a few microseconds) +
# PLL relock ~2 ms + SysTick re-init ~0.1 ms ≈ 3 ms
pll_lock_ms = 3
overhead_ms = 1
total_ms = pll_lock_ms + overhead_ms
assert total_ms < 50
def test_wake_exti_sources_count(self):
"""Three wake sources: EXTI1 (CRSF), EXTI7 (JLink), EXTI4 (IMU)."""
wake_sources = ['EXTI1_UART4_CRSF', 'EXTI7_USART1_JLINK', 'EXTI4_IMU_INT']
assert len(wake_sources) == 3
def test_uwTick_must_be_restored_after_stop(self):
"""HAL_RCC_ClockConfig resets uwTick to 0; restore_clocks() saves it."""
# Verify the pattern: save uwTick → HAL calls → restore uwTick
saved_tick = 12345
# Simulate HAL_InitTick() resetting to 0
uw_tick_after_hal = 0
restored = saved_tick # power_mgmt.c: uwTick = saved_tick
assert restored == saved_tick
assert uw_tick_after_hal != saved_tick # HAL reset it
# ---------------------------------------------------------------------------
# Tests: Hardware constants
# ---------------------------------------------------------------------------
class TestHardwareConstants:
def test_pll_params_for_216mhz(self):
"""PLLM=8, PLLN=216, PLLP=2 → VCO=216*2=432 MHz, SYSCLK=216 MHz."""
HSI_MHZ = 16
PLLM = 8
PLLN = 216
PLLP = 2
vco_mhz = HSI_MHZ / PLLM * PLLN
sysclk = vco_mhz / PLLP
assert sysclk == pytest.approx(216.0, rel=1e-6)
def test_apb1_54mhz(self):
"""APB1 = SYSCLK / 4 = 54 MHz."""
assert 216 / 4 == 54
def test_apb2_108mhz(self):
"""APB2 = SYSCLK / 2 = 108 MHz."""
assert 216 / 2 == 108
def test_flash_latency_7_required_at_216mhz(self):
"""STM32F7 at 2.7-3.3 V: 7 wait states for 210-216 MHz."""
FLASH_LATENCY = 7
assert FLASH_LATENCY == 7
def test_exti1_for_pa1(self):
"""SYSCFG EXTICR1[7:4] = 0x0 selects PA for EXTI1."""
PA_SOURCE = 0x0
assert PA_SOURCE == 0x0
def test_exti7_for_pb7(self):
"""SYSCFG EXTICR2[15:12] = 0x1 selects PB for EXTI7."""
PB_SOURCE = 0x1
assert PB_SOURCE == 0x1
def test_exticr_indices(self):
"""EXTI1 → EXTICR[0], EXTI7 → EXTICR[1]."""
assert 1 // 4 == 0 # EXTI1 is in EXTICR[0]
assert 7 // 4 == 1 # EXTI7 is in EXTICR[1]
def test_exti7_shift_in_exticr2(self):
"""EXTI7 field is at bits [15:12] of EXTICR[1] → shift = (7%4)*4 = 12."""
shift = (7 % 4) * 4
assert shift == 12
def test_idle_timeout_30s(self):
assert PM_IDLE_TIMEOUT_MS == 30_000
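The EXTICR index/shift arithmetic asserted above generalizes to a small helper (a sketch, valid for any EXTI line 0-15, with 4 bits per port-select field):

```python
def exticr_position(exti_line: int):
    """SYSCFG EXTICR register index and bit shift for an EXTI line (4-bit fields)."""
    return exti_line // 4, (exti_line % 4) * 4
```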


@ -5,13 +5,16 @@
 * Status | Faces | Conversation | Personality | Navigation
 *
 * Telemetry tabs (issue #126):
-* IMU | Battery | Motors | Map | Control | Health
+* IMU | Battery | Motors | Map | Control | Health | Cameras
 *
 * Fleet tabs (issue #139):
 * Fleet (self-contained via useFleet)
 *
 * Mission tabs (issue #145):
 * Missions (waypoint editor, route builder, geofence, schedule, execute)
+*
+* Camera viewer (issue #177):
+* CSI × 4 + D435i RGB/depth + panoramic, detection overlays, recording
 */
 import { useState, useCallback } from 'react';
@ -41,6 +44,9 @@ import { MissionPlanner } from './components/MissionPlanner.jsx';
 // Settings panel (issue #160)
 import { SettingsPanel } from './components/SettingsPanel.jsx';
+// Camera viewer (issue #177)
+import { CameraViewer } from './components/CameraViewer.jsx';
 const TAB_GROUPS = [
 {
 label: 'SOCIAL',
@ -63,6 +69,7 @@
 { id: 'map', label: 'Map', },
 { id: 'control', label: 'Control', },
 { id: 'health', label: 'Health', },
+{ id: 'cameras', label: 'Cameras', },
 ],
 },
 {
@ -206,6 +213,7 @@ export default function App() {
 {activeTab === 'map' && <MapViewer subscribe={subscribe} />}
 {activeTab === 'control' && <ControlMode subscribe={subscribe} />}
 {activeTab === 'health' && <SystemHealth subscribe={subscribe} />}
+{activeTab === 'cameras' && <CameraViewer subscribe={subscribe} />}
 {activeTab === 'fleet' && <FleetPanel />}
 {activeTab === 'missions' && <MissionPlanner />}


@ -0,0 +1,671 @@
/**
* CameraViewer.jsx - Live camera stream viewer (Issue #177).
*
* Features:
* - 7 cameras: front/left/rear/right (CSI), D435i RGB/depth, panoramic
* - Detection overlays: face boxes + names, gesture icons, scene object labels
* - 360° panoramic equirect viewer with mouse drag pan
* - One-click recording (MP4/WebM) + download
* - Snapshot to PNG with annotations + timestamp
* - Picture-in-picture (up to 3 pinned cameras)
* - Per-camera FPS indicator + adaptive quality badge
*
* Topics consumed:
* /camera/<name>/image_raw/compressed sensor_msgs/CompressedImage
* /camera/color/image_raw/compressed sensor_msgs/CompressedImage (D435i)
* /camera/depth/image_rect_raw/compressed sensor_msgs/CompressedImage (D435i)
* /camera/panoramic/compressed sensor_msgs/CompressedImage
* /social/faces/detections saltybot_social_msgs/FaceDetectionArray
* /social/gestures saltybot_social_msgs/GestureArray
* /social/scene/objects saltybot_scene_msgs/SceneObjectArray
*/
import { useEffect, useRef, useState, useCallback } from 'react';
import { useCamera, CAMERAS, CAMERA_BY_ID, CAMERA_BY_ROS_ID } from '../hooks/useCamera.js';
// Constants
const GESTURE_ICONS = {
wave: '👋',
point: '👆',
stop_palm: '✋',
thumbs_up: '👍',
thumbs_down: '👎',
come_here: '🤏',
follow: '☞',
arms_up: '🙌',
crouch: '⬇',
arms_spread: '↔',
};
const HAZARD_COLORS = {
1: '#f59e0b', // stairs amber
2: '#ef4444', // drop red
3: '#60a5fa', // wet floor blue
4: '#a855f7', // glass door purple
5: '#f97316', // pet orange
};
// Detection overlay drawing helpers
function drawFaceBoxes(ctx, faces, scaleX, scaleY) {
for (const face of faces) {
const x = face.bbox_x * scaleX;
const y = face.bbox_y * scaleY;
const w = face.bbox_w * scaleX;
const h = face.bbox_h * scaleY;
const isKnown = face.person_name && face.person_name !== 'unknown';
ctx.strokeStyle = isKnown ? '#06b6d4' : '#f59e0b';
ctx.lineWidth = 2;
ctx.shadowBlur = 6;
ctx.shadowColor = ctx.strokeStyle;
ctx.strokeRect(x, y, w, h);
ctx.shadowBlur = 0;
// Corner accent marks
const cLen = 8;
ctx.lineWidth = 3;
[[x,y,1,1],[x+w,y,-1,1],[x,y+h,1,-1],[x+w,y+h,-1,-1]].forEach(([cx,cy,dx,dy]) => {
ctx.beginPath();
ctx.moveTo(cx, cy + dy * cLen);
ctx.lineTo(cx, cy);
ctx.lineTo(cx + dx * cLen, cy);
ctx.stroke();
});
// Label
const label = isKnown
? `${face.person_name} ${(face.recognition_score * 100).toFixed(0)}%`
: `face #${face.face_id}`;
ctx.font = 'bold 11px monospace';
const tw = ctx.measureText(label).width;
ctx.fillStyle = isKnown ? 'rgba(6,182,212,0.8)' : 'rgba(245,158,11,0.8)';
ctx.fillRect(x, y - 16, tw + 6, 16);
ctx.fillStyle = '#000';
ctx.fillText(label, x + 3, y - 4);
}
}
function drawGestureIcons(ctx, gestures, activeCamId, scaleX, scaleY) {
for (const g of gestures) {
// Only show gestures from the currently viewed camera
const cam = CAMERA_BY_ROS_ID[g.camera_id];
if (!cam || cam.cameraId !== activeCamId) continue;
const x = g.hand_x * ctx.canvas.width;
const y = g.hand_y * ctx.canvas.height;
const icon = GESTURE_ICONS[g.gesture_type] ?? '?';
ctx.font = '24px serif';
ctx.shadowBlur = 8;
ctx.shadowColor = '#f97316';
ctx.fillText(icon, x - 12, y + 8);
ctx.shadowBlur = 0;
ctx.font = 'bold 10px monospace';
ctx.fillStyle = '#f97316';
const label = g.gesture_type;
ctx.fillText(label, x - ctx.measureText(label).width / 2, y + 22);
}
}
function drawSceneObjects(ctx, objects, scaleX, scaleY) {
for (const obj of objects) {
// vision_msgs/BoundingBox2D: center_x, center_y, size_x, size_y
const bb = obj.bbox;
const cx = bb?.center?.x ?? bb?.center_x;
const cy = bb?.center?.y ?? bb?.center_y;
const sw = bb?.size_x ?? 0;
const sh = bb?.size_y ?? 0;
if (cx == null) continue;
const x = (cx - sw / 2) * scaleX;
const y = (cy - sh / 2) * scaleY;
const w = sw * scaleX;
const h = sh * scaleY;
const color = HAZARD_COLORS[obj.hazard_type] ?? '#22c55e';
ctx.strokeStyle = color;
ctx.lineWidth = 1.5;
ctx.setLineDash([4, 3]);
ctx.strokeRect(x, y, w, h);
ctx.setLineDash([]);
const dist = obj.distance_m > 0 ? ` ${obj.distance_m.toFixed(1)}m` : '';
const label = `${obj.class_name}${dist}`;
ctx.font = '10px monospace';
const tw = ctx.measureText(label).width;
ctx.fillStyle = `${color}cc`;
ctx.fillRect(x, y + h, tw + 4, 14);
ctx.fillStyle = '#000';
ctx.fillText(label, x + 2, y + h + 11);
}
}
// Overlay canvas
function OverlayCanvas({ faces, gestures, sceneObjects, activeCam, containerW, containerH, canvasRef: externalRef }) {
const canvasRef = useRef(null);
useEffect(() => {
const canvas = canvasRef.current;
if (!canvas) return;
// Expose the canvas to the parent via a plain prop (a function component
// can't receive `ref` without forwardRef) so snapshots can capture overlays.
if (externalRef) externalRef.current = canvas;
const ctx = canvas.getContext('2d');
ctx.clearRect(0, 0, canvas.width, canvas.height);
if (!activeCam) return;
const scaleX = canvas.width / (activeCam.width || 640);
const scaleY = canvas.height / (activeCam.height || 480);
// Draw overlays: only for front camera (face + gesture source)
if (activeCam.id === 'front') {
drawFaceBoxes(ctx, faces, scaleX, scaleY);
}
if (!activeCam.isPanoramic) {
drawGestureIcons(ctx, gestures, activeCam.cameraId, scaleX, scaleY);
}
if (activeCam.id === 'color') {
drawSceneObjects(ctx, sceneObjects, scaleX, scaleY);
}
}, [faces, gestures, sceneObjects, activeCam]);
return (
<canvas
ref={canvasRef}
width={containerW || 640}
height={containerH || 480}
className="absolute inset-0 w-full h-full pointer-events-none"
/>
);
}
// Panoramic equirect viewer
function PanoViewer({ frameUrl }) {
const canvasRef = useRef(null);
const azRef = useRef(0); // 0..1920 px horizontal offset into the equirect
const dragRef = useRef(null);
const imgRef = useRef(null);
const draw = useCallback(() => {
const canvas = canvasRef.current;
const img = imgRef.current;
if (!canvas || !img || !img.complete) return;
const ctx = canvas.getContext('2d');
const W = canvas.width;
const H = canvas.height;
const iW = img.naturalWidth; // 1920
const iH = img.naturalHeight; // 960
const vW = iW / 2; // viewport = 50% of equirect width
const vH = Math.round((H / W) * vW);
const vY = Math.round((iH - vH) / 2);
const off = Math.round(azRef.current) % iW;
ctx.clearRect(0, 0, W, H);
// Draw left segment
const srcX1 = off;
const srcW1 = Math.min(vW, iW - off);
const dstW1 = Math.round((srcW1 / vW) * W);
if (dstW1 > 0) {
ctx.drawImage(img, srcX1, vY, srcW1, vH, 0, 0, dstW1, H);
}
// Draw wrapped right segment (if viewport crosses 0°)
if (srcW1 < vW) {
const srcX2 = 0;
const srcW2 = vW - srcW1;
const dstX2 = dstW1;
const dstW2 = W - dstW1;
ctx.drawImage(img, srcX2, vY, srcW2, vH, dstX2, 0, dstW2, H);
}
// Compass badge
const azDeg = Math.round((azRef.current / iW) * 360);
ctx.fillStyle = 'rgba(0,0,0,0.5)';
ctx.fillRect(W - 58, 6, 52, 18);
ctx.fillStyle = '#06b6d4';
ctx.font = 'bold 11px monospace';
ctx.fillText(`${azDeg}°`, W - 52, 19);
}, []);
// Load image when URL changes
useEffect(() => {
if (!frameUrl) return;
const img = new Image();
img.onload = draw;
img.src = frameUrl;
imgRef.current = img;
}, [frameUrl, draw]);
// Re-draw when azimuth changes
const onMouseDown = e => { dragRef.current = e.clientX; };
const onMouseMove = e => {
if (dragRef.current == null) return;
const dx = e.clientX - dragRef.current;
dragRef.current = e.clientX;
azRef.current = ((azRef.current - dx * 2) % 1920 + 1920) % 1920;
draw();
};
const onMouseUp = () => { dragRef.current = null; };
const onTouchStart = e => { dragRef.current = e.touches[0].clientX; };
const onTouchMove = e => {
if (dragRef.current == null) return;
const dx = e.touches[0].clientX - dragRef.current;
dragRef.current = e.touches[0].clientX;
azRef.current = ((azRef.current - dx * 2) % 1920 + 1920) % 1920;
draw();
};
return (
<canvas
ref={canvasRef}
width={960}
height={240}
className="w-full object-contain bg-black cursor-ew-resize rounded"
onMouseDown={onMouseDown}
onMouseMove={onMouseMove}
onMouseUp={onMouseUp}
onMouseLeave={onMouseUp}
onTouchStart={onTouchStart}
onTouchMove={onTouchMove}
onTouchEnd={() => { dragRef.current = null; }}
/>
);
}
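The wrap-around segment math in `draw()` above (split the viewport into at most two source rectangles when it crosses 0°) can be checked standalone. A sketch of the same arithmetic, with hypothetical names:

```python
def pano_segments(off: int, view_w: int, img_w: int):
    """Split an equirect viewport starting at x=off into <= 2 source segments.

    Returns [(src_x, src_w), ...]; a second segment appears only when the
    viewport wraps past the image's right edge back to x=0.
    """
    first_w = min(view_w, img_w - off)
    segments = [(off, first_w)]
    if first_w < view_w:
        segments.append((0, view_w - first_w))
    return segments
```

With a 1920-px equirect and a half-width (960 px) viewport, an offset past 960 produces the two-segment wrapped case that the second `drawImage` call handles.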
// PiP mini window
function PiPWindow({ cam, frameUrl, fps, onClose, index }) {
const positions = [
'bottom-2 left-2',
'bottom-2 left-40',
'bottom-2 left-[18rem]',
];
return (
<div className={`absolute ${positions[index] ?? 'bottom-2 left-2'} w-36 rounded border border-cyan-900 overflow-hidden bg-black shadow-lg shadow-black z-10`}>
<div className="flex items-center justify-between px-1.5 py-0.5 bg-gray-950 text-xs">
<span className="text-cyan-700 font-bold">{cam.label}</span>
<div className="flex items-center gap-1">
<span className="text-gray-700">{fps}fps</span>
<button onClick={onClose} className="text-gray-600 hover:text-red-400">×</button>
</div>
</div>
{frameUrl ? (
<img src={frameUrl} alt={cam.label} className="w-full aspect-video object-cover block" />
) : (
<div className="w-full aspect-video flex items-center justify-center text-gray-800 text-xs">
no signal
</div>
)}
</div>
);
}
// Camera selector strip
function CameraStrip({ cameras, activeId, pipList, frames, fps, onSelect, onTogglePip }) {
return (
<div className="flex gap-1.5 flex-wrap">
{cameras.map(cam => {
const hasFrame = !!frames[cam.id];
const camFps = fps[cam.id] ?? 0;
const isActive = activeId === cam.id;
const isPip = pipList.includes(cam.id);
return (
<div key={cam.id} className="relative">
<button
onClick={() => onSelect(cam.id)}
className={`flex flex-col items-start rounded border px-2.5 py-1.5 text-xs font-bold transition-colors ${
isActive
? 'border-cyan-500 bg-cyan-950 bg-opacity-50 text-cyan-300'
: hasFrame
? 'border-gray-700 bg-gray-900 text-gray-400 hover:border-cyan-800 hover:text-gray-200'
: 'border-gray-800 bg-gray-950 text-gray-700 hover:border-gray-700'
}`}
>
<span>{cam.label.toUpperCase()}</span>
<span className={`text-xs font-normal mt-0.5 ${
camFps >= 12 ? 'text-green-600' :
camFps > 0 ? 'text-amber-600' :
'text-gray-700'
}`}>
{camFps > 0 ? `${camFps}fps` : 'no signal'}
</span>
</button>
{/* PiP pin button — only when NOT the active camera */}
{!isActive && (
<button
onClick={() => onTogglePip(cam.id)}
title={isPip ? 'Unpin PiP' : 'Pin PiP'}
className={`absolute -top-1.5 -right-1.5 w-4 h-4 rounded-full text-[9px] flex items-center justify-center border transition-colors ${
isPip
? 'bg-cyan-600 border-cyan-400 text-white'
: 'bg-gray-800 border-gray-700 text-gray-600 hover:border-cyan-700 hover:text-cyan-500'
}`}
>
{isPip ? '×' : '⊕'}
</button>
)}
</div>
);
})}
</div>
);
}
// Recording bar
function RecordingBar({ recording, recSeconds, onStart, onStop, onSnapshot, overlayRef }) {
const fmtTime = s => `${String(Math.floor(s / 60)).padStart(2, '0')}:${String(s % 60).padStart(2, '0')}`;
return (
<div className="flex items-center gap-2 flex-wrap">
{!recording ? (
<button
onClick={onStart}
className="flex items-center gap-1.5 px-3 py-1.5 rounded border border-red-900 bg-red-950 text-red-400 hover:bg-red-900 text-xs font-bold transition-colors"
>
<span className="w-2 h-2 rounded-full bg-red-500" />
REC
</button>
) : (
<button
onClick={onStop}
className="flex items-center gap-1.5 px-3 py-1.5 rounded border border-red-600 bg-red-900 text-red-300 hover:bg-red-800 text-xs font-bold animate-pulse"
>
<span className="w-2 h-2 rounded bg-red-400" />
STOP {fmtTime(recSeconds)}
</button>
)}
<button
onClick={() => onSnapshot(overlayRef?.current)}
className="flex items-center gap-1 px-3 py-1.5 rounded border border-gray-700 bg-gray-900 text-gray-400 hover:border-cyan-700 hover:text-cyan-400 text-xs font-bold transition-colors"
>
📷 SNAP
</button>
{recording && (
<span className="text-xs text-red-500 animate-pulse font-mono">
RECORDING {fmtTime(recSeconds)}
</span>
)}
</div>
);
}
// Main component
export function CameraViewer({ subscribe }) {
const {
cameras, frames, fps,
activeId, setActiveId,
pipList, togglePip,
recording, recSeconds,
startRecording, stopRecording,
takeSnapshot,
} = useCamera({ subscribe });
// Detection state
const [faces, setFaces] = useState([]);
const [gestures, setGestures] = useState([]);
const [sceneObjects, setSceneObjects] = useState([]);
const [showOverlay, setShowOverlay] = useState(true);
const [overlayMode, setOverlayMode] = useState('all'); // 'all' | 'faces' | 'gestures' | 'objects' | 'off'
const overlayCanvasRef = useRef(null);
// Subscribe to detection topics
useEffect(() => {
if (!subscribe) return;
const u1 = subscribe('/social/faces/detections', 'saltybot_social_msgs/FaceDetectionArray', msg => {
setFaces(msg.faces ?? []);
});
const u2 = subscribe('/social/gestures', 'saltybot_social_msgs/GestureArray', msg => {
setGestures(msg.gestures ?? []);
});
const u3 = subscribe('/social/scene/objects', 'saltybot_scene_msgs/SceneObjectArray', msg => {
setSceneObjects(msg.objects ?? []);
});
return () => { u1?.(); u2?.(); u3?.(); };
}, [subscribe]);
const activeCam = CAMERA_BY_ID[activeId];
const activeFrame = frames[activeId];
// Filter overlay data based on mode
const visibleFaces = (overlayMode === 'all' || overlayMode === 'faces') ? faces : [];
const visibleGestures = (overlayMode === 'all' || overlayMode === 'gestures') ? gestures : [];
const visibleObjects = (overlayMode === 'all' || overlayMode === 'objects') ? sceneObjects : [];
// Container size tracking (for overlay canvas sizing)
const containerRef = useRef(null);
const [containerSize, setContainerSize] = useState({ w: 640, h: 480 });
useEffect(() => {
if (!containerRef.current) return;
const ro = new ResizeObserver(entries => {
const e = entries[0];
setContainerSize({ w: Math.round(e.contentRect.width), h: Math.round(e.contentRect.height) });
});
ro.observe(containerRef.current);
return () => ro.disconnect();
}, []);
// Quality badge
const camFps = fps[activeId] ?? 0;
const quality = camFps >= 13 ? 'FULL' : camFps >= 8 ? 'GOOD' : camFps > 0 ? 'LOW' : 'NO SIGNAL';
const qualColor = camFps >= 13 ? 'text-green-500' : camFps >= 8 ? 'text-amber-500' : camFps > 0 ? 'text-red-500' : 'text-gray-700';
return (
<div className="space-y-3">
{/* ── Camera strip ── */}
<div className="bg-gray-950 rounded-lg border border-cyan-950 p-3 space-y-2">
<div className="flex items-center justify-between">
<div className="text-cyan-700 text-xs font-bold tracking-widest">CAMERA SELECT</div>
<span className={`text-xs font-bold ${qualColor}`}>{quality} {camFps > 0 ? `${camFps}fps` : ''}</span>
</div>
<CameraStrip
cameras={cameras}
activeId={activeId}
pipList={pipList}
frames={frames}
fps={fps}
onSelect={setActiveId}
onTogglePip={togglePip}
/>
</div>
{/* ── Main viewer ── */}
<div className="bg-gray-950 rounded-lg border border-cyan-950 overflow-hidden">
{/* Viewer toolbar */}
<div className="flex items-center justify-between px-3 py-2 border-b border-cyan-950">
<div className="flex items-center gap-2">
<span className="text-cyan-400 text-xs font-bold">{activeCam?.label ?? '—'}</span>
{activeCam?.isDepth && (
<span className="text-xs text-gray-600 border border-gray-800 rounded px-1">DEPTH · greyscale</span>
)}
{activeCam?.isPanoramic && (
<span className="text-xs text-gray-600 border border-gray-800 rounded px-1">360° · drag to pan</span>
)}
</div>
{/* Overlay mode selector */}
<div className="flex items-center gap-1">
{['off','faces','gestures','objects','all'].map(mode => (
<button
key={mode}
onClick={() => setOverlayMode(mode)}
className={`px-2 py-0.5 rounded text-xs border transition-colors ${
overlayMode === mode
? 'border-cyan-600 bg-cyan-950 text-cyan-400'
: 'border-gray-800 text-gray-600 hover:border-gray-700 hover:text-gray-400'
}`}
>
{mode === 'all' ? 'ALL' : mode === 'off' ? 'OFF' : mode.slice(0,3).toUpperCase()}
</button>
))}
</div>
</div>
{/* Image + overlay */}
<div className="relative" ref={containerRef}>
{activeCam?.isPanoramic ? (
<PanoViewer frameUrl={activeFrame} />
) : activeFrame ? (
<img
src={activeFrame}
alt={activeCam?.label ?? 'camera'}
className="w-full object-contain block bg-black"
style={{ maxHeight: '480px' }}
/>
) : (
<div className="w-full bg-black flex items-center justify-center text-gray-800 text-sm font-mono"
style={{ height: '360px' }}>
<div className="text-center space-y-2">
<div className="text-2xl">📷</div>
<div>Waiting for {activeCam?.label ?? '—'}</div>
<div className="text-xs text-gray-700">{activeCam?.topic}</div>
</div>
</div>
)}
{/* Detection overlay canvas */}
{overlayMode !== 'off' && !activeCam?.isPanoramic && (
<OverlayCanvas
ref={overlayCanvasRef}
faces={visibleFaces}
gestures={visibleGestures}
sceneObjects={visibleObjects}
activeCam={activeCam}
containerW={containerSize.w}
containerH={containerSize.h}
/>
)}
{/* PiP windows */}
{pipList.map((id, idx) => {
const cam = CAMERA_BY_ID[id];
if (!cam) return null;
return (
<PiPWindow
key={id}
cam={cam}
frameUrl={frames[id]}
fps={fps[id] ?? 0}
index={idx}
onClose={() => togglePip(id)}
/>
);
})}
</div>
</div>
{/* ── Recording controls ── */}
<div className="bg-gray-950 rounded-lg border border-cyan-950 p-3">
<div className="flex items-center justify-between mb-2">
<div className="text-cyan-700 text-xs font-bold tracking-widest">CAPTURE</div>
</div>
<RecordingBar
recording={recording}
recSeconds={recSeconds}
onStart={startRecording}
onStop={stopRecording}
onSnapshot={takeSnapshot}
overlayRef={overlayCanvasRef}
/>
<div className="mt-2 text-xs text-gray-700">
Recording saves as MP4/WebM to your Downloads.
Snapshot includes detection overlay + timestamp.
</div>
</div>
{/* ── Detection status ── */}
<div className="grid grid-cols-3 gap-2 text-xs">
<div className="bg-gray-950 rounded border border-gray-800 p-2">
<div className="text-gray-600 mb-1">FACES</div>
<div className={`font-bold ${faces.length > 0 ? 'text-cyan-400' : 'text-gray-700'}`}>
{faces.length > 0 ? `${faces.length} detected` : 'none'}
</div>
{faces.slice(0, 2).map((f, i) => (
<div key={i} className="text-gray-600 truncate">
{f.person_name && f.person_name !== 'unknown'
? `↳ ${f.person_name}`
: `↳ unknown #${f.face_id}`}
</div>
))}
</div>
<div className="bg-gray-950 rounded border border-gray-800 p-2">
<div className="text-gray-600 mb-1">GESTURES</div>
<div className={`font-bold ${gestures.length > 0 ? 'text-amber-400' : 'text-gray-700'}`}>
{gestures.length > 0 ? `${gestures.length} active` : 'none'}
</div>
{gestures.slice(0, 2).map((g, i) => {
const icon = GESTURE_ICONS[g.gesture_type] ?? '?';
return (
<div key={i} className="text-gray-600 truncate">
{icon} {g.gesture_type} cam{g.camera_id}
</div>
);
})}
</div>
<div className="bg-gray-950 rounded border border-gray-800 p-2">
<div className="text-gray-600 mb-1">OBJECTS</div>
<div className={`font-bold ${sceneObjects.length > 0 ? 'text-green-400' : 'text-gray-700'}`}>
{sceneObjects.length > 0 ? `${sceneObjects.length} objects` : 'none'}
</div>
{sceneObjects
.filter(o => o.hazard_type > 0)
.slice(0, 2)
.map((o, i) => (
<div key={i} className="text-amber-700 truncate">⚠ {o.class_name}</div>
))
}
{sceneObjects.filter(o => o.hazard_type === 0).slice(0, 2).map((o, i) => (
<div key={`ok${i}`} className="text-gray-600 truncate">
{o.class_name} {o.distance_m > 0 ? `${o.distance_m.toFixed(1)}m` : ''}
</div>
))}
</div>
</div>
{/* ── Legend ── */}
<div className="flex gap-4 text-xs text-gray-700 flex-wrap">
<div className="flex items-center gap-1">
<div className="w-3 h-3 rounded-sm border border-cyan-600" />
Known face
</div>
<div className="flex items-center gap-1">
<div className="w-3 h-3 rounded-sm border border-amber-600" />
Unknown face
</div>
<div className="flex items-center gap-1">
<span>👆</span> Gesture
</div>
<div className="flex items-center gap-1">
<div className="w-3 h-3 rounded-sm border border-green-700 border-dashed" />
Object
</div>
<div className="flex items-center gap-1">
<div className="w-3 h-3 rounded-sm border border-amber-600 border-dashed" />
Hazard
</div>
<div className="ml-auto text-gray-800 italic">
pin = PiP · overlay: {overlayMode}
</div>
</div>
</div>
);
}


@ -0,0 +1,325 @@
/**
* useCamera.js: Multi-camera stream manager (Issue #177).
*
* Subscribes to sensor_msgs/CompressedImage topics via rosbridge.
* Decodes base64 JPEG/PNG payloads into data URLs for <img>/<canvas> display.
* Tracks per-camera FPS. Manages MediaRecorder for recording and snapshots.
*
* Camera sources:
* front / left / rear / right 4× CSI IMX219, 640×480
* topic: /camera/<name>/image_raw/compressed
* color D435i RGB, 640×480
* topic: /camera/color/image_raw/compressed
* depth D435i depth, 640×480 greyscale (PNG16)
* topic: /camera/depth/image_rect_raw/compressed
* panoramic equirect stitch 1920×960
* topic: /camera/panoramic/compressed
*/
import { useState, useEffect, useRef, useCallback } from 'react';
// ── Camera catalogue ──────────────────────────────────────────────────────────
export const CAMERAS = [
{
id: 'front',
label: 'Front',
shortLabel: 'F',
topic: '/camera/front/image_raw/compressed',
msgType: 'sensor_msgs/CompressedImage',
cameraId: 0, // matches gesture_node camera_id
width: 640, height: 480,
},
{
id: 'left',
label: 'Left',
shortLabel: 'L',
topic: '/camera/left/image_raw/compressed',
msgType: 'sensor_msgs/CompressedImage',
cameraId: 1,
width: 640, height: 480,
},
{
id: 'rear',
label: 'Rear',
shortLabel: 'R',
topic: '/camera/rear/image_raw/compressed',
msgType: 'sensor_msgs/CompressedImage',
cameraId: 2,
width: 640, height: 480,
},
{
id: 'right',
label: 'Right',
shortLabel: 'Rt',
topic: '/camera/right/image_raw/compressed',
msgType: 'sensor_msgs/CompressedImage',
cameraId: 3,
width: 640, height: 480,
},
{
id: 'color',
label: 'D435i RGB',
shortLabel: 'D',
topic: '/camera/color/image_raw/compressed',
msgType: 'sensor_msgs/CompressedImage',
cameraId: 4,
width: 640, height: 480,
},
{
id: 'depth',
label: 'Depth',
shortLabel: '≋',
topic: '/camera/depth/image_rect_raw/compressed',
msgType: 'sensor_msgs/CompressedImage',
cameraId: 5,
width: 640, height: 480,
isDepth: true,
},
{
id: 'panoramic',
label: 'Panoramic',
shortLabel: '360',
topic: '/camera/panoramic/compressed',
msgType: 'sensor_msgs/CompressedImage',
cameraId: -1,
width: 1920, height: 960,
isPanoramic: true,
},
];
export const CAMERA_BY_ID = Object.fromEntries(CAMERAS.map(c => [c.id, c]));
export const CAMERA_BY_ROS_ID = Object.fromEntries(
CAMERAS.filter(c => c.cameraId >= 0).map(c => [c.cameraId, c])
);
const TARGET_FPS = 15;
const FPS_INTERVAL = 1000; // ms between FPS counter resets
// ── Hook ──────────────────────────────────────────────────────────────────────
export function useCamera({ subscribe } = {}) {
const [frames, setFrames] = useState(() =>
Object.fromEntries(CAMERAS.map(c => [c.id, null]))
);
const [fps, setFps] = useState(() =>
Object.fromEntries(CAMERAS.map(c => [c.id, 0]))
);
const [activeId, setActiveId] = useState('front');
const [pipList, setPipList] = useState([]); // up to 3 extra camera ids
const [recording, setRecording] = useState(false);
const [recSeconds, setRecSeconds] = useState(0);
// ── Refs (not state — no re-render needed) ─────────────────────────────────
const countRef = useRef(Object.fromEntries(CAMERAS.map(c => [c.id, 0])));
const mediaRecRef = useRef(null);
const chunksRef = useRef([]);
const recTimerRef = useRef(null);
const recordCanvas = useRef(null); // hidden canvas used for recording
const recAnimRef = useRef(null); // setTimeout handle for record-canvas loop
const latestFrameRef = useRef(Object.fromEntries(CAMERAS.map(c => [c.id, null])));
const latestTsRef = useRef(Object.fromEntries(CAMERAS.map(c => [c.id, 0])));
// ── FPS counter ────────────────────────────────────────────────────────────
useEffect(() => {
const timer = setInterval(() => {
setFps({ ...countRef.current });
const reset = Object.fromEntries(CAMERAS.map(c => [c.id, 0]));
countRef.current = reset;
}, FPS_INTERVAL);
return () => clearInterval(timer);
}, []);
// ── Subscribe all camera topics ────────────────────────────────────────────
useEffect(() => {
if (!subscribe) return;
const unsubs = CAMERAS.map(cam => {
let lastTs = 0;
const interval = Math.floor(1000 / TARGET_FPS); // client-side 15fps gate
return subscribe(cam.topic, cam.msgType, (msg) => {
const now = Date.now();
if (now - lastTs < interval) return; // drop frames > 15fps
lastTs = now;
const fmt = msg.format || 'jpeg';
const mime = fmt.includes('png') || fmt.includes('16UC') ? 'image/png' : 'image/jpeg';
const dataUrl = `data:${mime};base64,${msg.data}`;
latestFrameRef.current[cam.id] = dataUrl;
latestTsRef.current[cam.id] = now;
countRef.current[cam.id] = (countRef.current[cam.id] ?? 0) + 1;
setFrames(prev => ({ ...prev, [cam.id]: dataUrl }));
});
});
return () => unsubs.forEach(fn => fn?.());
}, [subscribe]);
// ── Create hidden record canvas ────────────────────────────────────────────
useEffect(() => {
const c = document.createElement('canvas');
c.width = 640;
c.height = 480;
c.style.display = 'none';
document.body.appendChild(c);
recordCanvas.current = c;
return () => { c.remove(); };
}, []);
// ── Draw loop for record canvas ────────────────────────────────────────────
// Runs at TARGET_FPS when recording — draws active frame to hidden canvas
const startRecordLoop = useCallback(() => {
const canvas = recordCanvas.current;
if (!canvas) return;
const interval = Math.floor(1000 / TARGET_FPS);
const step = () => {
const cam = CAMERA_BY_ID[activeId];
const src = latestFrameRef.current[activeId];
const ctx = canvas.getContext('2d');
if (!cam || !src) {
recAnimRef.current = setTimeout(step, interval); // retry until a frame arrives
return;
}
// Resize canvas to match source
if (canvas.width !== cam.width || canvas.height !== cam.height) {
canvas.width = cam.width;
canvas.height = cam.height;
}
const img = new Image();
img.onload = () => {
ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
};
img.src = src;
recAnimRef.current = setTimeout(step, interval);
};
step();
}, [activeId]);
const stopRecordLoop = useCallback(() => {
if (recAnimRef.current) {
clearTimeout(recAnimRef.current);
recAnimRef.current = null;
}
}, []);
// ── Recording ──────────────────────────────────────────────────────────────
const startRecording = useCallback(() => {
const canvas = recordCanvas.current;
if (!canvas || recording) return;
startRecordLoop();
const stream = canvas.captureStream(TARGET_FPS);
const mimeType =
MediaRecorder.isTypeSupported('video/mp4') ? 'video/mp4' :
MediaRecorder.isTypeSupported('video/webm;codecs=vp9') ? 'video/webm;codecs=vp9' :
MediaRecorder.isTypeSupported('video/webm;codecs=vp8') ? 'video/webm;codecs=vp8' :
'video/webm';
chunksRef.current = [];
const mr = new MediaRecorder(stream, { mimeType, videoBitsPerSecond: 2_500_000 });
mr.ondataavailable = e => { if (e.data?.size > 0) chunksRef.current.push(e.data); };
mr.start(200);
mediaRecRef.current = mr;
setRecording(true);
setRecSeconds(0);
recTimerRef.current = setInterval(() => setRecSeconds(s => s + 1), 1000);
}, [recording, startRecordLoop]);
const stopRecording = useCallback(() => {
const mr = mediaRecRef.current;
if (!mr || mr.state === 'inactive') return;
mr.onstop = () => {
const ext = mr.mimeType.includes('mp4') ? 'mp4' : 'webm';
const blob = new Blob(chunksRef.current, { type: mr.mimeType });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = `saltybot-${activeId}-${Date.now()}.${ext}`;
a.click();
URL.revokeObjectURL(url);
};
mr.stop();
stopRecordLoop();
clearInterval(recTimerRef.current);
setRecording(false);
}, [activeId, stopRecordLoop]);
// ── Snapshot ───────────────────────────────────────────────────────────────
const takeSnapshot = useCallback((overlayCanvasEl) => {
const src = latestFrameRef.current[activeId];
if (!src) return;
const cam = CAMERA_BY_ID[activeId];
const canvas = document.createElement('canvas');
canvas.width = cam.width;
canvas.height = cam.height;
const ctx = canvas.getContext('2d');
const img = new Image();
img.onload = () => {
ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
// Composite detection overlay if provided
if (overlayCanvasEl) {
ctx.drawImage(overlayCanvasEl, 0, 0, canvas.width, canvas.height);
}
// Timestamp watermark
ctx.fillStyle = 'rgba(0,0,0,0.5)';
ctx.fillRect(0, canvas.height - 20, canvas.width, 20);
ctx.fillStyle = '#06b6d4';
ctx.font = '11px monospace';
ctx.fillText(`SALTYBOT ${cam.label} ${new Date().toISOString()}`, 8, canvas.height - 6);
canvas.toBlob(blob => {
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = `saltybot-snap-${activeId}-${Date.now()}.png`;
a.click();
URL.revokeObjectURL(url);
}, 'image/png');
};
img.src = src;
}, [activeId]);
// ── PiP management ─────────────────────────────────────────────────────────
const togglePip = useCallback(id => {
setPipList(prev => {
if (prev.includes(id)) return prev.filter(x => x !== id);
const next = [...prev, id].filter(x => x !== activeId);
return next.slice(-3); // max 3 PIPs
});
}, [activeId]);
// Remove PiP if it becomes the active camera
useEffect(() => {
setPipList(prev => prev.filter(id => id !== activeId));
}, [activeId]);
return {
cameras: CAMERAS,
frames,
fps,
activeId, setActiveId,
pipList, togglePip,
recording, recSeconds,
startRecording, stopRecording,
takeSnapshot,
};
}
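The hook's per-message handling boils down to two pieces: format sniffing to pick a MIME type for the data URL, and a client-side gate that drops frames above `TARGET_FPS`. A minimal sketch of that logic as standalone pure functions (the function names `decodeFrame` and `makeFpsGate` are illustrative, not part of the hook's exports):

```javascript
// decodeFrame mirrors the subscribe callback's format sniffing: PNG / 16UC
// formats (e.g. the D435i depth stream) map to image/png, everything else
// is treated as JPEG.
function decodeFrame(msg) {
  const fmt = msg.format || 'jpeg';
  const mime = fmt.includes('png') || fmt.includes('16UC') ? 'image/png' : 'image/jpeg';
  return `data:${mime};base64,${msg.data}`;
}

// makeFpsGate returns a predicate admitting at most targetFps frames per
// second, matching the hook's drop logic. The clock is injectable so the
// behaviour can be tested without real time passing (defaults to Date.now).
function makeFpsGate(targetFps, now = Date.now) {
  const interval = Math.floor(1000 / targetFps);
  let lastTs = 0;
  return () => {
    const t = now();
    if (t - lastTs < interval) return false; // too soon: drop this frame
    lastTs = t;
    return true;
  };
}
```

With `TARGET_FPS = 15` the gate rejects any frame arriving less than 66 ms after the previously accepted one, which caps client-side decode work regardless of the publisher's rate.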