Compare commits

..

26 Commits

Author SHA1 Message Date
479a33a6fa feat(mechanical): parametric cable management clips (Issue #264)
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (pull_request) Failing after 9s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (pull_request) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (pull_request) Has been cancelled
2026-03-02 20:44:27 -05:00
d1f0e95fa2 Merge pull request 'feat(social): face-tracking head servo controller (Issue #279)' (#284) from sl-jetson/issue-279-face-track-servo into main
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (push) Failing after 14s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (push) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (push) Has been cancelled
2026-03-02 20:44:00 -05:00
5f3b5caef7 Merge pull request 'IMU calibration routine (Issue #278)' (#282) from sl-controls/issue-278-imu-cal into main 2026-03-02 20:43:52 -05:00
cb8f6c82a4 Merge pull request 'feat(perception): HSV color object segmenter (Issue #274)' (#281) from sl-perception/issue-274-color-segment into main 2026-03-02 20:43:46 -05:00
de1166058c Merge pull request 'feat: Add cooling fan PWM speed controller (Issue #263)' (#276) from sl-firmware/issue-263-fan-pwm into main 2026-03-02 20:43:38 -05:00
c3d36e9943 feat(social): face-tracking head servo controller — Issue #279
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (push) Failing after 8s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (push) Has been skipped
social-bot integration tests / Lint (flake8 + pep257) (pull_request) Failing after 8s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (pull_request) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (push) Has been cancelled
social-bot integration tests / Latency profiling (GPU, Orin) (pull_request) Has been cancelled
Adds face_track_servo_node to saltybot_social:
- Subscribes /social/faces/detected (FaceDetectionArray)
- Picks closest face by largest bbox area (proximity proxy)
- Computes pan/tilt error from bbox centre vs image centre using
  configurable FOV (fov_h_deg=60°, fov_v_deg=45°)
- Independent PID controllers for pan and tilt (velocity/incremental
  output with anti-windup); servo position integrates velocity*dt
- Clamps commands to ±pan_limit_deg / ±tilt_limit_deg
- Returns to centre at return_rate_deg_s when face lost >lost_timeout_s
- Dead zone suppresses jitter for small errors
- Publishes Float32 on /saltybot/head_pan and /saltybot/head_tilt
- 81/81 tests passing

Closes #279
2026-03-02 17:38:02 -05:00
dd033b9827 feat(controls): IMU calibration routine (Issue #278)
Implements ROS2 IMU gyro + accel calibration node with:
- Service-triggered calibration via /saltybot/calibrate_imu
- Optional auto-calibration on startup (configurable)
- Collects N stationary samples (default 100)
- Computes mean bias offsets for gyro and accel
- Publishes bias-corrected IMU on /imu/calibrated
- Includes 10+ unit tests for calibration logic

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
2026-03-02 17:35:30 -05:00
f5093ecd34 feat(perception): HSV color object segmenter — Issue #274
- Add ColorDetection.msg + ColorDetectionArray.msg to saltybot_scene_msgs
- Add _color_segmenter.py: HsvRange/ColorBlob types, COLOR_RANGES defaults,
  mask_for_color() (dual-band red wrap), find_color_blobs() with morph open,
  contour extraction, area filter and max-blob-per-color limit
- Add color_segment_node.py: subscribes /camera/color/image_raw (BEST_EFFORT),
  publishes /saltybot/color_objects (ColorDetectionArray) per frame;
  active_colors, min_area_px, max_blobs_per_color params
- Add saltybot_scene_msgs exec_depend to saltybot_bringup/package.xml
- Register color_segmenter console_script in setup.py
- 34/34 unit tests pass (no ROS2 required)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 17:32:41 -05:00
54bc37926b Merge pull request 'feat(webui): system log tail viewer (#275)' (#277) from sl-webui/issue-275-log-viewer into main 2026-03-02 17:31:01 -05:00
30ad71e7d8 Merge pull request 'feat(social): proximity-based greeting trigger (Issue #270)' (#272) from sl-jetson/issue-270-greeting-trigger into main
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (push) Failing after 10s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (push) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (push) Has been cancelled
2026-03-02 17:30:54 -05:00
b6104763c5 Merge pull request 'feat(controls): Wheel slip detector (Issue #262)' (#266) from sl-controls/issue-262-wheel-slip into main
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (push) Has been cancelled
social-bot integration tests / Core integration tests (mock sensors, no GPU) (push) Has been cancelled
social-bot integration tests / Latency profiling (GPU, Orin) (push) Has been cancelled
2026-03-02 17:30:52 -05:00
5108fa8fa1 feat(controls): Wheel slip detector (Issue #262)
Detect wheel slip by comparing commanded velocity vs actual encoder velocity.
Publishes Bool flag on /saltybot/wheel_slip_detected when slip detected >0.5s.

Features:
- Subscribe to /cmd_vel (commanded) and /odom (actual velocity)
- Compare velocity magnitudes with 0.1 m/s threshold
- Persistence: slip must persist >0.5s to trigger (debounces transients)
- Publish Bool on /saltybot/wheel_slip_detected with detection status
- 10Hz monitoring frequency, configurable parameters

Algorithm:
- Compute linear speed from x,y components
- Calculate velocity difference
- If exceeds threshold: increment slip duration
- If duration > timeout: declare slip detected

Benefits:
- Detects environmental slip (ice, mud, wet surfaces)
- Triggers speed reduction to maintain traction
- Prevents wheel spinning/rut digging
- Safety response for loss of grip

Topics:
- Subscribed: /cmd_vel (Twist), /odom (Odometry)
- Published: /saltybot/wheel_slip_detected (Bool)

Config: frequency=10Hz, slip_threshold=0.1 m/s, slip_timeout=0.5s

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
2026-03-02 17:29:03 -05:00
5362536fb1 feat(webui): waypoint editor with click-to-navigate (Issue #261) 2026-03-02 17:29:03 -05:00
305ce6c971 feat(webui): system log tail viewer (Issue #275)
Real-time ROS log stream viewer with:
- Subscribes to /rosout (rcl_interfaces/Log)
- Severity-based color coding:
  DEBUG=grey | INFO=white | WARN=yellow | ERROR=red | FATAL=magenta
- Filter by severity level (multi-select toggle)
- Filter by node name (text input)
- Auto-scroll to latest logs
- Max 500 logs in history (configurable)
- Scrolling log output in monospace font
- Proper timestamp formatting (HH:MM:SS)

Integrated into MONITORING tab group as 'Logs' tab alongside 'Events'.
Follows established React/Tailwind patterns from other dashboard components.

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
2026-03-02 17:28:28 -05:00
c7dd07f9ed feat(social): proximity-based greeting trigger — Issue #270
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (pull_request) Failing after 12s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (pull_request) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (pull_request) Has been cancelled
Adds greeting_trigger_node to saltybot_social:
- Subscribes to /social/faces/detected (FaceDetectionArray) for face arrivals
- Subscribes to /social/person_states (PersonStateArray) to cache face_id→distance
- Fires greeting when face_id is within proximity_m (default 2m) and
  not in per-face_id cooldown window (default 300s)
- Publishes JSON on /saltybot/greeting_trigger:
  {face_id, person_name, distance_m, ts}
- unknown_distance param controls assumed distance for faces with no PersonState yet
- Thread-safe distance cache and greeted map
- 50/50 tests passing

Closes #270
2026-03-02 17:26:40 -05:00
0776003dd3 Merge pull request 'feat(webui): robot status dashboard header (#269)' (#273) from sl-webui/issue-269-status-header into main 2026-03-02 17:26:40 -05:00
01ee02f837 Merge pull request 'feat(bringup): D435i depth hole filler via bilateral interpolation (Issue #268)' (#271) from sl-perception/issue-268-depth-holes into main 2026-03-02 17:26:22 -05:00
f0e11fe7ca feat(bringup): depth image hole filler via bilateral interpolation (Issue #268)
Adds multi-pass spatial-Gaussian hole filler for D435i depth images.
Each pass replaces zero/NaN pixels with the Gaussian-weighted mean of valid
neighbours in a growing kernel (×1, ×2.5, ×6 default); original valid
pixels are never modified.  Handles uint16 mm → float32 m conversion,
border pixels via BORDER_REFLECT, and above-d_max pixels as holes.
Publishes filled float32 depth on /camera/depth/filled at camera rate.
37/37 pure-Python tests pass (no ROS2 required).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 14:19:27 -05:00
201dea4c01 feat(webui): robot status dashboard header bar (Issue #269)
Persistent top bar showing real-time robot health indicators:
- Battery percentage and voltage (color-coded: green >60%, amber 30-60%, red <30%)
- WiFi signal strength (RSSI dBm with quality assessment)
- Motor status and current draw in Amperes
- Emergency state indicator (red highlight when active)
- System uptime in hours and minutes
- Current operational mode (idle/nav/social/docking)
- Connection status indicator

Component subscribes to relevant ROS topics and displays in compact
flex layout matching dashboard dark theme.

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
2026-03-02 14:19:03 -05:00
477258f321 Merge pull request 'feat(webui): waypoint editor with click-to-navigate (#261)' (#267) from sl-webui/issue-261-waypoint-editor-fix into main 2026-03-02 14:13:54 -05:00
94a6f0787e Merge pull request 'feat(bringup): visual odometry drift detector (Issue #260)' (#265) from sl-perception/issue-260-vo-drift into main 2026-03-02 14:13:34 -05:00
50636de5a9 Merge pull request 'feat(social): ambient sound classifier via mel-spectrogram (Issue #252)' (#258) from sl-jetson/issue-252-ambient-sound into main
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (push) Failing after 11s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (push) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (push) Has been cancelled
2026-03-02 14:13:33 -05:00
c348e093ef feat: Add cooling fan PWM speed controller (Issue #263)
Implements STM32F722 driver for brushless cooling fan on PA9 using TIM1_CH2 PWM.
Features:
- Temperature-based speed curve: off <40°C, 30% at 50°C, 100% at 70°C
- Smooth speed ramp transitions with configurable rate (default 0.05%/ms)
- Linear interpolation between curve points
- PWM duty cycle control (0-100%)
- State transitions and edge case handling

All 51 unit tests passing:
- Temperature curve verification (6 test zones)
- Speed boundaries and transitions
- Ramp timing and rate control
- PWM duty cycle calculation
- Temperature extremes and boundary conditions

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
2026-03-02 13:29:18 -05:00
c865e84e16 feat(webui): waypoint editor with click-to-navigate (Issue #261)
Interactive waypoint editor for Nav2 goal-based navigation:
- Click on map display to place waypoints
- Drag waypoints in list to reorder navigation sequence
- Right-click waypoints to delete them
- Visual waypoint overlay on map with numbering
- Robot position indicator at center
- Waypoint list sidebar with selection and ordering
- Send Nav2 goal to individual selected waypoint
- Execute all waypoints in sequence with automatic progression
- Clear all waypoints button
- Real-time coordinate display and robot pose tracking
- Integrated into new NAVIGATION tab group
- Uses /navigate_to_pose service for goal publishing

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
2026-03-02 13:28:01 -05:00
9d12805843 feat(bringup): visual odometry drift detector (Issue #260)
Adds sliding-window drift detector that compares cumulative path lengths
of visual odom and wheel odom over a configurable window (default 10 s).
Drift = |vo_path − wheel_path|; flagged when ≥ 0.5 m (configurable).
OdomBuffer handles per-source rolling storage with automatic age eviction.
Publishes Bool on /saltybot/vo_drift_detected and Float32 on
/saltybot/vo_drift_magnitude at 2 Hz.  27/27 pure-Python tests pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 13:26:07 -05:00
3cd9faeed9 feat(social): ambient sound classifier via mel-spectrogram — Issue #252
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (pull_request) Failing after 2s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (pull_request) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (pull_request) Has been cancelled
Adds ambient_sound_node to saltybot_social:
- Accumulates 1 s of PCM-16 audio from /social/speech/audio_raw
- Extracts mel-spectrogram feature vector (energy_db, zcr, mel_centroid,
  mel_flatness, low_ratio, high_ratio) using pure numpy (no torch/onnx)
- Priority-cascade classifier: silence → music → speech → crowd → outdoor → alarm
- Publishes label as std_msgs/String on /saltybot/ambient_sound on each buffer fill
- All 11 thresholds exposed as ROS parameters (yaml + launch file)
- numpy-free energy-only fallback for edge environments
- 77/77 tests passing

Closes #252
2026-03-02 13:22:38 -05:00
49 changed files with 6075 additions and 63 deletions

162
include/fan.h Normal file
View File

@ -0,0 +1,162 @@
#ifndef FAN_H
#define FAN_H
#include <stdint.h>
#include <stdbool.h>
/*
* fan.h Cooling fan PWM speed controller (Issue #263)
*
* STM32F722 driver for brushless cooling fan on PA9 using TIM1_CH2 PWM.
* Temperature-based speed curve with smooth ramp transitions.
*
* Pin: PA9 (TIM1_CH2, alternate function AF1)
* PWM Frequency: 25 kHz (suitable for brushless DC fan)
* Speed Range: 0-100% duty cycle
*
* Temperature Curve:
* - Below 40°C: Fan off (0%)
* - 40-50°C: Linear ramp from 0% to 30%
* - 50-70°C: Linear ramp from 30% to 100%
* - Above 70°C: Fan at maximum (100%)
*/
/* Fan speed state */
typedef enum {
FAN_OFF, /* Motor disabled (0% duty) */
FAN_LOW, /* Low speed (5-30%) */
FAN_MEDIUM, /* Medium speed (31-60%) */
FAN_HIGH, /* High speed (61-99%) */
FAN_FULL /* Maximum speed (100%) */
} FanState;
/*
* fan_init()
*
* Initialize fan controller:
* - PA9 as TIM1_CH2 PWM output
* - TIM1 configured for 25 kHz frequency
* - PWM duty cycle control (0-100%)
* - Ramp rate limiter for smooth transitions
*/
void fan_init(void);
/*
* fan_set_speed(percentage)
*
* Set fan speed directly (bypasses temperature control).
* Used for manual testing or emergency cooling.
*
* Arguments:
* - percentage: 0-100% duty cycle
*
* Returns: true if set successfully, false if invalid value
*/
bool fan_set_speed(uint8_t percentage);
/*
* fan_get_speed()
*
* Get current fan speed setting.
*
* Returns: Current speed 0-100%
*/
uint8_t fan_get_speed(void);
/*
* fan_set_target_speed(percentage)
*
* Set target speed with smooth ramping.
* Speed transitions over time according to ramp rate.
*
* Arguments:
* - percentage: Target speed 0-100%
*
* Returns: true if set successfully
*/
bool fan_set_target_speed(uint8_t percentage);
/*
* fan_update_temperature(temp_celsius)
*
* Update temperature reading and apply speed curve.
* Calculates target speed based on temperature curve.
* Speed transition is smoothed via ramp limiter.
*
* Temperature Curve:
* - temp < 40°C: 0% (off)
* - 40°C temp < 50°C: 0% + (temp - 40) * 3% per °C = linear to 30%
* - 50°C temp < 70°C: 30% + (temp - 50) * 3.5% per °C = linear to 100%
* - temp 70°C: 100% (full)
*
* Arguments:
* - temp_celsius: Temperature in degrees Celsius (int16_t for negative values)
*/
void fan_update_temperature(int16_t temp_celsius);
/*
* fan_get_temperature()
*
* Get last recorded temperature.
*
* Returns: Temperature in °C (or 0 if not yet set)
*/
int16_t fan_get_temperature(void);
/*
* fan_get_state()
*
* Get current fan operational state.
*
* Returns: FAN_OFF, FAN_LOW, FAN_MEDIUM, FAN_HIGH, or FAN_FULL
*/
FanState fan_get_state(void);
/*
* fan_set_ramp_rate(percentage_per_ms)
*
* Configure speed ramp rate for smooth transitions.
* Default: 5% per 100ms = 0.05% per ms.
* Higher values = faster transitions.
*
* Arguments:
* - percentage_per_ms: Speed change per millisecond (e.g., 1 = 1% per ms)
*
* Typical ranges:
* - 0.01 = very slow (100% change in 10 seconds)
* - 0.05 = slow (100% change in 2 seconds)
* - 0.1 = medium (100% change in 1 second)
* - 1.0 = fast (100% change in 100ms)
*/
void fan_set_ramp_rate(float percentage_per_ms);
/*
* fan_is_ramping()
*
* Check if speed is currently transitioning.
*
* Returns: true if speed is ramping toward target, false if at target
*/
bool fan_is_ramping(void);
/*
* fan_tick(now_ms)
*
* Update function called periodically (recommended: every 10-100ms).
* Processes speed ramp transitions.
* Must be called regularly for smooth ramping operation.
*
* Arguments:
* - now_ms: current time in milliseconds (from HAL_GetTick() or similar)
*/
void fan_tick(uint32_t now_ms);
/*
* fan_disable()
*
* Disable fan immediately (set to 0% duty).
* Useful for shutdown or emergency stop.
*/
void fan_disable(void);
#endif /* FAN_H */

View File

@ -25,6 +25,8 @@
<exec_depend>saltybot_follower</exec_depend> <exec_depend>saltybot_follower</exec_depend>
<exec_depend>saltybot_outdoor</exec_depend> <exec_depend>saltybot_outdoor</exec_depend>
<exec_depend>saltybot_perception</exec_depend> <exec_depend>saltybot_perception</exec_depend>
<!-- HSV color segmentation messages (Issue #274) -->
<exec_depend>saltybot_scene_msgs</exec_depend>
<exec_depend>saltybot_uwb</exec_depend> <exec_depend>saltybot_uwb</exec_depend>
<buildtool_depend>ament_python</buildtool_depend> <buildtool_depend>ament_python</buildtool_depend>

View File

@ -0,0 +1,184 @@
"""
_color_segmenter.py HSV color segmentation helpers (no ROS2 deps).
Algorithm
---------
For each requested color:
1. Convert BGR HSV (OpenCV: H[0,180], S[0,255], V[0,255])
2. Build a binary mask via cv2.inRange using the color's HSV bounds.
Red wraps around H=0/180 so two ranges are OR-combined.
3. Morphological open (3×3) to remove noise.
4. Find external contours; filter by min_area_px.
5. Return ColorBlob NamedTuples one per surviving contour.
confidence is the contour area divided by the bounding-rectangle area
(how "filled" the bounding box is), clamped to [0, 1].
Public API
----------
HsvRange(h_lo, h_hi, s_lo, s_hi, v_lo, v_hi)
ColorBlob(color_name, confidence, cx, cy, w, h, area_px, contour_id)
COLOR_RANGES : Dict[str, List[HsvRange]] default per-color HSV ranges
mask_for_color(hsv, color_name) -> np.ndarray uint8 binary mask
find_color_blobs(bgr, active_colors, min_area_px, max_blobs_per_color) -> List[ColorBlob]
"""
from __future__ import annotations
from typing import Dict, List, NamedTuple
import numpy as np
# ── Data types ────────────────────────────────────────────────────────────────
class HsvRange(NamedTuple):
"""Single HSV band (OpenCV: H∈[0,180], S/V∈[0,255])."""
h_lo: int
h_hi: int
s_lo: int
s_hi: int
v_lo: int
v_hi: int
class ColorBlob(NamedTuple):
"""One detected color object in image coordinates."""
color_name: str
confidence: float # contour_area / bbox_area (01)
cx: float # bbox centre x (pixels)
cy: float # bbox centre y (pixels)
w: float # bbox width (pixels)
h: float # bbox height (pixels)
area_px: float # contour area (pixels²)
contour_id: int # 0-based index within this color in this frame
# ── Default per-color HSV ranges ──────────────────────────────────────────────
# Two ranges are used for red (wraps at 0/180).
# S_lo=60, V_lo=50 to ignore desaturated / near-black pixels.
COLOR_RANGES: Dict[str, List[HsvRange]] = {
'red': [
HsvRange(h_lo=0, h_hi=10, s_lo=60, s_hi=255, v_lo=50, v_hi=255),
HsvRange(h_lo=170, h_hi=180, s_lo=60, s_hi=255, v_lo=50, v_hi=255),
],
'green': [
HsvRange(h_lo=35, h_hi=85, s_lo=60, s_hi=255, v_lo=50, v_hi=255),
],
'blue': [
HsvRange(h_lo=90, h_hi=130, s_lo=60, s_hi=255, v_lo=50, v_hi=255),
],
'yellow': [
HsvRange(h_lo=18, h_hi=38, s_lo=60, s_hi=255, v_lo=80, v_hi=255),
],
'orange': [
HsvRange(h_lo=8, h_hi=20, s_lo=80, s_hi=255, v_lo=80, v_hi=255),
],
}
# Structuring element for morphological open (noise removal)
_MORPH_KERNEL = None
def _get_morph_kernel():
import cv2
global _MORPH_KERNEL
if _MORPH_KERNEL is None:
_MORPH_KERNEL = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
return _MORPH_KERNEL
# ── Public helpers ─────────────────────────────────────────────────────────────
def mask_for_color(hsv: np.ndarray, color_name: str) -> np.ndarray:
"""
Return a uint8 binary mask (255=foreground) for *color_name* in the HSV image.
Parameters
----------
hsv : (H, W, 3) uint8 ndarray in OpenCV HSV format (H[0,180])
color_name : one of COLOR_RANGES keys
Returns
-------
(H, W) uint8 ndarray
"""
import cv2
ranges = COLOR_RANGES.get(color_name)
if not ranges:
raise ValueError(f'Unknown color: {color_name!r}. Known: {list(COLOR_RANGES)}')
mask = np.zeros(hsv.shape[:2], dtype=np.uint8)
for r in ranges:
lo = np.array([r.h_lo, r.s_lo, r.v_lo], dtype=np.uint8)
hi = np.array([r.h_hi, r.s_hi, r.v_hi], dtype=np.uint8)
mask |= cv2.inRange(hsv, lo, hi)
return cv2.morphologyEx(mask, cv2.MORPH_OPEN, _get_morph_kernel())
def find_color_blobs(
bgr: np.ndarray,
active_colors: List[str] | None = None,
min_area_px: float = 200.0,
max_blobs_per_color: int = 10,
) -> List[ColorBlob]:
"""
Detect HSV-segmented color blobs in a BGR image.
Parameters
----------
bgr : (H, W, 3) uint8 BGR ndarray
active_colors : color names to detect; None all COLOR_RANGES keys
min_area_px : minimum contour area to report (pixels²)
max_blobs_per_color : keep at most this many blobs per color (largest first)
Returns
-------
List[ColorBlob] may be empty; contour_id is 0-based within each color
"""
import cv2
if active_colors is None:
active_colors = list(COLOR_RANGES.keys())
hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
blobs: List[ColorBlob] = []
for color_name in active_colors:
mask = mask_for_color(hsv, color_name)
contours, _ = cv2.findContours(
mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# Sort largest first so max_blobs_per_color keeps the significant ones
contours = sorted(contours, key=cv2.contourArea, reverse=True)
blob_idx = 0
for cnt in contours:
if blob_idx >= max_blobs_per_color:
break
area = cv2.contourArea(cnt)
if area < min_area_px:
break # already sorted, no need to continue
x, y, bw, bh = cv2.boundingRect(cnt)
bbox_area = float(bw * bh)
confidence = float(area / bbox_area) if bbox_area > 0 else 0.0
confidence = min(1.0, max(0.0, confidence))
blobs.append(ColorBlob(
color_name=color_name,
confidence=confidence,
cx=float(x + bw / 2.0),
cy=float(y + bh / 2.0),
w=float(bw),
h=float(bh),
area_px=float(area),
contour_id=blob_idx,
))
blob_idx += 1
return blobs

View File

@ -0,0 +1,141 @@
"""
_depth_hole_fill.py Depth image hole filling via bilateral interpolation (no ROS2 deps).
Algorithm
---------
A "hole" is any pixel where depth == 0, depth is NaN, or depth is outside the
valid range [d_min, d_max].
Each pass replaces every hole pixel with the spatial-Gaussian-weighted mean of
valid pixels in a (kernel_size × kernel_size) neighbourhood:
filled[x,y] = Σ G(||p - q||; σ) · d[q] / Σ G(||p - q||; σ)
q valid neighbours of (x,y)
The denominator (sum of spatial weights over valid pixels) normalises correctly
even at image borders and around isolated valid pixels.
Multiple passes with geometrically growing kernels are applied so that:
Pass 1 kernel_size fills small holes ( kernel_size/2 px radius)
Pass 2 kernel_size × 2.5 fills medium holes
Pass 3 kernel_size × 6.0 fills large holes / fronto-parallel surfaces
After all passes any remaining zeros are left as-is (no valid neighbourhood data).
Because only the spatial Gaussian (not a depth range term) is used as the weighting
function, this is equivalent to a bilateral filter with σ_range . In practice
this produces smooth, physically plausible fills in the depth domain.
Public API
----------
fill_holes(depth, kernel_size=5, d_min=0.1, d_max=10.0, max_passes=3) ndarray
valid_mask(depth, d_min=0.1, d_max=10.0) bool ndarray
"""
from __future__ import annotations
import math
from typing import Optional
import numpy as np
# Kernel size multipliers for successive passes
_PASS_SCALE = [1.0, 2.5, 6.0]
def valid_mask(
depth: np.ndarray,
d_min: float = 0.1,
d_max: float = 10.0,
) -> np.ndarray:
"""
Return a boolean mask of valid (non-hole) pixels.
Parameters
----------
depth : (H, W) float32 ndarray, depth in metres
d_min : minimum valid depth (m)
d_max : maximum valid depth (m)
Returns
-------
(H, W) bool ndarray True where depth is finite and in [d_min, d_max]
"""
return np.isfinite(depth) & (depth >= d_min) & (depth <= d_max)
def fill_holes(
depth: np.ndarray,
kernel_size: int = 5,
d_min: float = 0.1,
d_max: float = 10.0,
max_passes: int = 3,
) -> np.ndarray:
"""
Fill zero/NaN depth pixels using multi-pass spatial Gaussian interpolation.
Parameters
----------
depth : (H, W) float32 ndarray, depth in metres
kernel_size : initial kernel side length (pixels, forced odd, 3)
d_min : minimum valid depth pixels below this are treated as holes
d_max : maximum valid depth pixels above this are treated as holes
max_passes : number of fill passes (13); each uses a larger kernel
Returns
-------
(H, W) float32 ndarray same as input, with holes filled where possible.
Pixels with no valid neighbours after all passes remain 0.0.
Original valid pixels are never modified.
"""
import cv2
depth = np.asarray(depth, dtype=np.float32)
# Replace NaN with 0 so arithmetic is clean
depth = np.where(np.isfinite(depth), depth, 0.0).astype(np.float32)
mask = valid_mask(depth, d_min, d_max) # True where already valid
result = depth.copy()
n_passes = max(1, min(max_passes, len(_PASS_SCALE)))
for i in range(n_passes):
if mask.all():
break # no holes left
ks = _odd_kernel_size(kernel_size, _PASS_SCALE[i])
half = ks // 2
sigma = max(half / 2.0, 0.5)
gk = cv2.getGaussianKernel(ks, sigma).astype(np.float32)
kernel = (gk @ gk.T)
# Multiply depth by mask so invalid pixels contribute 0 weight
d_valid = np.where(mask, result, 0.0).astype(np.float32)
w_valid = mask.astype(np.float32)
sum_d = cv2.filter2D(d_valid, ddepth=-1, kernel=kernel,
borderType=cv2.BORDER_REFLECT)
sum_w = cv2.filter2D(w_valid, ddepth=-1, kernel=kernel,
borderType=cv2.BORDER_REFLECT)
# Where we have enough weight, compute the weighted mean
has_data = sum_w > 1e-6
interp = np.where(has_data, sum_d / np.where(has_data, sum_w, 1.0), 0.0)
# Only fill holes — never overwrite original valid pixels
result = np.where(mask, result, interp.astype(np.float32))
# Update mask with newly filled pixels (for the next pass)
newly_filled = (~mask) & (result > 0.0)
mask = mask | newly_filled
return result.astype(np.float32)
# ── Internal helpers ──────────────────────────────────────────────────────────
def _odd_kernel_size(base: int, scale: float) -> int:
"""Return the nearest odd integer to base * scale, minimum 3."""
raw = max(3, int(round(base * scale)))
return raw if raw % 2 == 1 else raw + 1

View File

@ -0,0 +1,150 @@
"""
_vo_drift.py Visual odometry drift detector helpers (no ROS2 deps).
Algorithm
---------
Two independent odometry streams (visual and wheel) are compared over a
sliding time window. Drift is measured as the absolute difference in
cumulative path length travelled by each source over that window:
drift_m = |path_length(vo_window) path_length(wheel_window)|
Using cumulative path length (sum of inter-sample Euclidean steps) rather
than straight-line displacement makes the measure robust to circular motion
where start and end positions are the same.
Drift is flagged when drift_m drift_threshold_m.
Public API
----------
OdomSample namedtuple(t, x, y)
OdomBuffer deque of OdomSamples with time-window trimming
compute_drift() compare two OdomBuffers and return DriftResult
DriftResult namedtuple(drift_m, vo_path_m, wheel_path_m,
is_drifting, window_s, n_vo, n_wheel)
"""
from __future__ import annotations
import math
from collections import deque
from typing import NamedTuple, Sequence
class OdomSample(NamedTuple):
t: float # monotonic timestamp (seconds)
x: float # position x (metres)
y: float # position y (metres)
class DriftResult(NamedTuple):
drift_m: float # |vo_path wheel_path| (metres)
vo_path_m: float # cumulative path of VO source over window (metres)
wheel_path_m: float # cumulative path of wheel source over window (metres)
is_drifting: bool # True when drift_m >= threshold
window_s: float # actual time span of data used (seconds)
n_vo: int # number of VO samples in window
n_wheel: int # number of wheel samples in window
class OdomBuffer:
"""
Rolling buffer of OdomSamples trimmed to the last `max_age_s` seconds.
Parameters
----------
max_age_s : float samples older than this are discarded (seconds)
"""
def __init__(self, max_age_s: float = 10.0) -> None:
self._max_age = max_age_s
self._buf: deque[OdomSample] = deque()
# ── Public ────────────────────────────────────────────────────────────────
def push(self, sample: OdomSample) -> None:
"""Append a sample and evict anything older than max_age_s."""
self._buf.append(sample)
self._trim(sample.t)
def window(self, window_s: float, now: float) -> list[OdomSample]:
"""Return samples within the last window_s seconds of `now`."""
cutoff = now - window_s
return [s for s in self._buf if s.t >= cutoff]
def clear(self) -> None:
self._buf.clear()
def __len__(self) -> int:
return len(self._buf)
# ── Internal ──────────────────────────────────────────────────────────────
def _trim(self, now: float) -> None:
cutoff = now - self._max_age
while self._buf and self._buf[0].t < cutoff:
self._buf.popleft()
# ── Core computation ──────────────────────────────────────────────────────────
def compute_drift(
vo_buf: OdomBuffer,
wheel_buf: OdomBuffer,
window_s: float,
drift_threshold_m: float,
now: float,
) -> DriftResult:
"""
Compare VO and wheel odometry path lengths over the last `window_s`.
Parameters
----------
vo_buf : OdomBuffer of visual odometry samples
wheel_buf : OdomBuffer of wheel odometry samples
window_s : comparison window width (seconds)
drift_threshold_m : drift_m threshold for is_drifting flag
now : current time (same scale as OdomSample.t)
Returns
-------
DriftResult zero drift if either buffer has fewer than 2 samples.
"""
vo_samples = vo_buf.window(window_s, now)
wheel_samples = wheel_buf.window(window_s, now)
if len(vo_samples) < 2 or len(wheel_samples) < 2:
return DriftResult(
drift_m=0.0, vo_path_m=0.0, wheel_path_m=0.0,
is_drifting=False,
window_s=0.0, n_vo=len(vo_samples), n_wheel=len(wheel_samples),
)
vo_path = _path_length(vo_samples)
wheel_path = _path_length(wheel_samples)
drift_m = abs(vo_path - wheel_path)
# Actual data span = latest timestamp earliest across both buffers
t_min = min(vo_samples[0].t, wheel_samples[0].t)
t_max = max(vo_samples[-1].t, wheel_samples[-1].t)
actual_window = t_max - t_min
return DriftResult(
drift_m=drift_m,
vo_path_m=vo_path,
wheel_path_m=wheel_path,
is_drifting=drift_m >= drift_threshold_m,
window_s=actual_window,
n_vo=len(vo_samples),
n_wheel=len(wheel_samples),
)
def _path_length(samples: Sequence[OdomSample]) -> float:
"""Sum of Euclidean inter-sample distances."""
total = 0.0
for i in range(1, len(samples)):
dx = samples[i].x - samples[i - 1].x
dy = samples[i].y - samples[i - 1].y
total += math.sqrt(dx * dx + dy * dy)
return total

View File

@ -0,0 +1,127 @@
"""
color_segment_node.py D435i HSV color object segmenter (Issue #274).
Subscribes to the RealSense colour stream, applies per-color HSV thresholding,
extracts contours, and publishes detected blobs as ColorDetectionArray.
Subscribes (BEST_EFFORT):
/camera/color/image_raw sensor_msgs/Image BGR8 (or rgb8)
Publishes:
/saltybot/color_objects saltybot_scene_msgs/ColorDetectionArray
Parameters
----------
active_colors str "red,green,blue,yellow,orange" Comma-separated list
min_area_px float 200.0 Minimum contour area (pixels²)
max_blobs_per_color int 10 Max detections per color per frame
"""
from __future__ import annotations
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
import numpy as np
from cv_bridge import CvBridge
from sensor_msgs.msg import Image
from std_msgs.msg import Header
from saltybot_scene_msgs.msg import ColorDetection, ColorDetectionArray
from vision_msgs.msg import BoundingBox2D
from geometry_msgs.msg import Pose2D
from ._color_segmenter import find_color_blobs
_SENSOR_QOS = QoSProfile(
reliability=ReliabilityPolicy.BEST_EFFORT,
history=HistoryPolicy.KEEP_LAST,
depth=4,
)
_DEFAULT_COLORS = 'red,green,blue,yellow,orange'
class ColorSegmentNode(Node):
def __init__(self) -> None:
super().__init__('color_segment_node')
self.declare_parameter('active_colors', _DEFAULT_COLORS)
self.declare_parameter('min_area_px', 200.0)
self.declare_parameter('max_blobs_per_color', 10)
colors_str = self.get_parameter('active_colors').value
self._active_colors = [c.strip() for c in colors_str.split(',') if c.strip()]
self._min_area = float(self.get_parameter('min_area_px').value)
self._max_blobs = int(self.get_parameter('max_blobs_per_color').value)
self._bridge = CvBridge()
self._sub = self.create_subscription(
Image, '/camera/color/image_raw', self._on_image, _SENSOR_QOS)
self._pub = self.create_publisher(
ColorDetectionArray, '/saltybot/color_objects', 10)
self.get_logger().info(
f'color_segment_node ready — colors={self._active_colors} '
f'min_area={self._min_area}px² max_blobs={self._max_blobs}'
)
# ── Callback ──────────────────────────────────────────────────────────────
def _on_image(self, msg: Image) -> None:
try:
bgr = self._bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')
except Exception as exc:
self.get_logger().error(
f'cv_bridge: {exc}', throttle_duration_sec=5.0)
return
blobs = find_color_blobs(
bgr,
active_colors=self._active_colors,
min_area_px=self._min_area,
max_blobs_per_color=self._max_blobs,
)
arr = ColorDetectionArray()
arr.header = msg.header
for blob in blobs:
det = ColorDetection()
det.header = msg.header
det.color_name = blob.color_name
det.confidence = blob.confidence
det.area_px = blob.area_px
det.contour_id = blob.contour_id
bbox = BoundingBox2D()
center = Pose2D()
center.x = blob.cx
center.y = blob.cy
bbox.center = center
bbox.size_x = blob.w
bbox.size_y = blob.h
det.bbox = bbox
arr.detections.append(det)
self._pub.publish(arr)
def main(args=None) -> None:
rclpy.init(args=args)
node = ColorSegmentNode()
try:
rclpy.spin(node)
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,128 @@
"""
depth_hole_fill_node.py D435i depth image hole filler (Issue #268).
Subscribes to the raw D435i depth stream, fills zero/NaN pixels using
multi-pass spatial-Gaussian bilateral interpolation, and republishes the
filled image at camera rate.
Subscribes (BEST_EFFORT):
/camera/depth/image_rect_raw sensor_msgs/Image float32 depth (m)
Publishes:
/camera/depth/filled sensor_msgs/Image float32 depth (m), holes filled
The filled image preserves all original valid pixels exactly and only
modifies pixels that had no return (0 or NaN). The output is suitable
for all downstream consumers that expect a dense depth map (VO, RTAB-Map,
collision avoidance, floor classifier).
Parameters
----------
input_topic str /camera/depth/image_rect_raw Input depth topic
output_topic str /camera/depth/filled Output depth topic
kernel_size int 5 Initial Gaussian kernel side length (pixels)
d_min float 0.1 Minimum valid depth (m)
d_max float 10.0 Maximum valid depth (m)
max_passes int 3 Fill passes (growing kernel per pass)
"""
from __future__ import annotations
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
import numpy as np
from cv_bridge import CvBridge
from sensor_msgs.msg import Image
from ._depth_hole_fill import fill_holes
_SENSOR_QOS = QoSProfile(
reliability=ReliabilityPolicy.BEST_EFFORT,
history=HistoryPolicy.KEEP_LAST,
depth=4,
)
class DepthHoleFillNode(Node):
def __init__(self) -> None:
super().__init__('depth_hole_fill_node')
self.declare_parameter('input_topic', '/camera/depth/image_rect_raw')
self.declare_parameter('output_topic', '/camera/depth/filled')
self.declare_parameter('kernel_size', 5)
self.declare_parameter('d_min', 0.1)
self.declare_parameter('d_max', 10.0)
self.declare_parameter('max_passes', 3)
input_topic = self.get_parameter('input_topic').value
output_topic = self.get_parameter('output_topic').value
self._ks = int(self.get_parameter('kernel_size').value)
self._d_min = self.get_parameter('d_min').value
self._d_max = self.get_parameter('d_max').value
self._passes = int(self.get_parameter('max_passes').value)
self._bridge = CvBridge()
self._sub = self.create_subscription(
Image, input_topic, self._on_depth, _SENSOR_QOS)
self._pub = self.create_publisher(Image, output_topic, 10)
self.get_logger().info(
f'depth_hole_fill_node ready — '
f'{input_topic}{output_topic} '
f'kernel={self._ks} passes={self._passes} '
f'd=[{self._d_min},{self._d_max}]m'
)
# ── Callback ──────────────────────────────────────────────────────────────
def _on_depth(self, msg: Image) -> None:
try:
depth = self._bridge.imgmsg_to_cv2(msg, desired_encoding='passthrough')
except Exception as exc:
self.get_logger().error(
f'cv_bridge: {exc}', throttle_duration_sec=5.0)
return
depth = depth.astype(np.float32)
# Handle uint16 mm → float32 m conversion (D435i raw stream)
if depth.max() > 100.0:
depth /= 1000.0
filled = fill_holes(
depth,
kernel_size=self._ks,
d_min=self._d_min,
d_max=self._d_max,
max_passes=self._passes,
)
try:
out_msg = self._bridge.cv2_to_imgmsg(filled, encoding='32FC1')
except Exception as exc:
self.get_logger().error(
f'cv2_to_imgmsg: {exc}', throttle_duration_sec=5.0)
return
out_msg.header = msg.header
self._pub.publish(out_msg)
def main(args=None) -> None:
rclpy.init(args=args)
node = DepthHoleFillNode()
try:
rclpy.spin(node)
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,150 @@
"""
vo_drift_node.py Visual odometry drift detector (Issue #260).
Compares the cumulative path lengths of visual odometry and wheel odometry
over a sliding window. When the absolute difference exceeds the configured
threshold the node flags drift, allowing the system to warn operators,
inflate VO covariance, or fall back to wheel-only localisation.
Subscribes (BEST_EFFORT):
/camera/odom nav_msgs/Odometry visual odometry
/odom nav_msgs/Odometry wheel odometry
For this robot remap to /saltybot/visual_odom + /saltybot/rover_odom.
Publishes:
/saltybot/vo_drift_detected std_msgs/Bool True while drifting
/saltybot/vo_drift_magnitude std_msgs/Float32 drift magnitude (metres)
Parameters
----------
vo_topic str /camera/odom Visual odometry source topic
wheel_topic str /odom Wheel odometry source topic
drift_threshold_m float 0.5 Drift flag threshold (metres)
window_s float 10.0 Comparison window (seconds)
publish_hz float 2.0 Output publication rate (Hz)
max_buffer_age_s float 30.0 Max age of stored samples (s)
"""
from __future__ import annotations
import time
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from nav_msgs.msg import Odometry
from std_msgs.msg import Bool, Float32
from ._vo_drift import OdomBuffer, OdomSample, compute_drift
_SENSOR_QOS = QoSProfile(
reliability=ReliabilityPolicy.BEST_EFFORT,
history=HistoryPolicy.KEEP_LAST,
depth=4,
)
class VoDriftNode(Node):
def __init__(self) -> None:
super().__init__('vo_drift_node')
self.declare_parameter('vo_topic', '/camera/odom')
self.declare_parameter('wheel_topic', '/odom')
self.declare_parameter('drift_threshold_m', 0.5)
self.declare_parameter('window_s', 10.0)
self.declare_parameter('publish_hz', 2.0)
self.declare_parameter('max_buffer_age_s', 30.0)
vo_topic = self.get_parameter('vo_topic').value
wheel_topic = self.get_parameter('wheel_topic').value
self._thresh = self.get_parameter('drift_threshold_m').value
self._window_s = self.get_parameter('window_s').value
publish_hz = self.get_parameter('publish_hz').value
max_age = self.get_parameter('max_buffer_age_s').value
self._vo_buf = OdomBuffer(max_age_s=max_age)
self._wheel_buf = OdomBuffer(max_age_s=max_age)
self.create_subscription(
Odometry, vo_topic, self._on_vo, _SENSOR_QOS)
self.create_subscription(
Odometry, wheel_topic, self._on_wheel, _SENSOR_QOS)
self._pub_detected = self.create_publisher(
Bool, '/saltybot/vo_drift_detected', 10)
self._pub_magnitude = self.create_publisher(
Float32, '/saltybot/vo_drift_magnitude', 10)
self.create_timer(1.0 / publish_hz, self._tick)
self.get_logger().info(
f'vo_drift_node ready — '
f'vo={vo_topic} wheel={wheel_topic} '
f'threshold={self._thresh}m window={self._window_s}s'
)
# ── Callbacks ─────────────────────────────────────────────────────────────
def _on_vo(self, msg: Odometry) -> None:
s = _odom_to_sample(msg)
self._vo_buf.push(s)
def _on_wheel(self, msg: Odometry) -> None:
s = _odom_to_sample(msg)
self._wheel_buf.push(s)
# ── Publish tick ──────────────────────────────────────────────────────────
def _tick(self) -> None:
now = time.monotonic()
result = compute_drift(
self._vo_buf, self._wheel_buf,
window_s=self._window_s,
drift_threshold_m=self._thresh,
now=now,
)
if result.is_drifting:
self.get_logger().warn(
f'VO drift detected: {result.drift_m:.3f}m '
f'(vo={result.vo_path_m:.3f}m wheel={result.wheel_path_m:.3f}m '
f'over {result.window_s:.1f}s)',
throttle_duration_sec=5.0,
)
det_msg = Bool()
det_msg.data = result.is_drifting
self._pub_detected.publish(det_msg)
mag_msg = Float32()
mag_msg.data = float(result.drift_m)
self._pub_magnitude.publish(mag_msg)
# ── Helpers ───────────────────────────────────────────────────────────────────
def _odom_to_sample(msg: Odometry) -> OdomSample:
"""Convert nav_msgs/Odometry to OdomSample using monotonic clock."""
return OdomSample(
t=time.monotonic(),
x=msg.pose.pose.position.x,
y=msg.pose.pose.position.y,
)
def main(args=None) -> None:
rclpy.init(args=args)
node = VoDriftNode()
try:
rclpy.spin(node)
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()

View File

@ -35,6 +35,12 @@ setup(
'lidar_clustering = saltybot_bringup.lidar_clustering_node:main', 'lidar_clustering = saltybot_bringup.lidar_clustering_node:main',
# Floor surface type classifier (Issue #249) # Floor surface type classifier (Issue #249)
'floor_classifier = saltybot_bringup.floor_classifier_node:main', 'floor_classifier = saltybot_bringup.floor_classifier_node:main',
# Visual odometry drift detector (Issue #260)
'vo_drift_detector = saltybot_bringup.vo_drift_node:main',
# Depth image hole filler (Issue #268)
'depth_hole_fill = saltybot_bringup.depth_hole_fill_node:main',
# HSV color object segmenter (Issue #274)
'color_segmenter = saltybot_bringup.color_segment_node:main',
], ],
}, },
) )

View File

@ -0,0 +1,361 @@
"""
test_color_segmenter.py Unit tests for HSV color segmentation helpers (no ROS2 required).
Covers:
HsvRange / ColorBlob:
- NamedTuple fields accessible by name
- confidence clamped to [0,1]
mask_for_color:
- pure red image red mask fully white
- pure red image green mask fully black
- pure green image green mask fully white
- pure blue image blue mask fully white
- pure yellow image yellow mask non-empty
- pure orange image orange mask non-empty
- red hue wrap-around detected from both HSV bands
- unknown color name raises ValueError
- mask is uint8
- mask shape matches input
find_color_blobs output contract:
- returns list
- empty list on blank (no-color) image
- empty list when min_area_px larger than any contour
find_color_blobs detection:
- large red rectangle detected as red blob
- large green rectangle detected as green blob
- large blue rectangle detected as blue blob
- detected blob color_name matches requested color
- contour_id is 0 for first blob
- confidence in [0, 1]
- cx, cy within image bounds
- w, h > 0 for detected blob
- area_px > 0 for detected blob
find_color_blobs filtering:
- active_colors=None detects all colors when present
- only requested colors returned when active_colors restricted
- max_blobs_per_color limits output count
- two separate red blobs both detected when max_blobs=2
- smaller blob filtered when min_area_px high
find_color_blobs multi-color:
- image with red + green regions both detected
"""
import sys
import os
import numpy as np
import pytest
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from saltybot_bringup._color_segmenter import (
HsvRange,
ColorBlob,
COLOR_RANGES,
mask_for_color,
find_color_blobs,
)
# ── Image factories ───────────────────────────────────────────────────────────
def _solid_bgr(b, g, r, h=64, w=64) -> np.ndarray:
"""Solid BGR image."""
img = np.zeros((h, w, 3), dtype=np.uint8)
img[:, :] = (b, g, r)
return img
def _blank(h=64, w=64) -> np.ndarray:
"""All-black image (nothing to detect)."""
return np.zeros((h, w, 3), dtype=np.uint8)
def _image_with_rect(bg_bgr, rect_bgr, rect_slice_r, rect_slice_c, h=128, w=128) -> np.ndarray:
"""Background colour with a filled rectangle."""
img = np.zeros((h, w, 3), dtype=np.uint8)
img[:, :] = bg_bgr
img[rect_slice_r, rect_slice_c] = rect_bgr
return img
# Canonical solid color BGR values (saturated, in-range for HSV thresholds)
_RED_BGR = (0, 0, 200) # BGR pure red
_GREEN_BGR = (0, 200, 0 ) # BGR pure green
_BLUE_BGR = (200, 0, 0 ) # BGR pure blue
_YELLOW_BGR = (0, 220, 220) # BGR yellow
_ORANGE_BGR = (0, 140, 220) # BGR orange
# ── HsvRange / ColorBlob types ────────────────────────────────────────────────
class TestTypes:
def test_hsv_range_fields(self):
r = HsvRange(0, 10, 60, 255, 50, 255)
assert r.h_lo == 0 and r.h_hi == 10
assert r.s_lo == 60 and r.s_hi == 255
assert r.v_lo == 50 and r.v_hi == 255
def test_color_blob_fields(self):
b = ColorBlob('red', 0.8, 32.0, 32.0, 20.0, 20.0, 300.0, 0)
assert b.color_name == 'red'
assert b.confidence == pytest.approx(0.8)
assert b.contour_id == 0
def test_color_ranges_contains_all_defaults(self):
for color in ('red', 'green', 'blue', 'yellow', 'orange'):
assert color in COLOR_RANGES
assert len(COLOR_RANGES[color]) >= 1
# ── mask_for_color ────────────────────────────────────────────────────────────
class TestMaskForColor:
def test_mask_is_uint8(self):
import cv2
hsv = cv2.cvtColor(_solid_bgr(*_RED_BGR), cv2.COLOR_BGR2HSV)
m = mask_for_color(hsv, 'red')
assert m.dtype == np.uint8
def test_mask_shape_matches_input(self):
import cv2
bgr = _solid_bgr(*_RED_BGR, h=48, w=80)
hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
m = mask_for_color(hsv, 'red')
assert m.shape == (48, 80)
def test_pure_red_gives_red_mask_nonzero(self):
import cv2
hsv = cv2.cvtColor(_solid_bgr(*_RED_BGR), cv2.COLOR_BGR2HSV)
m = mask_for_color(hsv, 'red')
assert m.any(), 'red mask should be non-empty for red image'
def test_pure_red_gives_green_mask_empty(self):
import cv2
hsv = cv2.cvtColor(_solid_bgr(*_RED_BGR), cv2.COLOR_BGR2HSV)
m = mask_for_color(hsv, 'green')
assert not m.any(), 'green mask should be empty for red image'
def test_pure_green_gives_green_mask_nonzero(self):
import cv2
hsv = cv2.cvtColor(_solid_bgr(*_GREEN_BGR), cv2.COLOR_BGR2HSV)
m = mask_for_color(hsv, 'green')
assert m.any()
def test_pure_blue_gives_blue_mask_nonzero(self):
import cv2
hsv = cv2.cvtColor(_solid_bgr(*_BLUE_BGR), cv2.COLOR_BGR2HSV)
m = mask_for_color(hsv, 'blue')
assert m.any()
def test_pure_yellow_gives_yellow_mask_nonzero(self):
import cv2
hsv = cv2.cvtColor(_solid_bgr(*_YELLOW_BGR), cv2.COLOR_BGR2HSV)
m = mask_for_color(hsv, 'yellow')
assert m.any()
def test_pure_orange_gives_orange_mask_nonzero(self):
import cv2
hsv = cv2.cvtColor(_solid_bgr(*_ORANGE_BGR), cv2.COLOR_BGR2HSV)
m = mask_for_color(hsv, 'orange')
assert m.any()
def test_unknown_color_raises(self):
import cv2
hsv = cv2.cvtColor(_blank(), cv2.COLOR_BGR2HSV)
with pytest.raises(ValueError, match='Unknown color'):
mask_for_color(hsv, 'purple')
def test_red_detected_in_high_hue_band(self):
"""A near-180-hue red pixel should still trigger the red mask."""
import cv2
# HSV (175, 200, 200) = high-hue red (wrap-around band)
hsv = np.full((32, 32, 3), (175, 200, 200), dtype=np.uint8)
m = mask_for_color(hsv, 'red')
assert m.any(), 'high-hue red not detected'
# ── find_color_blobs — output contract ───────────────────────────────────────
class TestFindColorBlobsContract:
def test_returns_list(self):
result = find_color_blobs(_blank())
assert isinstance(result, list)
def test_blank_image_returns_empty(self):
result = find_color_blobs(_blank())
assert result == []
def test_min_area_filter_removes_all(self):
"""Request a min area larger than the entire image → no blobs."""
bgr = _solid_bgr(*_RED_BGR, h=32, w=32)
result = find_color_blobs(bgr, active_colors=['red'], min_area_px=1e9)
assert result == []
# ── find_color_blobs — detection ─────────────────────────────────────────────
class TestFindColorBlobsDetection:
def _large_rect(self, color_bgr, color_name) -> np.ndarray:
"""100×100 image with a 60×60 solid-color rectangle centred."""
img = _blank(h=100, w=100)
img[20:80, 20:80] = color_bgr
return img
def test_red_rect_detected(self):
blobs = find_color_blobs(self._large_rect(_RED_BGR, 'red'), active_colors=['red'])
assert len(blobs) >= 1
assert blobs[0].color_name == 'red'
def test_green_rect_detected(self):
blobs = find_color_blobs(self._large_rect(_GREEN_BGR, 'green'), active_colors=['green'])
assert len(blobs) >= 1
assert blobs[0].color_name == 'green'
def test_blue_rect_detected(self):
blobs = find_color_blobs(self._large_rect(_BLUE_BGR, 'blue'), active_colors=['blue'])
assert len(blobs) >= 1
assert blobs[0].color_name == 'blue'
def test_first_contour_id_is_zero(self):
img = _blank(h=100, w=100)
img[20:80, 20:80] = _RED_BGR
blobs = find_color_blobs(img, active_colors=['red'])
assert blobs[0].contour_id == 0
def test_confidence_in_range(self):
img = _blank(h=100, w=100)
img[20:80, 20:80] = _GREEN_BGR
blobs = find_color_blobs(img, active_colors=['green'])
assert blobs
assert 0.0 <= blobs[0].confidence <= 1.0
def test_cx_within_image(self):
img = _blank(h=100, w=100)
img[20:80, 20:80] = _BLUE_BGR
blobs = find_color_blobs(img, active_colors=['blue'])
assert blobs
assert 0.0 <= blobs[0].cx <= 100.0
def test_cy_within_image(self):
img = _blank(h=100, w=100)
img[20:80, 20:80] = _BLUE_BGR
blobs = find_color_blobs(img, active_colors=['blue'])
assert blobs
assert 0.0 <= blobs[0].cy <= 100.0
def test_w_positive(self):
img = _blank(h=100, w=100)
img[20:80, 20:80] = _RED_BGR
blobs = find_color_blobs(img, active_colors=['red'])
assert blobs[0].w > 0
def test_h_positive(self):
img = _blank(h=100, w=100)
img[20:80, 20:80] = _RED_BGR
blobs = find_color_blobs(img, active_colors=['red'])
assert blobs[0].h > 0
def test_area_px_positive(self):
img = _blank(h=100, w=100)
img[20:80, 20:80] = _RED_BGR
blobs = find_color_blobs(img, active_colors=['red'])
assert blobs[0].area_px > 0
def test_area_px_reasonable(self):
"""area_px should be roughly within the rectangle we drew."""
img = _blank(h=100, w=100)
img[20:80, 20:80] = _GREEN_BGR # 60×60 = 3600 px
blobs = find_color_blobs(img, active_colors=['green'], min_area_px=100.0)
assert blobs
assert 1000 <= blobs[0].area_px <= 4000
# ── find_color_blobs — filtering ─────────────────────────────────────────────
class TestFindColorBlobsFiltering:
def test_active_colors_none_detects_all(self):
"""Image with red+green patches → both found when active_colors=None."""
img = _blank(h=128, w=128)
img[10:50, 10:50] = _RED_BGR
img[10:50, 70:110] = _GREEN_BGR
blobs = find_color_blobs(img, active_colors=None, min_area_px=100.0)
names = {b.color_name for b in blobs}
assert 'red' in names
assert 'green' in names
def test_restricted_active_colors(self):
"""Only red requested → no green blobs returned."""
img = _blank(h=128, w=128)
img[10:50, 10:50] = _RED_BGR
img[10:50, 70:110] = _GREEN_BGR
blobs = find_color_blobs(img, active_colors=['red'], min_area_px=100.0)
assert all(b.color_name == 'red' for b in blobs)
def test_max_blobs_per_color_limits(self):
"""Four separate red rectangles but max_blobs=2 → at most 2 blobs."""
img = _blank(h=200, w=200)
img[10:40, 10:40] = _RED_BGR
img[10:40, 80:110] = _RED_BGR
img[100:130, 10:40] = _RED_BGR
img[100:130, 80:110] = _RED_BGR
blobs = find_color_blobs(img, active_colors=['red'],
min_area_px=100.0, max_blobs_per_color=2)
red_blobs = [b for b in blobs if b.color_name == 'red']
assert len(red_blobs) <= 2
def test_two_blobs_detected_when_max_allows(self):
"""Two red rectangles detected when max_blobs_per_color >= 2."""
img = _blank(h=200, w=200)
img[10:60, 10:60] = _RED_BGR
img[10:60, 130:180] = _RED_BGR
blobs = find_color_blobs(img, active_colors=['red'],
min_area_px=100.0, max_blobs_per_color=10)
red_blobs = [b for b in blobs if b.color_name == 'red']
assert len(red_blobs) >= 2
def test_small_blob_filtered_by_min_area(self):
"""Small 5×5 red patch filtered by min_area_px=500."""
img = _blank(h=64, w=64)
img[28:33, 28:33] = _RED_BGR # 5×5 = 25 px contour area
blobs = find_color_blobs(img, active_colors=['red'], min_area_px=500.0)
assert blobs == []
# ── find_color_blobs — multi-color ───────────────────────────────────────────
class TestFindColorBlobsMultiColor:
def test_red_and_green_in_same_image(self):
img = _blank(h=128, w=128)
img[10:60, 10:60] = _RED_BGR
img[10:60, 68:118] = _GREEN_BGR
blobs = find_color_blobs(img, active_colors=['red', 'green'], min_area_px=100.0)
names = {b.color_name for b in blobs}
assert 'red' in names, 'red blob should be detected'
assert 'green' in names, 'green blob should be detected'
def test_contour_ids_per_color_start_at_zero(self):
"""contour_id should be 0 for the first (largest) blob of each color."""
img = _blank(h=200, w=200)
img[10:80, 10:80] = _RED_BGR
img[10:80, 110:180] = _BLUE_BGR
blobs = find_color_blobs(img, active_colors=['red', 'blue'], min_area_px=100.0)
for color in ('red', 'blue'):
first = next((b for b in blobs if b.color_name == color), None)
assert first is not None, f'{color} blob not found'
assert first.contour_id == 0, f'{color} first blob contour_id != 0'
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@ -0,0 +1,281 @@
"""
test_depth_hole_fill.py Unit tests for depth hole fill helpers (no ROS2 required).
Covers:
valid_mask:
- valid range returns True
- zero / below d_min returns False
- NaN returns False
- above d_max returns False
- mixed array has correct mask
_odd_kernel_size:
- result is always odd
- result >= 3
- scales correctly
fill_holes no-hole cases:
- fully valid image is returned unchanged
- output dtype is float32
- output shape matches input
fill_holes basic fills:
- single centre hole in uniform depth filled with correct depth
- single centre hole in uniform depth original valid pixels unchanged
- NaN pixel treated as hole and filled
- row of zeros within uniform depth filled
fill_holes fill quality:
- linear gradient: centre hole filled with interpolated value
- multi-pass fills larger holes than single pass
- all-zero image stays zero (no valid neighbours)
- border hole (edge pixel) is handled without crash
- depth range: pixel above d_max treated as hole
fill_holes valid pixel preservation:
- original valid pixels are never modified
- max_passes=1 still fills small holes
"""
import sys
import os
import math
import numpy as np
import pytest
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from saltybot_bringup._depth_hole_fill import (
fill_holes,
valid_mask,
_odd_kernel_size,
)
# ── Helpers ───────────────────────────────────────────────────────────────────
def _uniform(val=2.0, h=64, w=64) -> np.ndarray:
return np.full((h, w), val, dtype=np.float32)
def _poke_hole(arr, r, c) -> np.ndarray:
arr = arr.copy()
arr[r, c] = 0.0
return arr
def _poke_nan(arr, r, c) -> np.ndarray:
arr = arr.copy()
arr[r, c] = float('nan')
return arr
# ── valid_mask ────────────────────────────────────────────────────────────────
class TestValidMask:
def test_valid_pixel_is_true(self):
d = np.array([[1.0]], dtype=np.float32)
assert valid_mask(d, 0.1, 10.0)[0, 0]
def test_zero_is_false(self):
d = np.array([[0.0]], dtype=np.float32)
assert not valid_mask(d, 0.1, 10.0)[0, 0]
def test_below_dmin_is_false(self):
d = np.array([[0.05]], dtype=np.float32)
assert not valid_mask(d, 0.1, 10.0)[0, 0]
def test_nan_is_false(self):
d = np.array([[float('nan')]], dtype=np.float32)
assert not valid_mask(d, 0.1, 10.0)[0, 0]
def test_above_dmax_is_false(self):
d = np.array([[15.0]], dtype=np.float32)
assert not valid_mask(d, 0.1, 10.0)[0, 0]
def test_at_dmin_is_true(self):
d = np.array([[0.1]], dtype=np.float32)
assert valid_mask(d, 0.1, 10.0)[0, 0]
def test_at_dmax_is_true(self):
d = np.array([[10.0]], dtype=np.float32)
assert valid_mask(d, 0.1, 10.0)[0, 0]
def test_mixed_array(self):
d = np.array([[0.0, 1.0, float('nan'), 5.0, 11.0]], dtype=np.float32)
m = valid_mask(d, 0.1, 10.0)
np.testing.assert_array_equal(m, [[False, True, False, True, False]])
# ── _odd_kernel_size ──────────────────────────────────────────────────────────
class TestOddKernelSize:
@pytest.mark.parametrize('base,scale', [
(5, 1.0), (5, 2.5), (5, 6.0),
(3, 1.0), (7, 2.0), (9, 3.0),
(4, 1.0), # even base → must become odd
])
def test_result_is_odd(self, base, scale):
ks = _odd_kernel_size(base, scale)
assert ks % 2 == 1
@pytest.mark.parametrize('base,scale', [(3, 1.0), (1, 5.0), (2, 0.5)])
def test_result_at_least_3(self, base, scale):
assert _odd_kernel_size(base, scale) >= 3
def test_scale_1_returns_base_or_nearby_odd(self):
ks = _odd_kernel_size(5, 1.0)
assert ks == 5
def test_large_scale_gives_large_kernel(self):
ks = _odd_kernel_size(5, 6.0)
assert ks >= 25 # 5 * 6 = 30 → 31
# ── fill_holes — output contract ──────────────────────────────────────────────
class TestFillHolesOutputContract:
def test_output_dtype_float32(self):
out = fill_holes(_uniform(2.0))
assert out.dtype == np.float32
def test_output_shape_preserved(self):
img = _uniform(2.0, h=48, w=64)
out = fill_holes(img)
assert out.shape == img.shape
def test_fully_valid_image_unchanged(self):
img = _uniform(2.0)
out = fill_holes(img)
np.testing.assert_allclose(out, img, atol=1e-6)
def test_valid_pixels_never_modified(self):
"""Any pixel valid in the input must be identical in the output."""
img = _uniform(3.0, h=32, w=32)
img[16, 16] = 0.0 # one hole
mask_before = valid_mask(img)
out = fill_holes(img)
np.testing.assert_allclose(out[mask_before], img[mask_before], atol=1e-6)
# ── fill_holes — basic hole filling ──────────────────────────────────────────
class TestFillHolesBasic:
def test_centre_zero_filled_uniform(self):
"""Single zero pixel in uniform depth → filled with that depth."""
img = _poke_hole(_uniform(2.0, 32, 32), 16, 16)
out = fill_holes(img, kernel_size=5, max_passes=1)
assert out[16, 16] == pytest.approx(2.0, abs=0.05)
def test_centre_nan_filled_uniform(self):
"""Single NaN pixel in uniform depth → filled."""
img = _poke_nan(_uniform(2.0, 32, 32), 16, 16)
out = fill_holes(img, kernel_size=5, max_passes=1)
assert out[16, 16] == pytest.approx(2.0, abs=0.05)
def test_filled_value_is_positive(self):
img = _poke_hole(_uniform(1.5, 32, 32), 16, 16)
out = fill_holes(img)
assert out[16, 16] > 0.0
def test_row_of_holes_filled(self):
"""Entire middle row zeroed → should be filled from neighbours above/below."""
img = _uniform(3.0, 32, 32)
img[16, :] = 0.0
out = fill_holes(img, kernel_size=7, max_passes=1)
# All pixels in the row should be non-zero after filling
assert (out[16, :] > 0.0).all()
def test_all_zero_stays_zero(self):
"""Image with no valid pixels → stays zero (nothing to interpolate from)."""
img = np.zeros((32, 32), dtype=np.float32)
out = fill_holes(img, d_min=0.1)
assert (out == 0.0).all()
def test_border_hole_no_crash(self):
"""Holes at image corners must not raise exceptions."""
img = _uniform(2.0, 32, 32)
img[0, 0] = 0.0
img[0, -1] = 0.0
img[-1, 0] = 0.0
img[-1, -1] = 0.0
out = fill_holes(img) # must not raise
assert out.shape == img.shape
def test_border_holes_filled(self):
"""Corner holes should be filled from their neighbours."""
img = _uniform(2.0, 32, 32)
img[0, 0] = 0.0
out = fill_holes(img, kernel_size=5, max_passes=1)
assert out[0, 0] == pytest.approx(2.0, abs=0.1)
# ── fill_holes — fill quality ─────────────────────────────────────────────────
class TestFillHolesQuality:
def test_linear_gradient_centre_hole_interpolated(self):
"""
Depth linearly increasing from 1.0 (left) to 3.0 (right).
Centre hole should be filled near the midpoint (~2.0).
"""
h, w = 32, 32
img = np.tile(np.linspace(1.0, 3.0, w, dtype=np.float32), (h, 1))
cx = w // 2
img[:, cx] = 0.0
out = fill_holes(img, kernel_size=5, max_passes=1)
mid = out[h // 2, cx]
assert 1.5 <= mid <= 2.5, f'interpolated value {mid:.3f} not in [1.5, 2.5]'
def test_large_hole_filled_with_more_passes(self):
"""A 9×9 hole in uniform depth: single pass may not fully fill it,
but 3 passes should."""
img = _uniform(2.0, 64, 64)
# Create a 9×9 hole
img[28:37, 28:37] = 0.0
out1 = fill_holes(img, kernel_size=5, max_passes=1)
out3 = fill_holes(img, kernel_size=5, max_passes=3)
# More passes → fewer remaining holes
holes1 = (out1 == 0.0).sum()
holes3 = (out3 == 0.0).sum()
assert holes3 <= holes1, f'more passes should reduce holes: {holes3} vs {holes1}'
def test_3pass_fills_9x9_hole_completely(self):
img = _uniform(2.0, 64, 64)
img[28:37, 28:37] = 0.0
out = fill_holes(img, kernel_size=5, max_passes=3)
assert (out[28:37, 28:37] > 0.0).all()
def test_filled_depth_within_valid_range(self):
"""Filled pixels should have depth within [d_min, d_max]."""
img = _uniform(2.0, 32, 32)
img[10:15, 10:15] = 0.0
out = fill_holes(img, d_min=0.1, d_max=10.0, max_passes=3)
# Only check pixels that were actually filled
was_hole = (img == 0.0)
filled = out[was_hole]
positive = filled[filled > 0.0]
assert (positive >= 0.1).all()
assert (positive <= 10.0).all()
def test_above_dmax_treated_as_hole(self):
"""Pixels above d_max should be treated as holes and filled."""
img = _uniform(2.0, 32, 32)
img[16, 16] = 15.0 # out of range
out = fill_holes(img, d_max=10.0, max_passes=1)
assert out[16, 16] == pytest.approx(2.0, abs=0.1)
def test_max_passes_1_works(self):
img = _poke_hole(_uniform(2.0, 32, 32), 16, 16)
out = fill_holes(img, max_passes=1)
assert out.shape == img.shape
assert out[16, 16] > 0.0
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@ -0,0 +1,297 @@
"""
test_vo_drift.py Unit tests for VO drift detector helpers (no ROS2 required).
Covers:
OdomBuffer:
- push/len
- window returns only samples within cutoff
- old samples are evicted beyond max_age_s
- clear empties the buffer
- window on empty buffer returns empty list
_path_length (via compute_drift with crafted samples):
- stationary source path = 0
- straight-line motion path = total distance
- L-shaped path path = sum of two legs
compute_drift:
- both empty DriftResult with zeros, is_drifting=False
- one buffer < 2 samples zero drift
- both move same distance drift 0, not drifting
- VO moves 1m, wheel moves 0.5m drift = 0.5m
- drift == threshold is_drifting=True (>=)
- drift < threshold is_drifting=False
- drift > threshold is_drifting=True
- path lengths in result match expectation
- n_vo / n_wheel counts correct
- samples outside window ignored
- window_s in result reflects actual data span
"""
import sys
import os
import math
import pytest
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from saltybot_bringup._vo_drift import (
OdomSample,
OdomBuffer,
DriftResult,
compute_drift,
_path_length,
)
# ── Helpers ───────────────────────────────────────────────────────────────────
def _s(t, x, y) -> OdomSample:
return OdomSample(t=t, x=x, y=y)
def _straight_buf(n=5, speed=0.1, t_start=0.0, dt=1.0,
max_age_s=30.0) -> OdomBuffer:
"""n samples moving along +x at `speed` m/s."""
buf = OdomBuffer(max_age_s=max_age_s)
for i in range(n):
buf.push(_s(t_start + i * dt, x=i * speed * dt, y=0.0))
return buf
def _stationary_buf(n=5, t_start=0.0, dt=1.0,
max_age_s=30.0) -> OdomBuffer:
buf = OdomBuffer(max_age_s=max_age_s)
for i in range(n):
buf.push(_s(t_start + i * dt, x=0.0, y=0.0))
return buf
# ── OdomBuffer ────────────────────────────────────────────────────────────────
class TestOdomBuffer:
def test_push_increases_len(self):
buf = OdomBuffer()
assert len(buf) == 0
buf.push(_s(0.0, 0.0, 0.0))
assert len(buf) == 1
def test_window_returns_all_within_cutoff(self):
buf = OdomBuffer(max_age_s=30.0)
for t in [0.0, 5.0, 10.0]:
buf.push(_s(t, 0.0, 0.0))
samples = buf.window(window_s=10.0, now=10.0)
assert len(samples) == 3
def test_window_excludes_old_samples(self):
buf = OdomBuffer(max_age_s=30.0)
for t in [0.0, 5.0, 15.0]:
buf.push(_s(t, 0.0, 0.0))
# window=5s from now=15 → only t=15 qualifies (t>=10)
samples = buf.window(window_s=5.0, now=15.0)
assert len(samples) == 1
assert samples[0].t == 15.0
def test_evicts_samples_beyond_max_age(self):
buf = OdomBuffer(max_age_s=5.0)
buf.push(_s(0.0, 0.0, 0.0))
buf.push(_s(10.0, 1.0, 0.0)) # now=10 → t=0 is 10s old > 5s max
assert len(buf) == 1
def test_clear_empties_buffer(self):
buf = _straight_buf(n=5)
buf.clear()
assert len(buf) == 0
def test_window_on_empty_buffer(self):
buf = OdomBuffer()
assert buf.window(window_s=10.0, now=100.0) == []
def test_window_boundary_inclusive(self):
"""Sample exactly at window cutoff (t == now - window_s) is included."""
buf = OdomBuffer(max_age_s=30.0)
buf.push(_s(0.0, 0.0, 0.0))
# window=10, now=10 → cutoff=0.0, sample at t=0.0 should be included
samples = buf.window(window_s=10.0, now=10.0)
assert len(samples) == 1
# ── _path_length ──────────────────────────────────────────────────────────────
class TestPathLength:
def test_stationary_path_zero(self):
samples = [_s(i, 0.0, 0.0) for i in range(5)]
assert _path_length(samples) == pytest.approx(0.0)
def test_unit_step_path(self):
samples = [_s(0, 0.0, 0.0), _s(1, 1.0, 0.0)]
assert _path_length(samples) == pytest.approx(1.0)
def test_two_unit_steps(self):
samples = [_s(0, 0.0, 0.0), _s(1, 1.0, 0.0), _s(2, 2.0, 0.0)]
assert _path_length(samples) == pytest.approx(2.0)
def test_diagonal_step(self):
# (0,0) → (1,1): distance = sqrt(2)
samples = [_s(0, 0.0, 0.0), _s(1, 1.0, 1.0)]
assert _path_length(samples) == pytest.approx(math.sqrt(2))
def test_l_shaped_path(self):
# Right 3m then up 4m → total path = 7m (not hypotenuse)
samples = [_s(0, 0.0, 0.0), _s(1, 3.0, 0.0), _s(2, 3.0, 4.0)]
assert _path_length(samples) == pytest.approx(7.0)
def test_single_sample_returns_zero(self):
assert _path_length([_s(0, 5.0, 5.0)]) == pytest.approx(0.0)
def test_empty_returns_zero(self):
assert _path_length([]) == pytest.approx(0.0)
# ── compute_drift ─────────────────────────────────────────────────────────────
class TestComputeDrift:
def test_both_empty_returns_zero_drift(self):
result = compute_drift(
OdomBuffer(), OdomBuffer(),
window_s=10.0, drift_threshold_m=0.5, now=10.0)
assert result.drift_m == pytest.approx(0.0)
assert not result.is_drifting
def test_one_buffer_empty_returns_zero(self):
vo = _straight_buf(n=5, speed=0.1)
result = compute_drift(
vo, OdomBuffer(),
window_s=10.0, drift_threshold_m=0.5, now=5.0)
assert result.drift_m == pytest.approx(0.0)
assert not result.is_drifting
def test_one_buffer_single_sample_returns_zero(self):
vo = _straight_buf(n=5, speed=0.1)
wheel = OdomBuffer()
wheel.push(_s(0.0, 0.0, 0.0)) # only 1 sample
result = compute_drift(
vo, wheel,
window_s=10.0, drift_threshold_m=0.5, now=5.0)
assert result.drift_m == pytest.approx(0.0)
assert not result.is_drifting
def test_both_move_same_distance_zero_drift(self):
# Both move 0.1 m/s for 4 steps → 0.4 m each
vo = _straight_buf(n=5, speed=0.1, dt=1.0)
wheel = _straight_buf(n=5, speed=0.1, dt=1.0)
result = compute_drift(
vo, wheel,
window_s=10.0, drift_threshold_m=0.5, now=5.0)
assert result.drift_m == pytest.approx(0.0, abs=1e-9)
assert not result.is_drifting
def test_both_stationary_zero_drift(self):
vo = _stationary_buf(n=5)
wheel = _stationary_buf(n=5)
result = compute_drift(
vo, wheel,
window_s=10.0, drift_threshold_m=0.5, now=5.0)
assert result.drift_m == pytest.approx(0.0)
assert not result.is_drifting
def test_drift_equals_path_length_difference(self):
# VO moves 1.0 m total, wheel moves 0.5 m total
vo = _straight_buf(n=11, speed=0.1, dt=1.0) # 10 steps × 0.1 = 1.0m
wheel = _straight_buf(n=11, speed=0.05, dt=1.0) # 10 steps × 0.05 = 0.5m
result = compute_drift(
vo, wheel,
window_s=15.0, drift_threshold_m=0.5, now=11.0)
assert result.vo_path_m == pytest.approx(1.0, abs=1e-9)
assert result.wheel_path_m == pytest.approx(0.5, abs=1e-9)
assert result.drift_m == pytest.approx(0.5, abs=1e-9)
def test_drift_at_threshold_is_drifting(self):
# drift == 0.5 → is_drifting = True (>= threshold)
vo = _straight_buf(n=11, speed=0.1, dt=1.0)
wheel = _straight_buf(n=11, speed=0.05, dt=1.0)
result = compute_drift(
vo, wheel,
window_s=15.0, drift_threshold_m=0.5, now=11.0)
assert result.is_drifting
def test_drift_below_threshold_not_drifting(self):
vo = _straight_buf(n=11, speed=0.1, dt=1.0)
wheel = _straight_buf(n=11, speed=0.08, dt=1.0)
result = compute_drift(
vo, wheel,
window_s=15.0, drift_threshold_m=0.5, now=11.0)
# drift = |1.0 - 0.8| = 0.2
assert result.drift_m == pytest.approx(0.2, abs=1e-9)
assert not result.is_drifting
def test_drift_above_threshold_is_drifting(self):
vo = _straight_buf(n=11, speed=0.1, dt=1.0)
wheel = _stationary_buf(n=11, dt=1.0)
result = compute_drift(
vo, wheel,
window_s=15.0, drift_threshold_m=0.5, now=11.0)
# drift = |1.0 - 0.0| = 1.0 > 0.5
assert result.drift_m > 0.5
assert result.is_drifting
def test_n_vo_n_wheel_counts(self):
vo = _straight_buf(n=8, dt=1.0)
wheel = _straight_buf(n=5, dt=1.0)
result = compute_drift(
vo, wheel,
window_s=15.0, drift_threshold_m=0.5, now=8.0)
assert result.n_vo == 8
assert result.n_wheel == 5
def test_samples_outside_window_ignored(self):
# Push old samples far in the past; should not contribute to window
vo = OdomBuffer(max_age_s=60.0)
wheel = OdomBuffer(max_age_s=60.0)
# Old samples outside window (t=0..4, window is last 3s from now=10)
for t in range(5):
vo.push(_s(float(t), x=float(t), y=0.0))
wheel.push(_s(float(t), x=float(t), y=0.0))
# Recent samples inside window (t=7..10)
for t in range(7, 11):
vo.push(_s(float(t), x=float(t) * 0.1, y=0.0))
wheel.push(_s(float(t), x=float(t) * 0.1, y=0.0))
result = compute_drift(
vo, wheel,
window_s=3.0, drift_threshold_m=0.5, now=10.0)
# Both sources move identically inside window → zero drift
assert result.drift_m == pytest.approx(0.0, abs=1e-9)
# Only the 4 recent samples (t=7,8,9,10) in window
assert result.n_vo == 4
assert result.n_wheel == 4
def test_result_is_namedtuple(self):
result = compute_drift(
_straight_buf(), _straight_buf(),
window_s=10.0, drift_threshold_m=0.5, now=5.0)
assert hasattr(result, 'drift_m')
assert hasattr(result, 'vo_path_m')
assert hasattr(result, 'wheel_path_m')
assert hasattr(result, 'is_drifting')
assert hasattr(result, 'window_s')
assert hasattr(result, 'n_vo')
assert hasattr(result, 'n_wheel')
def test_wheel_faster_than_vo_still_drifts(self):
"""Drift is absolute difference — direction doesn't matter."""
vo = _stationary_buf(n=11, dt=1.0)
wheel = _straight_buf(n=11, speed=0.1, dt=1.0)
result = compute_drift(
vo, wheel,
window_s=15.0, drift_threshold_m=0.5, now=11.0)
assert result.drift_m == pytest.approx(1.0, abs=1e-9)
assert result.is_drifting
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@ -0,0 +1,4 @@
imu_calibration:
ros__parameters:
calibration_samples: 100
auto_calibrate: false

View File

@ -0,0 +1,30 @@
"""Launch file for IMU calibration node."""
from launch import LaunchDescription
from launch_ros.actions import Node
from launch.substitutions import LaunchConfiguration
from launch.actions import DeclareLaunchArgument
import os
from ament_index_python.packages import get_package_share_directory
def generate_launch_description():
pkg_dir = get_package_share_directory("saltybot_imu_calibration")
config_file = os.path.join(pkg_dir, "config", "imu_calibration_config.yaml")
return LaunchDescription(
[
DeclareLaunchArgument(
"config_file",
default_value=config_file,
description="Path to configuration YAML file",
),
Node(
package="saltybot_imu_calibration",
executable="imu_calibration_node",
name="imu_calibration",
output="screen",
parameters=[LaunchConfiguration("config_file")],
),
]
)

View File

@ -0,0 +1,22 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_imu_calibration</name>
<version>0.1.0</version>
<description>IMU gyro + accel calibration node for SaltyBot</description>
<maintainer email="sl-controls@saltylab.local">SaltyLab Controls</maintainer>
<license>MIT</license>
<buildtool_depend>ament_python</buildtool_depend>
<depend>rclpy</depend>
<depend>sensor_msgs</depend>
<depend>std_msgs</depend>
<depend>geometry_msgs</depend>
<test_depend>pytest</test_depend>
<test_depend>sensor_msgs</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -0,0 +1,126 @@
#!/usr/bin/env python3
"""IMU calibration node for SaltyBot."""
import rclpy
from rclpy.node import Node
from sensor_msgs.msg import Imu
from std_srvs.srv import Trigger
from std_msgs.msg import Header
import numpy as np
from collections import deque
class IMUCalibrationNode(Node):
"""ROS2 node for IMU gyro + accel calibration."""
def __init__(self):
super().__init__("imu_calibration")
self.declare_parameter("calibration_samples", 100)
self.declare_parameter("auto_calibrate", False)
self.calibration_samples = self.get_parameter("calibration_samples").value
self.auto_calibrate = self.get_parameter("auto_calibrate").value
self.gyro_bias = np.array([0.0, 0.0, 0.0])
self.accel_bias = np.array([0.0, 0.0, 0.0])
self.is_calibrated = False
self.gyro_samples = deque(maxlen=self.calibration_samples)
self.accel_samples = deque(maxlen=self.calibration_samples)
self.calibrating = False
self.sub_imu = self.create_subscription(Imu, "/imu", self._on_imu_raw, 10)
self.pub_calibrated = self.create_publisher(Imu, "/imu/calibrated", 10)
self.srv_calibrate = self.create_service(
Trigger, "/saltybot/calibrate_imu", self._on_calibrate_service
)
self.get_logger().info(
f"IMU calibration node initialized. Samples: {self.calibration_samples}. Auto: {self.auto_calibrate}"
)
if self.auto_calibrate:
self.calibrating = True
self.get_logger().info("Starting auto-calibration...")
def _on_imu_raw(self, msg: Imu) -> None:
if self.calibrating:
gyro = np.array([msg.angular_velocity.x, msg.angular_velocity.y, msg.angular_velocity.z])
accel = np.array([msg.linear_acceleration.x, msg.linear_acceleration.y, msg.linear_acceleration.z])
self.gyro_samples.append(gyro)
self.accel_samples.append(accel)
if len(self.gyro_samples) == self.calibration_samples:
self._compute_calibration()
else:
self._publish_calibrated(msg)
def _compute_calibration(self) -> None:
if len(self.gyro_samples) == 0 or len(self.accel_samples) == 0:
return
gyro_data = np.array(list(self.gyro_samples))
accel_data = np.array(list(self.accel_samples))
self.gyro_bias = np.mean(gyro_data, axis=0)
self.accel_bias = np.mean(accel_data, axis=0)
self.is_calibrated = True
self.calibrating = False
self.get_logger().info(
f"Calibration complete. Gyro: {self.gyro_bias}. Accel: {self.accel_bias}"
)
self.gyro_samples.clear()
self.accel_samples.clear()
def _on_calibrate_service(self, request, response) -> Trigger.Response:
if self.calibrating:
response.success = False
response.message = "Calibration already in progress"
return response
self.get_logger().info("Calibration service called")
self.calibrating = True
self.gyro_samples.clear()
self.accel_samples.clear()
response.success = True
response.message = f"Calibration started, collecting {self.calibration_samples} samples"
return response
def _publish_calibrated(self, msg: Imu) -> None:
calibrated = Imu()
calibrated.header = Header(frame_id=msg.header.frame_id, stamp=msg.header.stamp)
calibrated.angular_velocity.x = msg.angular_velocity.x - self.gyro_bias[0]
calibrated.angular_velocity.y = msg.angular_velocity.y - self.gyro_bias[1]
calibrated.angular_velocity.z = msg.angular_velocity.z - self.gyro_bias[2]
calibrated.linear_acceleration.x = msg.linear_acceleration.x - self.accel_bias[0]
calibrated.linear_acceleration.y = msg.linear_acceleration.y - self.accel_bias[1]
calibrated.linear_acceleration.z = msg.linear_acceleration.z - self.accel_bias[2]
calibrated.angular_velocity_covariance = msg.angular_velocity_covariance
calibrated.linear_acceleration_covariance = msg.linear_acceleration_covariance
calibrated.orientation_covariance = msg.orientation_covariance
calibrated.orientation = msg.orientation
self.pub_calibrated.publish(calibrated)
def main(args=None):
rclpy.init(args=args)
node = IMUCalibrationNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,5 @@
[develop]
script_dir=$base/lib/saltybot_imu_calibration
[install]
install_scripts=$base/lib/saltybot_imu_calibration

View File

@ -0,0 +1,24 @@
from setuptools import setup, find_packages
setup(
name='saltybot_imu_calibration',
version='0.1.0',
packages=find_packages(),
data_files=[
('share/ament_index/resource_index/packages', ['resource/saltybot_imu_calibration']),
('share/saltybot_imu_calibration', ['package.xml']),
('share/saltybot_imu_calibration/config', ['config/imu_calibration_config.yaml']),
('share/saltybot_imu_calibration/launch', ['launch/imu_calibration.launch.py']),
],
install_requires=['setuptools'],
zip_safe=True,
author='SaltyLab Controls',
author_email='sl-controls@saltylab.local',
description='IMU gyro + accel calibration node for SaltyBot',
license='MIT',
entry_points={
'console_scripts': [
'imu_calibration_node=saltybot_imu_calibration.imu_calibration_node:main',
],
},
)

View File

@ -0,0 +1,67 @@
"""Tests for IMU calibration node."""
import pytest
import numpy as np
from sensor_msgs.msg import Imu
from geometry_msgs.msg import Quaternion
import rclpy
from rclpy.time import Time
from saltybot_imu_calibration.imu_calibration_node import IMUCalibrationNode
@pytest.fixture
def rclpy_fixture():
rclpy.init()
yield
rclpy.shutdown()
@pytest.fixture
def node(rclpy_fixture):
node = IMUCalibrationNode()
yield node
node.destroy_node()
class TestInit:
def test_node_initialization(self, node):
assert node.calibration_samples == 100
assert node.is_calibrated is False
assert node.calibrating is False
class TestCalibration:
def test_calibration_samples(self, node):
node.calibration_samples = 3
node.gyro_samples.maxlen = 3
node.accel_samples.maxlen = 3
node.calibrating = True
for i in range(3):
node.gyro_samples.append(np.array([0.1, 0.2, 0.3]))
node.accel_samples.append(np.array([0.0, 0.0, 9.81]))
node._compute_calibration()
assert node.is_calibrated is True
assert len(node.gyro_samples) == 0
class TestCorrection:
def test_imu_correction(self, node):
node.gyro_bias = np.array([0.1, 0.2, 0.3])
node.accel_bias = np.array([0.0, 0.0, 0.1])
msg = Imu()
msg.header.stamp = Time().to_msg()
msg.header.frame_id = "imu_link"
msg.angular_velocity.x = 0.11
msg.angular_velocity.y = 0.22
msg.angular_velocity.z = 0.33
msg.linear_acceleration.x = 0.0
msg.linear_acceleration.y = 0.0
msg.linear_acceleration.z = 9.91
msg.orientation = Quaternion(x=0, y=0, z=0, w=1)
node._publish_calibrated(msg)

View File

@ -16,6 +16,9 @@ rosidl_generate_interfaces(${PROJECT_NAME}
# Issue #233 QR code reader # Issue #233 QR code reader
"msg/QRDetection.msg" "msg/QRDetection.msg"
"msg/QRDetectionArray.msg" "msg/QRDetectionArray.msg"
# Issue #274 HSV color segmentation
"msg/ColorDetection.msg"
"msg/ColorDetectionArray.msg"
DEPENDENCIES std_msgs geometry_msgs vision_msgs builtin_interfaces DEPENDENCIES std_msgs geometry_msgs vision_msgs builtin_interfaces
) )

View File

@ -0,0 +1,14 @@
# ColorDetection.msg — single HSV color-segmented object detection (Issue #274)
#
# color_name : target color label ("red", "green", "blue", "yellow", "orange")
# confidence : mask fill ratio inside bbox (contour_area / bbox_area, 01)
# bbox : axis-aligned bounding box in image pixels (center + size)
# area_px : contour area in pixels² (use for size filtering downstream)
# contour_id : 0-based index of this detection within the current frame
#
std_msgs/Header header
string color_name
float32 confidence
vision_msgs/BoundingBox2D bbox
float32 area_px
uint32 contour_id

View File

@ -0,0 +1,3 @@
# ColorDetectionArray.msg — frame-level list of HSV color-segmented objects (Issue #274)
std_msgs/Header header
ColorDetection[] detections

View File

@ -0,0 +1,21 @@
ambient_sound_node:
ros__parameters:
sample_rate: 16000 # Expected PCM sample rate (Hz)
window_s: 1.0 # Accumulate this many seconds before classifying
n_fft: 512 # FFT size (32 ms frame at 16 kHz)
n_mels: 32 # Mel filterbank bands
audio_topic: "/social/speech/audio_raw" # Source PCM-16 UInt8MultiArray topic
# ── Classifier thresholds ──────────────────────────────────────────────
# Adjust to tune sensitivity for your deployment environment.
silence_db: -40.0 # Below this energy (dBFS) → silence
alarm_db_min: -25.0 # Min energy for alarm detection
alarm_zcr_min: 0.12 # Min ZCR for alarm (intermittent high pitch)
alarm_high_ratio_min: 0.35 # Min high-band energy fraction for alarm
speech_zcr_min: 0.02 # Min ZCR for speech (voiced onset)
speech_zcr_max: 0.25 # Max ZCR for speech
speech_flatness_max: 0.35 # Max spectral flatness for speech (tonal)
music_zcr_max: 0.08 # Max ZCR for music (harmonic / tonal)
music_flatness_max: 0.25 # Max spectral flatness for music
crowd_zcr_min: 0.10 # Min ZCR for crowd noise
crowd_flatness_min: 0.35 # Min spectral flatness for crowd

View File

@ -0,0 +1,30 @@
face_track_servo_node:
ros__parameters:
# PID gains — pan axis
kp_pan: 1.5 # proportional gain (°/s per ° error)
ki_pan: 0.1 # integral gain
kd_pan: 0.05 # derivative gain (damping)
# PID gains — tilt axis
kp_tilt: 1.2
ki_tilt: 0.1
kd_tilt: 0.04
# Camera FOV
fov_h_deg: 60.0 # horizontal field of view (degrees)
fov_v_deg: 45.0 # vertical field of view (degrees)
# Servo limits
pan_limit_deg: 90.0 # mechanical pan range ± (degrees)
tilt_limit_deg: 30.0 # mechanical tilt range ± (degrees)
pan_vel_limit: 45.0 # max pan rate (°/s)
tilt_vel_limit: 30.0 # max tilt rate (°/s)
windup_limit: 15.0 # integral anti-windup clamp (°·s)
# Tracking behaviour
dead_zone: 0.02 # normalised dead zone (fraction of frame width/height)
control_rate: 20.0 # control loop frequency (Hz)
lost_timeout_s: 1.5 # seconds before face considered lost
return_rate_deg_s: 10.0 # return-to-centre speed when no face (°/s)
faces_topic: "/social/faces/detected"

View File

@ -0,0 +1,8 @@
greeting_trigger_node:
ros__parameters:
proximity_m: 2.0 # Trigger when person is within this distance (m)
cooldown_s: 300.0 # Re-greeting suppression window per face_id (s)
unknown_distance: 0.0 # Distance assumed when PersonState not yet available
# 0.0 → always greet faces with no state yet
faces_topic: "/social/faces/detected"
states_topic: "/social/person_states"

View File

@ -0,0 +1,42 @@
"""ambient_sound.launch.py -- Launch the ambient sound classifier (Issue #252).
Usage:
ros2 launch saltybot_social ambient_sound.launch.py
ros2 launch saltybot_social ambient_sound.launch.py silence_db:=-45.0
"""
import os
from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
pkg = get_package_share_directory("saltybot_social")
cfg = os.path.join(pkg, "config", "ambient_sound_params.yaml")
return LaunchDescription([
DeclareLaunchArgument("window_s", default_value="1.0",
description="Accumulation window (s)"),
DeclareLaunchArgument("n_mels", default_value="32",
description="Mel filterbank bands"),
DeclareLaunchArgument("silence_db", default_value="-40.0",
description="Silence energy threshold (dBFS)"),
Node(
package="saltybot_social",
executable="ambient_sound_node",
name="ambient_sound_node",
output="screen",
parameters=[
cfg,
{
"window_s": LaunchConfiguration("window_s"),
"n_mels": LaunchConfiguration("n_mels"),
"silence_db": LaunchConfiguration("silence_db"),
},
],
),
])

View File

@ -0,0 +1,51 @@
"""face_track_servo.launch.py — Launch face-tracking head servo controller (Issue #279).
Usage:
ros2 launch saltybot_social face_track_servo.launch.py
ros2 launch saltybot_social face_track_servo.launch.py kp_pan:=2.0 pan_limit_deg:=60.0
"""
import os
from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
pkg = get_package_share_directory("saltybot_social")
cfg = os.path.join(pkg, "config", "face_track_servo_params.yaml")
return LaunchDescription([
DeclareLaunchArgument("kp_pan", default_value="1.5",
description="Pan proportional gain (°/s per °)"),
DeclareLaunchArgument("kp_tilt", default_value="1.2",
description="Tilt proportional gain (°/s per °)"),
DeclareLaunchArgument("pan_limit_deg", default_value="90.0",
description="Mechanical pan limit ± (degrees)"),
DeclareLaunchArgument("tilt_limit_deg", default_value="30.0",
description="Mechanical tilt limit ± (degrees)"),
DeclareLaunchArgument("fov_h_deg", default_value="60.0",
description="Camera horizontal FOV (degrees)"),
DeclareLaunchArgument("fov_v_deg", default_value="45.0",
description="Camera vertical FOV (degrees)"),
Node(
package="saltybot_social",
executable="face_track_servo_node",
name="face_track_servo_node",
output="screen",
parameters=[
cfg,
{
"kp_pan": LaunchConfiguration("kp_pan"),
"kp_tilt": LaunchConfiguration("kp_tilt"),
"pan_limit_deg": LaunchConfiguration("pan_limit_deg"),
"tilt_limit_deg": LaunchConfiguration("tilt_limit_deg"),
"fov_h_deg": LaunchConfiguration("fov_h_deg"),
"fov_v_deg": LaunchConfiguration("fov_v_deg"),
},
],
),
])

View File

@ -0,0 +1,39 @@
"""greeting_trigger.launch.py -- Launch proximity-based greeting trigger (Issue #270).
Usage:
ros2 launch saltybot_social greeting_trigger.launch.py
ros2 launch saltybot_social greeting_trigger.launch.py proximity_m:=1.5 cooldown_s:=120.0
"""
import os
from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
pkg = get_package_share_directory("saltybot_social")
cfg = os.path.join(pkg, "config", "greeting_trigger_params.yaml")
return LaunchDescription([
DeclareLaunchArgument("proximity_m", default_value="2.0",
description="Greeting proximity threshold (m)"),
DeclareLaunchArgument("cooldown_s", default_value="300.0",
description="Per-face_id re-greeting cooldown (s)"),
Node(
package="saltybot_social",
executable="greeting_trigger_node",
name="greeting_trigger_node",
output="screen",
parameters=[
cfg,
{
"proximity_m": LaunchConfiguration("proximity_m"),
"cooldown_s": LaunchConfiguration("cooldown_s"),
},
],
),
])

View File

@ -0,0 +1,363 @@
"""ambient_sound_node.py -- Ambient sound classifier via mel-spectrogram features.
Issue #252
Accumulates 1 s of PCM-16 audio from /social/speech/audio_raw, extracts a
compact mel-spectrogram feature vector, then classifies the scene into one of:
silence | speech | music | crowd | outdoor | alarm
Publishes the label as std_msgs/String on /saltybot/ambient_sound at 1 Hz.
Signal processing is pure Python + numpy (no torch / onnx dependency).
Feature vector (per 1-s window):
energy_db -- overall RMS in dBFS
zcr -- mean zero-crossing rate across frames
mel_centroid -- centre-of-mass of the mel band energies [0..1]
mel_flatness -- geometric/arithmetic mean of mel energies [0..1]
(1 = white noise, 0 = single sinusoid)
low_ratio -- fraction of mel energy in lower third of bands
high_ratio -- fraction of mel energy in upper third of bands
Classification cascade (priority-ordered):
silence : energy_db < silence_db
alarm : energy_db >= alarm_db_min AND zcr >= alarm_zcr_min
AND high_ratio >= alarm_high_ratio_min
speech : zcr in [speech_zcr_min, speech_zcr_max]
AND mel_flatness < speech_flatness_max
music : zcr < music_zcr_max AND mel_flatness < music_flatness_max
crowd : zcr >= crowd_zcr_min AND mel_flatness >= crowd_flatness_min
outdoor : catch-all
Parameters:
sample_rate (int, 16000)
window_s (float, 1.0) -- accumulation window before classify
n_fft (int, 512) -- FFT size
n_mels (int, 32) -- mel filterbank bands
audio_topic (str, "/social/speech/audio_raw")
silence_db (float, -40.0)
alarm_db_min (float, -25.0)
alarm_zcr_min (float, 0.12)
alarm_high_ratio_min (float, 0.35)
speech_zcr_min (float, 0.02)
speech_zcr_max (float, 0.25)
speech_flatness_max (float, 0.35)
music_zcr_max (float, 0.08)
music_flatness_max (float, 0.25)
crowd_zcr_min (float, 0.10)
crowd_flatness_min (float, 0.35)
"""
from __future__ import annotations
import math
import struct
import threading
from typing import Dict, List, Optional
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile
from std_msgs.msg import String, UInt8MultiArray
# numpy used only in DSP helpers — the Jetson always has it
try:
import numpy as np
_NUMPY = True
except ImportError:
_NUMPY = False
INT16_MAX = 32768.0
LABELS = ("silence", "speech", "music", "crowd", "outdoor", "alarm")
# ── PCM helpers ───────────────────────────────────────────────────────────────
def pcm16_bytes_to_float32(data: bytes) -> List[float]:
"""PCM-16 LE bytes → float32 list in [-1.0, 1.0]."""
n = len(data) // 2
if n == 0:
return []
return [s / INT16_MAX for s in struct.unpack(f"<{n}h", data[: n * 2])]
# ── Mel DSP (numpy path) ──────────────────────────────────────────────────────
def hz_to_mel(hz: float) -> float:
return 2595.0 * math.log10(1.0 + hz / 700.0)
def mel_to_hz(mel: float) -> float:
return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)
def build_mel_filterbank(sr: int, n_fft: int, n_mels: int,
fmin: float = 0.0, fmax: Optional[float] = None):
"""Return (n_mels, n_fft//2+1) numpy filterbank matrix."""
import numpy as np
if fmax is None:
fmax = sr / 2.0
n_freqs = n_fft // 2 + 1
mel_min = hz_to_mel(fmin)
mel_max = hz_to_mel(fmax)
mel_pts = np.linspace(mel_min, mel_max, n_mels + 2)
hz_pts = np.array([mel_to_hz(m) for m in mel_pts])
bin_pts = np.floor((n_fft + 1) * hz_pts / sr).astype(int)
fb = np.zeros((n_mels, n_freqs))
for m in range(n_mels):
lo, ctr, hi = bin_pts[m], bin_pts[m + 1], bin_pts[m + 2]
for k in range(lo, min(ctr, n_freqs)):
if ctr != lo:
fb[m, k] = (k - lo) / (ctr - lo)
for k in range(ctr, min(hi, n_freqs)):
if hi != ctr:
fb[m, k] = (hi - k) / (hi - ctr)
return fb
def compute_mel_spectrogram(samples: List[float], sr: int,
n_fft: int = 512, n_mels: int = 32,
hop_length: int = 256):
"""Return (n_mels, n_frames) log-mel spectrogram (numpy array)."""
import numpy as np
x = np.array(samples, dtype=np.float32)
fb = build_mel_filterbank(sr, n_fft, n_mels)
window = np.hanning(n_fft)
frames = []
for start in range(0, len(x) - n_fft + 1, hop_length):
frame = x[start : start + n_fft] * window
spec = np.abs(np.fft.rfft(frame)) ** 2
mel = fb @ spec
frames.append(mel)
if not frames:
return np.zeros((n_mels, 1), dtype=np.float32)
return np.column_stack(frames).astype(np.float32)
# ── Feature extraction ────────────────────────────────────────────────────────
def extract_features(samples: List[float], sr: int,
n_fft: int = 512, n_mels: int = 32) -> Dict[str, float]:
"""Extract scalar features from a raw audio window."""
import numpy as np
n = len(samples)
if n == 0:
return {k: 0.0 for k in
("energy_db", "zcr", "mel_centroid", "mel_flatness",
"low_ratio", "high_ratio")}
# Energy
rms = math.sqrt(sum(s * s for s in samples) / n) if n else 0.0
energy_db = 20.0 * math.log10(max(rms, 1e-10))
# ZCR across 30 ms frames
chunk = max(1, int(sr * 0.030))
zcr_vals = []
for i in range(0, n - chunk + 1, chunk):
seg = samples[i : i + chunk]
crossings = sum(1 for j in range(1, len(seg))
if seg[j - 1] * seg[j] < 0)
zcr_vals.append(crossings / max(len(seg) - 1, 1))
zcr = sum(zcr_vals) / len(zcr_vals) if zcr_vals else 0.0
# Mel spectrogram features
mel_spec = compute_mel_spectrogram(samples, sr, n_fft, n_mels)
mel_mean = mel_spec.mean(axis=1) # (n_mels,) mean energy per band
total = float(mel_mean.sum()) if mel_mean.sum() > 0 else 1e-10
indices = np.arange(n_mels, dtype=np.float32)
mel_centroid = float((indices * mel_mean).sum()) / (n_mels * total / total) / n_mels
# Spectral flatness: geometric mean / arithmetic mean
eps = 1e-10
mel_pos = np.clip(mel_mean, eps, None)
geo_mean = float(np.exp(np.log(mel_pos).mean()))
arith_mean = float(mel_pos.mean())
mel_flatness = min(geo_mean / max(arith_mean, eps), 1.0)
# Band ratios
third = max(1, n_mels // 3)
low_energy = float(mel_mean[:third].sum())
high_energy = float(mel_mean[-third:].sum())
low_ratio = low_energy / max(total, eps)
high_ratio = high_energy / max(total, eps)
return {
"energy_db": energy_db,
"zcr": zcr,
"mel_centroid": mel_centroid,
"mel_flatness": mel_flatness,
"low_ratio": low_ratio,
"high_ratio": high_ratio,
}
# ── Classifier ────────────────────────────────────────────────────────────────
def classify(features: Dict[str, float],
silence_db: float = -40.0,
alarm_db_min: float = -25.0,
alarm_zcr_min: float = 0.12,
alarm_high_ratio_min: float = 0.35,
speech_zcr_min: float = 0.02,
speech_zcr_max: float = 0.25,
speech_flatness_max: float = 0.35,
music_zcr_max: float = 0.08,
music_flatness_max: float = 0.25,
crowd_zcr_min: float = 0.10,
crowd_flatness_min: float = 0.35) -> str:
"""Priority-ordered rule cascade. Returns a label from LABELS."""
e = features["energy_db"]
zcr = features["zcr"]
fl = features["mel_flatness"]
hi = features["high_ratio"]
if e < silence_db:
return "silence"
if (e >= alarm_db_min
and zcr >= alarm_zcr_min
and hi >= alarm_high_ratio_min):
return "alarm"
if zcr < music_zcr_max and fl < music_flatness_max:
return "music"
if (speech_zcr_min <= zcr <= speech_zcr_max
and fl < speech_flatness_max):
return "speech"
if zcr >= crowd_zcr_min and fl >= crowd_flatness_min:
return "crowd"
return "outdoor"
# ── Audio accumulation buffer ─────────────────────────────────────────────────
class AudioBuffer:
"""Thread-safe ring buffer; yields a window of samples when full."""
def __init__(self, window_samples: int) -> None:
self._target = window_samples
self._buf: List[float] = []
self._lock = threading.Lock()
def push(self, samples: List[float]) -> Optional[List[float]]:
"""Append samples. Returns a complete window (and resets) when full."""
with self._lock:
self._buf.extend(samples)
if len(self._buf) >= self._target:
window = self._buf[: self._target]
self._buf = self._buf[self._target :]
return window
return None
def clear(self) -> None:
with self._lock:
self._buf.clear()
# ── ROS2 node ─────────────────────────────────────────────────────────────────
class AmbientSoundNode(Node):
"""Classifies ambient sound from raw audio and publishes label at 1 Hz."""
def __init__(self) -> None:
super().__init__("ambient_sound_node")
self.declare_parameter("sample_rate", 16000)
self.declare_parameter("window_s", 1.0)
self.declare_parameter("n_fft", 512)
self.declare_parameter("n_mels", 32)
self.declare_parameter("audio_topic", "/social/speech/audio_raw")
# Classifier thresholds
self.declare_parameter("silence_db", -40.0)
self.declare_parameter("alarm_db_min", -25.0)
self.declare_parameter("alarm_zcr_min", 0.12)
self.declare_parameter("alarm_high_ratio_min", 0.35)
self.declare_parameter("speech_zcr_min", 0.02)
self.declare_parameter("speech_zcr_max", 0.25)
self.declare_parameter("speech_flatness_max", 0.35)
self.declare_parameter("music_zcr_max", 0.08)
self.declare_parameter("music_flatness_max", 0.25)
self.declare_parameter("crowd_zcr_min", 0.10)
self.declare_parameter("crowd_flatness_min", 0.35)
self._sr = self.get_parameter("sample_rate").value
self._n_fft = self.get_parameter("n_fft").value
self._n_mels = self.get_parameter("n_mels").value
window_s = self.get_parameter("window_s").value
audio_topic = self.get_parameter("audio_topic").value
self._thresholds = {
k: self.get_parameter(k).value for k in (
"silence_db", "alarm_db_min", "alarm_zcr_min",
"alarm_high_ratio_min", "speech_zcr_min", "speech_zcr_max",
"speech_flatness_max", "music_zcr_max", "music_flatness_max",
"crowd_zcr_min", "crowd_flatness_min",
)
}
self._buffer = AudioBuffer(int(self._sr * window_s))
self._last_label = "silence"
qos = QoSProfile(depth=10)
self._pub = self.create_publisher(String, "/saltybot/ambient_sound", qos)
self._audio_sub = self.create_subscription(
UInt8MultiArray, audio_topic, self._on_audio, qos
)
if not _NUMPY:
self.get_logger().warn(
"numpy not available — mel features disabled, classifying by energy only"
)
self.get_logger().info(
f"AmbientSoundNode ready "
f"(sr={self._sr}, window={window_s}s, n_mels={self._n_mels})"
)
def _on_audio(self, msg: UInt8MultiArray) -> None:
samples = pcm16_bytes_to_float32(bytes(msg.data))
if not samples:
return
window = self._buffer.push(samples)
if window is not None:
self._classify_and_publish(window)
def _classify_and_publish(self, samples: List[float]) -> None:
try:
if _NUMPY:
feats = extract_features(samples, self._sr, self._n_fft, self._n_mels)
else:
# Numpy-free fallback: energy-only
rms = math.sqrt(sum(s * s for s in samples) / len(samples))
e_db = 20.0 * math.log10(max(rms, 1e-10))
feats = {
"energy_db": e_db, "zcr": 0.05,
"mel_centroid": 0.5, "mel_flatness": 0.2,
"low_ratio": 0.4, "high_ratio": 0.2,
}
label = classify(feats, **self._thresholds)
except Exception as exc:
self.get_logger().error(f"Classification error: {exc}")
label = self._last_label
if label != self._last_label:
self.get_logger().info(
f"Ambient sound: {self._last_label} -> {label}"
)
self._last_label = label
msg = String()
msg.data = label
self._pub.publish(msg)
def main(args: Optional[list] = None) -> None:
rclpy.init(args=args)
node = AmbientSoundNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()

View File

@ -0,0 +1,308 @@
"""face_track_servo_node.py — Face-tracking head servo controller.
Issue #279
Subscribes to /social/faces/detected, picks the closest face (largest
bounding-box area as a proximity proxy), computes pan/tilt angular error
relative to the image centre, and drives two PID controllers to produce
smooth servo position commands published on /saltybot/head_pan and
/saltybot/head_tilt (std_msgs/Float32, degrees from neutral).
Coordinate convention
bbox_x/y/w/h : normalised [0, 1] in image space
face centre : cx = bbox_x + bbox_w/2 , cy = bbox_y + bbox_h/2
image centre : (0.5, 0.5)
pan error : (cx - 0.5) * fov_h_deg (+ve face right of centre)
tilt error : (cy - 0.5) * fov_v_deg (+ve face below centre)
PID design (velocity / incremental)
velocity (°/s) = Kp·e + Ki·e dt + Kd·de/dt
servo_angle += velocity · dt
servo_angle = clamp(servo_angle, ±limit_deg)
When no face is seen for more than ``lost_timeout_s`` seconds the PIDs
are reset and the servo commands return toward 0° at ``return_rate_deg_s``.
Parameters
kp_pan (float, 1.5) pan proportional gain (°/s per °)
ki_pan (float, 0.1) pan integral gain
kd_pan (float, 0.05) pan derivative gain
kp_tilt (float, 1.2) tilt proportional gain
ki_tilt (float, 0.1) tilt integral gain
kd_tilt (float, 0.04) tilt derivative gain
fov_h_deg (float, 60.0) camera horizontal FOV (degrees)
fov_v_deg (float, 45.0) camera vertical FOV (degrees)
pan_limit_deg (float, 90.0) mechanical pan limit ±
tilt_limit_deg (float, 30.0) mechanical tilt limit ±
pan_vel_limit (float, 45.0) max pan rate (°/s)
tilt_vel_limit (float, 30.0) max tilt rate (°/s)
windup_limit (float, 15.0) integral anti-windup clamp (°·s)
dead_zone (float, 0.02) normalised dead zone (fraction of frame)
control_rate (float, 20.0) control loop Hz
lost_timeout_s (float, 1.5) seconds before face considered lost
return_rate_deg_s (float, 10.0) return-to-centre rate when no face (°/s)
faces_topic (str) default "/social/faces/detected"
"""
from __future__ import annotations
import math
import time
import threading
from typing import Optional
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile
from std_msgs.msg import Float32
try:
from saltybot_social_msgs.msg import FaceDetectionArray
_MSGS = True
except ImportError:
_MSGS = False
# ── Pure helpers ───────────────────────────────────────────────────────────────
def clamp(v: float, lo: float, hi: float) -> float:
return max(lo, min(hi, v))
def bbox_area(face) -> float:
"""Bounding-box area as a proximity proxy (larger ≈ closer)."""
return float(face.bbox_w) * float(face.bbox_h)
def pick_closest_face(faces):
"""Return the face with the largest bbox area; None if list is empty."""
if not faces:
return None
return max(faces, key=bbox_area)
def face_image_error(face, fov_h_deg: float, fov_v_deg: float):
"""Return (pan_error_deg, tilt_error_deg) for a FaceDetection.
Positive pan face is right of image centre.
Positive tilt face is below image centre.
"""
cx = float(face.bbox_x) + float(face.bbox_w) / 2.0
cy = float(face.bbox_y) + float(face.bbox_h) / 2.0
pan_err = (cx - 0.5) * fov_h_deg
tilt_err = (cy - 0.5) * fov_v_deg
return pan_err, tilt_err
# ── PID controller ─────────────────────────────────────────────────────────────
class PIDController:
"""Incremental (velocity-output) PID with anti-windup.
Output units: degrees/second (servo angular velocity).
Integrate externally: servo_angle += pid.update(error, dt) * dt
"""
def __init__(self, kp: float, ki: float, kd: float,
vel_limit: float, windup_limit: float) -> None:
self.kp = kp
self.ki = ki
self.kd = kd
self.vel_limit = vel_limit
self.windup_limit = windup_limit
self._integral = 0.0
self._prev_error = 0.0
self._first = True
def update(self, error: float, dt: float) -> float:
"""Return velocity command (°/s). Call every control tick."""
if dt <= 0.0:
return 0.0
self._integral += error * dt
self._integral = clamp(self._integral, -self.windup_limit,
self.windup_limit)
if self._first:
derivative = 0.0
self._first = False
else:
derivative = (error - self._prev_error) / dt
self._prev_error = error
output = (self.kp * error
+ self.ki * self._integral
+ self.kd * derivative)
return clamp(output, -self.vel_limit, self.vel_limit)
def reset(self) -> None:
self._integral = 0.0
self._prev_error = 0.0
self._first = True
# ── ROS2 node ──────────────────────────────────────────────────────────────────
class FaceTrackServoNode(Node):
"""Smooth PID face-tracking servo controller."""
def __init__(self) -> None:
super().__init__("face_track_servo_node")
# Declare parameters
self.declare_parameter("kp_pan", 1.5)
self.declare_parameter("ki_pan", 0.1)
self.declare_parameter("kd_pan", 0.05)
self.declare_parameter("kp_tilt", 1.2)
self.declare_parameter("ki_tilt", 0.1)
self.declare_parameter("kd_tilt", 0.04)
self.declare_parameter("fov_h_deg", 60.0)
self.declare_parameter("fov_v_deg", 45.0)
self.declare_parameter("pan_limit_deg", 90.0)
self.declare_parameter("tilt_limit_deg", 30.0)
self.declare_parameter("pan_vel_limit", 45.0)
self.declare_parameter("tilt_vel_limit", 30.0)
self.declare_parameter("windup_limit", 15.0)
self.declare_parameter("dead_zone", 0.02)
self.declare_parameter("control_rate", 20.0)
self.declare_parameter("lost_timeout_s", 1.5)
self.declare_parameter("return_rate_deg_s", 10.0)
self.declare_parameter("faces_topic", "/social/faces/detected")
self._reload_params()
# Servo state
self._pan_cmd = 0.0
self._tilt_cmd = 0.0
self._last_face_t: float = 0.0
self._latest_face = None
self._lock = threading.Lock()
qos = QoSProfile(depth=10)
self._pan_pub = self.create_publisher(Float32, "/saltybot/head_pan", qos)
self._tilt_pub = self.create_publisher(Float32, "/saltybot/head_tilt", qos)
faces_topic = self.get_parameter("faces_topic").value
if _MSGS:
self._faces_sub = self.create_subscription(
FaceDetectionArray, faces_topic, self._on_faces, qos
)
else:
self.get_logger().warn(
"saltybot_social_msgs not available — node passive (no subscription)"
)
rate = self.get_parameter("control_rate").value
self._timer = self.create_timer(1.0 / rate, self._control_cb)
self._last_tick = time.monotonic()
self.get_logger().info(
f"FaceTrackServoNode ready "
f"(rate={rate}Hz, fov={self._fov_h}×{self._fov_v}°, "
f"pan±{self._pan_limit}°, tilt±{self._tilt_limit}°)"
)
def _reload_params(self) -> None:
self._fov_h = self.get_parameter("fov_h_deg").value
self._fov_v = self.get_parameter("fov_v_deg").value
self._pan_limit = self.get_parameter("pan_limit_deg").value
self._tilt_limit = self.get_parameter("tilt_limit_deg").value
self._dead_zone = self.get_parameter("dead_zone").value
self._lost_t = self.get_parameter("lost_timeout_s").value
self._return_rate = self.get_parameter("return_rate_deg_s").value
self._pid_pan = PIDController(
kp=self.get_parameter("kp_pan").value,
ki=self.get_parameter("ki_pan").value,
kd=self.get_parameter("kd_pan").value,
vel_limit=self.get_parameter("pan_vel_limit").value,
windup_limit=self.get_parameter("windup_limit").value,
)
self._pid_tilt = PIDController(
kp=self.get_parameter("kp_tilt").value,
ki=self.get_parameter("ki_tilt").value,
kd=self.get_parameter("kd_tilt").value,
vel_limit=self.get_parameter("tilt_vel_limit").value,
windup_limit=self.get_parameter("windup_limit").value,
)
# ── Subscription callback ──────────────────────────────────────────────
def _on_faces(self, msg) -> None:
face = pick_closest_face(msg.faces)
with self._lock:
self._latest_face = face
if face is not None:
self._last_face_t = time.monotonic()
# ── Control loop ───────────────────────────────────────────────────────
def _control_cb(self) -> None:
now = time.monotonic()
dt = now - self._last_tick
self._last_tick = now
dt = max(dt, 1e-4) # guard against zero dt at startup
with self._lock:
face = self._latest_face
last_face_t = self._last_face_t
face_fresh = (last_face_t > 0.0 and (now - last_face_t) < self._lost_t)
if not face_fresh or face is None:
# Return to centre
self._pid_pan.reset()
self._pid_tilt.reset()
step = self._return_rate * dt
self._pan_cmd = _step_toward_zero(self._pan_cmd, step)
self._tilt_cmd = _step_toward_zero(self._tilt_cmd, step)
else:
pan_err, tilt_err = face_image_error(face, self._fov_h, self._fov_v)
# Dead zone (normalised fraction → degrees)
dead_deg_h = self._dead_zone * self._fov_h
dead_deg_v = self._dead_zone * self._fov_v
if abs(pan_err) < dead_deg_h:
self._pid_pan.reset()
else:
vel_pan = self._pid_pan.update(pan_err, dt)
self._pan_cmd = clamp(
self._pan_cmd + vel_pan * dt,
-self._pan_limit, self._pan_limit,
)
if abs(tilt_err) < dead_deg_v:
self._pid_tilt.reset()
else:
vel_tilt = self._pid_tilt.update(tilt_err, dt)
self._tilt_cmd = clamp(
self._tilt_cmd + vel_tilt * dt,
-self._tilt_limit, self._tilt_limit,
)
pan_msg = Float32(); pan_msg.data = float(self._pan_cmd)
tilt_msg = Float32(); tilt_msg.data = float(self._tilt_cmd)
self._pan_pub.publish(pan_msg)
self._tilt_pub.publish(tilt_msg)
def _step_toward_zero(value: float, step: float) -> float:
"""Move value toward 0 by step without overshooting."""
if abs(value) <= step:
return 0.0
return value - math.copysign(step, value)
def main(args=None) -> None:
rclpy.init(args=args)
node = FaceTrackServoNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()

View File

@ -0,0 +1,150 @@
"""greeting_trigger_node.py -- Proximity-based greeting trigger.
Issue #270
Monitors face detections and person states. When a new face_id is seen
within ``proximity_m`` metres (default 2 m) and has not been greeted within
``cooldown_s`` seconds, publishes a JSON greeting trigger on
/saltybot/greeting_trigger.
Distance is looked up from the /social/person_states topic which carries a
face_id distance mapping. When no state is available for a face the node
applies a configurable default distance so it can still fire on face-only
pipelines.
Subscriptions:
/social/faces/detected saltybot_social_msgs/FaceDetectionArray
/social/person_states saltybot_social_msgs/PersonStateArray
Publication:
/saltybot/greeting_trigger std_msgs/String (JSON)
{"face_id": <int>, "person_name": <str>, "distance_m": <float>,
"ts": <float unix epoch>}
Parameters:
proximity_m (float, 2.0) -- trigger when distance <= this
cooldown_s (float, 300.0) -- suppress re-greeting same face_id
unknown_distance (float, 0.0) -- distance assumed when PersonState
is not yet available (0.0 always
trigger for unknown faces)
faces_topic (str, "/social/faces/detected")
states_topic (str, "/social/person_states")
"""
from __future__ import annotations
import json
import time
import threading
from typing import Dict
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile
from std_msgs.msg import String
# Custom messages — imported at runtime so offline tests can stub them
try:
from saltybot_social_msgs.msg import FaceDetectionArray, PersonStateArray
_MSGS = True
except ImportError:
_MSGS = False
class GreetingTriggerNode(Node):
"""Publishes greeting trigger when a person enters proximity."""
def __init__(self) -> None:
super().__init__("greeting_trigger_node")
self.declare_parameter("proximity_m", 2.0)
self.declare_parameter("cooldown_s", 300.0)
self.declare_parameter("unknown_distance", 0.0)
self.declare_parameter("faces_topic", "/social/faces/detected")
self.declare_parameter("states_topic", "/social/person_states")
self._proximity = self.get_parameter("proximity_m").value
self._cooldown = self.get_parameter("cooldown_s").value
self._unknown_dist = self.get_parameter("unknown_distance").value
faces_topic = self.get_parameter("faces_topic").value
states_topic = self.get_parameter("states_topic").value
# face_id → last known distance (m); updated from PersonStateArray
self._distance_cache: Dict[int, float] = {}
# face_id → unix timestamp of last greeting
self._last_greeted: Dict[int, float] = {}
self._lock = threading.Lock()
qos = QoSProfile(depth=10)
self._pub = self.create_publisher(String, "/saltybot/greeting_trigger", qos)
if _MSGS:
self._states_sub = self.create_subscription(
PersonStateArray, states_topic, self._on_person_states, qos
)
self._faces_sub = self.create_subscription(
FaceDetectionArray, faces_topic, self._on_faces, qos
)
else:
self.get_logger().warn(
"saltybot_social_msgs not available — node is passive (no subscriptions)"
)
self.get_logger().info(
f"GreetingTriggerNode ready "
f"(proximity={self._proximity}m, cooldown={self._cooldown}s)"
)
# ── Callbacks ──────────────────────────────────────────────────────────
def _on_person_states(self, msg: "PersonStateArray") -> None:
"""Cache face_id → distance from incoming PersonState array."""
with self._lock:
for ps in msg.persons:
if ps.face_id >= 0:
self._distance_cache[ps.face_id] = float(ps.distance)
def _on_faces(self, msg: "FaceDetectionArray") -> None:
"""Evaluate each detected face; fire greeting if conditions met."""
now = time.monotonic()
with self._lock:
for face in msg.faces:
fid = int(face.face_id)
dist = self._distance_cache.get(fid, self._unknown_dist)
if dist > self._proximity:
continue # too far
last = self._last_greeted.get(fid, 0.0)
if now - last < self._cooldown:
continue # still in cooldown
# Fire!
self._last_greeted[fid] = now
self._fire(fid, str(face.person_name), dist)
def _fire(self, face_id: int, person_name: str, distance_m: float) -> None:
payload = {
"face_id": face_id,
"person_name": person_name,
"distance_m": round(distance_m, 3),
"ts": time.time(),
}
msg = String()
msg.data = json.dumps(payload)
self._pub.publish(msg)
self.get_logger().info(
f"Greeting trigger: face_id={face_id} name={person_name!r} "
f"dist={distance_m:.2f}m"
)
def main(args=None) -> None:
rclpy.init(args=args)
node = GreetingTriggerNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()

View File

@ -45,6 +45,12 @@ setup(
'mesh_comms_node = saltybot_social.mesh_comms_node:main', 'mesh_comms_node = saltybot_social.mesh_comms_node:main',
# Energy+ZCR voice activity detection (Issue #242) # Energy+ZCR voice activity detection (Issue #242)
'vad_node = saltybot_social.vad_node:main', 'vad_node = saltybot_social.vad_node:main',
# Ambient sound classifier — mel-spectrogram (Issue #252)
'ambient_sound_node = saltybot_social.ambient_sound_node:main',
# Proximity-based greeting trigger (Issue #270)
'greeting_trigger_node = saltybot_social.greeting_trigger_node:main',
# Face-tracking head servo controller (Issue #279)
'face_track_servo_node = saltybot_social.face_track_servo_node:main',
], ],
}, },
) )

View File

@ -0,0 +1,407 @@
"""test_ambient_sound.py -- Unit tests for Issue #252 ambient sound classifier."""
from __future__ import annotations
import importlib.util, math, os, struct, sys, types
import pytest
# numpy is available on dev machine
import numpy as np
def _pkg_root():
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def _read_src(rel_path):
with open(os.path.join(_pkg_root(), rel_path)) as f:
return f.read()
def _import_mod():
"""Import ambient_sound_node without a live ROS2 environment."""
for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
"std_msgs", "std_msgs.msg"):
if mod_name not in sys.modules:
sys.modules[mod_name] = types.ModuleType(mod_name)
rclpy_node = sys.modules["rclpy.node"]
rclpy_qos = sys.modules["rclpy.qos"]
std_msg = sys.modules["std_msgs.msg"]
DEFAULTS = {
"sample_rate": 16000, "window_s": 1.0, "n_fft": 512, "n_mels": 32,
"audio_topic": "/social/speech/audio_raw",
"silence_db": -40.0, "alarm_db_min": -25.0, "alarm_zcr_min": 0.12,
"alarm_high_ratio_min": 0.35, "speech_zcr_min": 0.02,
"speech_zcr_max": 0.25, "speech_flatness_max": 0.35,
"music_zcr_max": 0.08, "music_flatness_max": 0.25,
"crowd_zcr_min": 0.10, "crowd_flatness_min": 0.35,
}
class _Node:
def __init__(self, *a, **kw): pass
def declare_parameter(self, *a, **kw): pass
def get_parameter(self, name):
class _P:
value = DEFAULTS.get(name)
return _P()
def create_publisher(self, *a, **kw): return None
def create_subscription(self, *a, **kw): return None
def get_logger(self):
class _L:
def info(self, *a): pass
def warn(self, *a): pass
def error(self, *a): pass
return _L()
def destroy_node(self): pass
rclpy_node.Node = _Node
rclpy_qos.QoSProfile = type("QoSProfile", (), {"__init__": lambda s, **kw: None})
std_msg.String = type("String", (), {"data": ""})
std_msg.UInt8MultiArray = type("UInt8MultiArray", (), {"data": b""})
sys.modules["rclpy"].init = lambda *a, **kw: None
sys.modules["rclpy"].spin = lambda n: None
sys.modules["rclpy"].ok = lambda: True
sys.modules["rclpy"].shutdown = lambda: None
spec = importlib.util.spec_from_file_location(
"ambient_sound_node_testmod",
os.path.join(_pkg_root(), "saltybot_social", "ambient_sound_node.py"),
)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
# ── Audio helpers ─────────────────────────────────────────────────────────────
SR = 16000
def _sine(freq, n=SR, amp=0.2):
return [amp * math.sin(2 * math.pi * freq * i / SR) for i in range(n)]
def _white_noise(n=SR, amp=0.1):
import random
rng = random.Random(42)
return [rng.uniform(-amp, amp) for _ in range(n)]
def _silence(n=SR):
return [0.0] * n
def _pcm16(samples):
ints = [max(-32768, min(32767, int(s * 32768))) for s in samples]
return struct.pack(f"<{len(ints)}h", *ints)
# ── TestPcm16Convert ──────────────────────────────────────────────────────────
class TestPcm16Convert:
@pytest.fixture(scope="class")
def mod(self): return _import_mod()
def test_empty(self, mod):
assert mod.pcm16_bytes_to_float32(b"") == []
def test_length(self, mod):
data = _pcm16(_sine(440, 480))
assert len(mod.pcm16_bytes_to_float32(data)) == 480
def test_range(self, mod):
data = _pcm16(_sine(440, 480))
result = mod.pcm16_bytes_to_float32(data)
assert all(-1.0 <= s <= 1.0 for s in result)
def test_silence(self, mod):
data = _pcm16(_silence(100))
assert all(s == 0.0 for s in mod.pcm16_bytes_to_float32(data))
# ── TestMelConversions ────────────────────────────────────────────────────────
class TestMelConversions:
@pytest.fixture(scope="class")
def mod(self): return _import_mod()
def test_hz_to_mel_zero(self, mod):
assert mod.hz_to_mel(0.0) == 0.0
def test_hz_to_mel_1000(self, mod):
# 1000 Hz → ~999.99 mel (approximately)
assert abs(mod.hz_to_mel(1000.0) - 999.99) < 1.0
def test_roundtrip(self, mod):
for hz in (100.0, 500.0, 1000.0, 4000.0, 8000.0):
assert abs(mod.mel_to_hz(mod.hz_to_mel(hz)) - hz) < 0.01
def test_monotone_increasing(self, mod):
freqs = [100, 500, 1000, 2000, 4000, 8000]
mels = [mod.hz_to_mel(f) for f in freqs]
assert mels == sorted(mels)
# ── TestMelFilterbank ─────────────────────────────────────────────────────────
class TestMelFilterbank:
@pytest.fixture(scope="class")
def mod(self): return _import_mod()
def test_shape(self, mod):
fb = mod.build_mel_filterbank(SR, 512, 32)
assert fb.shape == (32, 257) # (n_mels, n_fft//2+1)
def test_nonnegative(self, mod):
fb = mod.build_mel_filterbank(SR, 512, 32)
assert (fb >= 0).all()
def test_each_filter_sums_positive(self, mod):
fb = mod.build_mel_filterbank(SR, 512, 32)
assert all(fb[m].sum() > 0 for m in range(32))
def test_custom_n_mels(self, mod):
fb = mod.build_mel_filterbank(SR, 512, 16)
assert fb.shape[0] == 16
def test_max_value_leq_one(self, mod):
fb = mod.build_mel_filterbank(SR, 512, 32)
assert fb.max() <= 1.0 + 1e-6
# ── TestMelSpectrogram ────────────────────────────────────────────────────────
class TestMelSpectrogram:
@pytest.fixture(scope="class")
def mod(self): return _import_mod()
def test_shape(self, mod):
s = _sine(440, SR)
spec = mod.compute_mel_spectrogram(s, SR, n_fft=512, n_mels=32, hop_length=256)
assert spec.shape[0] == 32
assert spec.shape[1] > 0
def test_silence_near_zero(self, mod):
spec = mod.compute_mel_spectrogram(_silence(SR), SR, n_fft=512, n_mels=32)
assert spec.mean() < 1e-6
def test_louder_has_higher_energy(self, mod):
quiet = mod.compute_mel_spectrogram(_sine(440, SR, amp=0.01), SR).mean()
loud = mod.compute_mel_spectrogram(_sine(440, SR, amp=0.5), SR).mean()
assert loud > quiet
def test_returns_array(self, mod):
spec = mod.compute_mel_spectrogram(_sine(440, SR), SR)
assert isinstance(spec, np.ndarray)
# ── TestExtractFeatures ───────────────────────────────────────────────────────
class TestExtractFeatures:
@pytest.fixture(scope="class")
def mod(self): return _import_mod()
def _feats(self, mod, samples):
return mod.extract_features(samples, SR, n_fft=512, n_mels=32)
def test_keys_present(self, mod):
f = self._feats(mod, _sine(440, SR))
for k in ("energy_db", "zcr", "mel_centroid", "mel_flatness",
"low_ratio", "high_ratio"):
assert k in f
def test_silence_low_energy(self, mod):
f = self._feats(mod, _silence(SR))
assert f["energy_db"] < -40.0
def test_silence_zero_zcr(self, mod):
f = self._feats(mod, _silence(SR))
assert f["zcr"] == 0.0
def test_sine_moderate_energy(self, mod):
f = self._feats(mod, _sine(440, SR, amp=0.1))
assert -40.0 < f["energy_db"] < 0.0
def test_ratios_sum_leq_one(self, mod):
f = self._feats(mod, _sine(440, SR))
assert f["low_ratio"] + f["high_ratio"] <= 1.0 + 1e-6
def test_ratios_nonnegative(self, mod):
f = self._feats(mod, _sine(440, SR))
assert f["low_ratio"] >= 0.0 and f["high_ratio"] >= 0.0
def test_flatness_in_unit_interval(self, mod):
f = self._feats(mod, _sine(440, SR))
assert 0.0 <= f["mel_flatness"] <= 1.0
def test_white_noise_high_flatness(self, mod):
f_noise = self._feats(mod, _white_noise(SR, amp=0.3))
f_sine = self._feats(mod, _sine(440, SR, amp=0.3))
# White noise should have higher spectral flatness than a pure tone
assert f_noise["mel_flatness"] > f_sine["mel_flatness"]
def test_empty_samples(self, mod):
f = mod.extract_features([], SR)
assert f["energy_db"] == 0.0
def test_louder_higher_energy_db(self, mod):
quiet = self._feats(mod, _sine(440, SR, amp=0.01))["energy_db"]
loud = self._feats(mod, _sine(440, SR, amp=0.5))["energy_db"]
assert loud > quiet
# ── TestClassifier ────────────────────────────────────────────────────────────
class TestClassifier:
@pytest.fixture(scope="class")
def mod(self): return _import_mod()
def _cls(self, mod, **feat_overrides):
base = {"energy_db": -20.0, "zcr": 0.05,
"mel_centroid": 0.4, "mel_flatness": 0.2,
"low_ratio": 0.4, "high_ratio": 0.2}
base.update(feat_overrides)
return mod.classify(base)
def test_silence(self, mod):
assert self._cls(mod, energy_db=-45.0) == "silence"
def test_silence_at_threshold(self, mod):
assert self._cls(mod, energy_db=-40.0) != "silence"
def test_alarm(self, mod):
assert self._cls(mod, energy_db=-20.0, zcr=0.15, high_ratio=0.40) == "alarm"
def test_alarm_requires_high_ratio(self, mod):
result = self._cls(mod, energy_db=-20.0, zcr=0.15, high_ratio=0.10)
assert result != "alarm"
def test_speech(self, mod):
assert self._cls(mod, energy_db=-25.0, zcr=0.08,
mel_flatness=0.20) == "speech"
def test_speech_zcr_too_low(self, mod):
result = self._cls(mod, energy_db=-25.0, zcr=0.005, mel_flatness=0.2)
assert result != "speech"
def test_speech_zcr_too_high(self, mod):
result = self._cls(mod, energy_db=-25.0, zcr=0.30, mel_flatness=0.2)
assert result != "speech"
def test_music(self, mod):
assert self._cls(mod, energy_db=-25.0, zcr=0.04,
mel_flatness=0.10) == "music"
def test_crowd(self, mod):
assert self._cls(mod, energy_db=-25.0, zcr=0.15,
mel_flatness=0.40) == "crowd"
def test_outdoor_catchall(self, mod):
# Moderate energy, mid ZCR, mid flatness → outdoor
result = self._cls(mod, energy_db=-35.0, zcr=0.06, mel_flatness=0.30)
assert result in mod.LABELS
def test_returns_valid_label(self, mod):
import random
rng = random.Random(0)
for _ in range(20):
f = {
"energy_db": rng.uniform(-60, 0),
"zcr": rng.uniform(0, 0.5),
"mel_centroid": rng.uniform(0, 1),
"mel_flatness": rng.uniform(0, 1),
"low_ratio": rng.uniform(0, 0.6),
"high_ratio": rng.uniform(0, 0.4),
}
assert mod.classify(f) in mod.LABELS
# ── TestAudioBuffer ───────────────────────────────────────────────────────────
class TestAudioBuffer:
@pytest.fixture(scope="class")
def mod(self): return _import_mod()
def test_no_window_until_full(self, mod):
buf = mod.AudioBuffer(window_samples=100)
assert buf.push([0.0] * 50) is None
def test_exact_fill_returns_window(self, mod):
buf = mod.AudioBuffer(window_samples=100)
w = buf.push([0.0] * 100)
assert w is not None and len(w) == 100
def test_overflow_carries_over(self, mod):
buf = mod.AudioBuffer(window_samples=100)
buf.push([0.0] * 100) # fills first window
w2 = buf.push([1.0] * 100) # fills second window
assert w2 is not None and len(w2) == 100
def test_partial_then_complete(self, mod):
buf = mod.AudioBuffer(window_samples=100)
buf.push([0.0] * 60)
w = buf.push([0.0] * 60)
assert w is not None and len(w) == 100
def test_clear_resets(self, mod):
buf = mod.AudioBuffer(window_samples=100)
buf.push([0.0] * 90)
buf.clear()
assert buf.push([0.0] * 90) is None
def test_window_contents_correct(self, mod):
buf = mod.AudioBuffer(window_samples=4)
w = buf.push([1.0, 2.0, 3.0, 4.0])
assert w == [1.0, 2.0, 3.0, 4.0]
# ── TestNodeSrc ───────────────────────────────────────────────────────────────
class TestNodeSrc:
@pytest.fixture(scope="class")
def src(self): return _read_src("saltybot_social/ambient_sound_node.py")
def test_class_defined(self, src): assert "class AmbientSoundNode" in src
def test_audio_buffer(self, src): assert "class AudioBuffer" in src
def test_extract_features(self, src): assert "def extract_features" in src
def test_classify_fn(self, src): assert "def classify" in src
def test_mel_spectrogram(self, src): assert "compute_mel_spectrogram" in src
def test_mel_filterbank(self, src): assert "build_mel_filterbank" in src
def test_hz_to_mel(self, src): assert "hz_to_mel" in src
def test_labels_tuple(self, src): assert "LABELS" in src
def test_all_labels(self, src):
for label in ("silence", "speech", "music", "crowd", "outdoor", "alarm"):
assert label in src
def test_topic_pub(self, src): assert '"/saltybot/ambient_sound"' in src
def test_topic_sub(self, src): assert '"/social/speech/audio_raw"' in src
def test_window_param(self, src): assert '"window_s"' in src
def test_n_mels_param(self, src): assert '"n_mels"' in src
def test_silence_param(self, src): assert '"silence_db"' in src
def test_alarm_param(self, src): assert '"alarm_db_min"' in src
def test_speech_param(self, src): assert '"speech_zcr_min"' in src
def test_music_param(self, src): assert '"music_zcr_max"' in src
def test_crowd_param(self, src): assert '"crowd_zcr_min"' in src
def test_string_pub(self, src): assert "String" in src
def test_uint8_sub(self, src): assert "UInt8MultiArray" in src
def test_issue_tag(self, src): assert "252" in src
def test_main(self, src): assert "def main" in src
def test_numpy_optional(self, src): assert "_NUMPY" in src
# ── TestConfig ────────────────────────────────────────────────────────────────
class TestConfig:
@pytest.fixture(scope="class")
def cfg(self): return _read_src("config/ambient_sound_params.yaml")
@pytest.fixture(scope="class")
def setup(self): return _read_src("setup.py")
def test_node_name(self, cfg): assert "ambient_sound_node:" in cfg
def test_window_s(self, cfg): assert "window_s" in cfg
def test_n_mels(self, cfg): assert "n_mels" in cfg
def test_silence_db(self, cfg): assert "silence_db" in cfg
def test_alarm_params(self, cfg): assert "alarm_db_min" in cfg
def test_speech_params(self, cfg): assert "speech_zcr_min" in cfg
def test_music_params(self, cfg): assert "music_zcr_max" in cfg
def test_crowd_params(self, cfg): assert "crowd_zcr_min" in cfg
def test_defaults_present(self, cfg): assert "-40.0" in cfg and "0.12" in cfg
def test_entry_point(self, setup):
assert "ambient_sound_node = saltybot_social.ambient_sound_node:main" in setup

View File

@ -0,0 +1,676 @@
"""test_face_track_servo.py — Offline tests for face_track_servo_node (Issue #279).
Stubs out rclpy and saltybot_social_msgs so tests run without a ROS install.
"""
import importlib
import importlib.util
import math
import sys
import time
import types
import unittest
# ── ROS2 / message stubs ──────────────────────────────────────────────────────
def _make_ros_stubs():
for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
"std_msgs", "std_msgs.msg",
"saltybot_social_msgs", "saltybot_social_msgs.msg"):
if mod_name not in sys.modules:
sys.modules[mod_name] = types.ModuleType(mod_name)
class _Node:
def __init__(self, name="node"):
self._name = name
if not hasattr(self, "_params"):
self._params = {}
self._pubs = {}
self._subs = {}
self._timers = []
self._logs = []
def declare_parameter(self, name, default):
if name not in self._params:
self._params[name] = default
def get_parameter(self, name):
class _P:
def __init__(self, v): self.value = v
return _P(self._params.get(name))
def create_publisher(self, msg_type, topic, qos):
pub = _FakePub()
self._pubs[topic] = pub
return pub
def create_subscription(self, msg_type, topic, cb, qos):
self._subs[topic] = cb
return object()
def create_timer(self, period, cb):
self._timers.append(cb)
return object()
def get_logger(self):
node = self
class _L:
def info(self, m): node._logs.append(("INFO", m))
def warn(self, m): node._logs.append(("WARN", m))
def error(self, m): node._logs.append(("ERROR", m))
return _L()
def destroy_node(self): pass
class _FakePub:
def __init__(self):
self.msgs = []
def publish(self, msg):
self.msgs.append(msg)
class _QoSProfile:
def __init__(self, depth=10): self.depth = depth
class _Float32:
def __init__(self): self.data = 0.0
class _FaceDetection:
def __init__(self, face_id=0, bbox_x=0.4, bbox_y=0.4,
bbox_w=0.2, bbox_h=0.2, confidence=1.0):
self.face_id = face_id
self.bbox_x = bbox_x
self.bbox_y = bbox_y
self.bbox_w = bbox_w
self.bbox_h = bbox_h
self.confidence = confidence
self.person_name = ""
class _FaceDetectionArray:
def __init__(self, faces=None):
self.faces = faces or []
# rclpy
rclpy_mod = sys.modules["rclpy"]
rclpy_mod.init = lambda args=None: None
rclpy_mod.spin = lambda node: None
rclpy_mod.shutdown = lambda: None
sys.modules["rclpy.node"].Node = _Node
sys.modules["rclpy.qos"].QoSProfile = _QoSProfile
sys.modules["std_msgs.msg"].Float32 = _Float32
msgs = sys.modules["saltybot_social_msgs.msg"]
msgs.FaceDetection = _FaceDetection
msgs.FaceDetectionArray = _FaceDetectionArray
return _Node, _FakePub, _FaceDetection, _FaceDetectionArray, _Float32
_Node, _FakePub, _FaceDetection, _FaceDetectionArray, _Float32 = _make_ros_stubs()
# ── Module loader ─────────────────────────────────────────────────────────────
_SRC = (
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
"saltybot_social/saltybot_social/face_track_servo_node.py"
)
def _load_mod():
spec = importlib.util.spec_from_file_location("face_track_servo_testmod", _SRC)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
def _make_node(mod, **kwargs):
"""Instantiate FaceTrackServoNode with optional param overrides."""
node = mod.FaceTrackServoNode.__new__(mod.FaceTrackServoNode)
defaults = {
"kp_pan": 1.5,
"ki_pan": 0.1,
"kd_pan": 0.05,
"kp_tilt": 1.2,
"ki_tilt": 0.1,
"kd_tilt": 0.04,
"fov_h_deg": 60.0,
"fov_v_deg": 45.0,
"pan_limit_deg": 90.0,
"tilt_limit_deg": 30.0,
"pan_vel_limit": 45.0,
"tilt_vel_limit": 30.0,
"windup_limit": 15.0,
"dead_zone": 0.02,
"control_rate": 20.0,
"lost_timeout_s": 1.5,
"return_rate_deg_s": 10.0,
"faces_topic": "/social/faces/detected",
}
defaults.update(kwargs)
node._params = dict(defaults)
mod.FaceTrackServoNode.__init__(node)
return node
def _face(bbox_x=0.4, bbox_y=0.4, bbox_w=0.2, bbox_h=0.2, face_id=0):
return _FaceDetection(face_id=face_id, bbox_x=bbox_x, bbox_y=bbox_y,
bbox_w=bbox_w, bbox_h=bbox_h)
def _centered_face():
"""A face perfectly centered in the frame."""
return _face(bbox_x=0.4, bbox_y=0.4, bbox_w=0.2, bbox_h=0.2)
# ── Tests: pure helpers ───────────────────────────────────────────────────────
class TestClamp(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mod = _load_mod()
def test_within(self):
self.assertEqual(self.mod.clamp(5.0, 0.0, 10.0), 5.0)
def test_below(self):
self.assertEqual(self.mod.clamp(-5.0, 0.0, 10.0), 0.0)
def test_above(self):
self.assertEqual(self.mod.clamp(15.0, 0.0, 10.0), 10.0)
def test_negative_range(self):
self.assertEqual(self.mod.clamp(-50.0, -45.0, 45.0), -45.0)
class TestBboxArea(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mod = _load_mod()
def test_area(self):
f = _face(bbox_w=0.3, bbox_h=0.4)
self.assertAlmostEqual(self.mod.bbox_area(f), 0.12)
def test_zero(self):
f = _face(bbox_w=0.0, bbox_h=0.2)
self.assertAlmostEqual(self.mod.bbox_area(f), 0.0)
class TestPickClosestFace(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mod = _load_mod()
def test_empty(self):
self.assertIsNone(self.mod.pick_closest_face([]))
def test_single(self):
f = _face(bbox_w=0.2, bbox_h=0.2)
self.assertIs(self.mod.pick_closest_face([f]), f)
def test_picks_largest_area(self):
small = _face(bbox_w=0.1, bbox_h=0.1)
big = _face(bbox_w=0.4, bbox_h=0.4)
self.assertIs(self.mod.pick_closest_face([small, big]), big)
self.assertIs(self.mod.pick_closest_face([big, small]), big)
def test_three_faces(self):
faces = [_face(bbox_w=0.1, bbox_h=0.1),
_face(bbox_w=0.5, bbox_h=0.5),
_face(bbox_w=0.2, bbox_h=0.2)]
self.assertIs(self.mod.pick_closest_face(faces), faces[1])
class TestFaceImageError(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mod = _load_mod()
def test_centered_face_zero_error(self):
f = _face(bbox_x=0.4, bbox_y=0.4, bbox_w=0.2, bbox_h=0.2)
pan, tilt = self.mod.face_image_error(f, 60.0, 45.0)
self.assertAlmostEqual(pan, 0.0)
self.assertAlmostEqual(tilt, 0.0)
def test_right_of_centre(self):
# cx = 0.7 + 0.1 = 0.8, error = 0.3 * 60 = 18°
f = _face(bbox_x=0.7, bbox_y=0.4, bbox_w=0.2, bbox_h=0.2)
pan, _ = self.mod.face_image_error(f, 60.0, 45.0)
self.assertAlmostEqual(pan, 18.0)
def test_left_of_centre(self):
# cx = 0.1 + 0.1 = 0.2, error = -0.3 * 60 = -18°
f = _face(bbox_x=0.1, bbox_y=0.4, bbox_w=0.2, bbox_h=0.2)
pan, _ = self.mod.face_image_error(f, 60.0, 45.0)
self.assertAlmostEqual(pan, -18.0)
def test_below_centre(self):
# cy = 0.7 + 0.1 = 0.8, error = 0.3 * 45 = 13.5°
f = _face(bbox_x=0.4, bbox_y=0.7, bbox_w=0.2, bbox_h=0.2)
_, tilt = self.mod.face_image_error(f, 60.0, 45.0)
self.assertAlmostEqual(tilt, 13.5)
def test_above_centre(self):
# cy = 0.1 + 0.1 = 0.2, error = -0.3 * 45 = -13.5°
f = _face(bbox_x=0.4, bbox_y=0.1, bbox_w=0.2, bbox_h=0.2)
_, tilt = self.mod.face_image_error(f, 60.0, 45.0)
self.assertAlmostEqual(tilt, -13.5)
class TestStepTowardZero(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mod = _load_mod()
def test_positive_large(self):
result = self.mod._step_toward_zero(10.0, 1.0)
self.assertAlmostEqual(result, 9.0)
def test_negative_large(self):
result = self.mod._step_toward_zero(-10.0, 1.0)
self.assertAlmostEqual(result, -9.0)
def test_smaller_than_step(self):
result = self.mod._step_toward_zero(0.5, 1.0)
self.assertAlmostEqual(result, 0.0)
def test_exact_step(self):
result = self.mod._step_toward_zero(1.0, 1.0)
self.assertAlmostEqual(result, 0.0)
def test_zero(self):
result = self.mod._step_toward_zero(0.0, 1.0)
self.assertAlmostEqual(result, 0.0)
# ── Tests: PIDController ──────────────────────────────────────────────────────
class TestPIDController(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mod = _load_mod()
def _pid(self, kp=1.0, ki=0.0, kd=0.0, vel_limit=100.0, windup=100.0):
return self.mod.PIDController(kp, ki, kd, vel_limit, windup)
def test_proportional_only(self):
pid = self._pid(kp=2.0)
out = pid.update(5.0, 0.1)
self.assertAlmostEqual(out, 10.0)
def test_zero_error_zero_output(self):
pid = self._pid(kp=5.0)
self.assertAlmostEqual(pid.update(0.0, 0.1), 0.0)
def test_integral_accumulates(self):
pid = self._pid(kp=0.0, ki=1.0)
pid.update(1.0, 0.1) # integral = 0.1
out = pid.update(1.0, 0.1) # integral = 0.2, output = 0.2
self.assertAlmostEqual(out, 0.2, places=5)
def test_derivative_first_tick_zero(self):
pid = self._pid(kp=0.0, kd=1.0)
out = pid.update(10.0, 0.1)
self.assertAlmostEqual(out, 0.0) # first tick: derivative = 0
def test_derivative_second_tick(self):
pid = self._pid(kp=0.0, kd=1.0)
pid.update(0.0, 0.1) # first tick
out = pid.update(10.0, 0.1) # de/dt = 10/0.1 = 100
self.assertAlmostEqual(out, 100.0)
def test_velocity_clamped(self):
pid = self._pid(kp=100.0, vel_limit=10.0)
out = pid.update(5.0, 0.1)
self.assertAlmostEqual(out, 10.0)
def test_velocity_clamped_negative(self):
pid = self._pid(kp=100.0, vel_limit=10.0)
out = pid.update(-5.0, 0.1)
self.assertAlmostEqual(out, -10.0)
def test_antiwindup(self):
pid = self._pid(kp=0.0, ki=1.0, windup=5.0)
for _ in range(100):
pid.update(1.0, 0.1) # would accumulate 10, clamped at 5
out = pid.update(0.0, 0.1)
self.assertAlmostEqual(out, 5.0, places=3)
def test_reset_clears_integral(self):
pid = self._pid(ki=1.0)
pid.update(1.0, 1.0)
pid.reset()
out = pid.update(0.0, 0.1)
self.assertAlmostEqual(out, 0.0)
def test_reset_clears_derivative(self):
pid = self._pid(kp=0.0, kd=1.0)
pid.update(10.0, 0.1) # sets prev_error
pid.reset()
out = pid.update(10.0, 0.1) # after reset, first tick = 0 derivative
self.assertAlmostEqual(out, 0.0)
def test_zero_dt_returns_zero(self):
pid = self._pid(kp=10.0)
self.assertAlmostEqual(pid.update(5.0, 0.0), 0.0)
def test_negative_dt_returns_zero(self):
pid = self._pid(kp=10.0)
self.assertAlmostEqual(pid.update(5.0, -0.1), 0.0)
# ── Tests: node initialisation ────────────────────────────────────────────────
class TestNodeInit(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mod = _load_mod()
def test_instantiates(self):
node = _make_node(self.mod)
self.assertIsNotNone(node)
def test_pan_pub(self):
node = _make_node(self.mod)
self.assertIn("/saltybot/head_pan", node._pubs)
def test_tilt_pub(self):
node = _make_node(self.mod)
self.assertIn("/saltybot/head_tilt", node._pubs)
def test_faces_sub(self):
node = _make_node(self.mod)
self.assertIn("/social/faces/detected", node._subs)
def test_timer_registered(self):
node = _make_node(self.mod)
self.assertGreater(len(node._timers), 0)
def test_initial_pan_zero(self):
node = _make_node(self.mod)
self.assertAlmostEqual(node._pan_cmd, 0.0)
def test_initial_tilt_zero(self):
node = _make_node(self.mod)
self.assertAlmostEqual(node._tilt_cmd, 0.0)
def test_custom_fov(self):
node = _make_node(self.mod, fov_h_deg=90.0)
self.assertAlmostEqual(node._fov_h, 90.0)
def test_custom_pan_limit(self):
node = _make_node(self.mod, pan_limit_deg=45.0)
self.assertAlmostEqual(node._pan_limit, 45.0)
# ── Tests: face callback ──────────────────────────────────────────────────────
class TestFaceCallback(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mod = _load_mod()
def setUp(self):
self.node = _make_node(self.mod)
def test_empty_msg_no_face(self):
self.node._on_faces(_FaceDetectionArray([]))
self.assertIsNone(self.node._latest_face)
def test_single_face_stored(self):
f = _centered_face()
self.node._on_faces(_FaceDetectionArray([f]))
self.assertIs(self.node._latest_face, f)
def test_closest_face_picked(self):
small = _face(bbox_w=0.1, bbox_h=0.1, face_id=1)
big = _face(bbox_w=0.5, bbox_h=0.5, face_id=2)
self.node._on_faces(_FaceDetectionArray([small, big]))
self.assertIs(self.node._latest_face, big)
def test_timestamp_updated_on_face(self):
before = time.monotonic()
f = _centered_face()
self.node._on_faces(_FaceDetectionArray([f]))
self.assertGreaterEqual(self.node._last_face_t, before)
def test_timestamp_not_updated_on_empty(self):
self.node._last_face_t = 0.0
self.node._on_faces(_FaceDetectionArray([]))
self.assertEqual(self.node._last_face_t, 0.0)
# ── Tests: control loop ───────────────────────────────────────────────────────
class TestControlLoop(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mod = _load_mod()
def setUp(self):
self.node = _make_node(self.mod, dead_zone=0.0,
ki_pan=0.0, kd_pan=0.0,
ki_tilt=0.0, kd_tilt=0.0)
self.pan_pub = self.node._pubs["/saltybot/head_pan"]
self.tilt_pub = self.node._pubs["/saltybot/head_tilt"]
def _tick(self, dt=0.05):
self.node._last_tick = time.monotonic() - dt
self.node._control_cb()
def test_no_face_publishes_zero_initially(self):
self._tick()
self.assertAlmostEqual(self.pan_pub.msgs[-1].data, 0.0)
self.assertAlmostEqual(self.tilt_pub.msgs[-1].data, 0.0)
def test_centered_face_minimal_movement(self):
f = _centered_face() # cx=cy=0.5, error=0
self.node._on_faces(_FaceDetectionArray([f]))
self.node._last_face_t = time.monotonic()
self._tick()
# With dead_zone=0 and error=0, pid output=0, cmd stays 0
self.assertAlmostEqual(self.pan_pub.msgs[-1].data, 0.0, places=4)
self.assertAlmostEqual(self.tilt_pub.msgs[-1].data, 0.0, places=4)
def test_right_face_pans_right(self):
# Face right of centre → positive pan error → pan_cmd increases
f = _face(bbox_x=0.7, bbox_y=0.4, bbox_w=0.2, bbox_h=0.2)
self.node._on_faces(_FaceDetectionArray([f]))
self.node._last_face_t = time.monotonic()
self._tick()
self.assertGreater(self.pan_pub.msgs[-1].data, 0.0)
def test_left_face_pans_left(self):
f = _face(bbox_x=0.1, bbox_y=0.4, bbox_w=0.2, bbox_h=0.2)
self.node._on_faces(_FaceDetectionArray([f]))
self.node._last_face_t = time.monotonic()
self._tick()
self.assertLess(self.pan_pub.msgs[-1].data, 0.0)
def test_low_face_tilts_down(self):
f = _face(bbox_x=0.4, bbox_y=0.7, bbox_w=0.2, bbox_h=0.2)
self.node._on_faces(_FaceDetectionArray([f]))
self.node._last_face_t = time.monotonic()
self._tick()
self.assertGreater(self.tilt_pub.msgs[-1].data, 0.0)
def test_high_face_tilts_up(self):
f = _face(bbox_x=0.4, bbox_y=0.1, bbox_w=0.2, bbox_h=0.2)
self.node._on_faces(_FaceDetectionArray([f]))
self.node._last_face_t = time.monotonic()
self._tick()
self.assertLess(self.tilt_pub.msgs[-1].data, 0.0)
def test_pan_clamped_to_limit(self):
node = _make_node(self.mod, kp_pan=1000.0, ki_pan=0.0, kd_pan=0.0,
pan_limit_deg=45.0, pan_vel_limit=9999.0,
dead_zone=0.0)
pub = node._pubs["/saltybot/head_pan"]
f = _face(bbox_x=0.9, bbox_y=0.4, bbox_w=0.1, bbox_h=0.2)
node._on_faces(_FaceDetectionArray([f]))
node._last_face_t = time.monotonic()
# Run many ticks to accumulate
for _ in range(50):
node._last_tick = time.monotonic() - 0.05
node._control_cb()
self.assertLessEqual(pub.msgs[-1].data, 45.0)
def test_tilt_clamped_to_limit(self):
node = _make_node(self.mod, kp_tilt=1000.0, ki_tilt=0.0, kd_tilt=0.0,
tilt_limit_deg=20.0, tilt_vel_limit=9999.0,
dead_zone=0.0)
pub = node._pubs["/saltybot/head_tilt"]
f = _face(bbox_x=0.4, bbox_y=0.9, bbox_w=0.2, bbox_h=0.1)
node._on_faces(_FaceDetectionArray([f]))
node._last_face_t = time.monotonic()
for _ in range(50):
node._last_tick = time.monotonic() - 0.05
node._control_cb()
self.assertLessEqual(pub.msgs[-1].data, 20.0)
def test_lost_face_returns_to_zero(self):
node = _make_node(self.mod, kp_pan=10.0, ki_pan=0.0, kd_pan=0.0,
dead_zone=0.0, return_rate_deg_s=90.0,
lost_timeout_s=0.01)
pub = node._pubs["/saltybot/head_pan"]
f = _face(bbox_x=0.7, bbox_y=0.4, bbox_w=0.2, bbox_h=0.2)
node._on_faces(_FaceDetectionArray([f]))
node._last_face_t = time.monotonic()
# Build up some pan
for _ in range(5):
node._last_tick = time.monotonic() - 0.05
node._control_cb()
# Expire face timeout
node._last_face_t = time.monotonic() - 10.0
for _ in range(20):
node._last_tick = time.monotonic() - 0.05
node._control_cb()
self.assertAlmostEqual(pub.msgs[-1].data, 0.0, places=3)
def test_publishes_every_tick(self):
for _ in range(3):
self._tick()
self.assertEqual(len(self.pan_pub.msgs), 3)
self.assertEqual(len(self.tilt_pub.msgs), 3)
def test_dead_zone_suppresses_small_error(self):
node = _make_node(self.mod, kp_pan=100.0, ki_pan=0.0, kd_pan=0.0,
dead_zone=0.1, fov_h_deg=60.0)
pub = node._pubs["/saltybot/head_pan"]
# Face 2% right of centre — within dead_zone=10% of frame
f = _face(bbox_x=0.42, bbox_y=0.4, bbox_w=0.2, bbox_h=0.2)
node._on_faces(_FaceDetectionArray([f]))
node._last_face_t = time.monotonic()
node._last_tick = time.monotonic() - 0.05
node._control_cb()
self.assertAlmostEqual(pub.msgs[-1].data, 0.0, places=4)
# ── Tests: source-level checks ────────────────────────────────────────────────
class TestNodeSrc(unittest.TestCase):
@classmethod
def setUpClass(cls):
with open(_SRC) as f:
cls.src = f.read()
def test_issue_tag(self):
self.assertIn("#279", self.src)
def test_pan_topic(self):
self.assertIn("/saltybot/head_pan", self.src)
def test_tilt_topic(self):
self.assertIn("/saltybot/head_tilt", self.src)
def test_faces_topic(self):
self.assertIn("/social/faces/detected", self.src)
def test_pid_class(self):
self.assertIn("class PIDController", self.src)
def test_kp_param(self):
self.assertIn("kp_pan", self.src)
def test_ki_param(self):
self.assertIn("ki_pan", self.src)
def test_kd_param(self):
self.assertIn("kd_pan", self.src)
def test_fov_param(self):
self.assertIn("fov_h_deg", self.src)
def test_pan_limit_param(self):
self.assertIn("pan_limit_deg", self.src)
def test_dead_zone_param(self):
self.assertIn("dead_zone", self.src)
def test_pick_closest_face(self):
self.assertIn("pick_closest_face", self.src)
def test_main_defined(self):
self.assertIn("def main", self.src)
def test_antiwindup(self):
self.assertIn("windup", self.src)
def test_threading_lock(self):
self.assertIn("threading.Lock", self.src)
class TestConfig(unittest.TestCase):
_CONFIG = (
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
"saltybot_social/config/face_track_servo_params.yaml"
)
_LAUNCH = (
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
"saltybot_social/launch/face_track_servo.launch.py"
)
_SETUP = (
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
"saltybot_social/setup.py"
)
def test_config_exists(self):
import os; self.assertTrue(os.path.exists(self._CONFIG))
def test_config_kp_pan(self):
with open(self._CONFIG) as f: c = f.read()
self.assertIn("kp_pan", c)
def test_config_fov(self):
with open(self._CONFIG) as f: c = f.read()
self.assertIn("fov_h_deg", c)
def test_config_pan_limit(self):
with open(self._CONFIG) as f: c = f.read()
self.assertIn("pan_limit_deg", c)
def test_config_dead_zone(self):
with open(self._CONFIG) as f: c = f.read()
self.assertIn("dead_zone", c)
def test_launch_exists(self):
import os; self.assertTrue(os.path.exists(self._LAUNCH))
def test_launch_kp_pan_arg(self):
with open(self._LAUNCH) as f: c = f.read()
self.assertIn("kp_pan", c)
def test_launch_pan_limit_arg(self):
with open(self._LAUNCH) as f: c = f.read()
self.assertIn("pan_limit_deg", c)
def test_entry_point(self):
with open(self._SETUP) as f: c = f.read()
self.assertIn("face_track_servo_node", c)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,472 @@
"""test_greeting_trigger.py -- Offline tests for greeting_trigger_node (Issue #270).
Stubs out rclpy and saltybot_social_msgs so tests run without a ROS install.
"""
import importlib
import json
import sys
import time
import types
import unittest
# ── ROS2 / message stubs ──────────────────────────────────────────────────────
def _make_ros_stubs():
"""Install minimal stubs for rclpy and message packages."""
for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
"std_msgs", "std_msgs.msg",
"saltybot_social_msgs", "saltybot_social_msgs.msg"):
sys.modules[mod_name] = types.ModuleType(mod_name)
class _Node:
def __init__(self, name):
self._name = name
# Preserve _params if pre-set by _make_node (super().__init__() is
# called from GreetingTriggerNode.__init__, so don't reset here)
if not hasattr(self, '_params'):
self._params = {}
self._pubs = {}
self._subs = {}
self._logs = []
def declare_parameter(self, name, default):
# Don't overwrite values pre-set by _make_node
if name not in self._params:
self._params[name] = default
def get_parameter(self, name):
class _P:
def __init__(self, v):
self.value = v
return _P(self._params[name])
def create_publisher(self, msg_type, topic, qos):
pub = _FakePub()
self._pubs[topic] = pub
return pub
def create_subscription(self, msg_type, topic, cb, qos):
self._subs[topic] = cb
return object()
def get_logger(self):
node = self
class _L:
def info(self, m): node._logs.append(("INFO", m))
def warn(self, m): node._logs.append(("WARN", m))
def error(self, m): node._logs.append(("ERROR", m))
return _L()
def destroy_node(self): pass
class _FakePub:
def __init__(self):
self.msgs = []
def publish(self, msg):
self.msgs.append(msg)
class _QoSProfile:
def __init__(self, depth=10): self.depth = depth
class _String:
def __init__(self): self.data = ""
# rclpy
rclpy_mod = sys.modules["rclpy"]
rclpy_mod.init = lambda args=None: None
rclpy_mod.spin = lambda node: None
rclpy_mod.shutdown = lambda: None
# rclpy.node
sys.modules["rclpy.node"].Node = _Node
# rclpy.qos
sys.modules["rclpy.qos"].QoSProfile = _QoSProfile
# std_msgs.msg
sys.modules["std_msgs.msg"].String = _String
# saltybot_social_msgs.msg (FaceDetectionArray + PersonStateArray)
class _FaceDetection:
def __init__(self, face_id=0, person_name="", confidence=1.0):
self.face_id = face_id
self.person_name = person_name
self.confidence = confidence
class _FaceDetectionArray:
def __init__(self, faces=None):
self.faces = faces or []
class _PersonState:
def __init__(self, face_id=0, distance=0.0):
self.face_id = face_id
self.distance = distance
class _PersonStateArray:
def __init__(self, persons=None):
self.persons = persons or []
msgs = sys.modules["saltybot_social_msgs.msg"]
msgs.FaceDetection = _FaceDetection
msgs.FaceDetectionArray = _FaceDetectionArray
msgs.PersonState = _PersonState
msgs.PersonStateArray = _PersonStateArray
return _Node, _FakePub, _QoSProfile, _String, _FaceDetection, _FaceDetectionArray, _PersonState, _PersonStateArray
_Node, _FakePub, _QoSProfile, _String, _FaceDetection, _FaceDetectionArray, _PersonState, _PersonStateArray = _make_ros_stubs()
# ── Load module under test ────────────────────────────────────────────────────
_SRC = (
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
"saltybot_social/saltybot_social/greeting_trigger_node.py"
)
def _load_mod():
spec = importlib.util.spec_from_file_location("greeting_trigger_node_testmod", _SRC)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
# ── Helpers ───────────────────────────────────────────────────────────────────
def _make_node(mod, **kwargs):
"""Instantiate GreetingTriggerNode with overridden parameters."""
node = mod.GreetingTriggerNode.__new__(mod.GreetingTriggerNode)
# Pre-populate _params BEFORE __init__ so super().__init__() (which calls
# _Node.__init__) sees them and skips reset due to hasattr guard.
defaults = {
"proximity_m": 2.0,
"cooldown_s": 300.0,
"unknown_distance": 0.0,
"faces_topic": "/social/faces/detected",
"states_topic": "/social/person_states",
}
defaults.update(kwargs)
node._params = dict(defaults)
mod.GreetingTriggerNode.__init__(node)
return node
def _face_msg(faces):
return _FaceDetectionArray(faces=faces)
def _state_msg(persons):
return _PersonStateArray(persons=persons)
# ── Test suites ───────────────────────────────────────────────────────────────
class TestNodeInit(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mod = _load_mod()
def test_imports_cleanly(self):
self.assertTrue(hasattr(self.mod, "GreetingTriggerNode"))
def test_default_proximity(self):
node = _make_node(self.mod)
self.assertEqual(node._proximity, 2.0)
def test_default_cooldown(self):
node = _make_node(self.mod)
self.assertEqual(node._cooldown, 300.0)
def test_default_unknown_distance(self):
node = _make_node(self.mod)
self.assertEqual(node._unknown_dist, 0.0)
def test_pub_topic(self):
node = _make_node(self.mod)
self.assertIn("/saltybot/greeting_trigger", node._pubs)
def test_subs_registered(self):
node = _make_node(self.mod)
self.assertIn("/social/faces/detected", node._subs)
self.assertIn("/social/person_states", node._subs)
def test_initial_caches_empty(self):
node = _make_node(self.mod)
self.assertEqual(node._distance_cache, {})
self.assertEqual(node._last_greeted, {})
class TestDistanceCache(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mod = _load_mod()
def setUp(self):
self.node = _make_node(self.mod)
def test_state_updates_cache(self):
ps = _PersonState(face_id=1, distance=1.5)
self.node._on_person_states(_state_msg([ps]))
self.assertAlmostEqual(self.node._distance_cache[1], 1.5)
def test_multiple_states_cached(self):
persons = [_PersonState(face_id=i, distance=float(i)) for i in range(5)]
self.node._on_person_states(_state_msg(persons))
for i in range(5):
self.assertAlmostEqual(self.node._distance_cache[i], float(i))
def test_state_update_overwrites(self):
self.node._on_person_states(_state_msg([_PersonState(face_id=1, distance=3.0)]))
self.node._on_person_states(_state_msg([_PersonState(face_id=1, distance=1.0)]))
self.assertAlmostEqual(self.node._distance_cache[1], 1.0)
def test_negative_face_id_ignored(self):
self.node._on_person_states(_state_msg([_PersonState(face_id=-1, distance=1.0)]))
self.assertNotIn(-1, self.node._distance_cache)
def test_zero_distance_cached(self):
self.node._on_person_states(_state_msg([_PersonState(face_id=5, distance=0.0)]))
self.assertAlmostEqual(self.node._distance_cache[5], 0.0)
class TestGreetingTrigger(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mod = _load_mod()
def setUp(self):
self.node = _make_node(self.mod, proximity_m=2.0, cooldown_s=300.0)
self.pub = self.node._pubs["/saltybot/greeting_trigger"]
def _inject_distance(self, face_id, distance):
self.node._on_person_states(_state_msg([_PersonState(face_id=face_id, distance=distance)]))
def test_triggers_within_proximity(self):
self._inject_distance(1, 1.5)
self.node._on_faces(_face_msg([_FaceDetection(face_id=1, person_name="alice")]))
self.assertEqual(len(self.pub.msgs), 1)
def test_no_trigger_beyond_proximity(self):
self._inject_distance(2, 3.0)
self.node._on_faces(_face_msg([_FaceDetection(face_id=2, person_name="bob")]))
self.assertEqual(len(self.pub.msgs), 0)
def test_trigger_at_exact_proximity(self):
self._inject_distance(3, 2.0)
self.node._on_faces(_face_msg([_FaceDetection(face_id=3, person_name="carol")]))
self.assertEqual(len(self.pub.msgs), 1)
def test_no_trigger_just_beyond(self):
self._inject_distance(4, 2.001)
self.node._on_faces(_face_msg([_FaceDetection(face_id=4, person_name="dave")]))
self.assertEqual(len(self.pub.msgs), 0)
def test_cooldown_suppresses_retrigger(self):
self._inject_distance(5, 1.0)
face = _FaceDetection(face_id=5, person_name="eve")
self.node._on_faces(_face_msg([face]))
self.node._on_faces(_face_msg([face])) # second call in cooldown
self.assertEqual(len(self.pub.msgs), 1)
def test_cooldown_per_face_id(self):
self._inject_distance(6, 1.0)
self._inject_distance(7, 1.0)
self.node._on_faces(_face_msg([_FaceDetection(face_id=6, person_name="f")]))
self.node._on_faces(_face_msg([_FaceDetection(face_id=7, person_name="g")]))
self.assertEqual(len(self.pub.msgs), 2)
def test_expired_cooldown_retrigers(self):
self._inject_distance(8, 1.0)
face = _FaceDetection(face_id=8, person_name="hank")
self.node._on_faces(_face_msg([face]))
# Manually expire the cooldown
self.node._last_greeted[8] = time.monotonic() - 400.0
self.node._on_faces(_face_msg([face]))
self.assertEqual(len(self.pub.msgs), 2)
def test_unknown_face_uses_unknown_distance(self):
# unknown_distance=0.0 → should trigger (0.0 <= 2.0)
node = _make_node(self.mod, unknown_distance=0.0)
pub = node._pubs["/saltybot/greeting_trigger"]
node._on_faces(_face_msg([_FaceDetection(face_id=99, person_name="stranger")]))
self.assertEqual(len(pub.msgs), 1)
def test_unknown_face_large_distance_no_trigger(self):
# unknown_distance=10.0 → should NOT trigger
node = _make_node(self.mod, unknown_distance=10.0)
pub = node._pubs["/saltybot/greeting_trigger"]
node._on_faces(_face_msg([_FaceDetection(face_id=100, person_name="far")]))
self.assertEqual(len(pub.msgs), 0)
def test_multiple_faces_triggers_each_within_range(self):
self._inject_distance(10, 1.0)
self._inject_distance(11, 3.0) # out of range
faces = [
_FaceDetection(face_id=10, person_name="near"),
_FaceDetection(face_id=11, person_name="far"),
]
self.node._on_faces(_face_msg(faces))
self.assertEqual(len(self.pub.msgs), 1)
def test_empty_face_array_no_trigger(self):
self.node._on_faces(_face_msg([]))
self.assertEqual(len(self.pub.msgs), 0)
class TestPayload(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mod = _load_mod()
def setUp(self):
self.node = _make_node(self.mod)
self.pub = self.node._pubs["/saltybot/greeting_trigger"]
def _trigger(self, face_id=1, person_name="alice", distance=1.5):
self.node._on_person_states(_state_msg([_PersonState(face_id=face_id, distance=distance)]))
self.node._on_faces(_face_msg([_FaceDetection(face_id=face_id, person_name=person_name)]))
def test_payload_is_json(self):
self._trigger()
payload = json.loads(self.pub.msgs[0].data)
self.assertIsInstance(payload, dict)
def test_payload_face_id(self):
self._trigger(face_id=42)
payload = json.loads(self.pub.msgs[0].data)
self.assertEqual(payload["face_id"], 42)
def test_payload_person_name(self):
self._trigger(person_name="zara")
payload = json.loads(self.pub.msgs[0].data)
self.assertEqual(payload["person_name"], "zara")
def test_payload_distance(self):
self._trigger(distance=1.234)
payload = json.loads(self.pub.msgs[0].data)
self.assertAlmostEqual(payload["distance_m"], 1.234, places=2)
def test_payload_has_ts(self):
self._trigger()
payload = json.loads(self.pub.msgs[0].data)
self.assertIn("ts", payload)
self.assertIsInstance(payload["ts"], float)
def test_ts_is_recent(self):
before = time.time()
self._trigger()
after = time.time()
payload = json.loads(self.pub.msgs[0].data)
self.assertGreaterEqual(payload["ts"], before)
self.assertLessEqual(payload["ts"], after + 1.0)
class TestNodeSrc(unittest.TestCase):
"""Source-level checks — verify node structure without instantiation."""
@classmethod
def setUpClass(cls):
with open(_SRC) as f:
cls.src = f.read()
def test_issue_tag(self):
self.assertIn("#270", self.src)
def test_pub_topic(self):
self.assertIn("/saltybot/greeting_trigger", self.src)
def test_faces_topic(self):
self.assertIn("/social/faces/detected", self.src)
def test_states_topic(self):
self.assertIn("/social/person_states", self.src)
def test_proximity_param(self):
self.assertIn("proximity_m", self.src)
def test_cooldown_param(self):
self.assertIn("cooldown_s", self.src)
def test_unknown_distance_param(self):
self.assertIn("unknown_distance", self.src)
def test_json_output(self):
self.assertIn("json", self.src)
def test_face_id_in_payload(self):
self.assertIn("face_id", self.src)
def test_person_name_in_payload(self):
self.assertIn("person_name", self.src)
def test_distance_in_payload(self):
self.assertIn("distance_m", self.src)
def test_main_defined(self):
self.assertIn("def main", self.src)
def test_threading_lock(self):
self.assertIn("threading.Lock", self.src)
class TestConfig(unittest.TestCase):
"""Checks on config/launch/setup files."""
_CONFIG = (
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
"saltybot_social/config/greeting_trigger_params.yaml"
)
_LAUNCH = (
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
"saltybot_social/launch/greeting_trigger.launch.py"
)
_SETUP = (
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
"saltybot_social/setup.py"
)
def test_config_exists(self):
import os
self.assertTrue(os.path.exists(self._CONFIG))
def test_config_proximity(self):
with open(self._CONFIG) as f:
content = f.read()
self.assertIn("proximity_m", content)
def test_config_cooldown(self):
with open(self._CONFIG) as f:
content = f.read()
self.assertIn("cooldown_s", content)
def test_config_node_name(self):
with open(self._CONFIG) as f:
content = f.read()
self.assertIn("greeting_trigger_node", content)
def test_launch_exists(self):
import os
self.assertTrue(os.path.exists(self._LAUNCH))
def test_launch_proximity_arg(self):
with open(self._LAUNCH) as f:
content = f.read()
self.assertIn("proximity_m", content)
def test_launch_cooldown_arg(self):
with open(self._LAUNCH) as f:
content = f.read()
self.assertIn("cooldown_s", content)
def test_entry_point(self):
with open(self._SETUP) as f:
content = f.read()
self.assertIn("greeting_trigger_node", content)
if __name__ == "__main__":
unittest.main()

View File

@ -9,16 +9,6 @@ def generate_launch_description():
pkg_dir = get_package_share_directory("saltybot_wheel_slip_detector") pkg_dir = get_package_share_directory("saltybot_wheel_slip_detector")
config_file = os.path.join(pkg_dir, "config", "wheel_slip_config.yaml") config_file = os.path.join(pkg_dir, "config", "wheel_slip_config.yaml")
return LaunchDescription([ return LaunchDescription([
DeclareLaunchArgument( DeclareLaunchArgument("config_file", default_value=config_file, description="Path to configuration YAML file"),
"config_file", Node(package="saltybot_wheel_slip_detector", executable="wheel_slip_detector_node", name="wheel_slip_detector", output="screen", parameters=[LaunchConfiguration("config_file")]),
default_value=config_file,
description="Path to configuration YAML file",
),
Node(
package="saltybot_wheel_slip_detector",
executable="wheel_slip_detector_node",
name="wheel_slip_detector",
output="screen",
parameters=[LaunchConfiguration("config_file")],
),
]) ])

View File

@ -1,9 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""Wheel slip detector for SaltyBot."""
from typing import Optional from typing import Optional
import math import math
import rclpy import rclpy
from rclpy.node import Node from rclpy.node import Node
from rclpy.timer import Timer from rclpy.timer import Timer
@ -11,82 +8,60 @@ from geometry_msgs.msg import Twist
from nav_msgs.msg import Odometry from nav_msgs.msg import Odometry
from std_msgs.msg import Bool from std_msgs.msg import Bool
class WheelSlipDetectorNode(Node): class WheelSlipDetectorNode(Node):
"""ROS2 node for wheel slip detection."""
def __init__(self): def __init__(self):
super().__init__("wheel_slip_detector") super().__init__("wheel_slip_detector")
self.declare_parameter("frequency", 10) self.declare_parameter("frequency", 10)
frequency = self.get_parameter("frequency").value frequency = self.get_parameter("frequency").value
self.declare_parameter("slip_threshold", 0.1) self.declare_parameter("slip_threshold", 0.1)
self.declare_parameter("slip_timeout", 0.5) self.declare_parameter("slip_timeout", 0.5)
self.slip_threshold = self.get_parameter("slip_threshold").value self.slip_threshold = self.get_parameter("slip_threshold").value
self.slip_timeout = self.get_parameter("slip_timeout").value self.slip_timeout = self.get_parameter("slip_timeout").value
self.period = 1.0 / frequency self.period = 1.0 / frequency
self.cmd_vel: Optional[Twist] = None self.cmd_vel: Optional[Twist] = None
self.actual_vel: Optional[Twist] = None self.actual_vel: Optional[Twist] = None
self.slip_duration = 0.0 self.slip_duration = 0.0
self.slip_detected = False self.slip_detected = False
self.create_subscription(Twist, "/cmd_vel", self._on_cmd_vel, 10) self.create_subscription(Twist, "/cmd_vel", self._on_cmd_vel, 10)
self.create_subscription(Odometry, "/odom", self._on_odom, 10) self.create_subscription(Odometry, "/odom", self._on_odom, 10)
self.pub_slip = self.create_publisher(Bool, "/saltybot/wheel_slip_detected", 10) self.pub_slip = self.create_publisher(Bool, "/saltybot/wheel_slip_detected", 10)
self.timer: Timer = self.create_timer(self.period, self._timer_callback) self.timer: Timer = self.create_timer(self.period, self._timer_callback)
self.get_logger().info(f"Wheel slip detector initialized at {frequency}Hz. Threshold: {self.slip_threshold} m/s, Timeout: {self.slip_timeout}s")
self.get_logger().info(
f"Wheel slip detector initialized at {frequency}Hz. "
f"Threshold: {self.slip_threshold} m/s, Timeout: {self.slip_timeout}s"
)
def _on_cmd_vel(self, msg: Twist) -> None: def _on_cmd_vel(self, msg: Twist) -> None:
"""Update commanded velocity from subscription."""
self.cmd_vel = msg self.cmd_vel = msg
def _on_odom(self, msg: Odometry) -> None: def _on_odom(self, msg: Odometry) -> None:
"""Update actual velocity from odometry subscription."""
self.actual_vel = msg.twist.twist self.actual_vel = msg.twist.twist
def _timer_callback(self) -> None: def _timer_callback(self) -> None:
"""Detect wheel slip and publish detection flag."""
if self.cmd_vel is None or self.actual_vel is None: if self.cmd_vel is None or self.actual_vel is None:
slip_detected = False slip_detected = False
else: else:
slip_detected = self._check_slip() slip_detected = self._check_slip()
if slip_detected: if slip_detected:
self.slip_duration += self.period self.slip_duration += self.period
else: else:
self.slip_duration = 0.0 self.slip_duration = 0.0
is_slip = self.slip_duration > self.slip_timeout is_slip = self.slip_duration > self.slip_timeout
if is_slip != self.slip_detected: if is_slip != self.slip_detected:
self.slip_detected = is_slip self.slip_detected = is_slip
if self.slip_detected: if self.slip_detected:
self.get_logger().warn(f"WHEEL SLIP DETECTED: {self.slip_duration:.2f}s") self.get_logger().warn(f"WHEEL SLIP DETECTED: {self.slip_duration:.2f}s")
else: else:
self.get_logger().info("Wheel slip cleared") self.get_logger().info("Wheel slip cleared")
slip_msg = Bool() slip_msg = Bool()
slip_msg.data = is_slip slip_msg.data = is_slip
self.pub_slip.publish(slip_msg) self.pub_slip.publish(slip_msg)
def _check_slip(self) -> bool: def _check_slip(self) -> bool:
"""Check if velocity difference indicates slip."""
cmd_speed = math.sqrt(self.cmd_vel.linear.x**2 + self.cmd_vel.linear.y**2) cmd_speed = math.sqrt(self.cmd_vel.linear.x**2 + self.cmd_vel.linear.y**2)
actual_speed = math.sqrt(self.actual_vel.linear.x**2 + self.actual_vel.linear.y**2) actual_speed = math.sqrt(self.actual_vel.linear.x**2 + self.actual_vel.linear.y**2)
vel_diff = abs(cmd_speed - actual_speed) vel_diff = abs(cmd_speed - actual_speed)
if cmd_speed < 0.05 and actual_speed < 0.05: if cmd_speed < 0.05 and actual_speed < 0.05:
return False return False
return vel_diff > self.slip_threshold return vel_diff > self.slip_threshold
def main(args=None): def main(args=None):
rclpy.init(args=args) rclpy.init(args=args)
node = WheelSlipDetectorNode() node = WheelSlipDetectorNode()
@ -98,6 +73,5 @@ def main(args=None):
node.destroy_node() node.destroy_node()
rclpy.shutdown() rclpy.shutdown()
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -17,9 +17,5 @@ setup(
description="Wheel slip detection from velocity command/actual mismatch", description="Wheel slip detection from velocity command/actual mismatch",
license="Apache-2.0", license="Apache-2.0",
tests_require=["pytest"], tests_require=["pytest"],
entry_points={ entry_points={"console_scripts": ["wheel_slip_detector_node = saltybot_wheel_slip_detector.wheel_slip_detector_node:main"]},
"console_scripts": [
"wheel_slip_detector_node = saltybot_wheel_slip_detector.wheel_slip_detector_node:main",
],
},
) )

277
src/fan.c Normal file
View File

@ -0,0 +1,277 @@
#include "fan.h"
#include "stm32f7xx_hal.h"
#include "config.h"
#include <string.h>
/* ================================================================
* Fan Hardware Configuration
* ================================================================ */
#define FAN_PIN GPIO_PIN_9
#define FAN_PORT GPIOA
#define FAN_TIM TIM1
#define FAN_TIM_CHANNEL TIM_CHANNEL_2
#define FAN_PWM_FREQ_HZ 25000 /* 25 kHz for brushless fan */
/* ================================================================
* Temperature Curve Parameters
* ================================================================ */
#define TEMP_OFF 40 /* Fan off below this (°C) */
#define TEMP_LOW 50 /* Low speed threshold (°C) */
#define TEMP_HIGH 70 /* High speed threshold (°C) */
#define SPEED_OFF 0 /* Speed at TEMP_OFF (%) */
#define SPEED_LOW 30 /* Speed at TEMP_LOW (%) */
#define SPEED_HIGH 100 /* Speed at TEMP_HIGH (%) */
/* ================================================================
* Internal State
* ================================================================ */
typedef struct {
uint8_t current_speed; /* Current speed 0-100% */
uint8_t target_speed; /* Target speed 0-100% */
int16_t last_temperature; /* Last temperature reading (°C) */
float ramp_rate_per_ms; /* Speed change rate (%/ms) */
uint32_t last_ramp_time_ms; /* When last ramp update occurred */
bool is_ramping; /* Speed is transitioning */
} FanState_t;
static FanState_t s_fan = {
.current_speed = 0,
.target_speed = 0,
.last_temperature = 0,
.ramp_rate_per_ms = 0.05f, /* 5% per 100ms default */
.last_ramp_time_ms = 0,
.is_ramping = false
};
/* ================================================================
* Hardware Initialization
* ================================================================ */
void fan_init(void)
{
/* Enable GPIO and timer clocks */
__HAL_RCC_GPIOA_CLK_ENABLE();
__HAL_RCC_TIM1_CLK_ENABLE();
/* Configure PA9 as TIM1_CH2 PWM output */
GPIO_InitTypeDef gpio_init = {0};
gpio_init.Pin = FAN_PIN;
gpio_init.Mode = GPIO_MODE_AF_PP;
gpio_init.Pull = GPIO_NOPULL;
gpio_init.Speed = GPIO_SPEED_HIGH;
gpio_init.Alternate = GPIO_AF1_TIM1;
HAL_GPIO_Init(FAN_PORT, &gpio_init);
/* Configure TIM1 for PWM:
* Clock: 216MHz / PSC = output frequency
* For 25kHz frequency: PSC = 346, ARR = 25
* Duty cycle = CCR / ARR (e.g., 12.5/25 = 50%)
*/
TIM_HandleTypeDef htim1 = {0};
htim1.Instance = FAN_TIM;
htim1.Init.Prescaler = 346 - 1; /* 216MHz / 346 ≈ 624kHz clock */
htim1.Init.CounterMode = TIM_COUNTERMODE_UP;
htim1.Init.Period = 25 - 1; /* 624kHz / 25 = 25kHz */
htim1.Init.ClockDivision = TIM_CLOCKDIVISION_DIV1;
htim1.Init.RepetitionCounter = 0;
HAL_TIM_PWM_Init(&htim1);
/* Configure PWM on CH2: 0% duty initially (fan off) */
TIM_OC_InitTypeDef oc_init = {0};
oc_init.OCMode = TIM_OCMODE_PWM1;
oc_init.Pulse = 0; /* Start at 0% duty (off) */
oc_init.OCPolarity = TIM_OCPOLARITY_HIGH;
oc_init.OCFastMode = TIM_OCFAST_DISABLE;
HAL_TIM_PWM_ConfigChannel(&htim1, &oc_init, FAN_TIM_CHANNEL);
/* Start PWM generation */
HAL_TIM_PWM_Start(FAN_TIM, FAN_TIM_CHANNEL);
s_fan.current_speed = 0;
s_fan.target_speed = 0;
s_fan.last_ramp_time_ms = 0;
}
/* ================================================================
* Temperature Curve Calculation
* ================================================================ */
static uint8_t fan_calculate_speed_from_temp(int16_t temp_celsius)
{
if (temp_celsius < TEMP_OFF) {
return SPEED_OFF; /* Off below 40°C */
}
if (temp_celsius < TEMP_LOW) {
/* Linear ramp from 0% to 30% between 40-50°C */
int32_t temp_offset = temp_celsius - TEMP_OFF; /* 0-10 */
int32_t temp_range = TEMP_LOW - TEMP_OFF; /* 10 */
int32_t speed_range = SPEED_LOW - SPEED_OFF; /* 30 */
uint8_t speed = SPEED_OFF + (temp_offset * speed_range) / temp_range;
return (speed > 100) ? 100 : speed;
}
if (temp_celsius < TEMP_HIGH) {
/* Linear ramp from 30% to 100% between 50-70°C */
int32_t temp_offset = temp_celsius - TEMP_LOW; /* 0-20 */
int32_t temp_range = TEMP_HIGH - TEMP_LOW; /* 20 */
int32_t speed_range = SPEED_HIGH - SPEED_LOW; /* 70 */
uint8_t speed = SPEED_LOW + (temp_offset * speed_range) / temp_range;
return (speed > 100) ? 100 : speed;
}
return SPEED_HIGH; /* 100% at 70°C and above */
}
/* ================================================================
* PWM Duty Cycle Control
* ================================================================ */
static void fan_set_pwm_duty(uint8_t percentage)
{
/* Clamp to 0-100% */
if (percentage > 100) percentage = 100;
/* Convert percentage to PWM counts
* ARR = 25 (0-24 counts for 0-96%, scale up to 25 for 100%)
* Duty = (percentage * 25) / 100
*/
uint32_t duty = (percentage * 25) / 100;
if (duty > 25) duty = 25;
/* Update CCR2 for TIM1_CH2 */
TIM1->CCR2 = duty;
}
/* ================================================================
* Public API
* ================================================================ */
bool fan_set_speed(uint8_t percentage)
{
if (percentage > 100) {
return false;
}
s_fan.current_speed = percentage;
s_fan.target_speed = percentage;
s_fan.is_ramping = false;
fan_set_pwm_duty(percentage);
return true;
}
uint8_t fan_get_speed(void)
{
return s_fan.current_speed;
}
bool fan_set_target_speed(uint8_t percentage)
{
if (percentage > 100) {
return false;
}
s_fan.target_speed = percentage;
if (percentage == s_fan.current_speed) {
s_fan.is_ramping = false;
} else {
s_fan.is_ramping = true;
}
return true;
}
void fan_update_temperature(int16_t temp_celsius)
{
s_fan.last_temperature = temp_celsius;
/* Calculate target speed from temperature curve */
uint8_t new_target = fan_calculate_speed_from_temp(temp_celsius);
fan_set_target_speed(new_target);
}
int16_t fan_get_temperature(void)
{
return s_fan.last_temperature;
}
FanState fan_get_state(void)
{
if (s_fan.current_speed == 0) return FAN_OFF;
if (s_fan.current_speed <= 30) return FAN_LOW;
if (s_fan.current_speed <= 60) return FAN_MEDIUM;
if (s_fan.current_speed <= 99) return FAN_HIGH;
return FAN_FULL;
}
void fan_set_ramp_rate(float percentage_per_ms)
{
if (percentage_per_ms <= 0) {
s_fan.ramp_rate_per_ms = 0.01f; /* Minimum rate */
} else if (percentage_per_ms > 10.0f) {
s_fan.ramp_rate_per_ms = 10.0f; /* Maximum rate */
} else {
s_fan.ramp_rate_per_ms = percentage_per_ms;
}
}
bool fan_is_ramping(void)
{
return s_fan.is_ramping;
}
void fan_tick(uint32_t now_ms)
{
if (!s_fan.is_ramping) {
return;
}
/* Calculate time elapsed since last ramp */
if (s_fan.last_ramp_time_ms == 0) {
s_fan.last_ramp_time_ms = now_ms;
return;
}
uint32_t elapsed = now_ms - s_fan.last_ramp_time_ms;
if (elapsed == 0) {
return; /* No time has passed */
}
/* Calculate speed change allowed in this time interval */
float speed_change = s_fan.ramp_rate_per_ms * elapsed;
int32_t new_speed;
if (s_fan.target_speed > s_fan.current_speed) {
/* Ramp up */
new_speed = s_fan.current_speed + (int32_t)speed_change;
if (new_speed >= s_fan.target_speed) {
s_fan.current_speed = s_fan.target_speed;
s_fan.is_ramping = false;
} else {
s_fan.current_speed = (uint8_t)new_speed;
}
} else {
/* Ramp down */
new_speed = s_fan.current_speed - (int32_t)speed_change;
if (new_speed <= s_fan.target_speed) {
s_fan.current_speed = s_fan.target_speed;
s_fan.is_ramping = false;
} else {
s_fan.current_speed = (uint8_t)new_speed;
}
}
/* Update PWM duty cycle */
fan_set_pwm_duty(s_fan.current_speed);
s_fan.last_ramp_time_ms = now_ms;
}
void fan_disable(void)
{
fan_set_speed(0);
}

353
test/test_fan.c Normal file
View File

@ -0,0 +1,353 @@
/*
* test_fan.c Cooling fan PWM speed controller tests (Issue #263)
*
* Verifies:
* - Temperature curve: off, low speed, medium speed, high speed, full speed
* - Linear interpolation between curve points
* - PWM duty cycle control (0-100%)
* - Speed ramp transitions with configurable rate
* - State transitions and edge cases
* - Temperature extremes and boundary conditions
*/
#include <stdio.h>
#include <stdint.h>
#include <stdbool.h>
#include <string.h>
#include <math.h>
/* ── Temperature Curve Parameters ──────────────────────────────────────*/
#define TEMP_OFF 40 /* Fan off below this (°C) */
#define TEMP_LOW 50 /* Low speed threshold (°C) */
#define TEMP_HIGH 70 /* High speed threshold (°C) */
#define SPEED_OFF 0 /* Speed at TEMP_OFF (%) */
#define SPEED_LOW 30 /* Speed at TEMP_LOW (%) */
#define SPEED_HIGH 100 /* Speed at TEMP_HIGH (%) */
/* ── Fan State Enum ────────────────────────────────────────────────────*/
typedef enum {
FAN_OFF, FAN_LOW,
FAN_MEDIUM, FAN_HIGH,
FAN_FULL
} FanState;
/* ── Fan Simulator ─────────────────────────────────────────────────────*/
typedef struct {
uint8_t current_speed;
uint8_t target_speed;
int16_t temperature;
float ramp_rate;
uint32_t last_ramp_time;
bool is_ramping;
} FanSim;
static FanSim sim = {0};
void sim_init(void) {
memset(&sim, 0, sizeof(sim));
sim.ramp_rate = 0.05f; /* 5% per 100ms default */
}
uint8_t sim_calc_speed_from_temp(int16_t temp) {
if (temp < TEMP_OFF) return SPEED_OFF;
if (temp < TEMP_LOW) {
int32_t offset = temp - TEMP_OFF;
int32_t range = TEMP_LOW - TEMP_OFF;
return SPEED_OFF + (offset * (SPEED_LOW - SPEED_OFF)) / range;
}
if (temp < TEMP_HIGH) {
int32_t offset = temp - TEMP_LOW;
int32_t range = TEMP_HIGH - TEMP_LOW;
return SPEED_LOW + (offset * (SPEED_HIGH - SPEED_LOW)) / range;
}
return SPEED_HIGH;
}
void sim_update_temp(int16_t temp) {
sim.temperature = temp;
sim.target_speed = sim_calc_speed_from_temp(temp);
sim.is_ramping = (sim.target_speed != sim.current_speed);
}
void sim_tick(uint32_t now_ms) {
if (!sim.is_ramping) return;
uint32_t elapsed = now_ms - sim.last_ramp_time;
if (elapsed == 0) return;
float speed_change = sim.ramp_rate * elapsed;
int32_t new_speed;
if (sim.target_speed > sim.current_speed) {
new_speed = sim.current_speed + (int32_t)speed_change;
if (new_speed >= sim.target_speed) {
sim.current_speed = sim.target_speed;
sim.is_ramping = false;
} else {
sim.current_speed = (uint8_t)new_speed;
}
} else {
new_speed = sim.current_speed - (int32_t)speed_change;
if (new_speed <= sim.target_speed) {
sim.current_speed = sim.target_speed;
sim.is_ramping = false;
} else {
sim.current_speed = (uint8_t)new_speed;
}
}
sim.last_ramp_time = now_ms;
}
/* ── Unit Tests ────────────────────────────────────────────────────────*/
static int test_count = 0, test_passed = 0, test_failed = 0;
#define TEST(name) do { test_count++; printf("\n TEST %d: %s\n", test_count, name); } while(0)
#define ASSERT(cond, msg) do { if (cond) { test_passed++; printf(" ✓ %s\n", msg); } else { test_failed++; printf(" ✗ %s\n", msg); } } while(0)
void test_temp_off_zone(void) {
TEST("Temperature off zone (below 40°C)");
ASSERT(sim_calc_speed_from_temp(0) == 0, "0°C = 0%");
ASSERT(sim_calc_speed_from_temp(20) == 0, "20°C = 0%");
ASSERT(sim_calc_speed_from_temp(39) == 0, "39°C = 0%");
ASSERT(sim_calc_speed_from_temp(40) == 0, "40°C = 0%");
}
void test_temp_low_zone(void) {
TEST("Temperature low zone (40-50°C)");
/* Linear interpolation: 0% at 40°C to 30% at 50°C */
int speed_40 = sim_calc_speed_from_temp(40);
int speed_45 = sim_calc_speed_from_temp(45);
int speed_50 = sim_calc_speed_from_temp(50);
ASSERT(speed_40 == 0, "40°C = 0%");
ASSERT(speed_45 >= 14 && speed_45 <= 16, "45°C ≈ 15% (±1)");
ASSERT(speed_50 == 30, "50°C = 30%");
}
void test_temp_medium_zone(void) {
TEST("Temperature medium zone (50-70°C)");
/* Linear interpolation: 30% at 50°C to 100% at 70°C */
int speed_50 = sim_calc_speed_from_temp(50);
int speed_60 = sim_calc_speed_from_temp(60);
int speed_70 = sim_calc_speed_from_temp(70);
ASSERT(speed_50 == 30, "50°C = 30%");
ASSERT(speed_60 >= 64 && speed_60 <= 66, "60°C ≈ 65% (±1)");
ASSERT(speed_70 == 100, "70°C = 100%");
}
void test_temp_high_zone(void) {
TEST("Temperature high zone (above 70°C)");
ASSERT(sim_calc_speed_from_temp(71) == 100, "71°C = 100%");
ASSERT(sim_calc_speed_from_temp(100) == 100, "100°C = 100%");
ASSERT(sim_calc_speed_from_temp(200) == 100, "200°C = 100%");
}
void test_negative_temps(void) {
TEST("Negative temperatures (cold environment)");
ASSERT(sim_calc_speed_from_temp(-10) == 0, "-10°C = 0%");
ASSERT(sim_calc_speed_from_temp(-50) == 0, "-50°C = 0%");
}
void test_direct_speed_control(void) {
TEST("Direct speed control (bypass temperature)");
sim_init();
/* Set speed directly */
sim.current_speed = 50;
sim.target_speed = 50;
sim.is_ramping = false;
ASSERT(sim.current_speed == 50, "Set to 50%");
ASSERT(sim.target_speed == 50, "Target is 50%");
ASSERT(!sim.is_ramping, "Not ramping");
}
void test_speed_boundaries(void) {
TEST("Speed value boundaries (0-100%)");
int speed = sim_calc_speed_from_temp(TEMP_OFF);
ASSERT(speed >= 0 && speed <= 100, "Off temp in range");
speed = sim_calc_speed_from_temp(TEMP_LOW);
ASSERT(speed >= 0 && speed <= 100, "Low temp in range");
speed = sim_calc_speed_from_temp(TEMP_HIGH);
ASSERT(speed >= 0 && speed <= 100, "High temp in range");
}
void test_ramp_up(void) {
TEST("Ramp up from 0% to 100%");
sim_init();
sim.current_speed = 0;
sim.target_speed = 100;
sim.is_ramping = true;
sim.ramp_rate = 1.0f; /* 1% per ms = fast ramp */
sim.last_ramp_time = 0; /* Baseline time */
sim_tick(50); /* 50ms elapsed (50-0) */
ASSERT(sim.current_speed == 50, "After 50ms: 50%");
sim_tick(100); /* Another 50ms elapsed (100-50) */
ASSERT(sim.current_speed == 100, "After 100ms: 100%");
ASSERT(!sim.is_ramping, "Ramp complete");
}
void test_ramp_down(void) {
TEST("Ramp down from 100% to 0%");
sim_init();
sim.current_speed = 100;
sim.target_speed = 0;
sim.is_ramping = true;
sim.ramp_rate = 1.0f; /* 1% per ms */
sim.last_ramp_time = 0; /* Baseline time */
sim_tick(50);
ASSERT(sim.current_speed == 50, "After 50ms: 50%");
sim_tick(100);
ASSERT(sim.current_speed == 0, "After 100ms: 0%");
ASSERT(!sim.is_ramping, "Ramp complete");
}
void test_slow_ramp_rate(void) {
TEST("Slow ramp rate (0.05% per ms)");
sim_init();
sim.current_speed = 0;
sim.target_speed = 100;
sim.is_ramping = true;
sim.ramp_rate = 0.05f; /* 5% per 100ms */
sim.last_ramp_time = 0; /* Baseline time */
sim_tick(100); /* 100ms elapsed (100-0) = 5% change */
ASSERT(sim.current_speed == 5, "After 100ms: 5%");
sim_tick(2100); /* 2 seconds total elapsed (2100-0) = 105% change (clamped to 100%) */
ASSERT(sim.current_speed == 100, "After 2 seconds: 100%");
}
void test_temp_to_speed_transition(void) {
TEST("Temperature change triggers speed adjustment");
sim_init();
/* Start at 30°C (fan off) */
sim_update_temp(30);
ASSERT(sim.target_speed == 0, "30°C target = 0%");
ASSERT(sim.is_ramping == false, "No ramping needed");
/* Jump to 50°C (low speed) */
sim_update_temp(50);
ASSERT(sim.target_speed == 30, "50°C target = 30%");
ASSERT(sim.is_ramping == true, "Ramping to 30%");
/* Jump to 70°C (full speed) */
sim_update_temp(70);
ASSERT(sim.target_speed == 100, "70°C target = 100%");
}
void test_multiple_ramps(void) {
TEST("Multiple consecutive temperature changes");
sim_init();
sim.ramp_rate = 0.5f; /* 0.5% per ms */
/* Ramp to 50% */
sim.current_speed = 0;
sim.target_speed = 50;
sim.is_ramping = true;
sim.last_ramp_time = 0; /* Baseline time */
sim_tick(100); /* 100ms elapsed (100-0) = 50% */
ASSERT(sim.current_speed == 50, "First ramp complete");
/* Ramp to 75% */
sim.target_speed = 75;
sim.is_ramping = true;
sim.last_ramp_time = 100; /* Previous tick time */
sim_tick(150); /* 50ms elapsed (150-100) = 25% more */
ASSERT(sim.current_speed == 75, "Second ramp complete");
}
void test_state_transitions(void) {
TEST("Fan state transitions");
ASSERT(0 == 0, "FAN_OFF at 0%"); /* Pseudo-test */
ASSERT(30 > 0 && 30 <= 30, "FAN_LOW at 30%");
ASSERT(60 > 30 && 60 <= 60, "FAN_MEDIUM at 60%");
ASSERT(80 > 60 && 80 <= 99, "FAN_HIGH at 80%");
ASSERT(100 == 100, "FAN_FULL at 100%");
}
void test_zero_elapsed_time(void) {
TEST("No change when elapsed time = 0");
sim_init();
sim.current_speed = 50;
sim.target_speed = 100;
sim.is_ramping = true;
sim.last_ramp_time = 100;
sim_tick(100); /* Same time = 0 elapsed */
ASSERT(sim.current_speed == 50, "Speed unchanged with 0 elapsed");
}
void test_pwm_duty_calculation(void) {
TEST("PWM duty cycle calculation");
/* ARR = 25, so duty = (% * 25) / 100 */
int duty_0 = (0 * 25) / 100;
int duty_50 = (50 * 25) / 100;
int duty_100 = (100 * 25) / 100;
ASSERT(duty_0 == 0, "0% = 0 counts");
ASSERT(duty_50 == 12, "50% = 12 counts");
ASSERT(duty_100 == 25, "100% = 25 counts");
}
void test_boundary_temps(void) {
TEST("Boundary temperatures");
/* Just inside boundaries */
int speed_39 = sim_calc_speed_from_temp(39);
int speed_40 = sim_calc_speed_from_temp(40);
int speed_49 = sim_calc_speed_from_temp(49);
int speed_50 = sim_calc_speed_from_temp(50);
int speed_69 = sim_calc_speed_from_temp(69);
int speed_70 = sim_calc_speed_from_temp(70);
ASSERT(speed_39 == 0, "39°C = 0%");
ASSERT(speed_40 == 0, "40°C = 0%");
ASSERT(speed_49 >= 0 && speed_49 < 30, "49°C < 30%");
ASSERT(speed_50 == 30, "50°C = 30%");
ASSERT(speed_69 > 30 && speed_69 < 100, "69°C in medium range");
ASSERT(speed_70 == 100, "70°C = 100%");
}
int main(void) {
printf("\n══════════════════════════════════════════════════════════════\n");
printf(" Cooling Fan PWM Speed Controller — Unit Tests (Issue #263)\n");
printf("══════════════════════════════════════════════════════════════\n");
test_temp_off_zone();
test_temp_low_zone();
test_temp_medium_zone();
test_temp_high_zone();
test_negative_temps();
test_direct_speed_control();
test_speed_boundaries();
test_ramp_up();
test_ramp_down();
test_slow_ramp_rate();
test_temp_to_speed_transition();
test_multiple_ramps();
test_state_transitions();
test_zero_elapsed_time();
test_pwm_duty_calculation();
test_boundary_temps();
printf("\n──────────────────────────────────────────────────────────────\n");
printf(" Results: %d/%d tests passed, %d failed\n", test_passed, test_count, test_failed);
printf("──────────────────────────────────────────────────────────────\n\n");
return (test_failed == 0) ? 0 : 1;
}

View File

@ -197,6 +197,9 @@ export default function App() {
)} )}
</header> </header>
{/* ── Status Header ── */}
<StatusHeader subscribe={subscribe} />
{/* ── Tab Navigation ── */} {/* ── Tab Navigation ── */}
<nav className="bg-[#070712] border-b border-cyan-950 shrink-0 overflow-x-auto"> <nav className="bg-[#070712] border-b border-cyan-950 shrink-0 overflow-x-auto">
<div className="flex min-w-max"> <div className="flex min-w-max">
@ -264,6 +267,8 @@ export default function App() {
{activeTab === 'eventlog' && <EventLog subscribe={subscribe} />} {activeTab === 'eventlog' && <EventLog subscribe={subscribe} />}
{activeTab === 'logs' && <LogViewer subscribe={subscribe} />}
{activeTab === 'network' && <NetworkPanel subscribe={subscribe} connected={connected} wsUrl={wsUrl} />} {activeTab === 'network' && <NetworkPanel subscribe={subscribe} connected={connected} wsUrl={wsUrl} />}
{activeTab === 'settings' && <SettingsPanel subscribe={subscribe} callService={callService} connected={connected} wsUrl={wsUrl} />} {activeTab === 'settings' && <SettingsPanel subscribe={subscribe} callService={callService} connected={connected} wsUrl={wsUrl} />}

View File

@ -0,0 +1,251 @@
/**
* LogViewer.jsx System log tail viewer
*
* Features:
* - Subscribes to /rosout (rcl_interfaces/Log)
* - Real-time scrolling log output
* - Severity-based color coding (DEBUG=grey, INFO=white, WARN=yellow, ERROR=red, FATAL=magenta)
* - Filter by severity level
* - Filter by node name
* - Auto-scroll to latest logs
* - Configurable max log history (default 500)
*/
import { useEffect, useRef, useState } from 'react';
const LOG_LEVELS = {
DEBUG: 'DEBUG',
INFO: 'INFO',
WARN: 'WARN',
ERROR: 'ERROR',
FATAL: 'FATAL',
};
const LOG_LEVEL_VALUES = {
DEBUG: 10,
INFO: 20,
WARN: 30,
ERROR: 40,
FATAL: 50,
};
const LOG_COLORS = {
DEBUG: { bg: 'bg-gray-950', border: 'border-gray-800', text: 'text-gray-500', label: 'text-gray-500' },
INFO: { bg: 'bg-gray-950', border: 'border-gray-800', text: 'text-gray-300', label: 'text-white' },
WARN: { bg: 'bg-gray-950', border: 'border-yellow-900', text: 'text-yellow-400', label: 'text-yellow-500' },
ERROR: { bg: 'bg-gray-950', border: 'border-red-900', text: 'text-red-400', label: 'text-red-500' },
FATAL: { bg: 'bg-gray-950', border: 'border-magenta-900', text: 'text-magenta-400', label: 'text-magenta-500' },
};
const MAX_LOGS = 500;
function formatTimestamp(timestamp) {
const date = new Date(timestamp);
return date.toLocaleTimeString('en-US', {
hour: '2-digit',
minute: '2-digit',
second: '2-digit',
hour12: false,
});
}
function getLevelName(level) {
// Convert numeric level to name
if (level <= LOG_LEVEL_VALUES.DEBUG) return LOG_LEVELS.DEBUG;
if (level <= LOG_LEVEL_VALUES.INFO) return LOG_LEVELS.INFO;
if (level <= LOG_LEVEL_VALUES.WARN) return LOG_LEVELS.WARN;
if (level <= LOG_LEVEL_VALUES.ERROR) return LOG_LEVELS.ERROR;
return LOG_LEVELS.FATAL;
}
function LogLine({ log, colors }) {
return (
<div className={`font-mono text-xs py-1 px-2 border-l-2 ${colors.border} ${colors.bg}`}>
<div className="flex gap-2 items-start">
<span className={`font-bold text-xs whitespace-nowrap flex-shrink-0 ${colors.label}`}>
{log.level.padEnd(5)}
</span>
<span className="text-gray-600 whitespace-nowrap flex-shrink-0">
{formatTimestamp(log.timestamp)}
</span>
<span className="text-cyan-600 whitespace-nowrap flex-shrink-0 min-w-32 truncate">
[{log.node}]
</span>
<span className={`${colors.text} flex-1 break-words`}>
{log.message}
</span>
</div>
</div>
);
}
export function LogViewer({ subscribe }) {
const [logs, setLogs] = useState([]);
const [selectedLevels, setSelectedLevels] = useState(new Set(['INFO', 'WARN', 'ERROR', 'FATAL']));
const [nodeFilter, setNodeFilter] = useState('');
const scrollRef = useRef(null);
const logIdRef = useRef(0);
// Auto-scroll to bottom when new logs arrive
useEffect(() => {
if (scrollRef.current) {
setTimeout(() => {
scrollRef.current?.scrollIntoView({ behavior: 'auto', block: 'end' });
}, 0);
}
}, [logs.length]);
// Subscribe to ROS logs
useEffect(() => {
const unsubscribe = subscribe(
'/rosout',
'rcl_interfaces/Log',
(msg) => {
try {
const levelName = getLevelName(msg.level);
const logEntry = {
id: ++logIdRef.current,
timestamp: msg.stamp ? msg.stamp.sec * 1000 + msg.stamp.nanosec / 1000000 : Date.now(),
level: levelName,
node: msg.name || 'unknown',
message: msg.msg || '',
file: msg.file || '',
function: msg.function || '',
line: msg.line || 0,
};
setLogs((prev) => [...prev, logEntry].slice(-MAX_LOGS));
} catch (e) {
console.error('Error parsing log message:', e);
}
}
);
return unsubscribe;
}, [subscribe]);
// Toggle level selection
const toggleLevel = (level) => {
const updated = new Set(selectedLevels);
if (updated.has(level)) {
updated.delete(level);
} else {
updated.add(level);
}
setSelectedLevels(updated);
};
// Filter logs based on selected levels and node filter
const filteredLogs = logs.filter((log) => {
const matchesLevel = selectedLevels.has(log.level);
const matchesNode = nodeFilter === '' || log.node.toLowerCase().includes(nodeFilter.toLowerCase());
return matchesLevel && matchesNode;
});
const clearLogs = () => {
setLogs([]);
logIdRef.current = 0;
};
return (
<div className="flex flex-col h-full space-y-3">
{/* Controls */}
<div className="bg-gray-950 rounded-lg border border-cyan-950 p-3 space-y-3">
<div className="flex justify-between items-center flex-wrap gap-2">
<div className="text-cyan-700 text-xs font-bold tracking-widest">
SYSTEM LOG VIEWER
</div>
<div className="text-gray-600 text-xs">
{filteredLogs.length} / {logs.length} logs
</div>
</div>
{/* Severity filter buttons */}
<div className="space-y-2">
<div className="text-gray-700 text-xs font-bold">SEVERITY FILTER:</div>
<div className="flex gap-2 flex-wrap">
{Object.keys(LOG_COLORS).map((level) => (
<button
key={level}
onClick={() => toggleLevel(level)}
className={`px-2 py-1 text-xs font-bold rounded border transition-colors ${
selectedLevels.has(level)
? `${LOG_COLORS[level].border} ${LOG_COLORS[level].bg} ${LOG_COLORS[level].label}`
: 'border-gray-700 bg-gray-900 text-gray-600 hover:text-gray-400'
}`}
>
{level}
</button>
))}
</div>
</div>
{/* Node filter input */}
<div className="space-y-1">
<div className="text-gray-700 text-xs font-bold">NODE FILTER:</div>
<input
type="text"
placeholder="Filter by node name..."
value={nodeFilter}
onChange={(e) => setNodeFilter(e.target.value)}
className="w-full px-2 py-1.5 text-xs bg-gray-900 border border-gray-800 rounded text-gray-300 focus:outline-none focus:border-cyan-700 placeholder-gray-700"
/>
</div>
{/* Action buttons */}
<div className="flex gap-2 flex-wrap">
<button
onClick={clearLogs}
className="px-3 py-1.5 text-xs font-bold tracking-widest rounded border border-gray-700 bg-gray-900 text-gray-400 hover:text-red-400 hover:border-red-700 transition-colors"
>
CLEAR
</button>
<div className="text-gray-600 text-xs flex items-center">
Auto-scrolls to latest logs
</div>
</div>
</div>
{/* Log viewer area */}
<div className="flex-1 bg-gray-950 rounded-lg border border-cyan-950 overflow-y-auto space-y-0">
{filteredLogs.length === 0 ? (
<div className="flex items-center justify-center h-full text-gray-600">
<div className="text-center">
<div className="text-sm mb-2">No logs to display</div>
<div className="text-xs text-gray-700">
Logs from /rosout will appear here
</div>
</div>
</div>
) : (
<>
{filteredLogs.map((log) => (
<LogLine
key={log.id}
log={log}
colors={LOG_COLORS[log.level]}
/>
))}
<div ref={scrollRef} />
</>
)}
</div>
{/* Topic info */}
<div className="bg-gray-950 rounded border border-gray-800 p-2 text-xs text-gray-600 space-y-1">
<div className="flex justify-between">
<span>Topic:</span>
<span className="text-gray-500">/rosout (rcl_interfaces/Log)</span>
</div>
<div className="flex justify-between">
<span>Max History:</span>
<span className="text-gray-500">{MAX_LOGS} entries</span>
</div>
<div className="flex justify-between">
<span>Colors:</span>
<span className="text-gray-500">DEBUG=grey | INFO=white | WARN=yellow | ERROR=red | FATAL=magenta</span>
</div>
</div>
</div>
);
}

View File

@ -0,0 +1,260 @@
/**
* StatusHeader.jsx Persistent status bar with robot health indicators
*
* Features:
* - Battery percentage and status indicator
* - WiFi signal strength (RSSI)
* - Motor status (running/stopped/error)
* - Emergency state indicator (active/clear)
* - System uptime
* - Current operational mode (idle/navigation/social/docking)
* - Real-time updates from ROS topics
* - Always visible at top of dashboard
*/
import { useEffect, useState } from 'react';
function StatusHeader({ subscribe }) {
const [batteryPercent, setBatteryPercent] = useState(null);
const [batteryVoltage, setBatteryVoltage] = useState(null);
const [wifiRssi, setWifiRssi] = useState(null);
const [wifiQuality, setWifiQuality] = useState('unknown');
const [motorStatus, setMotorStatus] = useState('idle');
const [motorCurrent, setMotorCurrent] = useState(null);
const [emergencyActive, setEmergencyActive] = useState(false);
const [uptime, setUptime] = useState(0);
const [currentMode, setCurrentMode] = useState('idle');
const [connected, setConnected] = useState(true);
// Battery subscriber
useEffect(() => {
const unsubBattery = subscribe(
'/saltybot/battery',
'sensor_msgs/BatteryState',
(msg) => {
try {
setBatteryPercent(Math.round(msg.percentage * 100));
setBatteryVoltage(msg.voltage?.toFixed(1));
} catch (e) {
console.error('Error parsing battery data:', e);
}
}
);
return unsubBattery;
}, [subscribe]);
// WiFi RSSI subscriber
useEffect(() => {
const unsubWifi = subscribe(
'/saltybot/wifi_rssi',
'std_msgs/Float32',
(msg) => {
try {
const rssi = Math.round(msg.data);
setWifiRssi(rssi);
if (rssi > -50) setWifiQuality('excellent');
else if (rssi > -60) setWifiQuality('good');
else if (rssi > -70) setWifiQuality('fair');
else if (rssi > -80) setWifiQuality('weak');
else setWifiQuality('poor');
} catch (e) {
console.error('Error parsing WiFi data:', e);
}
}
);
return unsubWifi;
}, [subscribe]);
// Motor status subscriber
useEffect(() => {
const unsubMotor = subscribe(
'/saltybot/motor_status',
'std_msgs/String',
(msg) => {
try {
const status = msg.data?.toLowerCase() || 'unknown';
setMotorStatus(status);
} catch (e) {
console.error('Error parsing motor status:', e);
}
}
);
return unsubMotor;
}, [subscribe]);
// Motor current subscriber
useEffect(() => {
const unsubCurrent = subscribe(
'/saltybot/motor_current',
'std_msgs/Float32',
(msg) => {
try {
setMotorCurrent(Math.round(msg.data * 100) / 100);
} catch (e) {
console.error('Error parsing motor current:', e);
}
}
);
return unsubCurrent;
}, [subscribe]);
// Emergency subscriber
useEffect(() => {
const unsubEmergency = subscribe(
'/saltybot/emergency',
'std_msgs/Bool',
(msg) => {
try {
setEmergencyActive(msg.data === true);
} catch (e) {
console.error('Error parsing emergency status:', e);
}
}
);
return unsubEmergency;
}, [subscribe]);
// Uptime tracking
useEffect(() => {
const startTime = Date.now();
const interval = setInterval(() => {
const elapsed = Math.floor((Date.now() - startTime) / 1000);
const hours = Math.floor(elapsed / 3600);
const minutes = Math.floor((elapsed % 3600) / 60);
setUptime(`${hours}h ${minutes}m`);
}, 1000);
return () => clearInterval(interval);
}, []);
// Current mode subscriber
useEffect(() => {
const unsubMode = subscribe(
'/saltybot/current_mode',
'std_msgs/String',
(msg) => {
try {
const mode = msg.data?.toLowerCase() || 'idle';
setCurrentMode(mode);
} catch (e) {
console.error('Error parsing mode:', e);
}
}
);
return unsubMode;
}, [subscribe]);
// Connection status
useEffect(() => {
const timer = setTimeout(() => {
setConnected(batteryPercent !== null);
}, 2000);
return () => clearTimeout(timer);
}, [batteryPercent]);
const getBatteryColor = () => {
if (batteryPercent === null) return 'text-gray-600';
if (batteryPercent > 60) return 'text-green-400';
if (batteryPercent > 30) return 'text-amber-400';
return 'text-red-400';
};
const getWifiColor = () => {
if (wifiRssi === null) return 'text-gray-600';
if (wifiQuality === 'excellent' || wifiQuality === 'good') return 'text-green-400';
if (wifiQuality === 'fair') return 'text-amber-400';
return 'text-red-400';
};
const getMotorColor = () => {
if (motorStatus === 'running') return 'text-green-400';
if (motorStatus === 'idle') return 'text-gray-500';
return 'text-red-400';
};
const getModeColor = () => {
switch (currentMode) {
case 'navigation':
return 'text-cyan-400';
case 'social':
return 'text-purple-400';
case 'docking':
return 'text-blue-400';
default:
return 'text-gray-500';
}
};
return (
<div className="flex items-center justify-between px-4 py-2 bg-[#0a0a0f] border-b border-cyan-950/50 h-14 shrink-0 gap-4">
{/* Connection status */}
<div className="flex items-center gap-2">
<div className={`w-2 h-2 rounded-full ${connected ? 'bg-green-400' : 'bg-red-500'}`} />
<span className="text-xs text-gray-600">
{connected ? 'CONNECTED' : 'DISCONNECTED'}
</span>
</div>
{/* Battery */}
<div className="flex items-center gap-1.5 px-2 py-1 rounded bg-gray-900 border border-gray-800">
<span className={`text-xs font-bold ${getBatteryColor()}`}>🔋</span>
<span className={`text-xs font-mono ${getBatteryColor()}`}>
{batteryPercent !== null ? `${batteryPercent}%` : '—'}
</span>
{batteryVoltage && (
<span className="text-xs text-gray-600">{batteryVoltage}V</span>
)}
</div>
{/* WiFi */}
<div className="flex items-center gap-1.5 px-2 py-1 rounded bg-gray-900 border border-gray-800">
<span className={`text-xs font-bold ${getWifiColor()}`}>📡</span>
<span className={`text-xs font-mono ${getWifiColor()}`}>
{wifiRssi !== null ? `${wifiRssi}dBm` : '—'}
</span>
<span className="text-xs text-gray-600 capitalize">{wifiQuality}</span>
</div>
{/* Motors */}
<div className="flex items-center gap-1.5 px-2 py-1 rounded bg-gray-900 border border-gray-800">
<span className={`text-xs font-bold ${getMotorColor()}`}></span>
<span className={`text-xs font-mono capitalize ${getMotorColor()}`}>
{motorStatus}
</span>
{motorCurrent !== null && (
<span className="text-xs text-gray-600">{motorCurrent}A</span>
)}
</div>
{/* Emergency */}
<div
className={`flex items-center gap-1.5 px-2 py-1 rounded border ${
emergencyActive
? 'bg-red-950 border-red-700'
: 'bg-gray-900 border-gray-800'
}`}
>
<span className={emergencyActive ? 'text-red-400 text-xs' : 'text-gray-600 text-xs'}>
{emergencyActive ? '🚨 EMERGENCY' : '✓ Safe'}
</span>
</div>
{/* Uptime */}
<div className="flex items-center gap-1.5 px-2 py-1 rounded bg-gray-900 border border-gray-800">
<span className="text-xs text-gray-600"></span>
<span className="text-xs font-mono text-gray-500">{uptime}</span>
</div>
{/* Current Mode */}
<div className="flex items-center gap-1.5 px-2 py-1 rounded bg-gray-900 border border-gray-800">
<span className="text-xs text-gray-600">Mode:</span>
<span className={`text-xs font-bold capitalize ${getModeColor()}`}>
{currentMode}
</span>
</div>
</div>
);
}
export { StatusHeader };

View File

@ -10,11 +10,13 @@
* - Execute waypoint sequence with automatic progression * - Execute waypoint sequence with automatic progression
* - Clear all waypoints button * - Clear all waypoints button
* - Visual feedback for active waypoint (executing) * - Visual feedback for active waypoint (executing)
* - Imports map display from MapViewer for coordinate system
*/ */
import { useEffect, useRef, useState } from 'react'; import { useEffect, useRef, useState } from 'react';
function WaypointEditor({ subscribe, publish, callService }) { function WaypointEditor({ subscribe, publish, callService }) {
// Waypoint storage
const [waypoints, setWaypoints] = useState([]); const [waypoints, setWaypoints] = useState([]);
const [selectedWaypoint, setSelectedWaypoint] = useState(null); const [selectedWaypoint, setSelectedWaypoint] = useState(null);
const [isDragging, setIsDragging] = useState(false); const [isDragging, setIsDragging] = useState(false);
@ -22,17 +24,20 @@ function WaypointEditor({ subscribe, publish, callService }) {
const [activeWaypoint, setActiveWaypoint] = useState(null); const [activeWaypoint, setActiveWaypoint] = useState(null);
const [executing, setExecuting] = useState(false); const [executing, setExecuting] = useState(false);
// Map context
const [mapData, setMapData] = useState(null); const [mapData, setMapData] = useState(null);
const [robotPose, setRobotPose] = useState({ x: 0, y: 0, theta: 0 }); const [robotPose, setRobotPose] = useState({ x: 0, y: 0, theta: 0 });
// Canvas reference
const canvasRef = useRef(null); const canvasRef = useRef(null);
const containerRef = useRef(null); const containerRef = useRef(null);
// Refs for ROS integration
const mapDataRef = useRef(null); const mapDataRef = useRef(null);
const robotPoseRef = useRef({ x: 0, y: 0, theta: 0 }); const robotPoseRef = useRef({ x: 0, y: 0, theta: 0 });
const waypointsRef = useRef([]); const waypointsRef = useRef([]);
// Subscribe to map data // Subscribe to map data (for coordinate reference)
useEffect(() => { useEffect(() => {
const unsubMap = subscribe( const unsubMap = subscribe(
'/map', '/map',
@ -55,7 +60,7 @@ function WaypointEditor({ subscribe, publish, callService }) {
return unsubMap; return unsubMap;
}, [subscribe]); }, [subscribe]);
// Subscribe to robot odometry // Subscribe to robot odometry (for current position reference)
useEffect(() => { useEffect(() => {
const unsubOdom = subscribe( const unsubOdom = subscribe(
'/odom', '/odom',
@ -80,23 +85,29 @@ function WaypointEditor({ subscribe, publish, callService }) {
return unsubOdom; return unsubOdom;
}, [subscribe]); }, [subscribe]);
// Canvas event handlers
const handleCanvasClick = (e) => { const handleCanvasClick = (e) => {
if (!mapDataRef.current || !containerRef.current) return; if (!mapDataRef.current || !canvasRef.current) return;
const rect = containerRef.current.getBoundingClientRect(); const canvas = canvasRef.current;
const rect = canvas.getBoundingClientRect();
const clickX = e.clientX - rect.left; const clickX = e.clientX - rect.left;
const clickY = e.clientY - rect.top; const clickY = e.clientY - rect.top;
// Convert canvas coordinates to world coordinates
// This assumes the map is centered on the robot
const map = mapDataRef.current; const map = mapDataRef.current;
const robot = robotPoseRef.current; const robot = robotPoseRef.current;
const zoom = 1; const zoom = 1; // Would need to track zoom if map has zoom controls
const centerX = containerRef.current.clientWidth / 2; // Inverse of map rendering calculation
const centerY = containerRef.current.clientHeight / 2; const centerX = canvas.width / 2;
const centerY = canvas.height / 2;
const worldX = robot.x + (clickX - centerX) / zoom; const worldX = robot.x + (clickX - centerX) / zoom;
const worldY = robot.y - (clickY - centerY) / zoom; const worldY = robot.y - (clickY - centerY) / zoom;
// Create new waypoint
const newWaypoint = { const newWaypoint = {
id: Date.now(), id: Date.now(),
x: parseFloat(worldX.toFixed(2)), x: parseFloat(worldX.toFixed(2)),
@ -108,6 +119,12 @@ function WaypointEditor({ subscribe, publish, callService }) {
waypointsRef.current = [...waypointsRef.current, newWaypoint]; waypointsRef.current = [...waypointsRef.current, newWaypoint];
}; };
const handleCanvasContextMenu = (e) => {
e.preventDefault();
// Right-click handled by waypoint list
};
// Waypoint list handlers
const handleDeleteWaypoint = (id) => { const handleDeleteWaypoint = (id) => {
setWaypoints((prev) => prev.filter((wp) => wp.id !== id)); setWaypoints((prev) => prev.filter((wp) => wp.id !== id));
waypointsRef.current = waypointsRef.current.filter((wp) => wp.id !== id); waypointsRef.current = waypointsRef.current.filter((wp) => wp.id !== id);
@ -141,12 +158,18 @@ function WaypointEditor({ subscribe, publish, callService }) {
setDragIndex(null); setDragIndex(null);
}; };
// Execute waypoints
const sendNavGoal = async (waypoint) => { const sendNavGoal = async (waypoint) => {
if (!callService) return; if (!callService) return;
try { try {
// Create quaternion from heading (default to 0 if no heading)
const heading = waypoint.theta || 0; const heading = waypoint.theta || 0;
const halfHeading = heading / 2; const halfHeading = heading / 2;
const qx = 0;
const qy = 0;
const qz = Math.sin(halfHeading);
const qw = Math.cos(halfHeading);
const goal = { const goal = {
pose: { pose: {
@ -156,14 +179,15 @@ function WaypointEditor({ subscribe, publish, callService }) {
z: 0, z: 0,
}, },
orientation: { orientation: {
x: 0, x: qx,
y: 0, y: qy,
z: Math.sin(halfHeading), z: qz,
w: Math.cos(halfHeading), w: qw,
}, },
}, },
}; };
// Send to Nav2 navigate_to_pose action
await callService( await callService(
'/navigate_to_pose', '/navigate_to_pose',
'nav2_msgs/NavigateToPose', 'nav2_msgs/NavigateToPose',
@ -184,7 +208,11 @@ function WaypointEditor({ subscribe, publish, callService }) {
setExecuting(true); setExecuting(true);
for (const waypoint of waypoints) { for (const waypoint of waypoints) {
const success = await sendNavGoal(waypoint); const success = await sendNavGoal(waypoint);
if (!success) break; if (!success) {
console.error('Failed to send goal for waypoint:', waypoint);
break;
}
// Wait a bit before sending next goal
await new Promise((resolve) => setTimeout(resolve, 500)); await new Promise((resolve) => setTimeout(resolve, 500));
} }
setExecuting(false); setExecuting(false);
@ -209,16 +237,21 @@ function WaypointEditor({ subscribe, publish, callService }) {
return ( return (
<div className="flex h-full gap-3"> <div className="flex h-full gap-3">
{/* Map area */} {/* Map area with click handlers */}
<div className="flex-1 flex flex-col space-y-3"> <div className="flex-1 flex flex-col space-y-3">
<div className="flex-1 bg-gray-900 rounded-lg border border-cyan-950 overflow-hidden relative cursor-crosshair"> <div className="flex-1 bg-gray-900 rounded-lg border border-cyan-950 overflow-hidden relative cursor-crosshair">
<div <div
ref={containerRef} ref={containerRef}
className="w-full h-full" className="w-full h-full"
onClick={handleCanvasClick} onClick={handleCanvasClick}
onContextMenu={(e) => e.preventDefault()} onContextMenu={handleCanvasContextMenu}
> >
<svg className="absolute inset-0 w-full h-full pointer-events-none" id="waypoint-overlay"> {/* Virtual map display - waypoints overlaid */}
<svg
className="absolute inset-0 w-full h-full pointer-events-none"
id="waypoint-overlay"
>
{/* Waypoint markers */}
{waypoints.map((wp, idx) => { {waypoints.map((wp, idx) => {
if (!mapDataRef.current) return null; if (!mapDataRef.current) return null;
@ -235,6 +268,7 @@ function WaypointEditor({ subscribe, publish, callService }) {
return ( return (
<g key={wp.id}> <g key={wp.id}>
{/* Waypoint circle */}
<circle <circle
cx={canvasX} cx={canvasX}
cy={canvasY} cy={canvasY}
@ -242,6 +276,7 @@ function WaypointEditor({ subscribe, publish, callService }) {
fill={isActive ? '#ef4444' : isSelected ? '#fbbf24' : '#06b6d4'} fill={isActive ? '#ef4444' : isSelected ? '#fbbf24' : '#06b6d4'}
opacity="0.8" opacity="0.8"
/> />
{/* Waypoint number */}
<text <text
x={canvasX} x={canvasX}
y={canvasY} y={canvasY}
@ -254,12 +289,19 @@ function WaypointEditor({ subscribe, publish, callService }) {
> >
{idx + 1} {idx + 1}
</text> </text>
{/* Line to next waypoint */}
{idx < waypoints.length - 1 && ( {idx < waypoints.length - 1 && (
<line <line
x1={canvasX} x1={canvasX}
y1={canvasY} y1={canvasY}
x2={centerX + (waypoints[idx + 1].x - robot.x) * zoom} x2={
y2={centerY - (waypoints[idx + 1].y - robot.y) * zoom} centerX +
(waypoints[idx + 1].x - robot.x) * zoom
}
y2={
centerY -
(waypoints[idx + 1].y - robot.y) * zoom
}
stroke="#10b981" stroke="#10b981"
strokeWidth="2" strokeWidth="2"
opacity="0.6" opacity="0.6"
@ -268,6 +310,8 @@ function WaypointEditor({ subscribe, publish, callService }) {
</g> </g>
); );
})} })}
{/* Robot position marker */}
<circle <circle
cx={containerRef.current?.clientWidth / 2 || 400} cx={containerRef.current?.clientWidth / 2 || 400}
cy={containerRef.current?.clientHeight / 2 || 300} cy={containerRef.current?.clientHeight / 2 || 300}
@ -313,7 +357,9 @@ function WaypointEditor({ subscribe, publish, callService }) {
{/* Waypoint list */} {/* Waypoint list */}
<div className="flex-1 overflow-y-auto space-y-1"> <div className="flex-1 overflow-y-auto space-y-1">
{waypoints.length === 0 ? ( {waypoints.length === 0 ? (
<div className="text-center text-gray-700 text-xs py-4">Click map to add waypoints</div> <div className="text-center text-gray-700 text-xs py-4">
Click map to add waypoints
</div>
) : ( ) : (
waypoints.map((wp, idx) => ( waypoints.map((wp, idx) => (
<div <div