Compare commits
41 Commits
0207c29d8c
...
c0d946f858
| Author | SHA1 | Date | |
|---|---|---|---|
| c0d946f858 | |||
| b04fd916ff | |||
| a8a9771ec7 | |||
| 042c0529a1 | |||
| 46fc2db8e6 | |||
| 6592b58f65 | |||
| 45d456049a | |||
| 631282b95f | |||
| 0ecf341c57 | |||
| 94d12159b4 | |||
| eac203ecf4 | |||
| c620dc51a7 | |||
| bcf848109b | |||
| 672120bb50 | |||
| f7f89403d5 | |||
| ae76697a1c | |||
| 677e6eb75e | |||
| 0af4441120 | |||
| ddb93bec20 | |||
| 358c1ab6f9 | |||
| 7966eb5187 | |||
| 2a9b03dd76 | |||
| 93028dc847 | |||
| 3bee8f3cb4 | |||
| 813d6f2529 | |||
| 9b538395c0 | |||
| a0f3677732 | |||
| 3fce9bf577 | |||
| 1729e43964 | |||
| 347449ed95 | |||
| 5156100197 | |||
| a7d9531537 | |||
| bd9cb6da35 | |||
| c50899f000 | |||
| 79505579b1 | |||
| a310c8afc9 | |||
| 2f76d1d0d5 | |||
| eb61207532 | |||
| 1c8430e68a | |||
| ba39e9ba26 | |||
| 23e05a634f |
@ -13,6 +13,7 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
|
||||||
/* Initialise ADC3 for single-channel Vbat reading on PC1. */
|
/* Initialise ADC3 for single-channel Vbat reading on PC1. */
|
||||||
void battery_init(void);
|
void battery_init(void);
|
||||||
|
|||||||
@ -14,4 +14,7 @@ extern I2C_HandleTypeDef hi2c1;
|
|||||||
|
|
||||||
int i2c1_init(void);
|
int i2c1_init(void);
|
||||||
|
|
||||||
|
int i2c1_write(uint8_t addr, const uint8_t *data, int len);
|
||||||
|
int i2c1_read(uint8_t addr, uint8_t *data, int len);
|
||||||
|
|
||||||
#endif /* I2C1_H */
|
#endif /* I2C1_H */
|
||||||
|
|||||||
@ -0,0 +1,20 @@
|
|||||||
|
# PulseAudio Configuration for MageDok HDMI Audio
|
||||||
|
# Routes HDMI audio from DisplayPort adapter to internal speaker output
|
||||||
|
|
||||||
|
# Detect and load HDMI output module
|
||||||
|
load-module module-alsa-sink device=hw:0,3 sink_name=hdmi_stereo sink_properties="device.description='HDMI Audio'"
|
||||||
|
|
||||||
|
# Detect and configure internal speaker (fallback)
|
||||||
|
load-module module-alsa-sink device=hw:0,0 sink_name=speaker_mono sink_properties="device.description='Speaker'"
|
||||||
|
|
||||||
|
# Set HDMI as default output sink
|
||||||
|
set-default-sink hdmi_stereo
|
||||||
|
|
||||||
|
# Enable volume control
|
||||||
|
load-module module-volume-restore
|
||||||
|
|
||||||
|
# Auto-switch to HDMI when connected
|
||||||
|
load-module module-switch-on-connect
|
||||||
|
|
||||||
|
# Log sink configuration
|
||||||
|
.load-if-exists /etc/pulse/magedok-routing.conf
|
||||||
33
jetson/ros2_ws/src/saltybot_bringup/config/xorg-magedok.conf
Normal file
33
jetson/ros2_ws/src/saltybot_bringup/config/xorg-magedok.conf
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
# X11 Configuration for MageDok 7" Display
|
||||||
|
# Resolution: 1024×600 @ 60Hz
|
||||||
|
# Output: HDMI via DisplayPort adapter
|
||||||
|
|
||||||
|
Section "Monitor"
|
||||||
|
Identifier "MageDok"
|
||||||
|
Option "PreferredMode" "1024x600_60.00"
|
||||||
|
Option "Position" "0 0"
|
||||||
|
Option "Primary" "true"
|
||||||
|
EndSection
|
||||||
|
|
||||||
|
Section "Screen"
|
||||||
|
Identifier "Screen0"
|
||||||
|
Monitor "MageDok"
|
||||||
|
DefaultDepth 24
|
||||||
|
SubSection "Display"
|
||||||
|
Depth 24
|
||||||
|
Modes "1024x600" "1024x768" "800x600" "640x480"
|
||||||
|
EndSubSection
|
||||||
|
EndSection
|
||||||
|
|
||||||
|
Section "Device"
|
||||||
|
Identifier "NVIDIA Tegra"
|
||||||
|
Driver "nvidia"
|
||||||
|
BusID "PCI:0:0:0"
|
||||||
|
Option "RegistryDwords" "EnableBrightnessControl=1"
|
||||||
|
Option "ConnectedMonitor" "HDMI-0"
|
||||||
|
EndSection
|
||||||
|
|
||||||
|
Section "ServerLayout"
|
||||||
|
Identifier "Default"
|
||||||
|
Screen "Screen0"
|
||||||
|
EndSection
|
||||||
@ -0,0 +1,218 @@
|
|||||||
|
# MageDok 7" Touchscreen Display Setup
|
||||||
|
|
||||||
|
Issue #369: Display setup for MageDok 7" IPS touchscreen on Jetson Orin Nano.
|
||||||
|
|
||||||
|
## Hardware Setup
|
||||||
|
|
||||||
|
### Connections
|
||||||
|
- **Video**: DisplayPort → HDMI cable from Orin DP 1.2 connector to MageDok HDMI input
|
||||||
|
- **Touch**: USB 3.0 cable from Orin USB-A to MageDok USB-C connector
|
||||||
|
- **Audio**: HDMI carries embedded audio from DisplayPort (no separate audio cable needed)
|
||||||
|
|
||||||
|
### Display Specs
|
||||||
|
- **Resolution**: 1024×600 @ 60Hz
|
||||||
|
- **Panel Type**: 7" IPS (In-Plane Switching) - wide viewing angles
|
||||||
|
- **Sunlight Readable**: Yes, with high brightness
|
||||||
|
- **Built-in Speakers**: Yes (via HDMI audio)
|
||||||
|
|
||||||
|
## Installation Steps
|
||||||
|
|
||||||
|
### 1. Kernel and Display Driver Configuration
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Update display mode database (if needed)
|
||||||
|
sudo apt-get update && sudo apt-get install -y xrandr x11-utils edid-decode
|
||||||
|
|
||||||
|
# Verify X11 is running
|
||||||
|
echo $DISPLAY # Should show :0 or :1
|
||||||
|
|
||||||
|
# Check connected displays
|
||||||
|
xrandr --query
|
||||||
|
```
|
||||||
|
|
||||||
|
**Expected output**: HDMI-1 connected at 1024x600 resolution
|
||||||
|
|
||||||
|
### 2. Install udev Rules for Touch Input
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Copy udev rules
|
||||||
|
sudo cp jetson/ros2_ws/src/saltybot_bringup/udev/90-magedok-touch.rules \
|
||||||
|
/etc/udev/rules.d/
|
||||||
|
|
||||||
|
# Reload udev
|
||||||
|
sudo udevadm control --reload-rules
|
||||||
|
sudo udevadm trigger
|
||||||
|
|
||||||
|
# Verify touch device
|
||||||
|
ls -l /dev/magedok-touch
|
||||||
|
# Or check input devices
|
||||||
|
cat /proc/bus/input/devices | grep -i "eGTouch\|EETI"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. X11 Display Configuration
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Backup original X11 config
|
||||||
|
sudo cp /etc/X11/xorg.conf /etc/X11/xorg.conf.backup
|
||||||
|
|
||||||
|
# Apply MageDok X11 config
|
||||||
|
sudo cp jetson/ros2_ws/src/saltybot_bringup/config/xorg-magedok.conf \
|
||||||
|
/etc/X11/xorg.conf
|
||||||
|
|
||||||
|
# Restart X11 (or reboot)
|
||||||
|
sudo systemctl restart gdm3 # or startx if using console
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. PulseAudio Audio Routing
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Check current audio sinks
|
||||||
|
pactl list sinks | grep Name
|
||||||
|
|
||||||
|
# Find HDMI sink (typically contains "hdmi" in name)
|
||||||
|
pactl set-default-sink <hdmi-sink-name>
|
||||||
|
|
||||||
|
# Verify routing
|
||||||
|
pactl get-default-sink
|
||||||
|
|
||||||
|
# Optional: Set volume
|
||||||
|
pactl set-sink-volume <sink-name> 70%
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. ROS2 Launch Configuration
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build the saltybot_bringup package
|
||||||
|
cd jetson/ros2_ws
|
||||||
|
colcon build --packages-select saltybot_bringup
|
||||||
|
|
||||||
|
# Source workspace
|
||||||
|
source install/setup.bash
|
||||||
|
|
||||||
|
# Launch display setup
|
||||||
|
ros2 launch saltybot_bringup magedok_display.launch.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6. Enable Auto-Start on Boot
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Copy systemd service
|
||||||
|
sudo cp jetson/ros2_ws/src/saltybot_bringup/systemd/magedok-display.service \
|
||||||
|
/etc/systemd/system/
|
||||||
|
|
||||||
|
# Enable service
|
||||||
|
sudo systemctl daemon-reload
|
||||||
|
sudo systemctl enable magedok-display.service
|
||||||
|
|
||||||
|
# Start service
|
||||||
|
sudo systemctl start magedok-display.service
|
||||||
|
|
||||||
|
# Check status
|
||||||
|
sudo systemctl status magedok-display.service
|
||||||
|
sudo journalctl -u magedok-display -f # Follow logs
|
||||||
|
```
|
||||||
|
|
||||||
|
## Verification
|
||||||
|
|
||||||
|
### Display Resolution
|
||||||
|
```bash
|
||||||
|
# Check actual resolution
|
||||||
|
xdotool getactivewindow getwindowgeometry
|
||||||
|
|
||||||
|
# Verify with xrandr
|
||||||
|
xrandr | grep "1024x600"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Expected**: `1024x600_60.00 +0+0` or similar
|
||||||
|
|
||||||
|
### Touch Input
|
||||||
|
```bash
|
||||||
|
# List input devices
|
||||||
|
xinput list
|
||||||
|
|
||||||
|
# Should show "MageDok Touch" or "eGTouch Controller"
|
||||||
|
# Test touch by clicking on display - cursor should move
|
||||||
|
```
|
||||||
|
|
||||||
|
### Audio
|
||||||
|
```bash
|
||||||
|
# Test HDMI audio
|
||||||
|
speaker-test -c 2 -l 1 -s 1 -t sine
|
||||||
|
|
||||||
|
# Verify volume level
|
||||||
|
pactl list sinks | grep -A 10 RUNNING
|
||||||
|
```
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Display Not Detected
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Check EDID data
|
||||||
|
edid-decode /sys/class/drm/card0-HDMI-A-1/edid
|
||||||
|
|
||||||
|
# Force resolution
|
||||||
|
xrandr --output HDMI-1 --mode 1024x600 --rate 60
|
||||||
|
|
||||||
|
# Check kernel logs
|
||||||
|
dmesg | grep -i "drm\|HDMI\|dp"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Touch Not Working
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Check USB connection
|
||||||
|
lsusb | grep -i "eGTouch\|EETI"
|
||||||
|
|
||||||
|
# Verify udev rules applied
|
||||||
|
cat /etc/udev/rules.d/90-magedok-touch.rules
|
||||||
|
|
||||||
|
# Test touch device directly
|
||||||
|
evtest /dev/magedok-touch # Or /dev/input/eventX
|
||||||
|
```
|
||||||
|
|
||||||
|
### Audio Not Routing
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Check PulseAudio daemon
|
||||||
|
pulseaudio --version
|
||||||
|
systemctl status pulseaudio
|
||||||
|
|
||||||
|
# Restart PulseAudio
|
||||||
|
systemctl --user restart pulseaudio
|
||||||
|
|
||||||
|
# Monitor audio stream
|
||||||
|
pactl list sink-inputs
|
||||||
|
```
|
||||||
|
|
||||||
|
### Display Disconnection (Headless Fallback)
|
||||||
|
|
||||||
|
The system should continue operating normally with display disconnected:
|
||||||
|
- ROS2 services remain accessible via network
|
||||||
|
- Robot commands via `/cmd_vel` continue working
|
||||||
|
- Data logging and telemetry unaffected
|
||||||
|
- Dashboard accessible via SSH/webui from other machine
|
||||||
|
|
||||||
|
## Testing Checklist
|
||||||
|
|
||||||
|
- [ ] Display shows 1024×600 resolution
|
||||||
|
- [ ] Touch input registers in xinput (test by moving cursor)
|
||||||
|
- [ ] Audio plays through display speakers
|
||||||
|
- [ ] System boots without login prompt (if using auto-start)
|
||||||
|
- [ ] All ROS2 nodes launch correctly with display
|
||||||
|
- [ ] System operates normally when display is disconnected
|
||||||
|
- [ ] `/magedok/touch_status` topic shows true (ROS2 verify script)
|
||||||
|
- [ ] `/magedok/audio_status` topic shows HDMI sink (ROS2 audio router)
|
||||||
|
|
||||||
|
## Related Issues
|
||||||
|
|
||||||
|
- **#368**: Salty Face UI (depends on this display setup)
|
||||||
|
- **#370**: Animated expression UI
|
||||||
|
- **#371**: Deaf/accessibility mode with touch keyboard
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- MageDok 7" Specs: [HDMI, 1024×600, USB Touch, Built-in Speakers]
|
||||||
|
- Jetson Orin Nano DisplayPort Output: Requires active adapter (no DP Alt Mode on USB-C)
|
||||||
|
- PulseAudio: HDMI audio sink routing via ALSA
|
||||||
|
- X11/Xrandr: Display mode configuration
|
||||||
@ -0,0 +1,59 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
MageDok 7" Display Launch Configuration
|
||||||
|
- Video: DisplayPort → HDMI (1024×600)
|
||||||
|
- Touch: USB HID
|
||||||
|
- Audio: HDMI → internal speakers via PulseAudio
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from launch import LaunchDescription
|
||||||
|
from launch_ros.actions import Node
|
||||||
|
from launch.actions import ExecuteProcess
|
||||||
|
|
||||||
|
|
||||||
|
def generate_launch_description():
|
||||||
|
return LaunchDescription([
|
||||||
|
# Log startup
|
||||||
|
ExecuteProcess(
|
||||||
|
cmd=['echo', '[MageDok] Display setup starting...'],
|
||||||
|
shell=True,
|
||||||
|
),
|
||||||
|
|
||||||
|
# Verify display resolution
|
||||||
|
Node(
|
||||||
|
package='saltybot_bringup',
|
||||||
|
executable='verify_display.py',
|
||||||
|
name='display_verifier',
|
||||||
|
parameters=[
|
||||||
|
{'target_width': 1024},
|
||||||
|
{'target_height': 600},
|
||||||
|
{'target_refresh': 60},
|
||||||
|
],
|
||||||
|
output='screen',
|
||||||
|
),
|
||||||
|
|
||||||
|
# Monitor touch input
|
||||||
|
Node(
|
||||||
|
package='saltybot_bringup',
|
||||||
|
executable='touch_monitor.py',
|
||||||
|
name='touch_monitor',
|
||||||
|
parameters=[
|
||||||
|
{'device_name': 'MageDok Touch'},
|
||||||
|
{'poll_interval': 0.1},
|
||||||
|
],
|
||||||
|
output='screen',
|
||||||
|
),
|
||||||
|
|
||||||
|
# Audio routing (PulseAudio sink redirection)
|
||||||
|
Node(
|
||||||
|
package='saltybot_bringup',
|
||||||
|
executable='audio_router.py',
|
||||||
|
name='audio_router',
|
||||||
|
parameters=[
|
||||||
|
{'hdmi_sink': 'alsa_output.pci-0000_00_1d.0.hdmi-stereo'},
|
||||||
|
{'default_sink': True},
|
||||||
|
],
|
||||||
|
output='screen',
|
||||||
|
),
|
||||||
|
])
|
||||||
@ -0,0 +1,326 @@
|
|||||||
|
"""
|
||||||
|
_audio_scene.py — MFCC + spectral nearest-centroid audio scene classifier (no ROS2 deps).
|
||||||
|
|
||||||
|
Classifies 1-second audio clips into one of four environment labels:
|
||||||
|
'indoor' — enclosed space (room tone ~440 Hz, low flux)
|
||||||
|
'outdoor' — open ambient (wind/ambient noise, ~1500 Hz band)
|
||||||
|
'traffic' — road/engine (low-frequency fundamental ~80 Hz)
|
||||||
|
'park' — natural space (bird-like high-frequency ~3–5 kHz)
|
||||||
|
|
||||||
|
Feature vector (16 dimensions)
|
||||||
|
--------------------------------
|
||||||
|
[0..12] 13 MFCC coefficients (mean across frames, from 26-mel filterbank)
|
||||||
|
[13] Spectral centroid (Hz, mean across frames)
|
||||||
|
[14] Spectral rolloff (Hz at 85 % cumulative energy, mean across frames)
|
||||||
|
[15] Zero-crossing rate (crossings / sample, mean across frames)
|
||||||
|
|
||||||
|
Classifier
|
||||||
|
----------
|
||||||
|
Nearest-centroid: class whose prototype centroid has the smallest L2 distance
|
||||||
|
to the (per-dimension normalised) query feature vector.
|
||||||
|
Centroids are computed once at import from seeded synthetic prototype signals,
|
||||||
|
so they are always consistent with the feature extractor.
|
||||||
|
|
||||||
|
Confidence
|
||||||
|
----------
|
||||||
|
conf = 1 / (1 + dist_to_nearest_centroid) in normalised feature space.
|
||||||
|
|
||||||
|
Public API
|
||||||
|
----------
|
||||||
|
SCENE_LABELS tuple[str, ...] ordered class names
|
||||||
|
AudioSceneResult NamedTuple single-clip result
|
||||||
|
NearestCentroidClassifier class
|
||||||
|
.predict(features) → (label, conf)
|
||||||
|
extract_features(samples, sr) → np.ndarray shape (16,)
|
||||||
|
classify_audio(samples, sr) → AudioSceneResult
|
||||||
|
|
||||||
|
_CLASSIFIER module-level singleton
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import NamedTuple, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# ── Constants ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
SCENE_LABELS: Tuple[str, ...] = ('indoor', 'outdoor', 'traffic', 'park')
|
||||||
|
|
||||||
|
_SR_DEFAULT = 16_000 # Hz
|
||||||
|
_N_MFCC = 13
|
||||||
|
_N_MELS = 26
|
||||||
|
_N_FFT = 512
|
||||||
|
_WIN_LENGTH = 400 # 25 ms @ 16 kHz
|
||||||
|
_HOP_LENGTH = 160 # 10 ms @ 16 kHz
|
||||||
|
_N_FEATURES = _N_MFCC + 3 # 16 total
|
||||||
|
|
||||||
|
|
||||||
|
# ── Result type ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class AudioSceneResult(NamedTuple):
|
||||||
|
label: str # one of SCENE_LABELS
|
||||||
|
confidence: float # 0.0–1.0
|
||||||
|
features: np.ndarray # shape (16,) raw feature vector
|
||||||
|
|
||||||
|
|
||||||
|
# ── Low-level DSP helpers ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _frame(x: np.ndarray, frame_length: int, hop_length: int) -> np.ndarray:
|
||||||
|
"""Slice a 1-D signal into overlapping frames. Returns (n_frames, frame_length)."""
|
||||||
|
n_frames = max(1, 1 + (len(x) - frame_length) // hop_length)
|
||||||
|
idx = (
|
||||||
|
np.arange(frame_length)[np.newaxis, :]
|
||||||
|
+ np.arange(n_frames)[:, np.newaxis] * hop_length
|
||||||
|
)
|
||||||
|
# Pad x if the last index exceeds the signal length
|
||||||
|
needed = int(idx.max()) + 1
|
||||||
|
if needed > len(x):
|
||||||
|
x = np.pad(x, (0, needed - len(x)))
|
||||||
|
return x[idx]
|
||||||
|
|
||||||
|
|
||||||
|
def _mel_filterbank(
|
||||||
|
sr: int,
|
||||||
|
n_fft: int,
|
||||||
|
n_mels: int,
|
||||||
|
f_min: float = 0.0,
|
||||||
|
f_max: float | None = None,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Build (n_mels, n_fft//2+1) triangular mel filterbank matrix."""
|
||||||
|
if f_max is None:
|
||||||
|
f_max = sr / 2.0
|
||||||
|
|
||||||
|
def _hz2mel(f: float) -> float:
|
||||||
|
return 2595.0 * math.log10(1.0 + f / 700.0)
|
||||||
|
|
||||||
|
def _mel2hz(m: float) -> float:
|
||||||
|
return 700.0 * (10.0 ** (m / 2595.0) - 1.0)
|
||||||
|
|
||||||
|
mel_pts = np.linspace(_hz2mel(f_min), _hz2mel(f_max), n_mels + 2)
|
||||||
|
hz_pts = np.array([_mel2hz(m) for m in mel_pts])
|
||||||
|
|
||||||
|
freqs = np.fft.rfftfreq(n_fft, d=1.0 / sr) # (F,)
|
||||||
|
f = freqs[:, np.newaxis] # (F, 1)
|
||||||
|
f_left = hz_pts[:-2] # (n_mels,)
|
||||||
|
f_ctr = hz_pts[1:-1]
|
||||||
|
f_right = hz_pts[2:]
|
||||||
|
|
||||||
|
left_slope = (f - f_left) / np.maximum(f_ctr - f_left, 1e-10)
|
||||||
|
right_slope = (f_right - f) / np.maximum(f_right - f_ctr, 1e-10)
|
||||||
|
|
||||||
|
fbank = np.maximum(0.0, np.minimum(left_slope, right_slope)).T # (n_mels, F)
|
||||||
|
return fbank
|
||||||
|
|
||||||
|
|
||||||
|
def _dct2(x: np.ndarray) -> np.ndarray:
|
||||||
|
"""Type-II DCT of each row. x: (n_frames, N) → (n_frames, N)."""
|
||||||
|
N = x.shape[1]
|
||||||
|
n = np.arange(N, dtype=np.float64)
|
||||||
|
k = np.arange(N, dtype=np.float64)[:, np.newaxis]
|
||||||
|
D = np.cos(math.pi * k * (2.0 * n + 1.0) / (2.0 * N)) # (N, N)
|
||||||
|
return x @ D.T # (n_frames, N)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Feature extraction ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Pre-build filterbank once (module-level, shared across all calls at default SR)
|
||||||
|
_FBANK: np.ndarray = _mel_filterbank(_SR_DEFAULT, _N_FFT, _N_MELS)
|
||||||
|
_WINDOW: np.ndarray = np.hamming(_WIN_LENGTH)
|
||||||
|
_FREQS: np.ndarray = np.fft.rfftfreq(_N_FFT, d=1.0 / _SR_DEFAULT)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_features(
|
||||||
|
samples: np.ndarray,
|
||||||
|
sr: int = _SR_DEFAULT,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Extract a 16-d feature vector from a mono float audio clip.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
samples : 1-D array, float32 or float64, range roughly [-1, 1]
|
||||||
|
sr : sample rate (Hz); must match the audio
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
features : shape (16,) — [MFCC_0..12, centroid_hz, rolloff_hz, zcr]
|
||||||
|
"""
|
||||||
|
x = np.asarray(samples, dtype=np.float64)
|
||||||
|
if len(x) == 0:
|
||||||
|
return np.zeros(_N_FEATURES, dtype=np.float64)
|
||||||
|
|
||||||
|
# Rebuild filterbank/freqs if non-default sample rate
|
||||||
|
if sr != _SR_DEFAULT:
|
||||||
|
fbank = _mel_filterbank(sr, _N_FFT, _N_MELS)
|
||||||
|
freqs = np.fft.rfftfreq(_N_FFT, d=1.0 / sr)
|
||||||
|
else:
|
||||||
|
fbank = _FBANK
|
||||||
|
freqs = _FREQS
|
||||||
|
|
||||||
|
# Frame + window
|
||||||
|
frames = _frame(x, _WIN_LENGTH, _HOP_LENGTH) # (T, W)
|
||||||
|
frames = frames * _WINDOW # Hamming
|
||||||
|
|
||||||
|
# Magnitude spectrum
|
||||||
|
mag = np.abs(np.fft.rfft(frames, n=_N_FFT, axis=1)) # (T, F)
|
||||||
|
|
||||||
|
# ── MFCC ────────────────────────────────────────────────────────────────
|
||||||
|
mel_e = mag @ fbank.T # (T, n_mels)
|
||||||
|
log_me = np.log(np.maximum(mel_e, 1e-10))
|
||||||
|
mfccs = _dct2(log_me)[:, :_N_MFCC] # (T, 13)
|
||||||
|
mfcc_mean = mfccs.mean(axis=0) # (13,)
|
||||||
|
|
||||||
|
# ── Spectral centroid ───────────────────────────────────────────────────
|
||||||
|
power = mag ** 2 # (T, F)
|
||||||
|
p_sum = power.sum(axis=1, keepdims=True)
|
||||||
|
p_sum = np.maximum(p_sum, 1e-20)
|
||||||
|
sc_hz = (power @ freqs) / p_sum.squeeze(1) # (T,) weighted mean freq
|
||||||
|
centroid_mean = float(sc_hz.mean())
|
||||||
|
|
||||||
|
# ── Spectral rolloff (85 %) ─────────────────────────────────────────────
|
||||||
|
cumsum = np.cumsum(power, axis=1) # (T, F)
|
||||||
|
thresh = 0.85 * p_sum # (T, 1)
|
||||||
|
rolloff_idx = np.argmax(cumsum >= thresh, axis=1) # (T,) first bin ≥ 85 %
|
||||||
|
rolloff_hz = freqs[rolloff_idx] # (T,)
|
||||||
|
rolloff_mean = float(rolloff_hz.mean())
|
||||||
|
|
||||||
|
# ── Zero-crossing rate ──────────────────────────────────────────────────
|
||||||
|
signs = np.sign(x)
|
||||||
|
zcr = np.mean(np.abs(np.diff(signs)) / 2.0) # crossings / sample
|
||||||
|
zcr_val = float(zcr)
|
||||||
|
|
||||||
|
features = np.concatenate([
|
||||||
|
mfcc_mean,
|
||||||
|
[centroid_mean, rolloff_mean, zcr_val],
|
||||||
|
]).astype(np.float64)
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
# ── Nearest-centroid classifier ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
class NearestCentroidClassifier:
|
||||||
|
"""
|
||||||
|
Classify a feature vector by nearest L2 distance in normalised feature space.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
centroids : (n_classes, n_features) array — one row per class
|
||||||
|
labels : sequence of class name strings, same order as centroids rows
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
centroids: np.ndarray,
|
||||||
|
labels: Tuple[str, ...],
|
||||||
|
) -> None:
|
||||||
|
self.centroids = np.asarray(centroids, dtype=np.float64)
|
||||||
|
self.labels = tuple(labels)
|
||||||
|
# Per-dimension normalisation derived from centroid spread
|
||||||
|
self._min = self.centroids.min(axis=0)
|
||||||
|
self._range = np.maximum(
|
||||||
|
self.centroids.max(axis=0) - self._min, 1e-6
|
||||||
|
)
|
||||||
|
|
||||||
|
def predict(self, features: np.ndarray) -> Tuple[str, float]:
|
||||||
|
"""
|
||||||
|
Return (label, confidence) for the given feature vector.
|
||||||
|
|
||||||
|
confidence = 1 / (1 + min_distance) in normalised feature space.
|
||||||
|
"""
|
||||||
|
norm_feat = (features - self._min) / self._range
|
||||||
|
norm_cent = (self.centroids - self._min) / self._range
|
||||||
|
dists = np.linalg.norm(norm_cent - norm_feat, axis=1) # (n_classes,)
|
||||||
|
best = int(np.argmin(dists))
|
||||||
|
conf = float(1.0 / (1.0 + dists[best]))
|
||||||
|
return self.labels[best], conf
|
||||||
|
|
||||||
|
|
||||||
|
# ── Prototype centroid construction ──────────────────────────────────────────
|
||||||
|
|
||||||
|
def _make_prototype(
|
||||||
|
label: str,
|
||||||
|
sr: int = _SR_DEFAULT,
|
||||||
|
rng: np.random.RandomState | None = None,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Generate a 1-second synthetic prototype signal for each scene class."""
|
||||||
|
if rng is None:
|
||||||
|
rng = np.random.RandomState(42)
|
||||||
|
n = sr
|
||||||
|
t = np.linspace(0.0, 1.0, n, endpoint=False)
|
||||||
|
|
||||||
|
if label == 'indoor':
|
||||||
|
# Room tone: dominant 440 Hz + light broadband noise
|
||||||
|
sig = 0.5 * np.sin(2 * math.pi * 440.0 * t) + 0.05 * rng.randn(n)
|
||||||
|
|
||||||
|
elif label == 'outdoor':
|
||||||
|
# Wind/ambient: bandpass noise centred ~1500 Hz (800–2500 Hz)
|
||||||
|
noise = rng.randn(n)
|
||||||
|
# Band-pass via FFT zeroing
|
||||||
|
spec = np.fft.rfft(noise)
|
||||||
|
freqs = np.fft.rfftfreq(n, d=1.0 / sr)
|
||||||
|
spec[(freqs < 800) | (freqs > 2500)] = 0.0
|
||||||
|
sig = np.fft.irfft(spec, n=n).real * 0.5
|
||||||
|
|
||||||
|
elif label == 'traffic':
|
||||||
|
# Engine / road noise: 80 Hz fundamental + 2nd harmonic + light noise
|
||||||
|
sig = (
|
||||||
|
0.5 * np.sin(2 * math.pi * 80.0 * t)
|
||||||
|
+ 0.25 * np.sin(2 * math.pi * 160.0 * t)
|
||||||
|
+ 0.1 * rng.randn(n)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif label == 'park':
|
||||||
|
# Birdsong: high-frequency components at 3200 Hz + 4800 Hz
|
||||||
|
sig = (
|
||||||
|
0.4 * np.sin(2 * math.pi * 3200.0 * t)
|
||||||
|
+ 0.3 * np.sin(2 * math.pi * 4800.0 * t)
|
||||||
|
+ 0.05 * rng.randn(n)
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unknown scene label: {label!r}')
|
||||||
|
|
||||||
|
return sig.astype(np.float64)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_centroids(sr: int = _SR_DEFAULT) -> np.ndarray:
|
||||||
|
"""Compute (4, 16) centroid matrix from prototype signals."""
|
||||||
|
rng = np.random.RandomState(42)
|
||||||
|
rows = []
|
||||||
|
for label in SCENE_LABELS:
|
||||||
|
proto = _make_prototype(label, sr=sr, rng=rng)
|
||||||
|
rows.append(extract_features(proto, sr=sr))
|
||||||
|
return np.stack(rows, axis=0) # (4, 16)
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton — built once at import
|
||||||
|
_CENTROIDS: np.ndarray = _build_centroids()
|
||||||
|
_CLASSIFIER: NearestCentroidClassifier = NearestCentroidClassifier(
|
||||||
|
_CENTROIDS, SCENE_LABELS
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Main entry point ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def classify_audio(
|
||||||
|
samples: np.ndarray,
|
||||||
|
sr: int = _SR_DEFAULT,
|
||||||
|
) -> AudioSceneResult:
|
||||||
|
"""
|
||||||
|
Classify a mono audio clip into one of four scene categories.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
samples : 1-D float array (any length ≥ 1 sample)
|
||||||
|
sr : sample rate in Hz
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSceneResult with .label, .confidence, .features
|
||||||
|
"""
|
||||||
|
feat = extract_features(samples, sr=sr)
|
||||||
|
label, conf = _CLASSIFIER.predict(feat)
|
||||||
|
return AudioSceneResult(label=label, confidence=conf, features=feat)
|
||||||
@ -0,0 +1,329 @@
|
|||||||
|
"""
|
||||||
|
_camera_power_manager.py — Adaptive camera power mode FSM (Issue #375).
|
||||||
|
|
||||||
|
Pure-Python library (no ROS2 / hardware dependencies) for full unit-test
|
||||||
|
coverage without hardware.
|
||||||
|
|
||||||
|
Five Power Modes
|
||||||
|
----------------
|
||||||
|
SLEEP (0) — charging / idle.
|
||||||
|
Sensors: none (or 1 CSI at 1 fps as wake trigger).
|
||||||
|
~150 MB RAM.
|
||||||
|
|
||||||
|
SOCIAL (1) — parked / socialising.
|
||||||
|
Sensors: C920 webcam + face UI.
|
||||||
|
~400 MB RAM.
|
||||||
|
|
||||||
|
AWARE (2) — indoor, slow walking, <5 km/h.
|
||||||
|
Sensors: front CSI + RealSense + LIDAR.
|
||||||
|
~850 MB RAM.
|
||||||
|
|
||||||
|
ACTIVE (3) — sidewalk / bike path, 5–15 km/h.
|
||||||
|
Sensors: front+rear CSI + RealSense + LIDAR + UWB.
|
||||||
|
~1.15 GB RAM.
|
||||||
|
|
||||||
|
FULL (4) — street / high-speed >15 km/h, or crossing.
|
||||||
|
Sensors: all 4 CSI + RealSense + LIDAR + UWB.
|
||||||
|
~1.55 GB RAM.
|
||||||
|
|
||||||
|
Transition Logic
|
||||||
|
----------------
|
||||||
|
Automatic transitions are speed-driven with hysteresis to prevent flapping:
|
||||||
|
|
||||||
|
Upgrade thresholds (instantaneous):
|
||||||
|
SOCIAL → AWARE : speed ≥ 0.3 m/s (~1 km/h, any motion)
|
||||||
|
AWARE → ACTIVE : speed ≥ 1.4 m/s (~5 km/h)
|
||||||
|
ACTIVE → FULL : speed ≥ 4.2 m/s (~15 km/h)
|
||||||
|
|
||||||
|
Downgrade thresholds (held for downgrade_hold_s before applying):
|
||||||
|
FULL → ACTIVE : speed < 3.6 m/s (~13 km/h, 2 km/h hysteresis)
|
||||||
|
ACTIVE → AWARE : speed < 1.1 m/s (~4 km/h)
|
||||||
|
AWARE → SOCIAL : speed < 0.1 m/s and idle ≥ idle_to_social_s
|
||||||
|
|
||||||
|
Scenario overrides (bypass hysteresis, instant):
|
||||||
|
CROSSING → FULL immediately and held until scenario clears
|
||||||
|
EMERGENCY → FULL immediately (also signals speed reduction upstream)
|
||||||
|
INDOOR → cap at AWARE (never ACTIVE or FULL indoors)
|
||||||
|
PARKED → SOCIAL immediately
|
||||||
|
OUTDOOR → normal speed-based logic
|
||||||
|
|
||||||
|
Battery low override:
|
||||||
|
battery_pct < battery_low_pct → cap at AWARE to save power
|
||||||
|
|
||||||
|
Safety invariant:
|
||||||
|
Rear CSI is a safety sensor at speed — it is always on in ACTIVE and FULL.
|
||||||
|
Any mode downgrade during CROSSING is blocked.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import IntEnum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
# ── Mode enum ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class CameraMode(IntEnum):
|
||||||
|
SLEEP = 0
|
||||||
|
SOCIAL = 1
|
||||||
|
AWARE = 2
|
||||||
|
ACTIVE = 3
|
||||||
|
FULL = 4
|
||||||
|
|
||||||
|
@property
|
||||||
|
def label(self) -> str:
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
|
# ── Scenario enum ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class Scenario(str):
|
||||||
|
"""
|
||||||
|
Known scenario strings. The FSM accepts any string; these are the
|
||||||
|
canonical values that trigger special behaviour.
|
||||||
|
"""
|
||||||
|
UNKNOWN = 'unknown'
|
||||||
|
PARKED = 'parked'
|
||||||
|
INDOOR = 'indoor'
|
||||||
|
OUTDOOR = 'outdoor'
|
||||||
|
CROSSING = 'crossing'
|
||||||
|
EMERGENCY = 'emergency'
|
||||||
|
|
||||||
|
|
||||||
|
# ── Sensor configuration per mode ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ActiveSensors:
|
||||||
|
csi_front: bool = False
|
||||||
|
csi_rear: bool = False
|
||||||
|
csi_left: bool = False
|
||||||
|
csi_right: bool = False
|
||||||
|
realsense: bool = False
|
||||||
|
lidar: bool = False
|
||||||
|
uwb: bool = False
|
||||||
|
webcam: bool = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def active_count(self) -> int:
|
||||||
|
return sum([
|
||||||
|
self.csi_front, self.csi_rear, self.csi_left, self.csi_right,
|
||||||
|
self.realsense, self.lidar, self.uwb, self.webcam,
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
# Canonical sensor set for each mode
|
||||||
|
MODE_SENSORS: dict[CameraMode, ActiveSensors] = {
|
||||||
|
CameraMode.SLEEP: ActiveSensors(),
|
||||||
|
CameraMode.SOCIAL: ActiveSensors(webcam=True),
|
||||||
|
CameraMode.AWARE: ActiveSensors(csi_front=True, realsense=True, lidar=True),
|
||||||
|
CameraMode.ACTIVE: ActiveSensors(csi_front=True, csi_rear=True,
|
||||||
|
realsense=True, lidar=True, uwb=True),
|
||||||
|
CameraMode.FULL: ActiveSensors(csi_front=True, csi_rear=True,
|
||||||
|
csi_left=True, csi_right=True,
|
||||||
|
realsense=True, lidar=True, uwb=True),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Speed thresholds (m/s)
|
||||||
|
_SPD_MOTION = 0.3 # ~1 km/h — any meaningful motion
|
||||||
|
_SPD_ACTIVE_UP = 5.0 / 3.6 # 5 km/h upgrade to ACTIVE
|
||||||
|
_SPD_FULL_UP = 15.0 / 3.6 # 15 km/h upgrade to FULL
|
||||||
|
_SPD_ACTIVE_DOWN = 4.0 / 3.6 # 4 km/h downgrade from ACTIVE (hysteresis gap)
|
||||||
|
_SPD_FULL_DOWN = 13.0 / 3.6 # 13 km/h downgrade from FULL
|
||||||
|
|
||||||
|
# Scenario strings that force FULL
|
||||||
|
_FORCE_FULL = frozenset({Scenario.CROSSING, Scenario.EMERGENCY})
|
||||||
|
# Scenarios that cap the mode
|
||||||
|
_CAP_AWARE = frozenset({Scenario.INDOOR})
|
||||||
|
_CAP_SOCIAL = frozenset({Scenario.PARKED})
|
||||||
|
|
||||||
|
|
||||||
|
# ── FSM result ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModeDecision:
|
||||||
|
mode: CameraMode
|
||||||
|
sensors: ActiveSensors
|
||||||
|
trigger_speed_mps: float
|
||||||
|
trigger_scenario: str
|
||||||
|
scenario_override: bool
|
||||||
|
|
||||||
|
|
||||||
|
# ── FSM ───────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class CameraPowerFSM:
|
||||||
|
"""
|
||||||
|
Finite state machine for camera power mode management.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
downgrade_hold_s : seconds a downgrade condition must persist before
|
||||||
|
the mode drops (hysteresis). Default 5.0 s.
|
||||||
|
idle_to_social_s : seconds of near-zero speed before AWARE → SOCIAL.
|
||||||
|
Default 30.0 s.
|
||||||
|
battery_low_pct : battery level below which mode is capped at AWARE.
|
||||||
|
Default 20.0 %.
|
||||||
|
clock : callable() → float for monotonic time (injectable
|
||||||
|
for tests; defaults to time.monotonic).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
downgrade_hold_s: float = 5.0,
|
||||||
|
idle_to_social_s: float = 30.0,
|
||||||
|
battery_low_pct: float = 20.0,
|
||||||
|
clock=None,
|
||||||
|
) -> None:
|
||||||
|
self._downgrade_hold = downgrade_hold_s
|
||||||
|
self._idle_to_social = idle_to_social_s
|
||||||
|
self._battery_low_pct = battery_low_pct
|
||||||
|
self._clock = clock or time.monotonic
|
||||||
|
|
||||||
|
self._mode: CameraMode = CameraMode.SLEEP
|
||||||
|
self._downgrade_pending_since: Optional[float] = None
|
||||||
|
self._idle_since: Optional[float] = None
|
||||||
|
|
||||||
|
# Last input values (for reporting)
|
||||||
|
self._last_speed: float = 0.0
|
||||||
|
self._last_scenario: str = Scenario.UNKNOWN
|
||||||
|
|
||||||
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mode(self) -> CameraMode:
|
||||||
|
return self._mode
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
speed_mps: float,
|
||||||
|
scenario: str = Scenario.UNKNOWN,
|
||||||
|
battery_pct: float = 100.0,
|
||||||
|
) -> ModeDecision:
|
||||||
|
"""
|
||||||
|
Evaluate inputs and return the current mode decision.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
speed_mps : current speed in m/s (non-negative)
|
||||||
|
scenario : current operating scenario string
|
||||||
|
battery_pct : battery charge level 0–100
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ModeDecision with the resolved mode and active sensor set.
|
||||||
|
"""
|
||||||
|
now = self._clock()
|
||||||
|
speed_mps = max(0.0, float(speed_mps))
|
||||||
|
self._last_speed = speed_mps
|
||||||
|
self._last_scenario = scenario
|
||||||
|
|
||||||
|
scenario_override = False
|
||||||
|
|
||||||
|
# ── 1. Hard overrides (no hysteresis) ─────────────────────────────
|
||||||
|
if scenario in _FORCE_FULL:
|
||||||
|
self._mode = CameraMode.FULL
|
||||||
|
self._downgrade_pending_since = None
|
||||||
|
self._idle_since = None
|
||||||
|
scenario_override = True
|
||||||
|
return self._decision(scenario, scenario_override)
|
||||||
|
|
||||||
|
if scenario in _CAP_SOCIAL:
|
||||||
|
self._mode = CameraMode.SOCIAL
|
||||||
|
self._downgrade_pending_since = None
|
||||||
|
self._idle_since = None
|
||||||
|
scenario_override = True
|
||||||
|
return self._decision(scenario, scenario_override)
|
||||||
|
|
||||||
|
# ── 2. Compute desired mode from speed ────────────────────────────
|
||||||
|
desired = self._speed_to_mode(speed_mps)
|
||||||
|
|
||||||
|
# ── 3. Apply scenario caps ────────────────────────────────────────
|
||||||
|
if scenario in _CAP_AWARE:
|
||||||
|
desired = min(desired, CameraMode.AWARE)
|
||||||
|
|
||||||
|
# ── 4. Apply battery low cap ──────────────────────────────────────
|
||||||
|
if battery_pct < self._battery_low_pct:
|
||||||
|
desired = min(desired, CameraMode.AWARE)
|
||||||
|
|
||||||
|
# ── 5. Idle timer: delay AWARE → SOCIAL when briefly stopped ─────────
|
||||||
|
# _speed_to_mode already returns SOCIAL for speed < _SPD_MOTION.
|
||||||
|
# When in AWARE (or above) and nearly stopped, hold at AWARE until
|
||||||
|
# the idle timer expires to avoid flapping at traffic lights.
|
||||||
|
if speed_mps < 0.1 and self._mode >= CameraMode.AWARE:
|
||||||
|
if self._idle_since is None:
|
||||||
|
self._idle_since = now
|
||||||
|
if now - self._idle_since < self._idle_to_social:
|
||||||
|
# Timer not yet expired — enforce minimum AWARE
|
||||||
|
desired = max(desired, CameraMode.AWARE)
|
||||||
|
# else: timer expired — let desired remain SOCIAL naturally
|
||||||
|
else:
|
||||||
|
self._idle_since = None
|
||||||
|
|
||||||
|
# ── 6. Apply hysteresis for downgrades ────────────────────────────
|
||||||
|
if desired < self._mode:
|
||||||
|
# Downgrade requested — start hold timer on first detection, then
|
||||||
|
# check on every call (including the first, so hold=0 is instant).
|
||||||
|
if self._downgrade_pending_since is None:
|
||||||
|
self._downgrade_pending_since = now
|
||||||
|
if now - self._downgrade_pending_since >= self._downgrade_hold:
|
||||||
|
self._mode = desired
|
||||||
|
self._downgrade_pending_since = None
|
||||||
|
# else: hold not expired — keep current mode
|
||||||
|
else:
|
||||||
|
# Upgrade or no change — apply immediately, cancel any pending downgrade
|
||||||
|
self._downgrade_pending_since = None
|
||||||
|
self._mode = desired
|
||||||
|
|
||||||
|
return self._decision(scenario, scenario_override)
|
||||||
|
|
||||||
|
def reset(self, mode: CameraMode = CameraMode.SLEEP) -> None:
|
||||||
|
"""Reset FSM state (e.g. on node restart)."""
|
||||||
|
self._mode = mode
|
||||||
|
self._downgrade_pending_since = None
|
||||||
|
self._idle_since = None
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _speed_to_mode(self, speed_mps: float) -> CameraMode:
|
||||||
|
"""Map speed to desired mode using hysteresis thresholds."""
|
||||||
|
if self._mode == CameraMode.FULL:
|
||||||
|
# Downgrade from FULL uses lower threshold
|
||||||
|
if speed_mps >= _SPD_FULL_DOWN:
|
||||||
|
return CameraMode.FULL
|
||||||
|
elif speed_mps >= _SPD_ACTIVE_DOWN:
|
||||||
|
return CameraMode.ACTIVE
|
||||||
|
elif speed_mps >= _SPD_MOTION:
|
||||||
|
return CameraMode.AWARE
|
||||||
|
else:
|
||||||
|
return CameraMode.AWARE # idle handled separately
|
||||||
|
|
||||||
|
elif self._mode == CameraMode.ACTIVE:
|
||||||
|
if speed_mps >= _SPD_FULL_UP:
|
||||||
|
return CameraMode.FULL
|
||||||
|
elif speed_mps >= _SPD_ACTIVE_DOWN:
|
||||||
|
return CameraMode.ACTIVE
|
||||||
|
elif speed_mps >= _SPD_MOTION:
|
||||||
|
return CameraMode.AWARE
|
||||||
|
else:
|
||||||
|
return CameraMode.AWARE
|
||||||
|
|
||||||
|
else:
|
||||||
|
# SLEEP / SOCIAL / AWARE — use upgrade thresholds
|
||||||
|
if speed_mps >= _SPD_FULL_UP:
|
||||||
|
return CameraMode.FULL
|
||||||
|
elif speed_mps >= _SPD_ACTIVE_UP:
|
||||||
|
return CameraMode.ACTIVE
|
||||||
|
elif speed_mps >= _SPD_MOTION:
|
||||||
|
return CameraMode.AWARE
|
||||||
|
else:
|
||||||
|
return CameraMode.SOCIAL
|
||||||
|
|
||||||
|
def _decision(self, scenario: str, override: bool) -> ModeDecision:
|
||||||
|
return ModeDecision(
|
||||||
|
mode = self._mode,
|
||||||
|
sensors = MODE_SENSORS[self._mode],
|
||||||
|
trigger_speed_mps = self._last_speed,
|
||||||
|
trigger_scenario = scenario,
|
||||||
|
scenario_override = override,
|
||||||
|
)
|
||||||
@ -0,0 +1,316 @@
|
|||||||
|
"""
|
||||||
|
_face_emotion.py — Geometric face emotion classifier (no ROS2, no ML deps).
|
||||||
|
|
||||||
|
Classifies facial expressions using geometric rules derived from facial
|
||||||
|
landmark positions. Input is a set of key 2-D landmark coordinates in
|
||||||
|
normalised image space (x, y ∈ [0, 1]) as produced by MediaPipe Face Mesh.
|
||||||
|
|
||||||
|
Emotions
|
||||||
|
--------
|
||||||
|
neutral — baseline, no strong geometric signal
|
||||||
|
happy — lip corners raised relative to lip midpoint (smile)
|
||||||
|
surprised — raised eyebrows + wide eyes + open mouth
|
||||||
|
angry — furrowed inner brows (inner brow below outer brow level)
|
||||||
|
sad — lip corners depressed relative to lip midpoint (frown)
|
||||||
|
|
||||||
|
Classification priority: surprised > happy > angry > sad > neutral.
|
||||||
|
|
||||||
|
Coordinate convention
|
||||||
|
---------------------
|
||||||
|
x : 0 = left edge of image, 1 = right edge
|
||||||
|
y : 0 = top edge of image, 1 = bottom edge (y increases downward)
|
||||||
|
|
||||||
|
MediaPipe Face Mesh key indices used
|
||||||
|
-------------------------------------
|
||||||
|
MOUTH_UPPER = 13 inner upper-lip centre
|
||||||
|
MOUTH_LOWER = 14 inner lower-lip centre
|
||||||
|
MOUTH_LEFT = 61 left mouth corner
|
||||||
|
MOUTH_RIGHT = 291 right mouth corner
|
||||||
|
L_EYE_TOP = 159 left eye upper lid centre
|
||||||
|
L_EYE_BOT = 145 left eye lower lid centre
|
||||||
|
R_EYE_TOP = 386 right eye upper lid centre
|
||||||
|
R_EYE_BOT = 374 right eye lower lid centre
|
||||||
|
L_BROW_INNER = 107 left inner eyebrow
|
||||||
|
L_BROW_OUTER = 46 left outer eyebrow
|
||||||
|
R_BROW_INNER = 336 right inner eyebrow
|
||||||
|
R_BROW_OUTER = 276 right outer eyebrow
|
||||||
|
CHIN = 152 chin tip
|
||||||
|
FOREHEAD = 10 midpoint between eyes (forehead anchor)
|
||||||
|
|
||||||
|
Public API
|
||||||
|
----------
|
||||||
|
EMOTION_LABELS tuple[str, ...]
|
||||||
|
FaceLandmarks dataclass — key points only
|
||||||
|
EmotionFeatures NamedTuple — 5 normalised geometric scores
|
||||||
|
EmotionResult NamedTuple — (emotion, confidence, features)
|
||||||
|
|
||||||
|
from_mediapipe(lms) → FaceLandmarks
|
||||||
|
compute_features(fl) → EmotionFeatures
|
||||||
|
classify_emotion(features) → (emotion, confidence)
|
||||||
|
detect_emotion(fl) → EmotionResult
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import NamedTuple, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
# ── Public constants ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
EMOTION_LABELS: Tuple[str, ...] = (
|
||||||
|
'neutral', 'happy', 'surprised', 'angry', 'sad'
|
||||||
|
)
|
||||||
|
|
||||||
|
# MediaPipe Face Mesh 468-vertex indices for key facial geometry
|
||||||
|
MOUTH_UPPER = 13
|
||||||
|
MOUTH_LOWER = 14
|
||||||
|
MOUTH_LEFT = 61
|
||||||
|
MOUTH_RIGHT = 291
|
||||||
|
L_EYE_TOP = 159
|
||||||
|
L_EYE_BOT = 145
|
||||||
|
R_EYE_TOP = 386
|
||||||
|
R_EYE_BOT = 374
|
||||||
|
L_EYE_LEFT = 33
|
||||||
|
L_EYE_RIGHT = 133
|
||||||
|
R_EYE_LEFT = 362
|
||||||
|
R_EYE_RIGHT = 263
|
||||||
|
L_BROW_INNER = 107
|
||||||
|
L_BROW_OUTER = 46
|
||||||
|
R_BROW_INNER = 336
|
||||||
|
R_BROW_OUTER = 276
|
||||||
|
CHIN = 152
|
||||||
|
FOREHEAD = 10
|
||||||
|
|
||||||
|
|
||||||
|
# ── Data types ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FaceLandmarks:
|
||||||
|
"""
|
||||||
|
Key geometric points extracted from a MediaPipe Face Mesh result.
|
||||||
|
|
||||||
|
All coordinates are normalised to [0, 1] image space:
|
||||||
|
x = 0 (left) … 1 (right)
|
||||||
|
y = 0 (top) … 1 (bottom)
|
||||||
|
"""
|
||||||
|
mouth_upper: Tuple[float, float]
|
||||||
|
mouth_lower: Tuple[float, float]
|
||||||
|
mouth_left: Tuple[float, float] # left corner (image-left = face-right)
|
||||||
|
mouth_right: Tuple[float, float]
|
||||||
|
l_eye_top: Tuple[float, float]
|
||||||
|
l_eye_bot: Tuple[float, float]
|
||||||
|
r_eye_top: Tuple[float, float]
|
||||||
|
r_eye_bot: Tuple[float, float]
|
||||||
|
l_eye_left: Tuple[float, float]
|
||||||
|
l_eye_right: Tuple[float, float]
|
||||||
|
r_eye_left: Tuple[float, float]
|
||||||
|
r_eye_right: Tuple[float, float]
|
||||||
|
l_brow_inner: Tuple[float, float]
|
||||||
|
l_brow_outer: Tuple[float, float]
|
||||||
|
r_brow_inner: Tuple[float, float]
|
||||||
|
r_brow_outer: Tuple[float, float]
|
||||||
|
chin: Tuple[float, float]
|
||||||
|
forehead: Tuple[float, float]
|
||||||
|
|
||||||
|
|
||||||
|
class EmotionFeatures(NamedTuple):
|
||||||
|
"""
|
||||||
|
Five normalised geometric features derived from FaceLandmarks.
|
||||||
|
|
||||||
|
All features are relative to face_height; positive = expressive direction.
|
||||||
|
"""
|
||||||
|
mouth_open: float # mouth height / face_height (≥ 0)
|
||||||
|
smile: float # (lip-mid y − avg corner y) / face_h (+ = smile)
|
||||||
|
brow_raise: float # (eye-top y − inner-brow y) / face_h (+ = raised)
|
||||||
|
brow_furl: float # (inner-brow y − outer-brow y) / face_h (+ = angry)
|
||||||
|
eye_open: float # eye height / face_height (≥ 0)
|
||||||
|
face_height: float # forehead-to-chin distance in image coordinates
|
||||||
|
|
||||||
|
|
||||||
|
class EmotionResult(NamedTuple):
|
||||||
|
"""Result of detect_emotion()."""
|
||||||
|
emotion: str # one of EMOTION_LABELS
|
||||||
|
confidence: float # 0.0–1.0
|
||||||
|
features: EmotionFeatures
|
||||||
|
|
||||||
|
|
||||||
|
# ── MediaPipe adapter ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def from_mediapipe(lms) -> FaceLandmarks:
|
||||||
|
"""
|
||||||
|
Build a FaceLandmarks from a single MediaPipe NormalizedLandmarkList.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
lms : mediapipe.framework.formats.landmark_pb2.NormalizedLandmarkList
|
||||||
|
(element of FaceMeshResults.multi_face_landmarks)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
FaceLandmarks with normalised (x, y) coordinates.
|
||||||
|
"""
|
||||||
|
def _pt(idx: int) -> Tuple[float, float]:
|
||||||
|
lm = lms.landmark[idx]
|
||||||
|
return float(lm.x), float(lm.y)
|
||||||
|
|
||||||
|
return FaceLandmarks(
|
||||||
|
mouth_upper = _pt(MOUTH_UPPER),
|
||||||
|
mouth_lower = _pt(MOUTH_LOWER),
|
||||||
|
mouth_left = _pt(MOUTH_LEFT),
|
||||||
|
mouth_right = _pt(MOUTH_RIGHT),
|
||||||
|
l_eye_top = _pt(L_EYE_TOP),
|
||||||
|
l_eye_bot = _pt(L_EYE_BOT),
|
||||||
|
r_eye_top = _pt(R_EYE_TOP),
|
||||||
|
r_eye_bot = _pt(R_EYE_BOT),
|
||||||
|
l_eye_left = _pt(L_EYE_LEFT),
|
||||||
|
l_eye_right = _pt(L_EYE_RIGHT),
|
||||||
|
r_eye_left = _pt(R_EYE_LEFT),
|
||||||
|
r_eye_right = _pt(R_EYE_RIGHT),
|
||||||
|
l_brow_inner = _pt(L_BROW_INNER),
|
||||||
|
l_brow_outer = _pt(L_BROW_OUTER),
|
||||||
|
r_brow_inner = _pt(R_BROW_INNER),
|
||||||
|
r_brow_outer = _pt(R_BROW_OUTER),
|
||||||
|
chin = _pt(CHIN),
|
||||||
|
forehead = _pt(FOREHEAD),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Feature extraction ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def compute_features(fl: FaceLandmarks) -> EmotionFeatures:
|
||||||
|
"""
|
||||||
|
Derive five normalised geometric features from a FaceLandmarks object.
|
||||||
|
|
||||||
|
Coordinate notes
|
||||||
|
----------------
|
||||||
|
Image y increases downward, so:
|
||||||
|
• mouth_open > 0 when lower lip is below upper lip (mouth open)
|
||||||
|
• smile > 0 when corners are above the lip midpoint (raised)
|
||||||
|
• brow_raise > 0 when inner brow is above the eye-top (raised brow)
|
||||||
|
• brow_furl > 0 when inner brow is lower than outer brow (furrowed)
|
||||||
|
• eye_open > 0 when lower lid is below upper lid (open eye)
|
||||||
|
"""
|
||||||
|
face_h = max(fl.chin[1] - fl.forehead[1], 1e-4)
|
||||||
|
|
||||||
|
# ── Mouth openness ────────────────────────────────────────────────────
|
||||||
|
mouth_h = fl.mouth_lower[1] - fl.mouth_upper[1]
|
||||||
|
mouth_open = max(0.0, mouth_h) / face_h
|
||||||
|
|
||||||
|
# ── Smile / frown ─────────────────────────────────────────────────────
|
||||||
|
# lip midpoint y
|
||||||
|
mid_y = (fl.mouth_upper[1] + fl.mouth_lower[1]) / 2.0
|
||||||
|
# average corner y
|
||||||
|
corner_y = (fl.mouth_left[1] + fl.mouth_right[1]) / 2.0
|
||||||
|
# positive when corners are above midpoint (y_corner < y_mid)
|
||||||
|
smile = (mid_y - corner_y) / face_h
|
||||||
|
|
||||||
|
# ── Eyebrow raise ─────────────────────────────────────────────────────
|
||||||
|
# Distance from inner brow (above) to eye top (below) — larger = more raised
|
||||||
|
l_raise = fl.l_eye_top[1] - fl.l_brow_inner[1]
|
||||||
|
r_raise = fl.r_eye_top[1] - fl.r_brow_inner[1]
|
||||||
|
brow_raise = max(0.0, (l_raise + r_raise) / 2.0) / face_h
|
||||||
|
|
||||||
|
# ── Brow furl (angry) ─────────────────────────────────────────────────
|
||||||
|
# Inner brow below outer brow (in image y) → inner y > outer y → positive
|
||||||
|
l_furl = (fl.l_brow_inner[1] - fl.l_brow_outer[1]) / face_h
|
||||||
|
r_furl = (fl.r_brow_inner[1] - fl.r_brow_outer[1]) / face_h
|
||||||
|
brow_furl = (l_furl + r_furl) / 2.0
|
||||||
|
|
||||||
|
# ── Eye openness ──────────────────────────────────────────────────────
|
||||||
|
l_eye_h = fl.l_eye_bot[1] - fl.l_eye_top[1]
|
||||||
|
r_eye_h = fl.r_eye_bot[1] - fl.r_eye_top[1]
|
||||||
|
eye_open = max(0.0, (l_eye_h + r_eye_h) / 2.0) / face_h
|
||||||
|
|
||||||
|
return EmotionFeatures(
|
||||||
|
mouth_open = float(mouth_open),
|
||||||
|
smile = float(smile),
|
||||||
|
brow_raise = float(brow_raise),
|
||||||
|
brow_furl = float(brow_furl),
|
||||||
|
eye_open = float(eye_open),
|
||||||
|
face_height = float(face_h),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Classification rules ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Thresholds are relative to face_height (all features normalised by face_h).
|
||||||
|
# Chosen to match realistic facial proportions:
|
||||||
|
# brow_raise: 0.06–0.10 neutral, > 0.12 raised
|
||||||
|
# eye_open: 0.04–0.06 normal, > 0.07 wide
|
||||||
|
# mouth_open: 0.0–0.02 closed, > 0.07 open
|
||||||
|
# smile: -0.02–0.02 neutral, > 0.025 smile, < -0.025 frown
|
||||||
|
# brow_furl: -0.02–0.01 neutral, > 0.02 furrowed
|
||||||
|
|
||||||
|
_T_SURPRISED_BROW = 0.12 # brow_raise threshold for surprised
|
||||||
|
_T_SURPRISED_EYE = 0.07 # eye_open threshold for surprised
|
||||||
|
_T_SURPRISED_MOUTH = 0.07 # mouth_open threshold for surprised
|
||||||
|
_T_HAPPY_SMILE = 0.025 # smile threshold for happy
|
||||||
|
_T_ANGRY_FURL = 0.02 # brow_furl threshold for angry
|
||||||
|
_T_ANGRY_NO_SMILE = 0.01 # smile must be below this for angry
|
||||||
|
_T_SAD_FROWN = 0.025 # -smile threshold for sad
|
||||||
|
_T_SAD_NO_FURL = 0.015 # brow_furl must be below this for sad
|
||||||
|
|
||||||
|
|
||||||
|
def classify_emotion(features: EmotionFeatures) -> Tuple[str, float]:
|
||||||
|
"""
|
||||||
|
Apply geometric rules to classify emotion from EmotionFeatures.
|
||||||
|
|
||||||
|
Priority order: surprised → happy → angry → sad → neutral.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
(emotion, confidence) where confidence ∈ (0, 1].
|
||||||
|
"""
|
||||||
|
f = features
|
||||||
|
|
||||||
|
# ── Surprised ────────────────────────────────────────────────────────
|
||||||
|
if (f.brow_raise > _T_SURPRISED_BROW
|
||||||
|
and f.eye_open > _T_SURPRISED_EYE
|
||||||
|
and f.mouth_open > _T_SURPRISED_MOUTH):
|
||||||
|
# Confidence from excess over each threshold
|
||||||
|
br_exc = (f.brow_raise - _T_SURPRISED_BROW) / 0.05
|
||||||
|
ey_exc = (f.eye_open - _T_SURPRISED_EYE) / 0.03
|
||||||
|
mo_exc = (f.mouth_open - _T_SURPRISED_MOUTH) / 0.06
|
||||||
|
conf = min(1.0, 0.5 + (br_exc + ey_exc + mo_exc) / 6.0)
|
||||||
|
return 'surprised', conf
|
||||||
|
|
||||||
|
# ── Happy ─────────────────────────────────────────────────────────────
|
||||||
|
if f.smile > _T_HAPPY_SMILE:
|
||||||
|
conf = min(1.0, 0.5 + (f.smile - _T_HAPPY_SMILE) / 0.05)
|
||||||
|
return 'happy', conf
|
||||||
|
|
||||||
|
# ── Angry ─────────────────────────────────────────────────────────────
|
||||||
|
if f.brow_furl > _T_ANGRY_FURL and f.smile < _T_ANGRY_NO_SMILE:
|
||||||
|
conf = min(1.0, 0.5 + (f.brow_furl - _T_ANGRY_FURL) / 0.04)
|
||||||
|
return 'angry', conf
|
||||||
|
|
||||||
|
# ── Sad ───────────────────────────────────────────────────────────────
|
||||||
|
if f.smile < -_T_SAD_FROWN and f.brow_furl < _T_SAD_NO_FURL:
|
||||||
|
conf = min(1.0, 0.5 + (-f.smile - _T_SAD_FROWN) / 0.04)
|
||||||
|
return 'sad', conf
|
||||||
|
|
||||||
|
return 'neutral', 0.8
|
||||||
|
|
||||||
|
|
||||||
|
# ── Main entry point ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def detect_emotion(fl: FaceLandmarks) -> EmotionResult:
|
||||||
|
"""
|
||||||
|
Classify emotion from face landmark geometry.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
fl : FaceLandmarks — normalised key point positions
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
EmotionResult(emotion, confidence, features)
|
||||||
|
"""
|
||||||
|
features = compute_features(fl)
|
||||||
|
emotion, confidence = classify_emotion(features)
|
||||||
|
return EmotionResult(
|
||||||
|
emotion = emotion,
|
||||||
|
confidence = float(confidence),
|
||||||
|
features = features,
|
||||||
|
)
|
||||||
@ -0,0 +1,361 @@
|
|||||||
|
"""
|
||||||
|
_obstacle_size.py — Depth-based obstacle size estimation by projecting LIDAR
|
||||||
|
clusters into the D435i depth image (no ROS2 deps).
|
||||||
|
|
||||||
|
Algorithm
|
||||||
|
---------
|
||||||
|
1. Transform each LIDAR cluster centroid from LIDAR frame to camera frame
|
||||||
|
using a configurable translation-only extrinsic (typical co-mounted setup).
|
||||||
|
2. Project to depth-image pixel coordinates via the pinhole camera model.
|
||||||
|
3. Sample the D435i depth image in a window around the projected pixel to
|
||||||
|
obtain a robust Z estimate (uint16 mm → metres).
|
||||||
|
4. Derive horizontal width (metres) directly from the LIDAR cluster bounding
|
||||||
|
box, which is more reliable than image-based re-measurement.
|
||||||
|
5. Derive vertical height (metres) by scanning a vertical strip in the depth
|
||||||
|
image: collect pixels whose depth is within `z_tol` of the sampled Z, then
|
||||||
|
compute the row extent and back-project using the camera fy.
|
||||||
|
6. Return an ObstacleSizeEstimate per cluster.
|
||||||
|
|
||||||
|
Coordinate frames
|
||||||
|
-----------------
|
||||||
|
LIDAR frame (+X forward, +Y left, +Z up — 2-D scan at z=0):
|
||||||
|
x_lidar = forward distance (metres)
|
||||||
|
y_lidar = lateral distance (metres, positive = left)
|
||||||
|
z_lidar = 0 (2-D LIDAR always lies in the horizontal plane)
|
||||||
|
|
||||||
|
Camera frame (+X right, +Y down, +Z forward — standard pinhole):
|
||||||
|
x_cam = −y_lidar + ex
|
||||||
|
y_cam = −z_lidar + ey = ey (z_lidar=0 for 2-D LIDAR)
|
||||||
|
z_cam = x_lidar + ez
|
||||||
|
|
||||||
|
Extrinsic parameters (CameraParams.ex/ey/ez) give the LIDAR origin position
|
||||||
|
in the camera coordinate frame:
|
||||||
|
ex — lateral offset (positive = LIDAR origin is to the right of camera)
|
||||||
|
ey — vertical offset (positive = LIDAR origin is below camera)
|
||||||
|
ez — forward offset (positive = LIDAR origin is in front of camera)
|
||||||
|
|
||||||
|
Public API
|
||||||
|
----------
|
||||||
|
CameraParams dataclass — camera intrinsics + co-mount extrinsics
|
||||||
|
ObstacleSizeEstimate NamedTuple — per-cluster result
|
||||||
|
|
||||||
|
lidar_to_camera(x_l, y_l, params) → (x_cam, y_cam, z_cam)
|
||||||
|
project_to_pixel(x_cam, y_cam, z_cam, p) → (u, v) or None if out-of-bounds
|
||||||
|
sample_depth_median(depth_u16, u, v, win, scale) → (depth_m, n_valid)
|
||||||
|
estimate_height(depth_u16, u_c, v_c, z_ref, params,
|
||||||
|
search_rows, col_hw, z_tol) → height_m
|
||||||
|
estimate_cluster_size(cluster, depth_u16, params,
|
||||||
|
depth_window, search_rows, col_hw, z_tol,
|
||||||
|
cluster_id) → ObstacleSizeEstimate
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import NamedTuple, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# ── Camera parameters ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CameraParams:
|
||||||
|
"""Pinhole intrinsics and LIDAR→camera extrinsics for the D435i.
|
||||||
|
|
||||||
|
Intrinsics (D435i 640×480 depth stream defaults — override from CameraInfo):
|
||||||
|
fx, fy : focal lengths in pixels
|
||||||
|
cx, cy : principal point in pixels
|
||||||
|
width, height : image dimensions in pixels
|
||||||
|
depth_scale : multiply raw uint16 value by this to get metres
|
||||||
|
D435i default = 0.001 (raw value is millimetres)
|
||||||
|
|
||||||
|
Extrinsics (LIDAR origin position in camera frame):
|
||||||
|
ex, ey, ez : translation (metres) — see module docstring
|
||||||
|
"""
|
||||||
|
fx: float = 383.0 # D435i 640×480 depth approx
|
||||||
|
fy: float = 383.0
|
||||||
|
cx: float = 320.0
|
||||||
|
cy: float = 240.0
|
||||||
|
width: int = 640
|
||||||
|
height: int = 480
|
||||||
|
depth_scale: float = 0.001 # mm → m
|
||||||
|
|
||||||
|
# LIDAR→camera extrinsics: LIDAR origin position in camera frame
|
||||||
|
ex: float = 0.0 # lateral (m); 0 = LIDAR directly in front of camera
|
||||||
|
ey: float = 0.05 # vertical (m); +0.05 = LIDAR 5 cm below camera
|
||||||
|
ez: float = 0.0 # forward (m); 0 = LIDAR at same depth as camera
|
||||||
|
|
||||||
|
|
||||||
|
# ── Result type ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class ObstacleSizeEstimate(NamedTuple):
|
||||||
|
"""Size estimate for one LIDAR cluster projected into the depth image."""
|
||||||
|
obstacle_id: int # user-supplied cluster index (0-based) or track ID
|
||||||
|
centroid_x: float # LIDAR-frame forward distance (metres)
|
||||||
|
centroid_y: float # LIDAR-frame lateral distance (metres, +left)
|
||||||
|
depth_z: float # D435i sampled depth at projected centroid (metres)
|
||||||
|
width_m: float # horizontal size, LIDAR-derived (metres)
|
||||||
|
height_m: float # vertical size, depth-image-derived (metres)
|
||||||
|
pixel_u: int # projected centroid column (-1 = out of image)
|
||||||
|
pixel_v: int # projected centroid row (-1 = out of image)
|
||||||
|
lidar_range: float # range from LIDAR origin to centroid (metres)
|
||||||
|
confidence: float # 0.0–1.0; based on depth sample quality
|
||||||
|
|
||||||
|
|
||||||
|
# ── Coordinate transform ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def lidar_to_camera(
|
||||||
|
x_lidar: float,
|
||||||
|
y_lidar: float,
|
||||||
|
params: CameraParams,
|
||||||
|
) -> Tuple[float, float, float]:
|
||||||
|
"""
|
||||||
|
Transform a 2-D LIDAR point to the camera coordinate frame.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x_lidar : forward distance in LIDAR frame (metres)
|
||||||
|
y_lidar : lateral distance in LIDAR frame (metres, +left)
|
||||||
|
params : CameraParams with extrinsic translation ex/ey/ez
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
(x_cam, y_cam, z_cam) in camera frame (metres)
|
||||||
|
x_cam = right, y_cam = down, z_cam = forward
|
||||||
|
"""
|
||||||
|
x_cam = -y_lidar + params.ex
|
||||||
|
y_cam = params.ey # LIDAR z=0 for 2-D scan
|
||||||
|
z_cam = x_lidar + params.ez
|
||||||
|
return float(x_cam), float(y_cam), float(z_cam)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Pinhole projection ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def project_to_pixel(
|
||||||
|
x_cam: float,
|
||||||
|
y_cam: float,
|
||||||
|
z_cam: float,
|
||||||
|
params: CameraParams,
|
||||||
|
) -> Optional[Tuple[int, int]]:
|
||||||
|
"""
|
||||||
|
Project a 3-D camera-frame point to image pixel coordinates.
|
||||||
|
|
||||||
|
Returns None if z_cam ≤ 0 or the projection falls outside the image.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
(u, v) integer pixel coordinates (column, row), or None.
|
||||||
|
"""
|
||||||
|
if z_cam <= 0.0:
|
||||||
|
return None
|
||||||
|
u = params.fx * x_cam / z_cam + params.cx
|
||||||
|
v = params.fy * y_cam / z_cam + params.cy
|
||||||
|
ui, vi = int(round(u)), int(round(v))
|
||||||
|
if 0 <= ui < params.width and 0 <= vi < params.height:
|
||||||
|
return ui, vi
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Depth sampling ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def sample_depth_median(
|
||||||
|
depth_u16: np.ndarray,
|
||||||
|
u: int,
|
||||||
|
v: int,
|
||||||
|
window_px: int = 5,
|
||||||
|
depth_scale: float = 0.001,
|
||||||
|
) -> Tuple[float, int]:
|
||||||
|
"""
|
||||||
|
Median depth in a square window centred on (u, v) in a uint16 depth image.
|
||||||
|
|
||||||
|
Zeros (invalid readings) are excluded from the median.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
depth_u16 : (H, W) uint16 depth image (raw units, e.g. mm for D435i)
|
||||||
|
u, v : centre pixel (column, row)
|
||||||
|
window_px : half-side of the sampling square (window of 2*window_px+1)
|
||||||
|
depth_scale : scale factor to convert raw units to metres (0.001 for mm)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
(depth_m, n_valid) — median depth in metres and count of valid pixels.
|
||||||
|
depth_m = 0.0 when n_valid == 0.
|
||||||
|
"""
|
||||||
|
h, w = depth_u16.shape
|
||||||
|
r0 = max(0, v - window_px)
|
||||||
|
r1 = min(h, v + window_px + 1)
|
||||||
|
c0 = max(0, u - window_px)
|
||||||
|
c1 = min(w, u + window_px + 1)
|
||||||
|
|
||||||
|
patch = depth_u16[r0:r1, c0:c1].astype(np.float64)
|
||||||
|
valid = patch[patch > 0]
|
||||||
|
if len(valid) == 0:
|
||||||
|
return 0.0, 0
|
||||||
|
return float(np.median(valid) * depth_scale), int(len(valid))
|
||||||
|
|
||||||
|
|
||||||
|
# ── Height estimation ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def estimate_height(
|
||||||
|
depth_u16: np.ndarray,
|
||||||
|
u_c: int,
|
||||||
|
v_c: int,
|
||||||
|
z_ref: float,
|
||||||
|
params: CameraParams,
|
||||||
|
search_rows: int = 120,
|
||||||
|
col_hw: int = 10,
|
||||||
|
z_tol: float = 0.30,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Estimate the vertical height (metres) of an obstacle in the depth image.
|
||||||
|
|
||||||
|
Scans a vertical strip centred on (u_c, v_c) and finds the row extent of
|
||||||
|
pixels whose depth is within `z_tol` metres of `z_ref`.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
depth_u16 : (H, W) uint16 depth image
|
||||||
|
u_c, v_c : projected centroid pixel
|
||||||
|
z_ref : reference depth (metres) — from sample_depth_median
|
||||||
|
params : CameraParams (needs fy, depth_scale)
|
||||||
|
search_rows : half-height of the search strip (rows above/below v_c)
|
||||||
|
col_hw : half-width of the search strip (columns left/right of u_c)
|
||||||
|
z_tol : depth tolerance — pixels within z_ref ± z_tol are included
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
height_m — vertical extent in metres; 0.0 if fewer than 2 valid rows found.
|
||||||
|
"""
|
||||||
|
if z_ref <= 0.0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
h, w = depth_u16.shape
|
||||||
|
r0 = max(0, v_c - search_rows)
|
||||||
|
r1 = min(h, v_c + search_rows + 1)
|
||||||
|
c0 = max(0, u_c - col_hw)
|
||||||
|
c1 = min(w, u_c + col_hw + 1)
|
||||||
|
|
||||||
|
strip = depth_u16[r0:r1, c0:c1].astype(np.float64) * params.depth_scale
|
||||||
|
|
||||||
|
# For each row in the strip, check if any valid pixel is within z_tol of z_ref
|
||||||
|
z_min = z_ref - z_tol
|
||||||
|
z_max = z_ref + z_tol
|
||||||
|
valid_mask = (strip > 0) & (strip >= z_min) & (strip <= z_max)
|
||||||
|
valid_rows = np.any(valid_mask, axis=1) # (n_rows,) bool
|
||||||
|
|
||||||
|
if valid_rows.sum() < 2:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
row_indices = np.where(valid_rows)[0] # rows within the strip
|
||||||
|
v_top_strip = int(row_indices[0])
|
||||||
|
v_bottom_strip = int(row_indices[-1])
|
||||||
|
|
||||||
|
# Convert to absolute image rows
|
||||||
|
v_top = r0 + v_top_strip
|
||||||
|
v_bottom = r0 + v_bottom_strip
|
||||||
|
|
||||||
|
row_extent = v_bottom - v_top
|
||||||
|
if row_extent <= 0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# Back-project row extent to metres at depth z_ref
|
||||||
|
height_m = float(row_extent) * z_ref / params.fy
|
||||||
|
return float(height_m)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Main entry point ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def estimate_cluster_size(
|
||||||
|
cluster, # _lidar_clustering.Cluster NamedTuple
|
||||||
|
depth_u16: np.ndarray, # (H, W) uint16 depth image
|
||||||
|
params: CameraParams,
|
||||||
|
depth_window: int = 5, # half-side for depth median sampling
|
||||||
|
search_rows: int = 120, # half-height of height search strip
|
||||||
|
col_hw: int = 10, # half-width of height search strip
|
||||||
|
z_tol: float = 0.30, # depth tolerance for height estimation
|
||||||
|
obstacle_id: int = 0, # caller-assigned ID
|
||||||
|
) -> ObstacleSizeEstimate:
|
||||||
|
"""
|
||||||
|
Estimate the 3-D size of one LIDAR cluster using the D435i depth image.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
cluster : Cluster from _lidar_clustering.cluster_points()
|
||||||
|
depth_u16 : uint16 depth image (H × W), values in millimetres
|
||||||
|
params : CameraParams with intrinsics + extrinsics
|
||||||
|
depth_window : half-side of median depth sampling window (pixels)
|
||||||
|
search_rows : half-height for vertical height search (rows)
|
||||||
|
col_hw : half-width for vertical height search (columns)
|
||||||
|
z_tol : depth tolerance when collecting obstacle pixels (metres)
|
||||||
|
obstacle_id : arbitrary integer ID for this estimate
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ObstacleSizeEstimate
|
||||||
|
"""
|
||||||
|
cx, cy = float(cluster.centroid[0]), float(cluster.centroid[1])
|
||||||
|
lidar_range = float(math.sqrt(cx * cx + cy * cy))
|
||||||
|
|
||||||
|
# Transform centroid to camera frame
|
||||||
|
x_cam, y_cam, z_cam = lidar_to_camera(cx, cy, params)
|
||||||
|
|
||||||
|
# Project to depth image pixel
|
||||||
|
px = project_to_pixel(x_cam, y_cam, z_cam, params)
|
||||||
|
if px is None:
|
||||||
|
return ObstacleSizeEstimate(
|
||||||
|
obstacle_id = obstacle_id,
|
||||||
|
centroid_x = cx,
|
||||||
|
centroid_y = cy,
|
||||||
|
depth_z = 0.0,
|
||||||
|
width_m = float(cluster.width_m),
|
||||||
|
height_m = 0.0,
|
||||||
|
pixel_u = -1,
|
||||||
|
pixel_v = -1,
|
||||||
|
lidar_range = lidar_range,
|
||||||
|
confidence = 0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
u, v = px
|
||||||
|
|
||||||
|
# Sample depth
|
||||||
|
depth_m, n_valid = sample_depth_median(
|
||||||
|
depth_u16, u, v, window_px=depth_window,
|
||||||
|
depth_scale=params.depth_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use LIDAR range as fallback when depth image has no valid reading
|
||||||
|
if depth_m <= 0.0:
|
||||||
|
depth_m = z_cam
|
||||||
|
depth_conf = 0.0
|
||||||
|
else:
|
||||||
|
# Confidence: ratio of valid pixels in window + how close to expected Z
|
||||||
|
max_valid = (2 * depth_window + 1) ** 2
|
||||||
|
fill_ratio = min(1.0, n_valid / max(max_valid * 0.25, 1.0))
|
||||||
|
z_err = abs(depth_m - z_cam)
|
||||||
|
z_conf = max(0.0, 1.0 - z_err / max(z_cam, 0.1))
|
||||||
|
depth_conf = float(0.6 * fill_ratio + 0.4 * z_conf)
|
||||||
|
|
||||||
|
# Width: LIDAR gives reliable horizontal measurement
|
||||||
|
width_m = float(cluster.width_m)
|
||||||
|
|
||||||
|
# Height: from depth image vertical strip
|
||||||
|
height_m = estimate_height(
|
||||||
|
depth_u16, u, v, depth_m, params,
|
||||||
|
search_rows=search_rows, col_hw=col_hw, z_tol=z_tol,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ObstacleSizeEstimate(
|
||||||
|
obstacle_id = obstacle_id,
|
||||||
|
centroid_x = cx,
|
||||||
|
centroid_y = cy,
|
||||||
|
depth_z = float(depth_m),
|
||||||
|
width_m = width_m,
|
||||||
|
height_m = float(height_m),
|
||||||
|
pixel_u = int(u),
|
||||||
|
pixel_v = int(v),
|
||||||
|
lidar_range = lidar_range,
|
||||||
|
confidence = float(depth_conf),
|
||||||
|
)
|
||||||
@ -0,0 +1,375 @@
|
|||||||
|
"""
|
||||||
|
_obstacle_velocity.py — Dynamic obstacle velocity estimation via Kalman
|
||||||
|
filtering of LIDAR cluster centroids (no ROS2 deps).
|
||||||
|
|
||||||
|
Algorithm
|
||||||
|
---------
|
||||||
|
Each detected cluster centroid is tracked by a constant-velocity 2-D Kalman
|
||||||
|
filter:
|
||||||
|
|
||||||
|
State : [x, y, vx, vy]^T
|
||||||
|
Obs : [x, y]^T (centroid position)
|
||||||
|
|
||||||
|
Predict : x_pred = F(dt) @ x; P_pred = F @ P @ F^T + Q(dt)
|
||||||
|
Update : S = H @ P_pred @ H^T + R
|
||||||
|
K = P_pred @ H^T @ inv(S)
|
||||||
|
x = x_pred + K @ (z - H @ x_pred)
|
||||||
|
P = (I - K @ H) @ P_pred
|
||||||
|
|
||||||
|
Process noise Q uses the white-noise-acceleration approximation:
|
||||||
|
Q = diag([q_pos·dt², q_pos·dt², q_vel·dt, q_vel·dt])
|
||||||
|
|
||||||
|
Data association uses a greedy nearest-centroid approach: scan all
|
||||||
|
(track, cluster) pairs in order of Euclidean distance and accept matches
|
||||||
|
below `max_association_dist_m`, one-to-one.
|
||||||
|
|
||||||
|
Track lifecycle
|
||||||
|
---------------
|
||||||
|
created : first appearance of a cluster with no nearby track
|
||||||
|
confident : after `n_init_frames` consecutive updates (confidence=1.0)
|
||||||
|
coasting : cluster not matched for up to `max_coasting_frames` calls
|
||||||
|
deleted : coasting count exceeds `max_coasting_frames`
|
||||||
|
|
||||||
|
Public API
|
||||||
|
----------
|
||||||
|
KalmanTrack one tracked obstacle
|
||||||
|
.predict(dt) advance state by dt seconds
|
||||||
|
.update(centroid, width, depth, n) correct with new measurement
|
||||||
|
.position → np.ndarray (2,)
|
||||||
|
.velocity → np.ndarray (2,)
|
||||||
|
.speed → float (m/s)
|
||||||
|
.confidence → float 0–1
|
||||||
|
.alive → bool
|
||||||
|
|
||||||
|
associate(positions, centroids, max_dist)
|
||||||
|
→ (matches, unmatched_t, unmatched_c)
|
||||||
|
|
||||||
|
ObstacleTracker
|
||||||
|
.update(centroids, timestamp, widths, depths, point_counts)
|
||||||
|
→ List[KalmanTrack]
|
||||||
|
.tracks → Dict[int, KalmanTrack]
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# ── Kalman filter constants ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_H = np.array([[1.0, 0.0, 0.0, 0.0],
|
||||||
|
[0.0, 1.0, 0.0, 0.0]], dtype=np.float64)
|
||||||
|
|
||||||
|
_I4 = np.eye(4, dtype=np.float64)
|
||||||
|
|
||||||
|
|
||||||
|
# ── KalmanTrack ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class KalmanTrack:
|
||||||
|
"""
|
||||||
|
Constant-velocity 2-D Kalman filter for one LIDAR cluster centroid.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
track_id : unique integer ID
|
||||||
|
centroid : initial (x, y) position in metres
|
||||||
|
q_pos : position process noise density (m²/s²)
|
||||||
|
q_vel : velocity process noise density (m²/s³)
|
||||||
|
r_pos : measurement noise std-dev (metres)
|
||||||
|
n_init_frames : consecutive updates needed for confidence=1
|
||||||
|
max_coasting : missed updates before track is deleted
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
track_id: int,
|
||||||
|
centroid: np.ndarray,
|
||||||
|
q_pos: float = 0.05,
|
||||||
|
q_vel: float = 0.50,
|
||||||
|
r_pos: float = 0.10,
|
||||||
|
n_init_frames: int = 3,
|
||||||
|
max_coasting: int = 5,
|
||||||
|
) -> None:
|
||||||
|
self._id = track_id
|
||||||
|
self._n_init = max(n_init_frames, 1)
|
||||||
|
self._max_coast = max_coasting
|
||||||
|
self._q_pos = float(q_pos)
|
||||||
|
self._q_vel = float(q_vel)
|
||||||
|
self._R = np.diag([r_pos ** 2, r_pos ** 2])
|
||||||
|
self._age = 0 # successful updates
|
||||||
|
self._coasting = 0 # consecutive missed updates
|
||||||
|
self._alive = True
|
||||||
|
|
||||||
|
# State: [x, y, vx, vy]^T — initial velocity is zero
|
||||||
|
self._x = np.array(
|
||||||
|
[float(centroid[0]), float(centroid[1]), 0.0, 0.0],
|
||||||
|
dtype=np.float64,
|
||||||
|
)
|
||||||
|
# High initial velocity uncertainty, low position uncertainty
|
||||||
|
self._P = np.diag([r_pos ** 2, r_pos ** 2, 10.0, 10.0])
|
||||||
|
|
||||||
|
# Last-known cluster metadata (stored, not filtered)
|
||||||
|
self.last_width: float = 0.0
|
||||||
|
self.last_depth: float = 0.0
|
||||||
|
self.last_point_count: int = 0
|
||||||
|
|
||||||
|
# ── Properties ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@property
|
||||||
|
def track_id(self) -> int:
|
||||||
|
return self._id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def position(self) -> np.ndarray:
|
||||||
|
return self._x[:2].copy()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def velocity(self) -> np.ndarray:
|
||||||
|
return self._x[2:].copy()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def speed(self) -> float:
|
||||||
|
return float(np.linalg.norm(self._x[2:]))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def confidence(self) -> float:
|
||||||
|
return min(1.0, self._age / self._n_init)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def alive(self) -> bool:
|
||||||
|
return self._alive
|
||||||
|
|
||||||
|
@property
|
||||||
|
def age(self) -> int:
|
||||||
|
return self._age
|
||||||
|
|
||||||
|
@property
|
||||||
|
def coasting(self) -> int:
|
||||||
|
return self._coasting
|
||||||
|
|
||||||
|
# ── Kalman steps ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def predict(self, dt: float) -> None:
|
||||||
|
"""
|
||||||
|
Advance the filter state by `dt` seconds (constant-velocity model).
|
||||||
|
|
||||||
|
dt must be ≥ 0; negative values are clamped to 0.
|
||||||
|
Increments the coasting counter; marks the track dead when the
|
||||||
|
counter exceeds max_coasting_frames.
|
||||||
|
"""
|
||||||
|
dt = max(0.0, float(dt))
|
||||||
|
|
||||||
|
F = np.array([
|
||||||
|
[1.0, 0.0, dt, 0.0],
|
||||||
|
[0.0, 1.0, 0.0, dt],
|
||||||
|
[0.0, 0.0, 1.0, 0.0],
|
||||||
|
[0.0, 0.0, 0.0, 1.0],
|
||||||
|
], dtype=np.float64)
|
||||||
|
|
||||||
|
Q = np.diag([
|
||||||
|
self._q_pos * dt * dt,
|
||||||
|
self._q_pos * dt * dt,
|
||||||
|
self._q_vel * dt,
|
||||||
|
self._q_vel * dt,
|
||||||
|
])
|
||||||
|
|
||||||
|
self._x = F @ self._x
|
||||||
|
self._P = F @ self._P @ F.T + Q
|
||||||
|
|
||||||
|
self._coasting += 1
|
||||||
|
if self._coasting > self._max_coast:
|
||||||
|
self._alive = False
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
centroid: np.ndarray,
|
||||||
|
width: float = 0.0,
|
||||||
|
depth: float = 0.0,
|
||||||
|
point_count: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Correct state with a new centroid measurement.
|
||||||
|
|
||||||
|
Resets the coasting counter and increments the age counter.
|
||||||
|
"""
|
||||||
|
z = np.asarray(centroid, dtype=np.float64)[:2]
|
||||||
|
|
||||||
|
S = _H @ self._P @ _H.T + self._R # innovation covariance
|
||||||
|
K = self._P @ _H.T @ np.linalg.inv(S) # Kalman gain
|
||||||
|
|
||||||
|
self._x = self._x + K @ (z - _H @ self._x)
|
||||||
|
self._P = (_I4 - K @ _H) @ self._P
|
||||||
|
|
||||||
|
self._age += 1
|
||||||
|
self._coasting = 0
|
||||||
|
|
||||||
|
self.last_width = float(width)
|
||||||
|
self.last_depth = float(depth)
|
||||||
|
self.last_point_count = int(point_count)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Data association ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def associate(
|
||||||
|
positions: np.ndarray, # (N, 2) predicted track positions
|
||||||
|
centroids: np.ndarray, # (M, 2) new cluster centroids
|
||||||
|
max_dist: float,
|
||||||
|
) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
|
||||||
|
"""
|
||||||
|
Greedy nearest-centroid data association.
|
||||||
|
|
||||||
|
Iteratively matches the closest (track, cluster) pair, one-to-one, as
|
||||||
|
long as the distance is strictly less than `max_dist`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
matches : list of (track_idx, cluster_idx) pairs
|
||||||
|
unmatched_tracks : track indices with no assigned cluster
|
||||||
|
unmatched_clusters: cluster indices with no assigned track
|
||||||
|
"""
|
||||||
|
N = len(positions)
|
||||||
|
M = len(centroids)
|
||||||
|
|
||||||
|
if N == 0 or M == 0:
|
||||||
|
return [], list(range(N)), list(range(M))
|
||||||
|
|
||||||
|
# (N, M) pairwise Euclidean distance matrix
|
||||||
|
cost = np.linalg.norm(
|
||||||
|
positions[:, None, :] - centroids[None, :, :], axis=2
|
||||||
|
).astype(np.float64) # (N, M)
|
||||||
|
|
||||||
|
matched_t: set = set()
|
||||||
|
matched_c: set = set()
|
||||||
|
matches: List[Tuple[int, int]] = []
|
||||||
|
|
||||||
|
for _ in range(min(N, M)):
|
||||||
|
if cost.min() >= max_dist:
|
||||||
|
break
|
||||||
|
t_idx, c_idx = np.unravel_index(int(cost.argmin()), cost.shape)
|
||||||
|
matches.append((int(t_idx), int(c_idx)))
|
||||||
|
matched_t.add(t_idx)
|
||||||
|
matched_c.add(c_idx)
|
||||||
|
cost[t_idx, :] = np.inf
|
||||||
|
cost[:, c_idx] = np.inf
|
||||||
|
|
||||||
|
unmatched_t = [i for i in range(N) if i not in matched_t]
|
||||||
|
unmatched_c = [i for i in range(M) if i not in matched_c]
|
||||||
|
return matches, unmatched_t, unmatched_c
|
||||||
|
|
||||||
|
|
||||||
|
# ── ObstacleTracker ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class ObstacleTracker:
|
||||||
|
"""
|
||||||
|
Multi-obstacle Kalman tracker operating on LIDAR cluster centroids.
|
||||||
|
|
||||||
|
Call `update()` once per LIDAR scan with the list of cluster centroids
|
||||||
|
(and optional bbox / point_count metadata). Returns the current list
|
||||||
|
of alive KalmanTrack objects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_association_dist_m: float = 0.50,
|
||||||
|
max_coasting_frames: int = 5,
|
||||||
|
n_init_frames: int = 3,
|
||||||
|
q_pos: float = 0.05,
|
||||||
|
q_vel: float = 0.50,
|
||||||
|
r_pos: float = 0.10,
|
||||||
|
) -> None:
|
||||||
|
self._max_dist = float(max_association_dist_m)
|
||||||
|
self._max_coast = int(max_coasting_frames)
|
||||||
|
self._n_init = int(n_init_frames)
|
||||||
|
self._q_pos = float(q_pos)
|
||||||
|
self._q_vel = float(q_vel)
|
||||||
|
self._r_pos = float(r_pos)
|
||||||
|
|
||||||
|
self._tracks: Dict[int, KalmanTrack] = {}
|
||||||
|
self._next_id: int = 1
|
||||||
|
self._last_t: Optional[float] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tracks(self) -> Dict[int, KalmanTrack]:
|
||||||
|
return self._tracks
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
centroids: List[np.ndarray],
|
||||||
|
timestamp: float,
|
||||||
|
widths: Optional[List[float]] = None,
|
||||||
|
depths: Optional[List[float]] = None,
|
||||||
|
point_counts: Optional[List[int]] = None,
|
||||||
|
) -> List[KalmanTrack]:
|
||||||
|
"""
|
||||||
|
Process one frame of cluster centroids.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
centroids : list of (2,) arrays — cluster centroid positions
|
||||||
|
timestamp : wall-clock time in seconds (time.time())
|
||||||
|
widths : optional per-cluster bbox widths (metres)
|
||||||
|
depths : optional per-cluster bbox depths (metres)
|
||||||
|
point_counts : optional per-cluster LIDAR point counts
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List of all alive KalmanTrack objects after this update.
|
||||||
|
"""
|
||||||
|
dt = max(0.0, timestamp - self._last_t) if self._last_t is not None else 0.0
|
||||||
|
self._last_t = timestamp
|
||||||
|
|
||||||
|
# 1. Predict all existing tracks forward by dt
|
||||||
|
for track in self._tracks.values():
|
||||||
|
track.predict(dt)
|
||||||
|
|
||||||
|
# 2. Build arrays for association (alive tracks only)
|
||||||
|
alive = [t for t in self._tracks.values() if t.alive]
|
||||||
|
if alive:
|
||||||
|
pred_pos = np.array([t.position for t in alive]) # (N, 2)
|
||||||
|
else:
|
||||||
|
pred_pos = np.empty((0, 2), dtype=np.float64)
|
||||||
|
|
||||||
|
if centroids:
|
||||||
|
cent_arr = np.array([c[:2] for c in centroids], dtype=np.float64) # (M, 2)
|
||||||
|
else:
|
||||||
|
cent_arr = np.empty((0, 2), dtype=np.float64)
|
||||||
|
|
||||||
|
# 3. Associate
|
||||||
|
matches, _, unmatched_c = associate(pred_pos, cent_arr, self._max_dist)
|
||||||
|
|
||||||
|
# 4. Update matched tracks
|
||||||
|
for ti, ci in matches:
|
||||||
|
track = alive[ti]
|
||||||
|
track.update(
|
||||||
|
cent_arr[ci],
|
||||||
|
width = widths[ci] if widths else 0.0,
|
||||||
|
depth = depths[ci] if depths else 0.0,
|
||||||
|
point_count = point_counts[ci] if point_counts else 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. Spawn new tracks for unmatched clusters
|
||||||
|
for ci in unmatched_c:
|
||||||
|
track = KalmanTrack(
|
||||||
|
track_id = self._next_id,
|
||||||
|
centroid = cent_arr[ci],
|
||||||
|
q_pos = self._q_pos,
|
||||||
|
q_vel = self._q_vel,
|
||||||
|
r_pos = self._r_pos,
|
||||||
|
n_init_frames = self._n_init,
|
||||||
|
max_coasting = self._max_coast,
|
||||||
|
)
|
||||||
|
track.update(
|
||||||
|
cent_arr[ci],
|
||||||
|
width = widths[ci] if widths else 0.0,
|
||||||
|
depth = depths[ci] if depths else 0.0,
|
||||||
|
point_count = point_counts[ci] if point_counts else 0,
|
||||||
|
)
|
||||||
|
self._tracks[self._next_id] = track
|
||||||
|
self._next_id += 1
|
||||||
|
|
||||||
|
# 6. Prune dead tracks
|
||||||
|
self._tracks = {tid: t for tid, t in self._tracks.items() if t.alive}
|
||||||
|
|
||||||
|
return list(self._tracks.values())
|
||||||
@ -0,0 +1,379 @@
|
|||||||
|
"""
|
||||||
|
_path_edges.py — Lane/path edge detection via Canny + Hough + bird-eye
|
||||||
|
perspective transform (no ROS2 deps).
|
||||||
|
|
||||||
|
Algorithm
|
||||||
|
---------
|
||||||
|
1. Crop the bottom `roi_frac` of the BGR image to an ROI.
|
||||||
|
2. Convert ROI to grayscale → Gaussian blur → Canny edge map.
|
||||||
|
3. Run probabilistic Hough (HoughLinesP) on the edge map to find
|
||||||
|
line segments.
|
||||||
|
4. Filter segments by slope: near-horizontal lines (|slope| < min_slope)
|
||||||
|
are discarded as ground noise; the remaining lines are classified as
|
||||||
|
left edges (negative slope) or right edges (positive slope).
|
||||||
|
5. Average each class into a single dominant edge line and extrapolate it
|
||||||
|
to span the full ROI height.
|
||||||
|
6. Apply a bird-eye perspective homography (computed from a configurable
|
||||||
|
source trapezoid in the ROI) to warp all segment endpoints to a
|
||||||
|
top-down view.
|
||||||
|
7. Return a PathEdgesResult with all data.
|
||||||
|
|
||||||
|
Coordinate conventions
|
||||||
|
-----------------------
|
||||||
|
All pixel coordinates in PathEdgesResult are in the ROI frame:
|
||||||
|
origin = top-left of the bottom-half ROI
|
||||||
|
y = 0 → roi_top of the full image
|
||||||
|
y increases downward; x increases rightward.
|
||||||
|
|
||||||
|
Bird-eye homography
|
||||||
|
-------------------
|
||||||
|
The homography H is computed via cv2.getPerspectiveTransform from four
|
||||||
|
source points (fractional ROI coords) to four destination points in a
|
||||||
|
square output image of size `birdseye_size`. Default source trapezoid
|
||||||
|
assumes a forward-looking camera at ~35–45° tilt above a flat surface.
|
||||||
|
|
||||||
|
Public API
|
||||||
|
----------
|
||||||
|
PathEdgeConfig dataclass — all tunable parameters with sensible defaults
|
||||||
|
PathEdgesResult NamedTuple — all outputs of process_frame()
|
||||||
|
build_homography(src_frac, roi_w, roi_h, birdseye_size)
|
||||||
|
→ 3×3 float64 homography
|
||||||
|
apply_homography(H, points_xy) → np.ndarray (N, 2) warped points
|
||||||
|
canny_edges(bgr_roi, low, high, ksize) → uint8 edge map
|
||||||
|
hough_lines(edge_map, threshold, min_len, max_gap) → List[(x1,y1,x2,y2)]
|
||||||
|
classify_lines(lines, min_slope) → (left, right)
|
||||||
|
average_line(lines, roi_height) → Optional[(x1,y1,x2,y2)]
|
||||||
|
warp_segments(lines, H) → List[(bx1,by1,bx2,by2)]
|
||||||
|
process_frame(bgr, cfg) → PathEdgesResult
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, NamedTuple, Optional, Tuple
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# ── Config ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PathEdgeConfig:
|
||||||
|
# ROI
|
||||||
|
roi_frac: float = 0.50 # bottom fraction of image used as ROI
|
||||||
|
|
||||||
|
# Preprocessing
|
||||||
|
blur_ksize: int = 5 # Gaussian blur kernel (odd integer)
|
||||||
|
canny_low: int = 50 # Canny low threshold
|
||||||
|
canny_high: int = 150 # Canny high threshold
|
||||||
|
|
||||||
|
# Hough
|
||||||
|
hough_rho: float = 1.0 # distance resolution (px)
|
||||||
|
hough_theta: float = math.pi / 180.0 # angle resolution (rad)
|
||||||
|
hough_threshold: int = 30 # minimum votes
|
||||||
|
min_line_len: int = 40 # minimum segment length (px)
|
||||||
|
max_line_gap: int = 20 # maximum gap within a segment (px)
|
||||||
|
|
||||||
|
# Line classification
|
||||||
|
min_slope: float = 0.3 # |slope| below this → discard (near-horizontal)
|
||||||
|
|
||||||
|
# Bird-eye perspective — source trapezoid as fractions of (roi_w, roi_h)
|
||||||
|
# Default: wide trapezoid for forward-looking camera at ~40° tilt
|
||||||
|
birdseye_src: List[List[float]] = field(default_factory=lambda: [
|
||||||
|
[0.40, 0.05], # top-left (near horizon)
|
||||||
|
[0.60, 0.05], # top-right (near horizon)
|
||||||
|
[0.95, 0.95], # bottom-right (near robot)
|
||||||
|
[0.05, 0.95], # bottom-left (near robot)
|
||||||
|
])
|
||||||
|
birdseye_size: int = 400 # square bird-eye output image side (px)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Result ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class PathEdgesResult(NamedTuple):
|
||||||
|
lines: List[Tuple[float, float, float, float]] # ROI coords
|
||||||
|
left_lines: List[Tuple[float, float, float, float]]
|
||||||
|
right_lines: List[Tuple[float, float, float, float]]
|
||||||
|
left_edge: Optional[Tuple[float, float, float, float]] # None if absent
|
||||||
|
right_edge: Optional[Tuple[float, float, float, float]]
|
||||||
|
birdseye_lines: List[Tuple[float, float, float, float]]
|
||||||
|
birdseye_left: Optional[Tuple[float, float, float, float]]
|
||||||
|
birdseye_right: Optional[Tuple[float, float, float, float]]
|
||||||
|
H: np.ndarray # 3×3 homography (ROI → bird-eye)
|
||||||
|
roi_top: int # y-offset of ROI in the full image
|
||||||
|
|
||||||
|
|
||||||
|
# ── Homography ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def build_homography(
|
||||||
|
src_frac: List[List[float]],
|
||||||
|
roi_w: int,
|
||||||
|
roi_h: int,
|
||||||
|
birdseye_size: int,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Compute a perspective homography from ROI pixel coords to a square
|
||||||
|
bird-eye image.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
src_frac : four [fx, fy] pairs as fractions of (roi_w, roi_h)
|
||||||
|
roi_w, roi_h : ROI dimensions in pixels
|
||||||
|
birdseye_size : side length of the square output image
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
H : (3, 3) float64 homography matrix; maps ROI px → bird-eye px.
|
||||||
|
"""
|
||||||
|
src = np.array(
|
||||||
|
[[fx * roi_w, fy * roi_h] for fx, fy in src_frac],
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
S = float(birdseye_size)
|
||||||
|
dst = np.array([
|
||||||
|
[S * 0.25, 0.0],
|
||||||
|
[S * 0.75, 0.0],
|
||||||
|
[S * 0.75, S],
|
||||||
|
[S * 0.25, S],
|
||||||
|
], dtype=np.float32)
|
||||||
|
|
||||||
|
H = cv2.getPerspectiveTransform(src, dst)
|
||||||
|
return H.astype(np.float64)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_homography(H: np.ndarray, points_xy: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Apply a 3×3 homography to an (N, 2) float array of 2-D points.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
(N, 2) float32 array of transformed points.
|
||||||
|
"""
|
||||||
|
if len(points_xy) == 0:
|
||||||
|
return np.empty((0, 2), dtype=np.float32)
|
||||||
|
|
||||||
|
pts = np.column_stack([points_xy, np.ones(len(points_xy))]) # (N, 3)
|
||||||
|
warped = (H @ pts.T).T # (N, 3)
|
||||||
|
warped /= warped[:, 2:3] # homogenise
|
||||||
|
return warped[:, :2].astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Edge detection ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def canny_edges(bgr_roi: np.ndarray, low: int, high: int, ksize: int) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Convert a BGR ROI to a Canny edge map.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
bgr_roi : uint8 BGR image (the bottom-half ROI)
|
||||||
|
low : Canny lower threshold
|
||||||
|
high : Canny upper threshold
|
||||||
|
ksize : Gaussian blur kernel size (must be odd ≥ 1)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
uint8 binary edge map (0 or 255), same spatial size as bgr_roi.
|
||||||
|
"""
|
||||||
|
grey = cv2.cvtColor(bgr_roi, cv2.COLOR_BGR2GRAY)
|
||||||
|
if ksize >= 3 and ksize % 2 == 1:
|
||||||
|
grey = cv2.GaussianBlur(grey, (ksize, ksize), 0)
|
||||||
|
return cv2.Canny(grey, low, high)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Hough lines ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def hough_lines(
|
||||||
|
edge_map: np.ndarray,
|
||||||
|
threshold: int = 30,
|
||||||
|
min_len: int = 40,
|
||||||
|
max_gap: int = 20,
|
||||||
|
rho: float = 1.0,
|
||||||
|
theta: float = math.pi / 180.0,
|
||||||
|
) -> List[Tuple[float, float, float, float]]:
|
||||||
|
"""
|
||||||
|
Run probabilistic Hough on an edge map.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List of (x1, y1, x2, y2) tuples in the edge_map pixel frame.
|
||||||
|
Empty list when no lines are found.
|
||||||
|
"""
|
||||||
|
raw = cv2.HoughLinesP(
|
||||||
|
edge_map,
|
||||||
|
rho=rho,
|
||||||
|
theta=theta,
|
||||||
|
threshold=threshold,
|
||||||
|
minLineLength=min_len,
|
||||||
|
maxLineGap=max_gap,
|
||||||
|
)
|
||||||
|
if raw is None:
|
||||||
|
return []
|
||||||
|
return [
|
||||||
|
(float(x1), float(y1), float(x2), float(y2))
|
||||||
|
for x1, y1, x2, y2 in raw[:, 0]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Line classification ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def classify_lines(
|
||||||
|
lines: List[Tuple[float, float, float, float]],
|
||||||
|
min_slope: float = 0.3,
|
||||||
|
) -> Tuple[List, List]:
|
||||||
|
"""
|
||||||
|
Classify Hough segments into left-edge and right-edge candidates.
|
||||||
|
|
||||||
|
In image coordinates (y increases downward):
|
||||||
|
- Left lane lines have NEGATIVE slope (upper-right → lower-left)
|
||||||
|
- Right lane lines have POSITIVE slope (upper-left → lower-right)
|
||||||
|
- Near-horizontal segments (|slope| < min_slope) are discarded
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
(left_lines, right_lines)
|
||||||
|
"""
|
||||||
|
left: List = []
|
||||||
|
right: List = []
|
||||||
|
for x1, y1, x2, y2 in lines:
|
||||||
|
dx = x2 - x1
|
||||||
|
if abs(dx) < 1e-6:
|
||||||
|
continue # vertical — skip to avoid division by zero
|
||||||
|
slope = (y2 - y1) / dx
|
||||||
|
if slope < -min_slope:
|
||||||
|
left.append((x1, y1, x2, y2))
|
||||||
|
elif slope > min_slope:
|
||||||
|
right.append((x1, y1, x2, y2))
|
||||||
|
return left, right
|
||||||
|
|
||||||
|
|
||||||
|
# ── Line averaging ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def average_line(
|
||||||
|
lines: List[Tuple[float, float, float, float]],
|
||||||
|
roi_height: int,
|
||||||
|
) -> Optional[Tuple[float, float, float, float]]:
|
||||||
|
"""
|
||||||
|
Average a list of collinear Hough segments into one representative line,
|
||||||
|
extrapolated to span the full ROI height (y=0 → y=roi_height-1).
|
||||||
|
|
||||||
|
Returns None if the list is empty or the averaged slope is near-zero.
|
||||||
|
"""
|
||||||
|
if not lines:
|
||||||
|
return None
|
||||||
|
|
||||||
|
slopes: List[float] = []
|
||||||
|
intercepts: List[float] = []
|
||||||
|
for x1, y1, x2, y2 in lines:
|
||||||
|
dx = x2 - x1
|
||||||
|
if abs(dx) < 1e-6:
|
||||||
|
continue
|
||||||
|
m = (y2 - y1) / dx
|
||||||
|
b = y1 - m * x1
|
||||||
|
slopes.append(m)
|
||||||
|
intercepts.append(b)
|
||||||
|
|
||||||
|
if not slopes:
|
||||||
|
return None
|
||||||
|
|
||||||
|
m_avg = float(np.mean(slopes))
|
||||||
|
b_avg = float(np.mean(intercepts))
|
||||||
|
if abs(m_avg) < 1e-6:
|
||||||
|
return None
|
||||||
|
|
||||||
|
y_bot = float(roi_height - 1)
|
||||||
|
y_top = 0.0
|
||||||
|
x_bot = (y_bot - b_avg) / m_avg
|
||||||
|
x_top = (y_top - b_avg) / m_avg
|
||||||
|
return (x_bot, y_bot, x_top, y_top)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Bird-eye segment warping ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def warp_segments(
|
||||||
|
lines: List[Tuple[float, float, float, float]],
|
||||||
|
H: np.ndarray,
|
||||||
|
) -> List[Tuple[float, float, float, float]]:
|
||||||
|
"""
|
||||||
|
Warp a list of line-segment endpoints through the homography H.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List of (bx1, by1, bx2, by2) in bird-eye pixel coordinates.
|
||||||
|
"""
|
||||||
|
if not lines:
|
||||||
|
return []
|
||||||
|
pts = np.array(
|
||||||
|
[[x1, y1] for x1, y1, _, _ in lines] +
|
||||||
|
[[x2, y2] for _, _, x2, y2 in lines],
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
warped = apply_homography(H, pts)
|
||||||
|
n = len(lines)
|
||||||
|
return [
|
||||||
|
(float(warped[i, 0]), float(warped[i, 1]),
|
||||||
|
float(warped[i + n, 0]), float(warped[i + n, 1]))
|
||||||
|
for i in range(n)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Main entry point ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def process_frame(bgr: np.ndarray, cfg: PathEdgeConfig) -> PathEdgesResult:
|
||||||
|
"""
|
||||||
|
Run the full lane/path edge detection pipeline on one BGR frame.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
bgr : uint8 BGR image (H × W × 3)
|
||||||
|
cfg : PathEdgeConfig
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
PathEdgesResult with all detected edges in ROI + bird-eye coordinates.
|
||||||
|
"""
|
||||||
|
h, w = bgr.shape[:2]
|
||||||
|
roi_top = int(h * (1.0 - cfg.roi_frac))
|
||||||
|
roi = bgr[roi_top:, :]
|
||||||
|
roi_h, roi_w = roi.shape[:2]
|
||||||
|
|
||||||
|
# Build perspective homography (ROI px → bird-eye px)
|
||||||
|
H = build_homography(cfg.birdseye_src, roi_w, roi_h, cfg.birdseye_size)
|
||||||
|
|
||||||
|
# Edge detection
|
||||||
|
edges = canny_edges(roi, cfg.canny_low, cfg.canny_high, cfg.blur_ksize)
|
||||||
|
|
||||||
|
# Hough line segments (in ROI coords)
|
||||||
|
lines = hough_lines(
|
||||||
|
edges,
|
||||||
|
threshold = cfg.hough_threshold,
|
||||||
|
min_len = cfg.min_line_len,
|
||||||
|
max_gap = cfg.max_line_gap,
|
||||||
|
rho = cfg.hough_rho,
|
||||||
|
theta = cfg.hough_theta,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Classify and average
|
||||||
|
left_lines, right_lines = classify_lines(lines, cfg.min_slope)
|
||||||
|
left_edge = average_line(left_lines, roi_h)
|
||||||
|
right_edge = average_line(right_lines, roi_h)
|
||||||
|
|
||||||
|
# Warp all segments to bird-eye
|
||||||
|
birdseye_lines = warp_segments(lines, H)
|
||||||
|
birdseye_left = (warp_segments([left_edge], H)[0] if left_edge else None)
|
||||||
|
birdseye_right = (warp_segments([right_edge], H)[0] if right_edge else None)
|
||||||
|
|
||||||
|
return PathEdgesResult(
|
||||||
|
lines = lines,
|
||||||
|
left_lines = left_lines,
|
||||||
|
right_lines = right_lines,
|
||||||
|
left_edge = left_edge,
|
||||||
|
right_edge = right_edge,
|
||||||
|
birdseye_lines = birdseye_lines,
|
||||||
|
birdseye_left = birdseye_left,
|
||||||
|
birdseye_right = birdseye_right,
|
||||||
|
H = H,
|
||||||
|
roi_top = roi_top,
|
||||||
|
)
|
||||||
@ -0,0 +1,729 @@
|
|||||||
|
"""
|
||||||
|
_person_tracker.py — Multi-person tracker for follow-me mode (no ROS2 deps).
|
||||||
|
|
||||||
|
Pipeline (called once per colour frame)
|
||||||
|
-----------------------------------------
|
||||||
|
1. Person detections (bounding boxes + detector confidence) arrive from an
|
||||||
|
external detector (YOLOv8n, MobileNetSSD, etc.) — not this module's concern.
|
||||||
|
2. Active tracks are predicted one step forward via a constant-velocity Kalman.
|
||||||
|
3. Detections are matched to predicted tracks with greedy IoU ≥ iou_threshold.
|
||||||
|
4. Unmatched detections that survive re-ID histogram matching against LOST
|
||||||
|
tracks are reattached to those tracks (brief-occlusion recovery).
|
||||||
|
5. Remaining unmatched detections start new TENTATIVE tracks.
|
||||||
|
6. Tracks not updated for max_lost_frames are removed permanently.
|
||||||
|
7. Bearing and range to a designated follow target are derived from camera
|
||||||
|
intrinsics + an aligned depth image.
|
||||||
|
|
||||||
|
Re-identification
|
||||||
|
-----------------
|
||||||
|
Each track stores an HSV colour histogram of the person's torso region
|
||||||
|
(middle 50 % of bbox height, centre 80 % width). After occlusion, new
|
||||||
|
detections whose histogram Bhattacharyya similarity exceeds reid_threshold
|
||||||
|
*and* whose predicted position is within reid_max_dist pixels are candidates
|
||||||
|
for re-identification. Closest histogram match wins.
|
||||||
|
|
||||||
|
Kalman state (8-D, one per track)
|
||||||
|
----------------------------------
|
||||||
|
x = [cx, cy, w, h, vcx, vcy, vw, vh]
|
||||||
|
Measurement z = [cx, cy, w, h]
|
||||||
|
dt = 1 frame (constant velocity model)
|
||||||
|
|
||||||
|
Public API
|
||||||
|
----------
|
||||||
|
BBox NamedTuple (x, y, w, h) — pixel coordinates
|
||||||
|
Detection NamedTuple (bbox, confidence, frame_bgr)
|
||||||
|
TrackState Enum TENTATIVE / ACTIVE / LOST
|
||||||
|
PersonTrack dataclass per-track state snapshot
|
||||||
|
PersonTracker class .update() → list[PersonTrack]
|
||||||
|
|
||||||
|
iou(a, b) → float
|
||||||
|
bearing_from_pixel(u, cx_px, fx) → float (degrees)
|
||||||
|
depth_at_bbox(depth_u16, bbox, …) → (depth_m, quality)
|
||||||
|
extract_torso_hist(bgr, bbox, …) → ndarray | None
|
||||||
|
hist_similarity(h1, h2) → float (0 = different, 1 = same)
|
||||||
|
KalmanBoxFilter class
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import IntEnum
|
||||||
|
from typing import List, NamedTuple, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# ── Simple types ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class BBox(NamedTuple):
|
||||||
|
"""Axis-aligned bounding box in pixel coordinates."""
|
||||||
|
x: int # left edge
|
||||||
|
y: int # top edge
|
||||||
|
w: int # width (≥ 1)
|
||||||
|
h: int # height (≥ 1)
|
||||||
|
|
||||||
|
|
||||||
|
class Detection(NamedTuple):
|
||||||
|
"""One person detection from an external detector."""
|
||||||
|
bbox: BBox
|
||||||
|
confidence: float # 0–1 detector score
|
||||||
|
frame_bgr: Optional[np.ndarray] = None # colour frame (for histogram)
|
||||||
|
|
||||||
|
|
||||||
|
class TrackState(IntEnum):
|
||||||
|
TENTATIVE = 0 # seen < min_hits frames; not yet published
|
||||||
|
ACTIVE = 1 # confirmed; published to follow-me controller
|
||||||
|
LOST = 2 # missing; still kept for re-ID up to max_lost_frames
|
||||||
|
|
||||||
|
|
||||||
|
# ── Depth quality levels ──────────────────────────────────────────────────────
|
||||||
|
DEPTH_INVALID = 0
|
||||||
|
DEPTH_EXTRAPOLATED = 1
|
||||||
|
DEPTH_MARGINAL = 2
|
||||||
|
DEPTH_GOOD = 3
|
||||||
|
|
||||||
|
|
||||||
|
# ── IoU ───────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def iou(a: BBox, b: BBox) -> float:
|
||||||
|
"""Intersection-over-union of two bounding boxes."""
|
||||||
|
ax2, ay2 = a.x + a.w, a.y + a.h
|
||||||
|
bx2, by2 = b.x + b.w, b.y + b.h
|
||||||
|
ix1 = max(a.x, b.x); iy1 = max(a.y, b.y)
|
||||||
|
ix2 = min(ax2, bx2); iy2 = min(ay2, by2)
|
||||||
|
inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
|
||||||
|
union = a.w * a.h + b.w * b.h - inter
|
||||||
|
return float(inter) / max(float(union), 1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Kalman box filter ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class KalmanBoxFilter:
|
||||||
|
"""
|
||||||
|
Constant-velocity 8-state Kalman filter for bounding-box tracking.
|
||||||
|
|
||||||
|
State x = [cx, cy, w, h, vcx, vcy, vw, vh]
|
||||||
|
Meas z = [cx, cy, w, h]
|
||||||
|
|
||||||
|
Process noise Q: position uncertainty σ=0.1 px/frame²,
|
||||||
|
velocity uncertainty σ=10 px/frame
|
||||||
|
Measurement noise R: ±2 px std-dev on bbox edges
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, initial_bbox: BBox) -> None:
|
||||||
|
# Transition matrix (dt = 1 frame)
|
||||||
|
self._F = np.eye(8, dtype=np.float64)
|
||||||
|
self._F[:4, 4:] = np.eye(4)
|
||||||
|
|
||||||
|
# Measurement matrix
|
||||||
|
self._H = np.zeros((4, 8), dtype=np.float64)
|
||||||
|
self._H[:4, :4] = np.eye(4)
|
||||||
|
|
||||||
|
# Process noise
|
||||||
|
self._Q = np.diag([1., 1., 1., 1., 100., 100., 100., 100.]).astype(np.float64)
|
||||||
|
|
||||||
|
# Measurement noise (~2 px std-dev → var = 4)
|
||||||
|
self._R = np.eye(4, dtype=np.float64) * 4.0
|
||||||
|
|
||||||
|
# Initial state
|
||||||
|
cx = initial_bbox.x + initial_bbox.w * 0.5
|
||||||
|
cy = initial_bbox.y + initial_bbox.h * 0.5
|
||||||
|
self._x = np.array(
|
||||||
|
[cx, cy, float(initial_bbox.w), float(initial_bbox.h), 0., 0., 0., 0.],
|
||||||
|
dtype=np.float64,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initial covariance — small pos uncertainty, large velocity uncertainty
|
||||||
|
self._P = np.diag([10., 10., 10., 10., 1000., 1000., 1000., 1000.]).astype(np.float64)
|
||||||
|
|
||||||
|
# -- predict ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def predict(self) -> BBox:
|
||||||
|
"""Advance state one step forward; return predicted BBox."""
|
||||||
|
self._x = self._F @ self._x
|
||||||
|
self._P = self._F @ self._P @ self._F.T + self._Q
|
||||||
|
return self._to_bbox()
|
||||||
|
|
||||||
|
# -- update ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def update(self, bbox: BBox) -> BBox:
|
||||||
|
"""Correct state with observation; return corrected BBox."""
|
||||||
|
z = np.array(
|
||||||
|
[bbox.x + bbox.w * 0.5, bbox.y + bbox.h * 0.5,
|
||||||
|
float(bbox.w), float(bbox.h)],
|
||||||
|
dtype=np.float64,
|
||||||
|
)
|
||||||
|
y = z - self._H @ self._x
|
||||||
|
S = self._H @ self._P @ self._H.T + self._R
|
||||||
|
K = self._P @ self._H.T @ np.linalg.inv(S)
|
||||||
|
self._x = self._x + K @ y
|
||||||
|
self._P = (np.eye(8) - K @ self._H) @ self._P
|
||||||
|
return self._to_bbox()
|
||||||
|
|
||||||
|
# -- accessors ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@property
|
||||||
|
def velocity_px(self) -> Tuple[float, float]:
|
||||||
|
"""(vcx, vcy) in pixels / frame from Kalman state."""
|
||||||
|
return float(self._x[4]), float(self._x[5])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bbox(self) -> BBox:
|
||||||
|
return self._to_bbox()
|
||||||
|
|
||||||
|
def _to_bbox(self) -> BBox:
|
||||||
|
cx, cy, w, h = self._x[:4]
|
||||||
|
w = max(1.0, w); h = max(1.0, h)
|
||||||
|
return BBox(
|
||||||
|
int(round(cx - w * 0.5)),
|
||||||
|
int(round(cy - h * 0.5)),
|
||||||
|
int(round(w)),
|
||||||
|
int(round(h)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Colour histogram (HSV torso) ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
_HIST_H_BINS = 16
|
||||||
|
_HIST_S_BINS = 8
|
||||||
|
_HIST_SIZE = _HIST_H_BINS * _HIST_S_BINS # 128
|
||||||
|
|
||||||
|
|
||||||
|
def extract_torso_hist(
|
||||||
|
bgr: np.ndarray,
|
||||||
|
bbox: BBox,
|
||||||
|
h_bins: int = _HIST_H_BINS,
|
||||||
|
s_bins: int = _HIST_S_BINS,
|
||||||
|
) -> Optional[np.ndarray]:
|
||||||
|
"""
|
||||||
|
Extract a normalised HSV colour histogram from the torso region of a
|
||||||
|
person bounding box.
|
||||||
|
|
||||||
|
Torso region: middle 50 % of bbox height, centre 80 % of bbox width.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
bgr : (H, W, 3) uint8 colour image
|
||||||
|
bbox : person bounding box
|
||||||
|
h_bins, s_bins : histogram bins for H and S channels
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Normalised 1-D histogram of length h_bins * s_bins, or None if the
|
||||||
|
crop is too small or bgr is None.
|
||||||
|
"""
|
||||||
|
if bgr is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
ih, iw = bgr.shape[:2]
|
||||||
|
|
||||||
|
# Torso crop
|
||||||
|
y0 = bbox.y + bbox.h // 4 # skip top 25 % (head)
|
||||||
|
y1 = bbox.y + bbox.h * 3 // 4 # skip bottom 25 % (legs)
|
||||||
|
x0 = bbox.x + bbox.w // 10
|
||||||
|
x1 = bbox.x + bbox.w * 9 // 10
|
||||||
|
|
||||||
|
y0 = max(0, y0); y1 = min(ih, y1)
|
||||||
|
x0 = max(0, x0); x1 = min(iw, x1)
|
||||||
|
|
||||||
|
if y1 - y0 < 4 or x1 - x0 < 4:
|
||||||
|
return None
|
||||||
|
|
||||||
|
crop = bgr[y0:y1, x0:x1]
|
||||||
|
|
||||||
|
# BGR → HSV (manual, no OpenCV dependency for tests)
|
||||||
|
crop_f = crop.astype(np.float32) / 255.0
|
||||||
|
r, g, b = crop_f[..., 2], crop_f[..., 1], crop_f[..., 0]
|
||||||
|
|
||||||
|
cmax = np.maximum(np.maximum(r, g), b)
|
||||||
|
cmin = np.minimum(np.minimum(r, g), b)
|
||||||
|
delta_h = cmax - cmin + 1e-7 # epsilon only for hue angle computation
|
||||||
|
|
||||||
|
# Hue in [0, 360)
|
||||||
|
h = np.where(
|
||||||
|
cmax == r, 60.0 * ((g - b) / delta_h % 6),
|
||||||
|
np.where(cmax == g, 60.0 * ((b - r) / delta_h + 2),
|
||||||
|
60.0 * ((r - g) / delta_h + 4)),
|
||||||
|
)
|
||||||
|
h = np.clip(h % 360.0, 0.0, 359.9999)
|
||||||
|
# Saturation in [0, 1] — no epsilon so pure-colour pixels stay ≤ 1.0
|
||||||
|
s = np.clip(np.where(cmax > 1e-6, (cmax - cmin) / cmax, 0.0), 0.0, 1.0)
|
||||||
|
|
||||||
|
h_flat = h.ravel()
|
||||||
|
s_flat = s.ravel()
|
||||||
|
|
||||||
|
hist, _, _ = np.histogram2d(
|
||||||
|
h_flat, s_flat,
|
||||||
|
bins=[h_bins, s_bins],
|
||||||
|
range=[[0, 360], [0, 1]],
|
||||||
|
)
|
||||||
|
hist = hist.ravel().astype(np.float32)
|
||||||
|
total = hist.sum()
|
||||||
|
if total > 0:
|
||||||
|
hist /= total
|
||||||
|
return hist
|
||||||
|
|
||||||
|
|
||||||
|
def hist_similarity(h1: np.ndarray, h2: np.ndarray) -> float:
|
||||||
|
"""
|
||||||
|
Bhattacharyya similarity between two normalised histograms.
|
||||||
|
|
||||||
|
Returns a value in [0, 1]: 1 = identical, 0 = completely different.
|
||||||
|
"""
|
||||||
|
bc = float(np.sum(np.sqrt(h1 * h2)))
|
||||||
|
return float(np.clip(bc, 0.0, 1.0))
|
||||||
|
|
||||||
|
|
||||||
|
# ── Camera geometry ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def bearing_from_pixel(
|
||||||
|
u: float,
|
||||||
|
cx_px: float,
|
||||||
|
fx: float,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Convert a horizontal pixel coordinate to a bearing angle.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
u : pixel column (horizontal image coordinate)
|
||||||
|
cx_px : principal point x (from CameraInfo.K[2])
|
||||||
|
fx : horizontal focal length in pixels (from CameraInfo.K[0])
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bearing_deg : signed degrees; positive = right of camera centre.
|
||||||
|
"""
|
||||||
|
return math.degrees(math.atan2(float(u - cx_px), float(fx)))
|
||||||
|
|
||||||
|
|
||||||
|
def depth_at_bbox(
|
||||||
|
depth_u16: np.ndarray,
|
||||||
|
bbox: BBox,
|
||||||
|
depth_scale: float = 0.001,
|
||||||
|
window_frac: float = 0.3,
|
||||||
|
) -> Tuple[float, int]:
|
||||||
|
"""
|
||||||
|
Sample median depth from the central torso region of a bounding box in a
|
||||||
|
uint16 depth image (D435i mm units by default).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
depth_u16 : (H, W) uint16 depth image
|
||||||
|
bbox : person bounding box (colour image coordinates, assumed aligned)
|
||||||
|
depth_scale : multiply raw value to get metres (D435i: 0.001)
|
||||||
|
window_frac : fraction of bbox dimensions to use as central sample window
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
(depth_m, quality)
|
||||||
|
depth_m : median depth in metres (0.0 when no valid pixels)
|
||||||
|
quality : DEPTH_GOOD / DEPTH_MARGINAL / DEPTH_EXTRAPOLATED / DEPTH_INVALID
|
||||||
|
"""
|
||||||
|
ih, iw = depth_u16.shape
|
||||||
|
|
||||||
|
# Central window
|
||||||
|
wf = max(0.1, min(1.0, window_frac))
|
||||||
|
cx = bbox.x + bbox.w * 0.5
|
||||||
|
cy = bbox.y + bbox.h * 0.5
|
||||||
|
hw = bbox.w * wf * 0.5
|
||||||
|
hh = bbox.h * wf * 0.5
|
||||||
|
|
||||||
|
r0 = int(max(0, cy - hh))
|
||||||
|
r1 = int(min(ih, cy + hh + 1))
|
||||||
|
c0 = int(max(0, cx - hw))
|
||||||
|
c1 = int(min(iw, cx + hw + 1))
|
||||||
|
|
||||||
|
if r1 <= r0 or c1 <= c0:
|
||||||
|
return 0.0, DEPTH_INVALID
|
||||||
|
|
||||||
|
patch = depth_u16[r0:r1, c0:c1]
|
||||||
|
valid = patch[patch > 0]
|
||||||
|
n_total = patch.size
|
||||||
|
|
||||||
|
if len(valid) == 0:
|
||||||
|
return 0.0, DEPTH_INVALID
|
||||||
|
|
||||||
|
fill_ratio = len(valid) / max(n_total, 1)
|
||||||
|
depth_m = float(np.median(valid)) * depth_scale
|
||||||
|
|
||||||
|
if fill_ratio > 0.6:
|
||||||
|
quality = DEPTH_GOOD
|
||||||
|
elif fill_ratio > 0.25:
|
||||||
|
quality = DEPTH_MARGINAL
|
||||||
|
else:
|
||||||
|
quality = DEPTH_EXTRAPOLATED
|
||||||
|
|
||||||
|
return depth_m, quality
|
||||||
|
|
||||||
|
|
||||||
|
# ── Per-track state ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PersonTrack:
|
||||||
|
"""Current state of one tracked person."""
|
||||||
|
track_id: int
|
||||||
|
state: TrackState
|
||||||
|
bbox: BBox # smoothed Kalman bbox (colour img px)
|
||||||
|
bearing: float = 0.0 # degrees; 0 until cam params set
|
||||||
|
distance: float = 0.0 # metres; 0 until depth available
|
||||||
|
depth_qual: int = DEPTH_INVALID
|
||||||
|
confidence: float = 0.0 # 0–1 combined score
|
||||||
|
vel_u: float = 0.0 # Kalman horizontal velocity (px/frame)
|
||||||
|
vel_v: float = 0.0 # Kalman vertical velocity (px/frame)
|
||||||
|
|
||||||
|
hits: int = 0 # consecutive matched frames
|
||||||
|
age: int = 0 # total frames since creation
|
||||||
|
lost_age: int = 0 # consecutive unmatched frames
|
||||||
|
|
||||||
|
color_hist: Optional[np.ndarray] = None # HSV torso histogram
|
||||||
|
|
||||||
|
# Internal — not serialised
|
||||||
|
_kalman: Optional[KalmanBoxFilter] = field(default=None, repr=False)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Camera parameters (minimal) ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CamParams:
|
||||||
|
"""Minimal camera intrinsics needed for bearing calculation."""
|
||||||
|
fx: float = 615.0 # D435i 640×480 depth defaults
|
||||||
|
fy: float = 615.0
|
||||||
|
cx: float = 320.0
|
||||||
|
cy: float = 240.0
|
||||||
|
fps: float = 30.0 # frame rate (for velocity conversion px→m/s)
|
||||||
|
|
||||||
|
|
||||||
|
# ── PersonTracker ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class PersonTracker:
|
||||||
|
"""
|
||||||
|
Multi-person tracker combining Kalman prediction, IoU data association,
|
||||||
|
and HSV colour histogram re-identification.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
iou_threshold : minimum IoU to consider a detection-track match
|
||||||
|
min_hits : frames before a track transitions TENTATIVE → ACTIVE
|
||||||
|
max_lost_frames : frames a track survives without a detection before removal
|
||||||
|
reid_threshold : minimum histogram Bhattacharyya similarity for re-ID
|
||||||
|
reid_max_dist : max predicted-to-detection centre distance for re-ID (px)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
iou_threshold: float = 0.30,
|
||||||
|
min_hits: int = 3,
|
||||||
|
max_lost_frames: int = 30,
|
||||||
|
reid_threshold: float = 0.55,
|
||||||
|
reid_max_dist: float = 150.0,
|
||||||
|
) -> None:
|
||||||
|
self.iou_threshold = iou_threshold
|
||||||
|
self.min_hits = min_hits
|
||||||
|
self.max_lost_frames = max_lost_frames
|
||||||
|
self.reid_threshold = reid_threshold
|
||||||
|
self.reid_max_dist = reid_max_dist
|
||||||
|
|
||||||
|
self._tracks: List[PersonTrack] = []
|
||||||
|
self._next_id: int = 0
|
||||||
|
|
||||||
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tracks(self) -> List[PersonTrack]:
|
||||||
|
return list(self._tracks)
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
detections: List[Detection],
|
||||||
|
cam: Optional[CamParams] = None,
|
||||||
|
depth_u16: Optional[np.ndarray] = None,
|
||||||
|
depth_scale: float = 0.001,
|
||||||
|
) -> List[PersonTrack]:
|
||||||
|
"""
|
||||||
|
Process one frame of detections.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
detections : list of Detection from an external detector
|
||||||
|
cam : camera intrinsics (for bearing computation); None = skip
|
||||||
|
depth_u16 : aligned uint16 depth image; None = depth unavailable
|
||||||
|
depth_scale : mm-to-metres scale factor (D435i default 0.001)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list of PersonTrack (ACTIVE state only)
|
||||||
|
"""
|
||||||
|
# ── Step 1: Predict all tracks ────────────────────────────────────────
|
||||||
|
for trk in self._tracks:
|
||||||
|
if trk._kalman is not None:
|
||||||
|
trk.bbox = trk._kalman.predict()
|
||||||
|
trk.age += 1
|
||||||
|
|
||||||
|
# ── Step 2: IoU matching ──────────────────────────────────────────────
|
||||||
|
active_tracks = [t for t in self._tracks if t.state != TrackState.LOST]
|
||||||
|
matched_t, matched_d, unmatched_t, unmatched_d = \
|
||||||
|
self._match_iou(active_tracks, detections)
|
||||||
|
|
||||||
|
# ── Step 3: Update matched tracks ────────────────────────────────────
|
||||||
|
for t_idx, d_idx in zip(matched_t, matched_d):
|
||||||
|
trk = active_tracks[t_idx]
|
||||||
|
det = detections[d_idx]
|
||||||
|
if trk._kalman is not None:
|
||||||
|
trk.bbox = trk._kalman.update(det.bbox)
|
||||||
|
trk.vel_u, trk.vel_v = trk._kalman.velocity_px
|
||||||
|
trk.hits += 1
|
||||||
|
trk.lost_age = 0
|
||||||
|
trk.confidence = float(det.confidence)
|
||||||
|
if trk.state == TrackState.TENTATIVE and trk.hits >= self.min_hits:
|
||||||
|
trk.state = TrackState.ACTIVE
|
||||||
|
self._update_hist(trk, det)
|
||||||
|
|
||||||
|
# ── Step 4: Re-ID for unmatched detections vs LOST tracks ────────────
|
||||||
|
lost_tracks = [t for t in self._tracks if t.state == TrackState.LOST]
|
||||||
|
still_unmatched_d = list(unmatched_d)
|
||||||
|
|
||||||
|
for d_idx in list(still_unmatched_d):
|
||||||
|
det = detections[d_idx]
|
||||||
|
best_trk, best_sim = self._reid_match(det, lost_tracks)
|
||||||
|
if best_trk is not None:
|
||||||
|
if best_trk._kalman is not None:
|
||||||
|
best_trk.bbox = best_trk._kalman.update(det.bbox)
|
||||||
|
else:
|
||||||
|
best_trk.bbox = det.bbox
|
||||||
|
best_trk.state = TrackState.ACTIVE
|
||||||
|
best_trk.hits += 1
|
||||||
|
best_trk.lost_age = 0
|
||||||
|
best_trk.confidence = float(det.confidence)
|
||||||
|
self._update_hist(best_trk, det)
|
||||||
|
still_unmatched_d.remove(d_idx)
|
||||||
|
|
||||||
|
# ── Step 5: Create new tracks for still-unmatched detections ─────────
|
||||||
|
for d_idx in still_unmatched_d:
|
||||||
|
det = detections[d_idx]
|
||||||
|
trk = PersonTrack(
|
||||||
|
track_id = self._next_id,
|
||||||
|
state = TrackState.TENTATIVE,
|
||||||
|
bbox = det.bbox,
|
||||||
|
hits = 1,
|
||||||
|
confidence= float(det.confidence),
|
||||||
|
_kalman = KalmanBoxFilter(det.bbox),
|
||||||
|
)
|
||||||
|
self._update_hist(trk, det)
|
||||||
|
self._tracks.append(trk)
|
||||||
|
self._next_id += 1
|
||||||
|
|
||||||
|
# ── Step 6: Age lost tracks, remove stale ────────────────────────────
|
||||||
|
# Mark newly-unmatched tracks as LOST (reset lost_age to 0)
|
||||||
|
for t_idx in unmatched_t:
|
||||||
|
trk = active_tracks[t_idx]
|
||||||
|
if trk.state != TrackState.LOST:
|
||||||
|
trk.lost_age = 0
|
||||||
|
trk.state = TrackState.LOST
|
||||||
|
|
||||||
|
# Increment lost_age for every LOST track (including previously LOST ones)
|
||||||
|
for trk in self._tracks:
|
||||||
|
if trk.state == TrackState.LOST:
|
||||||
|
trk.lost_age += 1
|
||||||
|
|
||||||
|
self._tracks = [
|
||||||
|
t for t in self._tracks
|
||||||
|
if t.lost_age < self.max_lost_frames
|
||||||
|
]
|
||||||
|
|
||||||
|
# ── Step 7: Update bearing / depth for all active tracks ─────────────
|
||||||
|
for trk in self._tracks:
|
||||||
|
if trk.state != TrackState.ACTIVE:
|
||||||
|
continue
|
||||||
|
u_centre = trk.bbox.x + trk.bbox.w * 0.5
|
||||||
|
if cam is not None:
|
||||||
|
trk.bearing = bearing_from_pixel(u_centre, cam.cx, cam.fx)
|
||||||
|
if depth_u16 is not None:
|
||||||
|
trk.distance, trk.depth_qual = depth_at_bbox(
|
||||||
|
depth_u16, trk.bbox, depth_scale=depth_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
return [t for t in self._tracks if t.state == TrackState.ACTIVE]
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Clear all tracks."""
|
||||||
|
self._tracks.clear()
|
||||||
|
self._next_id = 0
|
||||||
|
|
||||||
|
# ── Internal helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _match_iou(
|
||||||
|
self,
|
||||||
|
tracks: List[PersonTrack],
|
||||||
|
detections: List[Detection],
|
||||||
|
) -> Tuple[List[int], List[int], List[int], List[int]]:
|
||||||
|
"""
|
||||||
|
Greedy IoU matching between tracks and detections.
|
||||||
|
|
||||||
|
Returns (matched_t_idx, matched_d_idx, unmatched_t_idx, unmatched_d_idx).
|
||||||
|
"""
|
||||||
|
if not tracks or not detections:
|
||||||
|
return [], [], list(range(len(tracks))), list(range(len(detections)))
|
||||||
|
|
||||||
|
iou_mat = np.zeros((len(tracks), len(detections)), dtype=np.float32)
|
||||||
|
for ti, trk in enumerate(tracks):
|
||||||
|
for di, det in enumerate(detections):
|
||||||
|
iou_mat[ti, di] = iou(trk.bbox, det.bbox)
|
||||||
|
|
||||||
|
matched_t: List[int] = []
|
||||||
|
matched_d: List[int] = []
|
||||||
|
used_t = set()
|
||||||
|
used_d = set()
|
||||||
|
|
||||||
|
# Greedy: highest IoU first
|
||||||
|
flat_order = np.argsort(iou_mat.ravel())[::-1]
|
||||||
|
for flat_idx in flat_order:
|
||||||
|
ti, di = divmod(int(flat_idx), len(detections))
|
||||||
|
if iou_mat[ti, di] < self.iou_threshold:
|
||||||
|
break
|
||||||
|
if ti not in used_t and di not in used_d:
|
||||||
|
matched_t.append(ti)
|
||||||
|
matched_d.append(di)
|
||||||
|
used_t.add(ti)
|
||||||
|
used_d.add(di)
|
||||||
|
|
||||||
|
unmatched_t = [i for i in range(len(tracks)) if i not in used_t]
|
||||||
|
unmatched_d = [i for i in range(len(detections)) if i not in used_d]
|
||||||
|
return matched_t, matched_d, unmatched_t, unmatched_d
|
||||||
|
|
||||||
|
def _reid_match(
|
||||||
|
self,
|
||||||
|
det: Detection,
|
||||||
|
lost: List[PersonTrack],
|
||||||
|
) -> Tuple[Optional[PersonTrack], float]:
|
||||||
|
"""
|
||||||
|
Find the best re-identification match for a detection among lost tracks.
|
||||||
|
|
||||||
|
Returns (best_track, similarity) or (None, 0.0) if no match found.
|
||||||
|
"""
|
||||||
|
if not lost:
|
||||||
|
return None, 0.0
|
||||||
|
|
||||||
|
det_hist = None
|
||||||
|
if det.frame_bgr is not None:
|
||||||
|
det_hist = extract_torso_hist(det.frame_bgr, det.bbox)
|
||||||
|
|
||||||
|
best_trk: Optional[PersonTrack] = None
|
||||||
|
best_sim: float = 0.0
|
||||||
|
|
||||||
|
det_cx = det.bbox.x + det.bbox.w * 0.5
|
||||||
|
det_cy = det.bbox.y + det.bbox.h * 0.5
|
||||||
|
|
||||||
|
for trk in lost:
|
||||||
|
# Spatial gate: predicted centre must be close enough
|
||||||
|
trk_cx = trk.bbox.x + trk.bbox.w * 0.5
|
||||||
|
trk_cy = trk.bbox.y + trk.bbox.h * 0.5
|
||||||
|
dist = math.sqrt((det_cx - trk_cx) ** 2 + (det_cy - trk_cy) ** 2)
|
||||||
|
if dist > self.reid_max_dist:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Histogram similarity
|
||||||
|
if det_hist is not None and trk.color_hist is not None:
|
||||||
|
sim = hist_similarity(det_hist, trk.color_hist)
|
||||||
|
if sim > self.reid_threshold and sim > best_sim:
|
||||||
|
best_sim = sim
|
||||||
|
best_trk = trk
|
||||||
|
|
||||||
|
return best_trk, best_sim
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _update_hist(trk: PersonTrack, det: Detection) -> None:
|
||||||
|
"""Update track's colour histogram with exponential decay."""
|
||||||
|
if det.frame_bgr is None:
|
||||||
|
return
|
||||||
|
new_hist = extract_torso_hist(det.frame_bgr, det.bbox)
|
||||||
|
if new_hist is None:
|
||||||
|
return
|
||||||
|
if trk.color_hist is None:
|
||||||
|
trk.color_hist = new_hist
|
||||||
|
else:
|
||||||
|
# Running average (α = 0.3 — new frame contributes 30 %)
|
||||||
|
trk.color_hist = 0.7 * trk.color_hist + 0.3 * new_hist
|
||||||
|
|
||||||
|
|
||||||
|
# ── Follow-target selector ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class FollowTargetSelector:
|
||||||
|
"""
|
||||||
|
Selects and locks onto a single PersonTrack as the follow target.
|
||||||
|
|
||||||
|
Strategy
|
||||||
|
--------
|
||||||
|
• On start() or when no target is locked: choose the nearest active track
|
||||||
|
(by depth distance, or by image-centre proximity when depth unavailable).
|
||||||
|
• Re-lock onto the same track_id each frame (continuous tracking).
|
||||||
|
• If the locked track disappears: hold the last known state for
|
||||||
|
`hold_frames` frames, then go inactive.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hold_frames: int = 15) -> None:
|
||||||
|
self.hold_frames = hold_frames
|
||||||
|
self._target_id: Optional[int] = None
|
||||||
|
self._last_target: Optional[PersonTrack] = None
|
||||||
|
self._held_frames: int = 0
|
||||||
|
self._active: bool = False
|
||||||
|
|
||||||
|
# -- control ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def start(self) -> None:
|
||||||
|
"""(Re-)enable follow mode; re-select target on next update."""
|
||||||
|
self._active = True
|
||||||
|
self._target_id = None
|
||||||
|
self._held_frames = 0
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
"""Disable follow mode."""
|
||||||
|
self._active = False
|
||||||
|
self._target_id = None
|
||||||
|
self._last_target = None
|
||||||
|
|
||||||
|
# -- update ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
active_tracks: List[PersonTrack],
|
||||||
|
img_cx: float = 320.0, # image centre x (px)
|
||||||
|
) -> Optional[PersonTrack]:
|
||||||
|
"""
|
||||||
|
Select the follow target from a list of active tracks.
|
||||||
|
|
||||||
|
Returns the locked PersonTrack, or None if follow mode is inactive or
|
||||||
|
no candidate found.
|
||||||
|
"""
|
||||||
|
if not self._active:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not active_tracks:
|
||||||
|
if self._last_target is not None and self._held_frames < self.hold_frames:
|
||||||
|
self._held_frames += 1
|
||||||
|
return self._last_target
|
||||||
|
self._last_target = None
|
||||||
|
return None
|
||||||
|
|
||||||
|
self._held_frames = 0
|
||||||
|
|
||||||
|
# Re-find locked track
|
||||||
|
if self._target_id is not None:
|
||||||
|
for t in active_tracks:
|
||||||
|
if t.track_id == self._target_id:
|
||||||
|
self._last_target = t
|
||||||
|
return t
|
||||||
|
# Locked track lost — re-select
|
||||||
|
self._target_id = None
|
||||||
|
|
||||||
|
# Select: prefer by smallest distance, then by image-centre proximity
|
||||||
|
def _score(t: PersonTrack) -> float:
|
||||||
|
if t.distance > 0:
|
||||||
|
return t.distance
|
||||||
|
return abs((t.bbox.x + t.bbox.w * 0.5) - img_cx)
|
||||||
|
|
||||||
|
chosen = min(active_tracks, key=_score)
|
||||||
|
self._target_id = chosen.track_id
|
||||||
|
self._last_target = chosen
|
||||||
|
return chosen
|
||||||
@ -0,0 +1,411 @@
|
|||||||
|
"""
|
||||||
|
_uwb_tracker.py — UWB DW3000 anchor/tag ranging + bearing estimation (Issue #365).
|
||||||
|
|
||||||
|
Pure-Python library (no ROS2 / hardware dependencies) so it can be fully unit-tested
|
||||||
|
on a development machine without the physical ESP32-UWB-Pro anchors attached.
|
||||||
|
|
||||||
|
Hardware layout
|
||||||
|
---------------
|
||||||
|
Two DW3000 anchors are mounted on the SaltyBot chassis, separated by a ~25 cm
|
||||||
|
baseline, both facing forward:
|
||||||
|
|
||||||
|
anchor0 (left, x = −baseline/2) anchor1 (right, x = +baseline/2)
|
||||||
|
│←──────── baseline ────────→│
|
||||||
|
↑
|
||||||
|
robot forward
|
||||||
|
|
||||||
|
The wearable tag is carried by the Tee (person being followed).
|
||||||
|
|
||||||
|
Bearing estimation
|
||||||
|
------------------
|
||||||
|
Given TWR (two-way ranging) distances d0 and d1 from the two anchors to the tag,
|
||||||
|
and the known baseline B between the anchors, we compute bearing θ using the
|
||||||
|
law of cosines:
|
||||||
|
|
||||||
|
cos α = (d0² - d1² + B²) / (2 · B · d1) [angle at anchor1]
|
||||||
|
|
||||||
|
Bearing is measured from the perpendicular bisector of the baseline (i.e. the
|
||||||
|
robot's forward axis):
|
||||||
|
|
||||||
|
θ = 90° − α
|
||||||
|
|
||||||
|
Positive θ = tag is to the right, negative = tag is to the left.
|
||||||
|
|
||||||
|
The formula degrades when the tag is very close to the baseline or at extreme
|
||||||
|
angles. We report a confidence value that penalises these conditions.
|
||||||
|
|
||||||
|
Serial protocol (ESP32 → Orin via USB-serial)
|
||||||
|
---------------------------------------------
|
||||||
|
Each anchor sends newline-terminated ASCII frames at ≥10 Hz:
|
||||||
|
|
||||||
|
RANGE,<anchor_id>,<tag_id>,<distance_mm>\n
|
||||||
|
|
||||||
|
Example:
|
||||||
|
RANGE,0,T0,1532\n → anchor 0, tag T0, distance 1532 mm
|
||||||
|
RANGE,1,T0,1748\n → anchor 1, tag T0, distance 1748 mm
|
||||||
|
|
||||||
|
A STATUS frame is sent once on connection:
|
||||||
|
STATUS,<anchor_id>,OK\n
|
||||||
|
|
||||||
|
The node opens two serial ports (one per anchor). A background thread reads
|
||||||
|
each port and updates a shared RangingState.
|
||||||
|
|
||||||
|
Kalman filter
|
||||||
|
-------------
|
||||||
|
A simple 1-D constant-velocity Kalman filter smooths the bearing output.
|
||||||
|
State: [bearing_deg, bearing_rate_dps].
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
import re
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# ── Constants ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Fix quality codes (written to UwbTarget.fix_quality)
|
||||||
|
FIX_NONE = 0
|
||||||
|
FIX_SINGLE = 1 # only one anchor responding
|
||||||
|
FIX_DUAL = 2 # both anchors responding — full bearing estimate
|
||||||
|
|
||||||
|
# Maximum age of a ranging measurement before it is considered stale (seconds)
|
||||||
|
_STALE_S = 0.5
|
||||||
|
|
||||||
|
# ── Serial protocol parsing ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_RANGE_RE = re.compile(r'^RANGE,(\d+),(\w+),([\d.]+)\s*$')
|
||||||
|
_STATUS_RE = re.compile(r'^STATUS,(\d+),(\w+)\s*$')
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RangeFrame:
|
||||||
|
"""Decoded RANGE frame from a DW3000 anchor."""
|
||||||
|
anchor_id: int
|
||||||
|
tag_id: str
|
||||||
|
distance_m: float
|
||||||
|
timestamp: float = field(default_factory=time.monotonic)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_frame(line: str) -> Optional[RangeFrame]:
|
||||||
|
"""
|
||||||
|
Parse one ASCII line from the anchor serial stream.
|
||||||
|
|
||||||
|
Returns a RangeFrame on success, None for STATUS/unknown/malformed lines.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
line : raw ASCII line (may include trailing whitespace / CR)
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> f = parse_frame("RANGE,0,T0,1532")
|
||||||
|
>>> f.anchor_id, f.tag_id, f.distance_m
|
||||||
|
(0, 'T0', 1.532)
|
||||||
|
>>> parse_frame("STATUS,0,OK") is None
|
||||||
|
True
|
||||||
|
>>> parse_frame("GARBAGE") is None
|
||||||
|
True
|
||||||
|
"""
|
||||||
|
m = _RANGE_RE.match(line.strip())
|
||||||
|
if m is None:
|
||||||
|
return None
|
||||||
|
anchor_id = int(m.group(1))
|
||||||
|
tag_id = m.group(2)
|
||||||
|
distance_m = float(m.group(3)) / 1000.0 # mm → m
|
||||||
|
return RangeFrame(anchor_id=anchor_id, tag_id=tag_id, distance_m=distance_m)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Two-anchor bearing geometry ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def bearing_from_ranges(
|
||||||
|
d0: float,
|
||||||
|
d1: float,
|
||||||
|
baseline_m: float,
|
||||||
|
) -> Tuple[float, float]:
|
||||||
|
"""
|
||||||
|
Compute bearing to tag from two anchor distances.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
d0 : distance from anchor-0 (left, at x = −baseline/2) to tag (m)
|
||||||
|
d1 : distance from anchor-1 (right, at x = +baseline/2) to tag (m)
|
||||||
|
baseline_m : separation between the two anchors (m)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
(bearing_deg, confidence)
|
||||||
|
bearing_deg : signed bearing in degrees (positive = right)
|
||||||
|
confidence : 0.0–1.0 quality estimate
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
Derived from law of cosines applied to the triangle
|
||||||
|
(anchor0, anchor1, tag):
|
||||||
|
cos(α) = (d0² − d1² + B²) / (2·B·d0) where α is angle at anchor0
|
||||||
|
Forward axis is the perpendicular bisector of the baseline, so:
|
||||||
|
bearing = 90° − α_at_anchor1
|
||||||
|
We use the symmetric formula through the midpoint:
|
||||||
|
x_tag = (d0² - d1²) / (2·B) [lateral offset from midpoint]
|
||||||
|
y_tag = sqrt(d0² - (x_tag + B/2)²) [forward distance]
|
||||||
|
bearing = atan2(x_tag, y_tag) * 180/π
|
||||||
|
"""
|
||||||
|
if baseline_m <= 0.0:
|
||||||
|
raise ValueError(f'baseline_m must be positive, got {baseline_m}')
|
||||||
|
if d0 <= 0.0 or d1 <= 0.0:
|
||||||
|
raise ValueError(f'distances must be positive, got d0={d0} d1={d1}')
|
||||||
|
|
||||||
|
B = baseline_m
|
||||||
|
# Lateral offset of tag from midpoint of baseline (positive = towards anchor1 = right)
|
||||||
|
x = (d0 * d0 - d1 * d1) / (2.0 * B)
|
||||||
|
|
||||||
|
# Forward distance (Pythagorean)
|
||||||
|
# anchor0 is at x_coord = -B/2 relative to midpoint
|
||||||
|
y_sq = d0 * d0 - (x + B / 2.0) ** 2
|
||||||
|
if y_sq < 0.0:
|
||||||
|
# Triangle inequality violated — noisy ranging; clamp to zero
|
||||||
|
y_sq = 0.0
|
||||||
|
y = math.sqrt(y_sq)
|
||||||
|
|
||||||
|
bearing_deg = math.degrees(math.atan2(x, max(y, 1e-3)))
|
||||||
|
|
||||||
|
# ── Confidence ───────────────────────────────────────────────────────────
|
||||||
|
# 1. Range agreement penalty — if |d0-d1| > sqrt(d0²+d1²) * factor, tag
|
||||||
|
# is almost directly on the baseline extension (extreme angle, poor geometry)
|
||||||
|
mean_d = (d0 + d1) / 2.0
|
||||||
|
diff = abs(d0 - d1)
|
||||||
|
# Geometric dilution: confidence falls as |bearing| approaches 90°
|
||||||
|
bearing_penalty = math.cos(math.radians(bearing_deg)) # 1.0 at 0°, 0.0 at 90°
|
||||||
|
bearing_penalty = max(0.0, bearing_penalty)
|
||||||
|
|
||||||
|
# 2. Distance sanity: if tag is unreasonably close (< B) confidence drops
|
||||||
|
dist_penalty = min(1.0, mean_d / max(B, 0.1))
|
||||||
|
|
||||||
|
confidence = float(np.clip(bearing_penalty * dist_penalty, 0.0, 1.0))
|
||||||
|
|
||||||
|
return bearing_deg, confidence
|
||||||
|
|
||||||
|
|
||||||
|
def bearing_single_anchor(
|
||||||
|
d0: float,
|
||||||
|
baseline_m: float = 0.25,
|
||||||
|
) -> Tuple[float, float]:
|
||||||
|
"""
|
||||||
|
Fallback bearing when only one anchor is responding.
|
||||||
|
|
||||||
|
With a single anchor we cannot determine lateral position, so we return
|
||||||
|
bearing=0 (directly ahead) with a low confidence proportional to 1/distance.
|
||||||
|
This keeps the follow-me controller moving toward the target even in
|
||||||
|
degraded mode.
|
||||||
|
|
||||||
|
Returns (0.0, confidence) where confidence ≤ 0.3.
|
||||||
|
"""
|
||||||
|
# Single-anchor confidence: ≤ 0.3, decreases with distance
|
||||||
|
confidence = float(np.clip(0.3 * (2.0 / max(d0, 0.5)), 0.0, 0.3))
|
||||||
|
return 0.0, confidence
|
||||||
|
|
||||||
|
|
||||||
|
# ── Kalman filter for bearing smoothing ───────────────────────────────────────
|
||||||
|
|
||||||
|
class BearingKalman:
|
||||||
|
"""
|
||||||
|
1-D constant-velocity Kalman filter for bearing smoothing.
|
||||||
|
|
||||||
|
State: [bearing_deg, bearing_rate_dps]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
process_noise_deg2: float = 10.0, # bearing process noise (deg²/frame)
|
||||||
|
process_noise_rate2: float = 100.0, # rate process noise
|
||||||
|
meas_noise_deg2: float = 4.0, # bearing measurement noise (deg²)
|
||||||
|
) -> None:
|
||||||
|
self._x = np.zeros(2, dtype=np.float64) # [bearing, rate]
|
||||||
|
self._P = np.diag([100.0, 1000.0]) # initial covariance
|
||||||
|
self._Q = np.diag([process_noise_deg2, process_noise_rate2])
|
||||||
|
self._R = np.array([[meas_noise_deg2]])
|
||||||
|
self._F = np.array([[1.0, 1.0], # state transition (dt=1 frame)
|
||||||
|
[0.0, 1.0]])
|
||||||
|
self._H = np.array([[1.0, 0.0]]) # observation: bearing only
|
||||||
|
self._initialised = False
|
||||||
|
|
||||||
|
def predict(self) -> float:
|
||||||
|
"""Predict next bearing; returns predicted bearing_deg."""
|
||||||
|
self._x = self._F @ self._x
|
||||||
|
self._P = self._F @ self._P @ self._F.T + self._Q
|
||||||
|
return float(self._x[0])
|
||||||
|
|
||||||
|
def update(self, bearing_deg: float) -> float:
|
||||||
|
"""
|
||||||
|
Update with a new bearing measurement.
|
||||||
|
|
||||||
|
On first call, seeds the filter state.
|
||||||
|
Returns smoothed bearing_deg.
|
||||||
|
"""
|
||||||
|
if not self._initialised:
|
||||||
|
self._x[0] = bearing_deg
|
||||||
|
self._initialised = True
|
||||||
|
return bearing_deg
|
||||||
|
|
||||||
|
self.predict()
|
||||||
|
|
||||||
|
y = np.array([[bearing_deg]]) - self._H @ self._x.reshape(-1, 1)
|
||||||
|
S = self._H @ self._P @ self._H.T + self._R
|
||||||
|
K = self._P @ self._H.T @ np.linalg.inv(S)
|
||||||
|
self._x = self._x + (K @ y).ravel()
|
||||||
|
self._P = (np.eye(2) - K @ self._H) @ self._P
|
||||||
|
return float(self._x[0])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bearing_rate_dps(self) -> float:
|
||||||
|
"""Current bearing rate estimate (degrees/second at nominal 10 Hz)."""
|
||||||
|
return float(self._x[1]) * 10.0 # per-frame rate → per-second
|
||||||
|
|
||||||
|
|
||||||
|
# ── Shared ranging state (thread-safe) ────────────────────────────────────────
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _AnchorState:
|
||||||
|
distance_m: float = 0.0
|
||||||
|
timestamp: float = 0.0
|
||||||
|
valid: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class UwbRangingState:
|
||||||
|
"""
|
||||||
|
Thread-safe store for the most recent ranging measurement from each anchor.
|
||||||
|
|
||||||
|
Updated by the serial reader threads, consumed by the ROS2 publish timer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
baseline_m: float = 0.25,
|
||||||
|
stale_timeout: float = _STALE_S,
|
||||||
|
) -> None:
|
||||||
|
self.baseline_m = baseline_m
|
||||||
|
self.stale_timeout = stale_timeout
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._anchors = [_AnchorState(), _AnchorState()]
|
||||||
|
self._kalman = BearingKalman()
|
||||||
|
|
||||||
|
def update_anchor(self, anchor_id: int, distance_m: float) -> None:
|
||||||
|
"""Record a new ranging measurement from anchor anchor_id (0 or 1)."""
|
||||||
|
if anchor_id not in (0, 1):
|
||||||
|
return
|
||||||
|
with self._lock:
|
||||||
|
s = self._anchors[anchor_id]
|
||||||
|
s.distance_m = distance_m
|
||||||
|
s.timestamp = time.monotonic()
|
||||||
|
s.valid = True
|
||||||
|
|
||||||
|
def compute(self) -> 'UwbResult':
|
||||||
|
"""
|
||||||
|
Derive bearing, distance, confidence, and fix quality from latest ranges.
|
||||||
|
|
||||||
|
Called at the publish rate (≥10 Hz). Applies Kalman smoothing to bearing.
|
||||||
|
"""
|
||||||
|
now = time.monotonic()
|
||||||
|
with self._lock:
|
||||||
|
a0 = self._anchors[0]
|
||||||
|
a1 = self._anchors[1]
|
||||||
|
v0 = a0.valid and (now - a0.timestamp) < self.stale_timeout
|
||||||
|
v1 = a1.valid and (now - a1.timestamp) < self.stale_timeout
|
||||||
|
d0 = a0.distance_m
|
||||||
|
d1 = a1.distance_m
|
||||||
|
|
||||||
|
if not v0 and not v1:
|
||||||
|
return UwbResult(valid=False)
|
||||||
|
|
||||||
|
if v0 and v1:
|
||||||
|
bearing_raw, conf = bearing_from_ranges(d0, d1, self.baseline_m)
|
||||||
|
fix_quality = FIX_DUAL
|
||||||
|
distance_m = (d0 + d1) / 2.0
|
||||||
|
elif v0:
|
||||||
|
bearing_raw, conf = bearing_single_anchor(d0, self.baseline_m)
|
||||||
|
fix_quality = FIX_SINGLE
|
||||||
|
distance_m = d0
|
||||||
|
d1 = 0.0
|
||||||
|
else:
|
||||||
|
bearing_raw, conf = bearing_single_anchor(d1, self.baseline_m)
|
||||||
|
fix_quality = FIX_SINGLE
|
||||||
|
distance_m = d1
|
||||||
|
d0 = 0.0
|
||||||
|
|
||||||
|
bearing_smooth = self._kalman.update(bearing_raw)
|
||||||
|
|
||||||
|
return UwbResult(
|
||||||
|
valid = True,
|
||||||
|
bearing_deg = bearing_smooth,
|
||||||
|
distance_m = distance_m,
|
||||||
|
confidence = conf,
|
||||||
|
anchor0_dist = d0,
|
||||||
|
anchor1_dist = d1,
|
||||||
|
baseline_m = self.baseline_m,
|
||||||
|
fix_quality = fix_quality,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UwbResult:
|
||||||
|
"""Computed UWB fix — mirrors UwbTarget.msg fields."""
|
||||||
|
valid: bool = False
|
||||||
|
bearing_deg: float = 0.0
|
||||||
|
distance_m: float = 0.0
|
||||||
|
confidence: float = 0.0
|
||||||
|
anchor0_dist: float = 0.0
|
||||||
|
anchor1_dist: float = 0.0
|
||||||
|
baseline_m: float = 0.25
|
||||||
|
fix_quality: int = FIX_NONE
|
||||||
|
|
||||||
|
|
||||||
|
# ── Serial reader (runs in background thread) ──────────────────────────────────
|
||||||
|
|
||||||
|
class AnchorSerialReader:
|
||||||
|
"""
|
||||||
|
Background thread that reads from one anchor's serial port and calls
|
||||||
|
state.update_anchor() on each valid RANGE frame.
|
||||||
|
|
||||||
|
Designed to be used with a real serial.Serial object in production, or
|
||||||
|
any file-like object with a readline() method for testing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
anchor_id: int,
|
||||||
|
port, # serial.Serial or file-like (readline() interface)
|
||||||
|
state: UwbRangingState,
|
||||||
|
logger=None,
|
||||||
|
) -> None:
|
||||||
|
self._anchor_id = anchor_id
|
||||||
|
self._port = port
|
||||||
|
self._state = state
|
||||||
|
self._log = logger
|
||||||
|
self._running = False
|
||||||
|
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||||
|
|
||||||
|
def start(self) -> None:
|
||||||
|
self._running = True
|
||||||
|
self._thread.start()
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
def _run(self) -> None:
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
raw = self._port.readline()
|
||||||
|
if isinstance(raw, bytes):
|
||||||
|
raw = raw.decode('ascii', errors='replace')
|
||||||
|
frame = parse_frame(raw)
|
||||||
|
if frame is not None:
|
||||||
|
self._state.update_anchor(frame.anchor_id, frame.distance_m)
|
||||||
|
except Exception as e:
|
||||||
|
if self._log:
|
||||||
|
self._log.warn(f'UWB anchor {self._anchor_id} read error: {e}')
|
||||||
|
time.sleep(0.05)
|
||||||
@ -0,0 +1,194 @@
|
|||||||
|
"""
|
||||||
|
_velocity_ramp.py — Acceleration/deceleration limiter for cmd_vel (Issue #350).
|
||||||
|
|
||||||
|
Pure-Python library (no ROS2 dependencies) for full unit-test coverage.
|
||||||
|
|
||||||
|
Behaviour
|
||||||
|
---------
|
||||||
|
The VelocityRamp class applies independent rate limits to the linear-x and
|
||||||
|
angular-z components of a 2D velocity command:
|
||||||
|
|
||||||
|
- Linear acceleration limit : max_lin_accel (m/s²)
|
||||||
|
- Linear deceleration limit : max_lin_decel (m/s²) — may differ from accel
|
||||||
|
- Angular acceleration limit : max_ang_accel (rad/s²)
|
||||||
|
- Angular deceleration limit : max_ang_decel (rad/s²)
|
||||||
|
|
||||||
|
"Deceleration" applies when the magnitude of the velocity is *decreasing*
|
||||||
|
(including moving toward zero from either sign direction), while "acceleration"
|
||||||
|
applies when the magnitude is increasing.
|
||||||
|
|
||||||
|
Emergency stop
|
||||||
|
--------------
|
||||||
|
When both linear and angular targets are exactly 0.0 the ramp is bypassed and
|
||||||
|
the output is forced to (0.0, 0.0) immediately. This allows a watchdog or
|
||||||
|
safety node to halt the robot instantly without waiting for the deceleration
|
||||||
|
ramp.
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
ramp = VelocityRamp(dt=0.02) # 50 Hz
|
||||||
|
out_lin, out_ang = ramp.step(1.0, 0.0) # target linear=1 m/s, angular=0
|
||||||
|
out_lin, out_ang = ramp.step(1.0, 0.0) # ramp climbs toward 1.0 m/s
|
||||||
|
out_lin, out_ang = ramp.step(0.0, 0.0) # emergency stop → (0.0, 0.0)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RampParams:
|
||||||
|
"""Acceleration / deceleration limits for one velocity axis."""
|
||||||
|
max_accel: float # magnitude / second — applied when |v| increasing
|
||||||
|
max_decel: float # magnitude / second — applied when |v| decreasing
|
||||||
|
|
||||||
|
|
||||||
|
def _ramp_axis(
|
||||||
|
current: float,
|
||||||
|
target: float,
|
||||||
|
params: RampParams,
|
||||||
|
dt: float,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Advance *current* one timestep toward *target* with rate limiting.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
current : current velocity on this axis
|
||||||
|
target : desired velocity on this axis
|
||||||
|
params : accel / decel limits
|
||||||
|
dt : timestep in seconds
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
New velocity, clamped so the change per step never exceeds the limits.
|
||||||
|
|
||||||
|
Notes on accel vs decel selection
|
||||||
|
----------------------------------
|
||||||
|
We treat the motion as decelerating when the target is closer to zero
|
||||||
|
than the current value, i.e. |target| < |current| or they have opposite
|
||||||
|
signs. In all other cases we use the acceleration limit.
|
||||||
|
"""
|
||||||
|
delta = target - current
|
||||||
|
|
||||||
|
# Choose limit: decel if velocity magnitude is falling, else accel.
|
||||||
|
is_decelerating = (
|
||||||
|
abs(target) < abs(current)
|
||||||
|
or (current > 0 and target < 0)
|
||||||
|
or (current < 0 and target > 0)
|
||||||
|
)
|
||||||
|
limit = params.max_decel if is_decelerating else params.max_accel
|
||||||
|
|
||||||
|
max_change = limit * dt
|
||||||
|
if abs(delta) <= max_change:
|
||||||
|
return target
|
||||||
|
return current + max_change * (1.0 if delta > 0 else -1.0)
|
||||||
|
|
||||||
|
|
||||||
|
class VelocityRamp:
|
||||||
|
"""
|
||||||
|
Smooths a stream of (linear_x, angular_z) velocity commands by applying
|
||||||
|
configurable acceleration and deceleration limits.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
dt : timestep in seconds (must match the control loop rate)
|
||||||
|
max_lin_accel : maximum linear acceleration (m/s²)
|
||||||
|
max_lin_decel : maximum linear deceleration (m/s²); defaults to max_lin_accel
|
||||||
|
max_ang_accel : maximum angular acceleration (rad/s²)
|
||||||
|
max_ang_decel : maximum angular deceleration (rad/s²); defaults to max_ang_accel
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dt: float = 0.02, # 50 Hz
|
||||||
|
max_lin_accel: float = 0.5, # m/s²
|
||||||
|
max_lin_decel: float | None = None,
|
||||||
|
max_ang_accel: float = 1.0, # rad/s²
|
||||||
|
max_ang_decel: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
if dt <= 0:
|
||||||
|
raise ValueError(f'dt must be positive, got {dt}')
|
||||||
|
if max_lin_accel <= 0:
|
||||||
|
raise ValueError(f'max_lin_accel must be positive, got {max_lin_accel}')
|
||||||
|
if max_ang_accel <= 0:
|
||||||
|
raise ValueError(f'max_ang_accel must be positive, got {max_ang_accel}')
|
||||||
|
|
||||||
|
self._dt = dt
|
||||||
|
self._lin = RampParams(
|
||||||
|
max_accel=max_lin_accel,
|
||||||
|
max_decel=max_lin_decel if max_lin_decel is not None else max_lin_accel,
|
||||||
|
)
|
||||||
|
self._ang = RampParams(
|
||||||
|
max_accel=max_ang_accel,
|
||||||
|
max_decel=max_ang_decel if max_ang_decel is not None else max_ang_accel,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._cur_lin: float = 0.0
|
||||||
|
self._cur_ang: float = 0.0
|
||||||
|
|
||||||
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def step(self, target_lin: float, target_ang: float) -> tuple[float, float]:
|
||||||
|
"""
|
||||||
|
Advance the ramp one timestep.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
target_lin : desired linear velocity (m/s)
|
||||||
|
target_ang : desired angular velocity (rad/s)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
(smoothed_linear, smoothed_angular) tuple.
|
||||||
|
|
||||||
|
If both target_lin and target_ang are exactly 0.0, an emergency stop
|
||||||
|
is applied and (0.0, 0.0) is returned immediately.
|
||||||
|
"""
|
||||||
|
# Emergency stop: bypass ramp entirely
|
||||||
|
if target_lin == 0.0 and target_ang == 0.0:
|
||||||
|
self._cur_lin = 0.0
|
||||||
|
self._cur_ang = 0.0
|
||||||
|
return 0.0, 0.0
|
||||||
|
|
||||||
|
self._cur_lin = _ramp_axis(self._cur_lin, target_lin, self._lin, self._dt)
|
||||||
|
self._cur_ang = _ramp_axis(self._cur_ang, target_ang, self._ang, self._dt)
|
||||||
|
return self._cur_lin, self._cur_ang
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset internal state to zero (e.g. on node restart or re-enable)."""
|
||||||
|
self._cur_lin = 0.0
|
||||||
|
self._cur_ang = 0.0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_linear(self) -> float:
|
||||||
|
"""Current smoothed linear velocity (m/s)."""
|
||||||
|
return self._cur_lin
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_angular(self) -> float:
|
||||||
|
"""Current smoothed angular velocity (rad/s)."""
|
||||||
|
return self._cur_ang
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dt(self) -> float:
|
||||||
|
return self._dt
|
||||||
|
|
||||||
|
# ── Derived ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def steps_to_reach(self, target_lin: float, target_ang: float) -> int:
|
||||||
|
"""
|
||||||
|
Estimate how many steps() are needed to reach the target from the
|
||||||
|
current state (useful for test assertions).
|
||||||
|
|
||||||
|
This is an approximation based on the worst-case axis; it does not
|
||||||
|
account for the emergency-stop shortcut.
|
||||||
|
"""
|
||||||
|
lin_steps = 0
|
||||||
|
ang_steps = 0
|
||||||
|
if self._lin.max_accel > 0:
|
||||||
|
lin_steps = int(abs(target_lin - self._cur_lin) / (self._lin.max_accel * self._dt)) + 1
|
||||||
|
if self._ang.max_accel > 0:
|
||||||
|
ang_steps = int(abs(target_ang - self._cur_ang) / (self._ang.max_accel * self._dt)) + 1
|
||||||
|
return max(lin_steps, ang_steps)
|
||||||
@ -0,0 +1,215 @@
|
|||||||
|
"""
|
||||||
|
audio_scene_node.py — Audio scene classifier node (Issue #353).
|
||||||
|
|
||||||
|
Buffers raw PCM audio from the audio_common stack, runs a nearest-centroid
|
||||||
|
MFCC classifier at 1 Hz, and publishes the estimated environment label.
|
||||||
|
|
||||||
|
Subscribes
|
||||||
|
----------
|
||||||
|
/audio/audio audio_common_msgs/AudioData (BEST_EFFORT)
|
||||||
|
/audio/audio_info audio_common_msgs/AudioInfo (RELIABLE, latched)
|
||||||
|
|
||||||
|
Publishes
|
||||||
|
---------
|
||||||
|
/saltybot/audio_scene saltybot_scene_msgs/AudioScene (1 Hz)
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sample_rate int 16000 Expected audio sample rate (Hz)
|
||||||
|
channels int 1 Expected number of channels (mono=1)
|
||||||
|
clip_duration_s float 1.0 Duration of each classification window (s)
|
||||||
|
publish_rate_hz float 1.0 Classification + publish rate (Hz)
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
* Audio data is expected as signed 16-bit little-endian PCM ('S16LE').
|
||||||
|
If AudioInfo arrives first with a different format, a warning is logged.
|
||||||
|
* If the audio buffer does not contain enough samples when the timer fires,
|
||||||
|
the existing buffer (padded with zeros to clip_duration_s) is used.
|
||||||
|
* The publish rate is independent of the audio callback rate; bursts and
|
||||||
|
gaps are handled gracefully.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import struct
|
||||||
|
from collections import deque
|
||||||
|
from typing import Deque
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import (
|
||||||
|
QoSProfile,
|
||||||
|
ReliabilityPolicy,
|
||||||
|
HistoryPolicy,
|
||||||
|
DurabilityPolicy,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from audio_common_msgs.msg import AudioData, AudioInfo
|
||||||
|
_HAVE_AUDIO_MSGS = True
|
||||||
|
except ImportError:
|
||||||
|
_HAVE_AUDIO_MSGS = False
|
||||||
|
|
||||||
|
from saltybot_scene_msgs.msg import AudioScene
|
||||||
|
from ._audio_scene import classify_audio, _N_FEATURES
|
||||||
|
|
||||||
|
|
||||||
|
_SENSOR_QOS = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=10,
|
||||||
|
)
|
||||||
|
_LATCHED_QOS = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.RELIABLE,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=1,
|
||||||
|
durability=DurabilityPolicy.TRANSIENT_LOCAL,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioSceneNode(Node):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__('audio_scene_node')
|
||||||
|
|
||||||
|
# ── Parameters ──────────────────────────────────────────────────────
|
||||||
|
self.declare_parameter('sample_rate', 16_000)
|
||||||
|
self.declare_parameter('channels', 1)
|
||||||
|
self.declare_parameter('clip_duration_s', 1.0)
|
||||||
|
self.declare_parameter('publish_rate_hz', 1.0)
|
||||||
|
|
||||||
|
p = self.get_parameter
|
||||||
|
self._sr = int(p('sample_rate').value)
|
||||||
|
self._channels = int(p('channels').value)
|
||||||
|
self._clip_dur = float(p('clip_duration_s').value)
|
||||||
|
self._pub_rate = float(p('publish_rate_hz').value)
|
||||||
|
|
||||||
|
self._clip_samples = int(self._sr * self._clip_dur)
|
||||||
|
|
||||||
|
# ── Audio sample ring buffer ─────────────────────────────────────────
|
||||||
|
# Store float32 mono samples; deque with max size = 2× clip
|
||||||
|
self._buffer: Deque[float] = deque(
|
||||||
|
maxlen=self._clip_samples * 2
|
||||||
|
)
|
||||||
|
self._audio_format = 'S16LE' # updated from AudioInfo if available
|
||||||
|
|
||||||
|
# ── Subscribers ──────────────────────────────────────────────────────
|
||||||
|
if _HAVE_AUDIO_MSGS:
|
||||||
|
self._audio_sub = self.create_subscription(
|
||||||
|
AudioData, '/audio/audio',
|
||||||
|
self._on_audio, _SENSOR_QOS,
|
||||||
|
)
|
||||||
|
self._info_sub = self.create_subscription(
|
||||||
|
AudioInfo, '/audio/audio_info',
|
||||||
|
self._on_audio_info, _LATCHED_QOS,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.get_logger().warn(
|
||||||
|
'audio_common_msgs not available — no audio will be received. '
|
||||||
|
'Install ros-humble-audio-common-msgs to enable audio input.'
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Publisher + timer ────────────────────────────────────────────────
|
||||||
|
self._pub = self.create_publisher(
|
||||||
|
AudioScene, '/saltybot/audio_scene', 10
|
||||||
|
)
|
||||||
|
period = 1.0 / max(self._pub_rate, 0.01)
|
||||||
|
self._timer = self.create_timer(period, self._on_timer)
|
||||||
|
|
||||||
|
self.get_logger().info(
|
||||||
|
f'audio_scene_node ready — '
|
||||||
|
f'sr={self._sr} Hz, clip={self._clip_dur:.1f}s, '
|
||||||
|
f'rate={self._pub_rate:.1f} Hz'
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Callbacks ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _on_audio_info(self, msg) -> None:
|
||||||
|
"""Latch audio format from AudioInfo."""
|
||||||
|
fmt = getattr(msg, 'sample_format', 'S16LE') or 'S16LE'
|
||||||
|
if fmt != self._audio_format:
|
||||||
|
self.get_logger().info(f'Audio format updated: {fmt!r}')
|
||||||
|
self._audio_format = fmt
|
||||||
|
reported_sr = getattr(msg, 'sample_rate', self._sr)
|
||||||
|
if reported_sr and reported_sr != self._sr:
|
||||||
|
self.get_logger().warn(
|
||||||
|
f'AudioInfo sample_rate={reported_sr} != '
|
||||||
|
f'param sample_rate={self._sr} — using param value',
|
||||||
|
throttle_duration_sec=30.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _on_audio(self, msg: 'AudioData') -> None:
|
||||||
|
"""Decode raw PCM bytes and push float32 mono samples to buffer."""
|
||||||
|
raw = bytes(msg.data)
|
||||||
|
if len(raw) < 2:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
n_samples = len(raw) // 2
|
||||||
|
samples_i16 = struct.unpack(f'<{n_samples}h', raw[:n_samples * 2])
|
||||||
|
except struct.error as exc:
|
||||||
|
self.get_logger().warn(
|
||||||
|
f'Failed to decode audio frame: {exc}',
|
||||||
|
throttle_duration_sec=5.0,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Convert to float32 mono
|
||||||
|
arr = np.array(samples_i16, dtype=np.float32) / 32768.0
|
||||||
|
if self._channels > 1:
|
||||||
|
# Downmix to mono: take channel 0
|
||||||
|
arr = arr[::self._channels]
|
||||||
|
|
||||||
|
self._buffer.extend(arr.tolist())
|
||||||
|
|
||||||
|
def _on_timer(self) -> None:
|
||||||
|
"""Classify buffered audio and publish result."""
|
||||||
|
n_buf = len(self._buffer)
|
||||||
|
if n_buf >= self._clip_samples:
|
||||||
|
# Take the most recent clip_samples
|
||||||
|
samples = np.array(
|
||||||
|
list(self._buffer)[-self._clip_samples:],
|
||||||
|
dtype=np.float64,
|
||||||
|
)
|
||||||
|
elif n_buf > 0:
|
||||||
|
# Pad with zeros at the start
|
||||||
|
samples = np.zeros(self._clip_samples, dtype=np.float64)
|
||||||
|
buf_arr = np.array(list(self._buffer), dtype=np.float64)
|
||||||
|
samples[-n_buf:] = buf_arr
|
||||||
|
else:
|
||||||
|
# No audio received — publish silence (indoor default)
|
||||||
|
samples = np.zeros(self._clip_samples, dtype=np.float64)
|
||||||
|
|
||||||
|
result = classify_audio(samples, sr=self._sr)
|
||||||
|
|
||||||
|
msg = AudioScene()
|
||||||
|
msg.header.stamp = self.get_clock().now().to_msg()
|
||||||
|
msg.header.frame_id = 'audio'
|
||||||
|
msg.label = result.label
|
||||||
|
msg.confidence = float(result.confidence)
|
||||||
|
feat_padded = np.zeros(_N_FEATURES, dtype=np.float32)
|
||||||
|
n = min(len(result.features), _N_FEATURES)
|
||||||
|
feat_padded[:n] = result.features[:n].astype(np.float32)
|
||||||
|
msg.features = feat_padded.tolist()
|
||||||
|
|
||||||
|
self._pub.publish(msg)
|
||||||
|
self.get_logger().debug(
|
||||||
|
f'audio_scene: {result.label!r} conf={result.confidence:.2f}'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None) -> None:
|
||||||
|
rclpy.init(args=args)
|
||||||
|
node = AudioSceneNode()
|
||||||
|
try:
|
||||||
|
rclpy.spin(node)
|
||||||
|
finally:
|
||||||
|
node.destroy_node()
|
||||||
|
rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@ -0,0 +1,215 @@
|
|||||||
|
"""
|
||||||
|
camera_power_node.py — Adaptive camera power mode manager (Issue #375).
|
||||||
|
|
||||||
|
Subscribes to speed, scenario, and battery level inputs, runs the
|
||||||
|
CameraPowerFSM, and publishes camera enable/disable commands at 2 Hz.
|
||||||
|
|
||||||
|
Subscribes
|
||||||
|
----------
|
||||||
|
/saltybot/speed std_msgs/Float32 robot speed in m/s
|
||||||
|
/saltybot/scenario std_msgs/String operating scenario label
|
||||||
|
/saltybot/battery_pct std_msgs/Float32 battery charge level 0–100 %
|
||||||
|
|
||||||
|
Publishes
|
||||||
|
---------
|
||||||
|
/saltybot/camera_mode saltybot_scene_msgs/CameraPowerMode (2 Hz)
|
||||||
|
/saltybot/camera_cmd/front std_msgs/Bool — front CSI enable
|
||||||
|
/saltybot/camera_cmd/rear std_msgs/Bool — rear CSI enable
|
||||||
|
/saltybot/camera_cmd/left std_msgs/Bool — left CSI enable
|
||||||
|
/saltybot/camera_cmd/right std_msgs/Bool — right CSI enable
|
||||||
|
/saltybot/camera_cmd/realsense std_msgs/Bool — D435i enable
|
||||||
|
/saltybot/camera_cmd/lidar std_msgs/Bool — LIDAR enable
|
||||||
|
/saltybot/camera_cmd/uwb std_msgs/Bool — UWB enable
|
||||||
|
/saltybot/camera_cmd/webcam std_msgs/Bool — C920 webcam enable
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
rate_hz float 2.0 FSM evaluation and publish rate
|
||||||
|
downgrade_hold_s float 5.0 Seconds before a downgrade is applied
|
||||||
|
idle_to_social_s float 30.0 Idle time before AWARE→SOCIAL
|
||||||
|
battery_low_pct float 20.0 Battery threshold for AWARE cap
|
||||||
|
initial_mode int 0 Starting mode (0=SLEEP … 4=FULL)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy, DurabilityPolicy
|
||||||
|
|
||||||
|
from std_msgs.msg import Bool, Float32, String
|
||||||
|
from saltybot_scene_msgs.msg import CameraPowerMode
|
||||||
|
|
||||||
|
from ._camera_power_manager import (
|
||||||
|
CameraMode,
|
||||||
|
CameraPowerFSM,
|
||||||
|
Scenario,
|
||||||
|
MODE_SENSORS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_RELIABLE_LATCHED = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.RELIABLE,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=1,
|
||||||
|
durability=DurabilityPolicy.TRANSIENT_LOCAL,
|
||||||
|
)
|
||||||
|
_RELIABLE = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.RELIABLE,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=10,
|
||||||
|
)
|
||||||
|
_SENSOR_QOS = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Camera command topic suffixes and their ActiveSensors field name
|
||||||
|
_CAM_TOPICS = [
|
||||||
|
('front', 'csi_front'),
|
||||||
|
('rear', 'csi_rear'),
|
||||||
|
('left', 'csi_left'),
|
||||||
|
('right', 'csi_right'),
|
||||||
|
('realsense', 'realsense'),
|
||||||
|
('lidar', 'lidar'),
|
||||||
|
('uwb', 'uwb'),
|
||||||
|
('webcam', 'webcam'),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class CameraPowerNode(Node):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__('camera_power_node')
|
||||||
|
|
||||||
|
# ── Parameters ──────────────────────────────────────────────────────
|
||||||
|
self.declare_parameter('rate_hz', 2.0)
|
||||||
|
self.declare_parameter('downgrade_hold_s', 5.0)
|
||||||
|
self.declare_parameter('idle_to_social_s', 30.0)
|
||||||
|
self.declare_parameter('battery_low_pct', 20.0)
|
||||||
|
self.declare_parameter('initial_mode', 0)
|
||||||
|
|
||||||
|
p = self.get_parameter
|
||||||
|
rate_hz = max(float(p('rate_hz').value), 0.1)
|
||||||
|
downgrade_hold_s = max(float(p('downgrade_hold_s').value), 0.0)
|
||||||
|
idle_to_social_s = max(float(p('idle_to_social_s').value), 0.0)
|
||||||
|
battery_low_pct = float(p('battery_low_pct').value)
|
||||||
|
initial_mode = int(p('initial_mode').value)
|
||||||
|
|
||||||
|
# ── FSM ──────────────────────────────────────────────────────────────
|
||||||
|
self._fsm = CameraPowerFSM(
|
||||||
|
downgrade_hold_s = downgrade_hold_s,
|
||||||
|
idle_to_social_s = idle_to_social_s,
|
||||||
|
battery_low_pct = battery_low_pct,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
self._fsm.reset(CameraMode(initial_mode))
|
||||||
|
except ValueError:
|
||||||
|
self._fsm.reset(CameraMode.SLEEP)
|
||||||
|
|
||||||
|
# ── Input state ──────────────────────────────────────────────────────
|
||||||
|
self._speed_mps: float = 0.0
|
||||||
|
self._scenario: str = Scenario.UNKNOWN
|
||||||
|
self._battery_pct: float = 100.0
|
||||||
|
|
||||||
|
# ── Subscribers ──────────────────────────────────────────────────────
|
||||||
|
self._speed_sub = self.create_subscription(
|
||||||
|
Float32, '/saltybot/speed',
|
||||||
|
lambda m: setattr(self, '_speed_mps', max(0.0, float(m.data))),
|
||||||
|
_SENSOR_QOS,
|
||||||
|
)
|
||||||
|
self._scenario_sub = self.create_subscription(
|
||||||
|
String, '/saltybot/scenario',
|
||||||
|
lambda m: setattr(self, '_scenario', str(m.data)),
|
||||||
|
_RELIABLE,
|
||||||
|
)
|
||||||
|
self._battery_sub = self.create_subscription(
|
||||||
|
Float32, '/saltybot/battery_pct',
|
||||||
|
lambda m: setattr(self, '_battery_pct',
|
||||||
|
max(0.0, min(100.0, float(m.data)))),
|
||||||
|
_RELIABLE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Publishers ───────────────────────────────────────────────────────
|
||||||
|
self._mode_pub = self.create_publisher(
|
||||||
|
CameraPowerMode, '/saltybot/camera_mode', _RELIABLE_LATCHED,
|
||||||
|
)
|
||||||
|
self._cam_pubs: dict[str, rclpy.publisher.Publisher] = {}
|
||||||
|
for suffix, _ in _CAM_TOPICS:
|
||||||
|
self._cam_pubs[suffix] = self.create_publisher(
|
||||||
|
Bool, f'/saltybot/camera_cmd/{suffix}', _RELIABLE_LATCHED,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Timer ────────────────────────────────────────────────────────────
|
||||||
|
self._timer = self.create_timer(1.0 / rate_hz, self._step)
|
||||||
|
self._prev_mode: CameraMode | None = None
|
||||||
|
|
||||||
|
self.get_logger().info(
|
||||||
|
f'camera_power_node ready — '
|
||||||
|
f'rate={rate_hz:.1f}Hz, '
|
||||||
|
f'downgrade_hold={downgrade_hold_s}s, '
|
||||||
|
f'idle_to_social={idle_to_social_s}s, '
|
||||||
|
f'battery_low={battery_low_pct}%'
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Step callback ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _step(self) -> None:
|
||||||
|
decision = self._fsm.update(
|
||||||
|
speed_mps = self._speed_mps,
|
||||||
|
scenario = self._scenario,
|
||||||
|
battery_pct = self._battery_pct,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log transitions
|
||||||
|
if decision.mode != self._prev_mode:
|
||||||
|
prev_name = self._prev_mode.label if self._prev_mode else 'INIT'
|
||||||
|
self.get_logger().info(
|
||||||
|
f'Camera mode: {prev_name} → {decision.mode.label} '
|
||||||
|
f'(speed={self._speed_mps:.2f}m/s '
|
||||||
|
f'scenario={self._scenario} '
|
||||||
|
f'battery={self._battery_pct:.0f}%)'
|
||||||
|
)
|
||||||
|
self._prev_mode = decision.mode
|
||||||
|
|
||||||
|
# Publish CameraPowerMode message
|
||||||
|
msg = CameraPowerMode()
|
||||||
|
msg.header.stamp = self.get_clock().now().to_msg()
|
||||||
|
msg.header.frame_id = ''
|
||||||
|
msg.mode = int(decision.mode)
|
||||||
|
msg.mode_name = decision.mode.label
|
||||||
|
s = decision.sensors
|
||||||
|
msg.csi_front = s.csi_front
|
||||||
|
msg.csi_rear = s.csi_rear
|
||||||
|
msg.csi_left = s.csi_left
|
||||||
|
msg.csi_right = s.csi_right
|
||||||
|
msg.realsense = s.realsense
|
||||||
|
msg.lidar = s.lidar
|
||||||
|
msg.uwb = s.uwb
|
||||||
|
msg.webcam = s.webcam
|
||||||
|
msg.trigger_speed_mps = float(decision.trigger_speed_mps)
|
||||||
|
msg.trigger_scenario = decision.trigger_scenario
|
||||||
|
msg.scenario_override = decision.scenario_override
|
||||||
|
self._mode_pub.publish(msg)
|
||||||
|
|
||||||
|
# Publish per-camera Bool commands
|
||||||
|
for suffix, field in _CAM_TOPICS:
|
||||||
|
enabled = bool(getattr(s, field))
|
||||||
|
b = Bool()
|
||||||
|
b.data = enabled
|
||||||
|
self._cam_pubs[suffix].publish(b)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None) -> None:
|
||||||
|
rclpy.init(args=args)
|
||||||
|
node = CameraPowerNode()
|
||||||
|
try:
|
||||||
|
rclpy.spin(node)
|
||||||
|
finally:
|
||||||
|
node.destroy_node()
|
||||||
|
rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@ -0,0 +1,207 @@
|
|||||||
|
"""
|
||||||
|
face_emotion_node.py — Facial emotion classifier node (Issue #359).
|
||||||
|
|
||||||
|
Subscribes to the raw colour camera stream, runs MediaPipe FaceMesh to
|
||||||
|
detect facial landmarks, applies geometric emotion rules, and publishes
|
||||||
|
the result at the camera frame rate (≤ param max_fps).
|
||||||
|
|
||||||
|
Subscribes
|
||||||
|
----------
|
||||||
|
/camera/color/image_raw sensor_msgs/Image (BEST_EFFORT)
|
||||||
|
|
||||||
|
Publishes
|
||||||
|
---------
|
||||||
|
/saltybot/face_emotions saltybot_scene_msgs/FaceEmotionArray
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
max_fps float 15.0 Maximum classification rate (Hz)
|
||||||
|
max_faces int 4 MediaPipe max_num_faces
|
||||||
|
min_detection_conf float 0.5 MediaPipe min detection confidence
|
||||||
|
min_tracking_conf float 0.5 MediaPipe min tracking confidence
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
* MediaPipe FaceMesh (mediapipe.solutions.face_mesh) is initialised lazily
|
||||||
|
in a background thread to avoid blocking the ROS2 executor.
|
||||||
|
* If mediapipe is not installed the node still starts but logs a warning
|
||||||
|
and publishes empty arrays.
|
||||||
|
* Landmark coordinates are always normalised to [0..1] image space.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||||
|
|
||||||
|
from sensor_msgs.msg import Image
|
||||||
|
from saltybot_scene_msgs.msg import FaceEmotion, FaceEmotionArray
|
||||||
|
from ._face_emotion import FaceLandmarks, detect_emotion, from_mediapipe
|
||||||
|
|
||||||
|
|
||||||
|
_SENSOR_QOS = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _MPFaceMesh:
|
||||||
|
"""Lazy wrapper around mediapipe.solutions.face_mesh.FaceMesh."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_faces: int = 4,
|
||||||
|
min_detection_conf: float = 0.5,
|
||||||
|
min_tracking_conf: float = 0.5,
|
||||||
|
) -> None:
|
||||||
|
self._max_faces = max_faces
|
||||||
|
self._min_det_conf = min_detection_conf
|
||||||
|
self._min_track_conf = min_tracking_conf
|
||||||
|
self._mesh = None
|
||||||
|
self._available = False
|
||||||
|
self._ready = threading.Event()
|
||||||
|
threading.Thread(target=self._init, daemon=True).start()
|
||||||
|
|
||||||
|
def _init(self) -> None:
|
||||||
|
try:
|
||||||
|
import mediapipe as mp
|
||||||
|
self._mesh = mp.solutions.face_mesh.FaceMesh(
|
||||||
|
static_image_mode = False,
|
||||||
|
max_num_faces = self._max_faces,
|
||||||
|
refine_landmarks = False,
|
||||||
|
min_detection_confidence = self._min_det_conf,
|
||||||
|
min_tracking_confidence = self._min_track_conf,
|
||||||
|
)
|
||||||
|
self._available = True
|
||||||
|
except Exception:
|
||||||
|
self._available = False
|
||||||
|
finally:
|
||||||
|
self._ready.set()
|
||||||
|
|
||||||
|
def process(self, bgr: np.ndarray):
|
||||||
|
"""Run FaceMesh on a BGR uint8 frame. Returns mp results or None."""
|
||||||
|
if not self._ready.wait(timeout=5.0) or not self._available:
|
||||||
|
return None
|
||||||
|
import cv2
|
||||||
|
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
||||||
|
return self._mesh.process(rgb)
|
||||||
|
|
||||||
|
|
||||||
|
class FaceEmotionNode(Node):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__('face_emotion_node')
|
||||||
|
|
||||||
|
# ── Parameters ──────────────────────────────────────────────────────
|
||||||
|
self.declare_parameter('max_fps', 15.0)
|
||||||
|
self.declare_parameter('max_faces', 4)
|
||||||
|
self.declare_parameter('min_detection_conf', 0.5)
|
||||||
|
self.declare_parameter('min_tracking_conf', 0.5)
|
||||||
|
|
||||||
|
p = self.get_parameter
|
||||||
|
self._min_period = 1.0 / max(float(p('max_fps').value), 0.1)
|
||||||
|
max_faces = int(p('max_faces').value)
|
||||||
|
min_det = float(p('min_detection_conf').value)
|
||||||
|
min_trk = float(p('min_tracking_conf').value)
|
||||||
|
|
||||||
|
# Lazy-initialised MediaPipe
|
||||||
|
self._mp = _MPFaceMesh(
|
||||||
|
max_faces = max_faces,
|
||||||
|
min_detection_conf = min_det,
|
||||||
|
min_tracking_conf = min_trk,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._last_proc: float = 0.0
|
||||||
|
|
||||||
|
# ── Subscriber / publisher ───────────────────────────────────────────
|
||||||
|
self._sub = self.create_subscription(
|
||||||
|
Image, '/camera/color/image_raw',
|
||||||
|
self._on_image, _SENSOR_QOS,
|
||||||
|
)
|
||||||
|
self._pub = self.create_publisher(
|
||||||
|
FaceEmotionArray, '/saltybot/face_emotions', 10
|
||||||
|
)
|
||||||
|
|
||||||
|
self.get_logger().info(
|
||||||
|
f'face_emotion_node ready — '
|
||||||
|
f'max_fps={1.0/self._min_period:.1f}, max_faces={max_faces}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Callback ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _on_image(self, msg: Image) -> None:
|
||||||
|
now = time.monotonic()
|
||||||
|
if now - self._last_proc < self._min_period:
|
||||||
|
return
|
||||||
|
self._last_proc = now
|
||||||
|
|
||||||
|
enc = msg.encoding.lower()
|
||||||
|
if enc in ('bgr8', 'rgb8'):
|
||||||
|
data = np.frombuffer(msg.data, dtype=np.uint8)
|
||||||
|
try:
|
||||||
|
frame = data.reshape((msg.height, msg.width, 3))
|
||||||
|
except ValueError:
|
||||||
|
return
|
||||||
|
if enc == 'rgb8':
|
||||||
|
import cv2
|
||||||
|
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
||||||
|
else:
|
||||||
|
self.get_logger().warn(
|
||||||
|
f'Unsupported image encoding {msg.encoding!r}',
|
||||||
|
throttle_duration_sec=10.0,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
results = self._mp.process(frame)
|
||||||
|
|
||||||
|
out = FaceEmotionArray()
|
||||||
|
out.header.stamp = msg.header.stamp
|
||||||
|
out.header.frame_id = msg.header.frame_id or 'camera_color_optical_frame'
|
||||||
|
|
||||||
|
if results and results.multi_face_landmarks:
|
||||||
|
for face_id, lms in enumerate(results.multi_face_landmarks):
|
||||||
|
try:
|
||||||
|
fl = from_mediapipe(lms)
|
||||||
|
res = detect_emotion(fl)
|
||||||
|
except Exception as exc:
|
||||||
|
self.get_logger().warn(
|
||||||
|
f'Emotion detection failed for face {face_id}: {exc}',
|
||||||
|
throttle_duration_sec=5.0,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
fe = FaceEmotion()
|
||||||
|
fe.header = out.header
|
||||||
|
fe.face_id = face_id
|
||||||
|
fe.emotion = res.emotion
|
||||||
|
fe.confidence = float(res.confidence)
|
||||||
|
fe.mouth_open = float(res.features.mouth_open)
|
||||||
|
fe.smile = float(res.features.smile)
|
||||||
|
fe.brow_raise = float(res.features.brow_raise)
|
||||||
|
fe.eye_open = float(res.features.eye_open)
|
||||||
|
out.faces.append(fe)
|
||||||
|
|
||||||
|
out.face_count = len(out.faces)
|
||||||
|
self._pub.publish(out)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None) -> None:
|
||||||
|
rclpy.init(args=args)
|
||||||
|
node = FaceEmotionNode()
|
||||||
|
try:
|
||||||
|
rclpy.spin(node)
|
||||||
|
finally:
|
||||||
|
node.destroy_node()
|
||||||
|
rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@ -0,0 +1,262 @@
|
|||||||
|
"""
|
||||||
|
obstacle_size_node.py — Depth-based obstacle size estimator (Issue #348).
|
||||||
|
|
||||||
|
Fuses 2-D LIDAR clusters with the D435i depth image to estimate the full
|
||||||
|
3-D width and height (metres) of each obstacle.
|
||||||
|
|
||||||
|
Subscribes
|
||||||
|
----------
|
||||||
|
/scan sensor_msgs/LaserScan (BEST_EFFORT, ~10 Hz)
|
||||||
|
/camera/depth/image_rect_raw sensor_msgs/Image (BEST_EFFORT, ~15 Hz)
|
||||||
|
/camera/depth/camera_info sensor_msgs/CameraInfo (RELIABLE)
|
||||||
|
|
||||||
|
Publishes
|
||||||
|
---------
|
||||||
|
/saltybot/obstacle_sizes saltybot_scene_msgs/ObstacleSizeArray
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
Clustering
|
||||||
|
distance_threshold_m float 0.20 Max gap between consecutive scan pts (m)
|
||||||
|
min_points int 3 Min LIDAR points per cluster
|
||||||
|
range_min_m float 0.05 Discard ranges below this (m)
|
||||||
|
range_max_m float 12.0 Discard ranges above this (m)
|
||||||
|
|
||||||
|
Depth sampling
|
||||||
|
depth_window_px int 5 Half-side of median sampling window
|
||||||
|
search_rows_px int 120 Half-height of height search strip
|
||||||
|
col_hw_px int 10 Half-width of height search strip
|
||||||
|
z_tol_m float 0.30 Depth tolerance for height collection
|
||||||
|
|
||||||
|
Extrinsics (LIDAR origin in camera frame — metres)
|
||||||
|
lidar_ex float 0.0 Camera-frame X (right) offset
|
||||||
|
lidar_ey float 0.05 Camera-frame Y (down) offset
|
||||||
|
lidar_ez float 0.0 Camera-frame Z (forward) offset
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
Camera intrinsics are read from /camera/depth/camera_info and override the
|
||||||
|
defaults in CameraParams once the first CameraInfo message arrives.
|
||||||
|
|
||||||
|
If the depth image is unavailable, the node publishes with depth_z=0 and
|
||||||
|
height_m=0, confidence=0 (LIDAR-only width estimate).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||||
|
|
||||||
|
from sensor_msgs.msg import CameraInfo, Image, LaserScan
|
||||||
|
from saltybot_scene_msgs.msg import ObstacleSize, ObstacleSizeArray
|
||||||
|
|
||||||
|
from ._lidar_clustering import scan_to_cartesian, cluster_points
|
||||||
|
from ._obstacle_size import CameraParams, estimate_cluster_size
|
||||||
|
|
||||||
|
|
||||||
|
_SENSOR_QOS = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=4,
|
||||||
|
)
|
||||||
|
_RELIABLE_QOS = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.RELIABLE,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ObstacleSizeNode(Node):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__('obstacle_size_node')
|
||||||
|
|
||||||
|
# ── Parameters ──────────────────────────────────────────────────────
|
||||||
|
# Clustering
|
||||||
|
self.declare_parameter('distance_threshold_m', 0.20)
|
||||||
|
self.declare_parameter('min_points', 3)
|
||||||
|
self.declare_parameter('range_min_m', 0.05)
|
||||||
|
self.declare_parameter('range_max_m', 12.0)
|
||||||
|
# Depth sampling
|
||||||
|
self.declare_parameter('depth_window_px', 5)
|
||||||
|
self.declare_parameter('search_rows_px', 120)
|
||||||
|
self.declare_parameter('col_hw_px', 10)
|
||||||
|
self.declare_parameter('z_tol_m', 0.30)
|
||||||
|
# Extrinsics
|
||||||
|
self.declare_parameter('lidar_ex', 0.0)
|
||||||
|
self.declare_parameter('lidar_ey', 0.05)
|
||||||
|
self.declare_parameter('lidar_ez', 0.0)
|
||||||
|
|
||||||
|
p = self.get_parameter
|
||||||
|
self._dist_thresh = float(p('distance_threshold_m').value)
|
||||||
|
self._min_pts = int(p('min_points').value)
|
||||||
|
self._range_min = float(p('range_min_m').value)
|
||||||
|
self._range_max = float(p('range_max_m').value)
|
||||||
|
self._win_px = int(p('depth_window_px').value)
|
||||||
|
self._rows_px = int(p('search_rows_px').value)
|
||||||
|
self._col_hw = int(p('col_hw_px').value)
|
||||||
|
self._z_tol = float(p('z_tol_m').value)
|
||||||
|
|
||||||
|
# Build default CameraParams (overridden when CameraInfo arrives)
|
||||||
|
self._cam = CameraParams(
|
||||||
|
ex=float(p('lidar_ex').value),
|
||||||
|
ey=float(p('lidar_ey').value),
|
||||||
|
ez=float(p('lidar_ez').value),
|
||||||
|
)
|
||||||
|
self._cam_info_received = False
|
||||||
|
|
||||||
|
# Latest depth frame
|
||||||
|
self._depth_image: Optional[np.ndarray] = None
|
||||||
|
|
||||||
|
# ── Subscribers ──────────────────────────────────────────────────────
|
||||||
|
self._scan_sub = self.create_subscription(
|
||||||
|
LaserScan, '/scan', self._on_scan, _SENSOR_QOS
|
||||||
|
)
|
||||||
|
self._depth_sub = self.create_subscription(
|
||||||
|
Image, '/camera/depth/image_rect_raw',
|
||||||
|
self._on_depth, _SENSOR_QOS
|
||||||
|
)
|
||||||
|
self._info_sub = self.create_subscription(
|
||||||
|
CameraInfo, '/camera/depth/camera_info',
|
||||||
|
self._on_camera_info, _RELIABLE_QOS
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Publisher ────────────────────────────────────────────────────────
|
||||||
|
self._pub = self.create_publisher(
|
||||||
|
ObstacleSizeArray, '/saltybot/obstacle_sizes', 10
|
||||||
|
)
|
||||||
|
|
||||||
|
self.get_logger().info(
|
||||||
|
f'obstacle_size_node ready — '
|
||||||
|
f'dist_thresh={self._dist_thresh}m, z_tol={self._z_tol}m'
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Callbacks ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _on_camera_info(self, msg: CameraInfo) -> None:
|
||||||
|
"""Latch camera intrinsics on first message."""
|
||||||
|
if self._cam_info_received:
|
||||||
|
return
|
||||||
|
ex = self._cam.ex
|
||||||
|
ey = self._cam.ey
|
||||||
|
ez = self._cam.ez
|
||||||
|
self._cam = CameraParams(
|
||||||
|
fx=float(msg.k[0]),
|
||||||
|
fy=float(msg.k[4]),
|
||||||
|
cx=float(msg.k[2]),
|
||||||
|
cy=float(msg.k[5]),
|
||||||
|
width=int(msg.width),
|
||||||
|
height=int(msg.height),
|
||||||
|
depth_scale=self._cam.depth_scale,
|
||||||
|
ex=ex, ey=ey, ez=ez,
|
||||||
|
)
|
||||||
|
self._cam_info_received = True
|
||||||
|
self.get_logger().info(
|
||||||
|
f'Camera intrinsics loaded: '
|
||||||
|
f'fx={self._cam.fx:.1f} fy={self._cam.fy:.1f} '
|
||||||
|
f'cx={self._cam.cx:.1f} cy={self._cam.cy:.1f} '
|
||||||
|
f'{self._cam.width}×{self._cam.height}'
|
||||||
|
)
|
||||||
|
|
||||||
|
def _on_depth(self, msg: Image) -> None:
|
||||||
|
"""Cache the latest depth frame as a numpy uint16 array."""
|
||||||
|
enc = msg.encoding.lower()
|
||||||
|
if enc not in ('16uc1', 'mono16'):
|
||||||
|
self.get_logger().warn(
|
||||||
|
f'Unexpected depth encoding {msg.encoding!r} — expected 16UC1',
|
||||||
|
throttle_duration_sec=10.0,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
data = np.frombuffer(msg.data, dtype=np.uint16)
|
||||||
|
try:
|
||||||
|
self._depth_image = data.reshape((msg.height, msg.width))
|
||||||
|
except ValueError as exc:
|
||||||
|
self.get_logger().warn(f'Depth reshape failed: {exc}')
|
||||||
|
|
||||||
|
def _on_scan(self, msg: LaserScan) -> None:
|
||||||
|
"""Cluster LIDAR scan, project clusters to depth image, publish sizes."""
|
||||||
|
points = scan_to_cartesian(
|
||||||
|
list(msg.ranges),
|
||||||
|
msg.angle_min,
|
||||||
|
msg.angle_increment,
|
||||||
|
range_min=self._range_min,
|
||||||
|
range_max=self._range_max,
|
||||||
|
)
|
||||||
|
clusters = cluster_points(
|
||||||
|
points,
|
||||||
|
distance_threshold_m=self._dist_thresh,
|
||||||
|
min_points=self._min_pts,
|
||||||
|
)
|
||||||
|
|
||||||
|
depth = self._depth_image
|
||||||
|
|
||||||
|
# Build output message
|
||||||
|
out = ObstacleSizeArray()
|
||||||
|
out.header.stamp = msg.header.stamp
|
||||||
|
out.header.frame_id = msg.header.frame_id or 'laser'
|
||||||
|
|
||||||
|
for idx, cluster in enumerate(clusters):
|
||||||
|
if depth is not None:
|
||||||
|
est = estimate_cluster_size(
|
||||||
|
cluster,
|
||||||
|
depth,
|
||||||
|
self._cam,
|
||||||
|
depth_window = self._win_px,
|
||||||
|
search_rows = self._rows_px,
|
||||||
|
col_hw = self._col_hw,
|
||||||
|
z_tol = self._z_tol,
|
||||||
|
obstacle_id = idx,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# No depth image available — publish LIDAR-only width
|
||||||
|
from ._obstacle_size import ObstacleSizeEstimate
|
||||||
|
import math as _math
|
||||||
|
cx = float(cluster.centroid[0])
|
||||||
|
cy = float(cluster.centroid[1])
|
||||||
|
est = ObstacleSizeEstimate(
|
||||||
|
obstacle_id = idx,
|
||||||
|
centroid_x = cx,
|
||||||
|
centroid_y = cy,
|
||||||
|
depth_z = 0.0,
|
||||||
|
width_m = float(cluster.width_m),
|
||||||
|
height_m = 0.0,
|
||||||
|
pixel_u = -1,
|
||||||
|
pixel_v = -1,
|
||||||
|
lidar_range = float(_math.sqrt(cx * cx + cy * cy)),
|
||||||
|
confidence = 0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
obs = ObstacleSize()
|
||||||
|
obs.header = out.header
|
||||||
|
obs.obstacle_id = idx
|
||||||
|
obs.centroid_x = float(est.centroid_x)
|
||||||
|
obs.centroid_y = float(est.centroid_y)
|
||||||
|
obs.depth_z = float(est.depth_z)
|
||||||
|
obs.width_m = float(est.width_m)
|
||||||
|
obs.height_m = float(est.height_m)
|
||||||
|
obs.pixel_u = int(est.pixel_u)
|
||||||
|
obs.pixel_v = int(est.pixel_v)
|
||||||
|
obs.lidar_range = float(est.lidar_range)
|
||||||
|
obs.confidence = float(est.confidence)
|
||||||
|
out.obstacles.append(obs)
|
||||||
|
|
||||||
|
self._pub.publish(out)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None) -> None:
|
||||||
|
rclpy.init(args=args)
|
||||||
|
node = ObstacleSizeNode()
|
||||||
|
try:
|
||||||
|
rclpy.spin(node)
|
||||||
|
finally:
|
||||||
|
node.destroy_node()
|
||||||
|
rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@ -0,0 +1,172 @@
|
|||||||
|
"""
|
||||||
|
obstacle_velocity_node.py — Dynamic obstacle velocity estimator (Issue #326).
|
||||||
|
|
||||||
|
Subscribes to the 2-D LIDAR scan, clusters points using the existing
|
||||||
|
gap-segmentation helper, tracks cluster centroids with per-obstacle
|
||||||
|
constant-velocity Kalman filters, and publishes estimated velocity
|
||||||
|
vectors on /saltybot/obstacle_velocities.
|
||||||
|
|
||||||
|
Subscribes:
|
||||||
|
/scan sensor_msgs/LaserScan (BEST_EFFORT)
|
||||||
|
|
||||||
|
Publishes:
|
||||||
|
/saltybot/obstacle_velocities saltybot_scene_msgs/ObstacleVelocityArray
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
distance_threshold_m float 0.20 Max gap between consecutive scan points (m)
|
||||||
|
min_points int 3 Minimum LIDAR points per valid cluster
|
||||||
|
range_min_m float 0.05 Discard ranges below this (m)
|
||||||
|
range_max_m float 12.0 Discard ranges above this (m)
|
||||||
|
max_association_dist_m float 0.50 Max centroid distance for track-cluster match
|
||||||
|
max_coasting_frames int 5 Missed frames before a track is deleted
|
||||||
|
n_init_frames int 3 Updates required to reach confidence=1.0
|
||||||
|
q_pos float 0.05 Process noise — position (m²/s²)
|
||||||
|
q_vel float 0.50 Process noise — velocity (m²/s³)
|
||||||
|
r_pos float 0.10 Measurement noise std-dev (m)
|
||||||
|
static_speed_threshold float 0.10 Speed (m/s) below which obstacle is static
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||||
|
|
||||||
|
from sensor_msgs.msg import LaserScan
|
||||||
|
from saltybot_scene_msgs.msg import ObstacleVelocity, ObstacleVelocityArray
|
||||||
|
|
||||||
|
from ._lidar_clustering import scan_to_cartesian, cluster_points
|
||||||
|
from ._obstacle_velocity import ObstacleTracker
|
||||||
|
|
||||||
|
|
||||||
|
_SENSOR_QOS = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ObstacleVelocityNode(Node):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__('obstacle_velocity_node')
|
||||||
|
|
||||||
|
self.declare_parameter('distance_threshold_m', 0.20)
|
||||||
|
self.declare_parameter('min_points', 3)
|
||||||
|
self.declare_parameter('range_min_m', 0.05)
|
||||||
|
self.declare_parameter('range_max_m', 12.0)
|
||||||
|
self.declare_parameter('max_association_dist_m', 0.50)
|
||||||
|
self.declare_parameter('max_coasting_frames', 5)
|
||||||
|
self.declare_parameter('n_init_frames', 3)
|
||||||
|
self.declare_parameter('q_pos', 0.05)
|
||||||
|
self.declare_parameter('q_vel', 0.50)
|
||||||
|
self.declare_parameter('r_pos', 0.10)
|
||||||
|
self.declare_parameter('static_speed_threshold', 0.10)
|
||||||
|
|
||||||
|
self._dist_thresh = float(self.get_parameter('distance_threshold_m').value)
|
||||||
|
self._min_pts = int(self.get_parameter('min_points').value)
|
||||||
|
self._range_min = float(self.get_parameter('range_min_m').value)
|
||||||
|
self._range_max = float(self.get_parameter('range_max_m').value)
|
||||||
|
self._static_thresh = float(self.get_parameter('static_speed_threshold').value)
|
||||||
|
|
||||||
|
self._tracker = ObstacleTracker(
|
||||||
|
max_association_dist_m = float(self.get_parameter('max_association_dist_m').value),
|
||||||
|
max_coasting_frames = int(self.get_parameter('max_coasting_frames').value),
|
||||||
|
n_init_frames = int(self.get_parameter('n_init_frames').value),
|
||||||
|
q_pos = float(self.get_parameter('q_pos').value),
|
||||||
|
q_vel = float(self.get_parameter('q_vel').value),
|
||||||
|
r_pos = float(self.get_parameter('r_pos').value),
|
||||||
|
)
|
||||||
|
|
||||||
|
self._sub = self.create_subscription(
|
||||||
|
LaserScan, '/scan', self._on_scan, _SENSOR_QOS)
|
||||||
|
self._pub = self.create_publisher(
|
||||||
|
ObstacleVelocityArray, '/saltybot/obstacle_velocities', 10)
|
||||||
|
|
||||||
|
self.get_logger().info(
|
||||||
|
f'obstacle_velocity_node ready — '
|
||||||
|
f'dist_thresh={self._dist_thresh}m '
|
||||||
|
f'static_thresh={self._static_thresh}m/s'
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Scan callback ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _on_scan(self, msg: LaserScan) -> None:
|
||||||
|
points = scan_to_cartesian(
|
||||||
|
list(msg.ranges),
|
||||||
|
msg.angle_min,
|
||||||
|
msg.angle_increment,
|
||||||
|
range_min=self._range_min,
|
||||||
|
range_max=self._range_max,
|
||||||
|
)
|
||||||
|
clusters = cluster_points(
|
||||||
|
points,
|
||||||
|
distance_threshold_m=self._dist_thresh,
|
||||||
|
min_points=self._min_pts,
|
||||||
|
)
|
||||||
|
|
||||||
|
centroids = [c.centroid for c in clusters]
|
||||||
|
widths = [c.width_m for c in clusters]
|
||||||
|
depths = [c.depth_m for c in clusters]
|
||||||
|
point_counts = [len(c.points) for c in clusters]
|
||||||
|
|
||||||
|
# Use scan timestamp if available, fall back to wall clock
|
||||||
|
stamp_secs = (
|
||||||
|
msg.header.stamp.sec + msg.header.stamp.nanosec * 1e-9
|
||||||
|
if msg.header.stamp.sec > 0
|
||||||
|
else time.time()
|
||||||
|
)
|
||||||
|
|
||||||
|
alive_tracks = self._tracker.update(
|
||||||
|
centroids,
|
||||||
|
timestamp = stamp_secs,
|
||||||
|
widths = widths,
|
||||||
|
depths = depths,
|
||||||
|
point_counts = point_counts,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Build and publish message ─────────────────────────────────────────
|
||||||
|
out = ObstacleVelocityArray()
|
||||||
|
out.header.stamp = msg.header.stamp
|
||||||
|
out.header.frame_id = msg.header.frame_id or 'laser'
|
||||||
|
|
||||||
|
for track in alive_tracks:
|
||||||
|
pos = track.position
|
||||||
|
vel = track.velocity
|
||||||
|
|
||||||
|
obs = ObstacleVelocity()
|
||||||
|
obs.header = out.header
|
||||||
|
obs.obstacle_id = track.track_id
|
||||||
|
obs.centroid.x = float(pos[0])
|
||||||
|
obs.centroid.y = float(pos[1])
|
||||||
|
obs.centroid.z = 0.0
|
||||||
|
obs.velocity.x = float(vel[0])
|
||||||
|
obs.velocity.y = float(vel[1])
|
||||||
|
obs.velocity.z = 0.0
|
||||||
|
obs.speed_mps = float(track.speed)
|
||||||
|
obs.width_m = float(track.last_width)
|
||||||
|
obs.depth_m = float(track.last_depth)
|
||||||
|
obs.point_count = int(track.last_point_count)
|
||||||
|
obs.confidence = float(track.confidence)
|
||||||
|
obs.is_static = track.speed < self._static_thresh
|
||||||
|
out.obstacles.append(obs)
|
||||||
|
|
||||||
|
self._pub.publish(out)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None) -> None:
|
||||||
|
rclpy.init(args=args)
|
||||||
|
node = ObstacleVelocityNode()
|
||||||
|
try:
|
||||||
|
rclpy.spin(node)
|
||||||
|
finally:
|
||||||
|
node.destroy_node()
|
||||||
|
rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@ -0,0 +1,164 @@
|
|||||||
|
"""
|
||||||
|
path_edges_node.py — ROS2 node for lane/path edge detection (Issue #339).
|
||||||
|
|
||||||
|
Subscribes
|
||||||
|
----------
|
||||||
|
/camera/color/image_raw (sensor_msgs/Image)
|
||||||
|
|
||||||
|
Publishes
|
||||||
|
---------
|
||||||
|
/saltybot/path_edges (saltybot_scene_msgs/PathEdges)
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
roi_frac (float, default 0.50) bottom fraction of image as ROI
|
||||||
|
blur_ksize (int, default 5) Gaussian blur kernel size (odd)
|
||||||
|
canny_low (int, default 50) Canny lower threshold
|
||||||
|
canny_high (int, default 150) Canny upper threshold
|
||||||
|
hough_threshold (int, default 30) minimum Hough votes
|
||||||
|
min_line_len (int, default 40) minimum Hough segment length (px)
|
||||||
|
max_line_gap (int, default 20) maximum Hough gap (px)
|
||||||
|
min_slope (float, default 0.3) |slope| below this → discard
|
||||||
|
birdseye_size (int, default 400) bird-eye square image side (px)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from sensor_msgs.msg import Image
|
||||||
|
from std_msgs.msg import Header
|
||||||
|
|
||||||
|
from saltybot_scene_msgs.msg import PathEdges
|
||||||
|
|
||||||
|
from ._path_edges import PathEdgeConfig, process_frame
|
||||||
|
|
||||||
|
|
||||||
|
class PathEdgesNode(Node):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__('path_edges_node')
|
||||||
|
|
||||||
|
# Declare parameters
|
||||||
|
self.declare_parameter('roi_frac', 0.50)
|
||||||
|
self.declare_parameter('blur_ksize', 5)
|
||||||
|
self.declare_parameter('canny_low', 50)
|
||||||
|
self.declare_parameter('canny_high', 150)
|
||||||
|
self.declare_parameter('hough_threshold', 30)
|
||||||
|
self.declare_parameter('min_line_len', 40)
|
||||||
|
self.declare_parameter('max_line_gap', 20)
|
||||||
|
self.declare_parameter('min_slope', 0.3)
|
||||||
|
self.declare_parameter('birdseye_size', 400)
|
||||||
|
|
||||||
|
self._sub = self.create_subscription(
|
||||||
|
Image,
|
||||||
|
'/camera/color/image_raw',
|
||||||
|
self._image_cb,
|
||||||
|
10,
|
||||||
|
)
|
||||||
|
self._pub = self.create_publisher(PathEdges, '/saltybot/path_edges', 10)
|
||||||
|
|
||||||
|
self.get_logger().info('PathEdgesNode ready — subscribing /camera/color/image_raw')
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _build_config(self) -> PathEdgeConfig:
|
||||||
|
p = self.get_parameter
|
||||||
|
return PathEdgeConfig(
|
||||||
|
roi_frac = p('roi_frac').value,
|
||||||
|
blur_ksize = p('blur_ksize').value,
|
||||||
|
canny_low = p('canny_low').value,
|
||||||
|
canny_high = p('canny_high').value,
|
||||||
|
hough_threshold = p('hough_threshold').value,
|
||||||
|
min_line_len = p('min_line_len').value,
|
||||||
|
max_line_gap = p('max_line_gap').value,
|
||||||
|
min_slope = p('min_slope').value,
|
||||||
|
birdseye_size = p('birdseye_size').value,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _image_cb(self, msg: Image) -> None:
|
||||||
|
# Convert ROS Image → BGR numpy array
|
||||||
|
bgr = self._ros_image_to_bgr(msg)
|
||||||
|
if bgr is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
cfg = self._build_config()
|
||||||
|
result = process_frame(bgr, cfg)
|
||||||
|
|
||||||
|
out = PathEdges()
|
||||||
|
out.header = msg.header
|
||||||
|
out.line_count = len(result.lines)
|
||||||
|
out.roi_top = result.roi_top
|
||||||
|
|
||||||
|
# Flat float32 arrays: [x1,y1,x2,y2, x1,y1,x2,y2, ...]
|
||||||
|
out.segments_px = [v for seg in result.lines for v in seg]
|
||||||
|
out.segments_birdseye_px = [v for seg in result.birdseye_lines for v in seg]
|
||||||
|
|
||||||
|
# Left edge
|
||||||
|
if result.left_edge is not None:
|
||||||
|
x1, y1, x2, y2 = result.left_edge
|
||||||
|
out.left_x1, out.left_y1 = float(x1), float(y1)
|
||||||
|
out.left_x2, out.left_y2 = float(x2), float(y2)
|
||||||
|
out.left_detected = True
|
||||||
|
else:
|
||||||
|
out.left_detected = False
|
||||||
|
|
||||||
|
if result.birdseye_left is not None:
|
||||||
|
bx1, by1, bx2, by2 = result.birdseye_left
|
||||||
|
out.left_birdseye_x1, out.left_birdseye_y1 = float(bx1), float(by1)
|
||||||
|
out.left_birdseye_x2, out.left_birdseye_y2 = float(bx2), float(by2)
|
||||||
|
|
||||||
|
# Right edge
|
||||||
|
if result.right_edge is not None:
|
||||||
|
x1, y1, x2, y2 = result.right_edge
|
||||||
|
out.right_x1, out.right_y1 = float(x1), float(y1)
|
||||||
|
out.right_x2, out.right_y2 = float(x2), float(y2)
|
||||||
|
out.right_detected = True
|
||||||
|
else:
|
||||||
|
out.right_detected = False
|
||||||
|
|
||||||
|
if result.birdseye_right is not None:
|
||||||
|
bx1, by1, bx2, by2 = result.birdseye_right
|
||||||
|
out.right_birdseye_x1, out.right_birdseye_y1 = float(bx1), float(by1)
|
||||||
|
out.right_birdseye_x2, out.right_birdseye_y2 = float(bx2), float(by2)
|
||||||
|
|
||||||
|
self._pub.publish(out)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ros_image_to_bgr(msg: Image) -> np.ndarray | None:
|
||||||
|
"""Convert a sensor_msgs/Image to a uint8 BGR numpy array."""
|
||||||
|
enc = msg.encoding.lower()
|
||||||
|
data = np.frombuffer(msg.data, dtype=np.uint8)
|
||||||
|
|
||||||
|
if enc in ('rgb8', 'bgr8', 'mono8'):
|
||||||
|
channels = 1 if enc == 'mono8' else 3
|
||||||
|
try:
|
||||||
|
img = data.reshape((msg.height, msg.width, channels))
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
if enc == 'rgb8':
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
||||||
|
elif enc == 'mono8':
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||||
|
return img
|
||||||
|
|
||||||
|
# Unsupported encoding
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None) -> None:
|
||||||
|
rclpy.init(args=args)
|
||||||
|
node = PathEdgesNode()
|
||||||
|
try:
|
||||||
|
rclpy.spin(node)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
node.destroy_node()
|
||||||
|
rclpy.try_shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@ -0,0 +1,371 @@
|
|||||||
|
"""
|
||||||
|
person_tracking_node.py — P0 follow-me person tracking node (Issue #363).
|
||||||
|
|
||||||
|
Runs a real-time person detection + tracking pipeline on the D435i colour and
|
||||||
|
depth streams and publishes a single TargetTrack message for the follow-me
|
||||||
|
motion controller.
|
||||||
|
|
||||||
|
Detection backend
|
||||||
|
-----------------
|
||||||
|
Attempts YOLOv8n via ultralytics (auto-converted to TensorRT FP16 on first
|
||||||
|
run for ≥ 15 fps on Orin Nano). Falls back to a simple HOG+SVM detector
|
||||||
|
when ultralytics is unavailable.
|
||||||
|
|
||||||
|
Subscribes
|
||||||
|
----------
|
||||||
|
/camera/color/image_raw sensor_msgs/Image (BEST_EFFORT)
|
||||||
|
/camera/depth/image_rect_raw sensor_msgs/Image (BEST_EFFORT)
|
||||||
|
/camera/depth/camera_info sensor_msgs/CameraInfo (RELIABLE)
|
||||||
|
/saltybot/follow_start std_msgs/Empty — start/resume following
|
||||||
|
/saltybot/follow_stop std_msgs/Empty — stop following
|
||||||
|
|
||||||
|
Publishes
|
||||||
|
---------
|
||||||
|
/saltybot/target_track saltybot_scene_msgs/TargetTrack
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
detector_model str 'yolov8n.pt' Ultralytics model file or TRT engine
|
||||||
|
use_tensorrt bool True Convert to TensorRT FP16 on first run
|
||||||
|
max_fps float 30.0 Maximum processing rate (Hz)
|
||||||
|
iou_threshold float 0.30 Tracker IoU matching threshold
|
||||||
|
min_hits int 3 Frames before TENTATIVE → ACTIVE
|
||||||
|
max_lost_frames int 30 Frames a track survives without det
|
||||||
|
reid_threshold float 0.55 HSV histogram re-ID similarity cutoff
|
||||||
|
depth_scale float 0.001 D435i raw-to-metres scale
|
||||||
|
depth_max_m float 5.0 Range beyond which depth degrades
|
||||||
|
auto_follow bool True Auto-select nearest person on start
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import (
|
||||||
|
QoSProfile, ReliabilityPolicy, HistoryPolicy, DurabilityPolicy,
|
||||||
|
)
|
||||||
|
|
||||||
|
from sensor_msgs.msg import CameraInfo, Image
|
||||||
|
from std_msgs.msg import Empty
|
||||||
|
from saltybot_scene_msgs.msg import TargetTrack
|
||||||
|
from ._person_tracker import (
|
||||||
|
BBox, CamParams, Detection,
|
||||||
|
PersonTracker, FollowTargetSelector,
|
||||||
|
DEPTH_GOOD, DEPTH_MARGINAL,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_SENSOR_QOS = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=2,
|
||||||
|
)
|
||||||
|
_RELIABLE_QOS = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.RELIABLE,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=1,
|
||||||
|
durability=DurabilityPolicy.TRANSIENT_LOCAL,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Detector wrappers ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _YoloDetector:
|
||||||
|
"""Lazy-initialised YOLOv8n person detector (ultralytics / TensorRT)."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_path: str = 'yolov8n.pt',
|
||||||
|
use_tensorrt: bool = True,
|
||||||
|
logger=None,
|
||||||
|
) -> None:
|
||||||
|
self._model_path = model_path
|
||||||
|
self._use_trt = use_tensorrt
|
||||||
|
self._log = logger
|
||||||
|
self._model = None
|
||||||
|
self._available = False
|
||||||
|
self._ready = threading.Event()
|
||||||
|
threading.Thread(target=self._load, daemon=True).start()
|
||||||
|
|
||||||
|
def _load(self) -> None:
|
||||||
|
try:
|
||||||
|
from ultralytics import YOLO
|
||||||
|
model = YOLO(self._model_path)
|
||||||
|
if self._use_trt:
|
||||||
|
try:
|
||||||
|
model = YOLO(model.export(format='engine', half=True, device=0))
|
||||||
|
except Exception as e:
|
||||||
|
if self._log:
|
||||||
|
self._log.warn(f'TRT export failed ({e}); using PyTorch')
|
||||||
|
self._model = model
|
||||||
|
self._available = True
|
||||||
|
if self._log:
|
||||||
|
self._log.info(f'YOLO detector loaded: {self._model_path}')
|
||||||
|
except Exception as e:
|
||||||
|
self._available = False
|
||||||
|
if self._log:
|
||||||
|
self._log.warn(f'ultralytics not available ({e}); using HOG fallback')
|
||||||
|
finally:
|
||||||
|
self._ready.set()
|
||||||
|
|
||||||
|
def detect(self, bgr: np.ndarray, conf_thresh: float = 0.40) -> List[Detection]:
|
||||||
|
if not self._ready.wait(timeout=0.01) or not self._available:
|
||||||
|
return []
|
||||||
|
results = self._model(bgr, classes=[0], conf=conf_thresh, verbose=False)
|
||||||
|
dets: List[Detection] = []
|
||||||
|
for r in results:
|
||||||
|
for box in r.boxes:
|
||||||
|
x1, y1, x2, y2 = (int(v) for v in box.xyxy[0].cpu().numpy())
|
||||||
|
w, h = max(1, x2 - x1), max(1, y2 - y1)
|
||||||
|
dets.append(Detection(
|
||||||
|
bbox = BBox(x1, y1, w, h),
|
||||||
|
confidence = float(box.conf[0]),
|
||||||
|
frame_bgr = bgr,
|
||||||
|
))
|
||||||
|
return dets
|
||||||
|
|
||||||
|
|
||||||
|
class _HogDetector:
|
||||||
|
"""OpenCV HOG+SVM person detector — CPU fallback, ~5–10 fps."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
import cv2
|
||||||
|
self._hog = cv2.HOGDescriptor()
|
||||||
|
self._hog.setSVMDetector(cv2.HOGDescriptor_getDefaultPeopleDetector())
|
||||||
|
|
||||||
|
def detect(self, bgr: np.ndarray, conf_thresh: float = 0.40) -> List[Detection]:
|
||||||
|
import cv2
|
||||||
|
small = cv2.resize(bgr, (320, 240)) if bgr.shape[1] > 320 else bgr
|
||||||
|
scale = bgr.shape[1] / small.shape[1]
|
||||||
|
rects, weights = self._hog.detectMultiScale(
|
||||||
|
small, winStride=(8, 8), padding=(4, 4), scale=1.05,
|
||||||
|
)
|
||||||
|
dets: List[Detection] = []
|
||||||
|
for rect, w in zip(rects, weights):
|
||||||
|
conf = float(np.clip(w, 0.0, 1.0))
|
||||||
|
if conf < conf_thresh:
|
||||||
|
continue
|
||||||
|
x, y, rw, rh = rect
|
||||||
|
dets.append(Detection(
|
||||||
|
bbox = BBox(
|
||||||
|
int(x * scale), int(y * scale),
|
||||||
|
int(rw * scale), int(rh * scale),
|
||||||
|
),
|
||||||
|
confidence = conf,
|
||||||
|
frame_bgr = bgr,
|
||||||
|
))
|
||||||
|
return dets
|
||||||
|
|
||||||
|
|
||||||
|
# ── ROS2 Node ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class PersonTrackingNode(Node):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__('person_tracking_node')
|
||||||
|
|
||||||
|
# ── Parameters ──────────────────────────────────────────────────────
|
||||||
|
self.declare_parameter('detector_model', 'yolov8n.pt')
|
||||||
|
self.declare_parameter('use_tensorrt', True)
|
||||||
|
self.declare_parameter('max_fps', 30.0)
|
||||||
|
self.declare_parameter('iou_threshold', 0.30)
|
||||||
|
self.declare_parameter('min_hits', 3)
|
||||||
|
self.declare_parameter('max_lost_frames', 30)
|
||||||
|
self.declare_parameter('reid_threshold', 0.55)
|
||||||
|
self.declare_parameter('depth_scale', 0.001)
|
||||||
|
self.declare_parameter('depth_max_m', 5.0)
|
||||||
|
self.declare_parameter('auto_follow', True)
|
||||||
|
|
||||||
|
p = self.get_parameter
|
||||||
|
self._min_period = 1.0 / max(float(p('max_fps').value), 1.0)
|
||||||
|
self._depth_scale = float(p('depth_scale').value)
|
||||||
|
self._depth_max_m = float(p('depth_max_m').value)
|
||||||
|
self._auto_follow = bool(p('auto_follow').value)
|
||||||
|
|
||||||
|
# ── Tracker + selector ───────────────────────────────────────────────
|
||||||
|
self._tracker = PersonTracker(
|
||||||
|
iou_threshold = float(p('iou_threshold').value),
|
||||||
|
min_hits = int(p('min_hits').value),
|
||||||
|
max_lost_frames = int(p('max_lost_frames').value),
|
||||||
|
reid_threshold = float(p('reid_threshold').value),
|
||||||
|
)
|
||||||
|
self._selector = FollowTargetSelector(hold_frames=15)
|
||||||
|
if self._auto_follow:
|
||||||
|
self._selector.start()
|
||||||
|
|
||||||
|
# ── Detector (lazy) ──────────────────────────────────────────────────
|
||||||
|
self._detector: Optional[_YoloDetector] = _YoloDetector(
|
||||||
|
model_path = str(p('detector_model').value),
|
||||||
|
use_tensorrt = bool(p('use_tensorrt').value),
|
||||||
|
logger = self.get_logger(),
|
||||||
|
)
|
||||||
|
self._hog_fallback: Optional[_HogDetector] = None
|
||||||
|
|
||||||
|
# ── Camera state ─────────────────────────────────────────────────────
|
||||||
|
self._cam = CamParams()
|
||||||
|
self._cam_received = False
|
||||||
|
self._depth_image: Optional[np.ndarray] = None
|
||||||
|
self._last_proc = 0.0
|
||||||
|
|
||||||
|
# ── Subscribers ──────────────────────────────────────────────────────
|
||||||
|
self._color_sub = self.create_subscription(
|
||||||
|
Image, '/camera/color/image_raw',
|
||||||
|
self._on_color, _SENSOR_QOS,
|
||||||
|
)
|
||||||
|
self._depth_sub = self.create_subscription(
|
||||||
|
Image, '/camera/depth/image_rect_raw',
|
||||||
|
self._on_depth, _SENSOR_QOS,
|
||||||
|
)
|
||||||
|
self._info_sub = self.create_subscription(
|
||||||
|
CameraInfo, '/camera/depth/camera_info',
|
||||||
|
self._on_cam_info, _RELIABLE_QOS,
|
||||||
|
)
|
||||||
|
self._start_sub = self.create_subscription(
|
||||||
|
Empty, '/saltybot/follow_start',
|
||||||
|
lambda _: self._selector.start(), 10,
|
||||||
|
)
|
||||||
|
self._stop_sub = self.create_subscription(
|
||||||
|
Empty, '/saltybot/follow_stop',
|
||||||
|
lambda _: self._selector.stop(), 10,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Publisher ────────────────────────────────────────────────────────
|
||||||
|
self._pub = self.create_publisher(
|
||||||
|
TargetTrack, '/saltybot/target_track', 10
|
||||||
|
)
|
||||||
|
|
||||||
|
self.get_logger().info(
|
||||||
|
'person_tracking_node ready — '
|
||||||
|
f'auto_follow={self._auto_follow}, '
|
||||||
|
f'max_fps={1.0/self._min_period:.0f}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Callbacks ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _on_cam_info(self, msg: CameraInfo) -> None:
|
||||||
|
if self._cam_received:
|
||||||
|
return
|
||||||
|
self._cam = CamParams(
|
||||||
|
fx=float(msg.k[0]), fy=float(msg.k[4]),
|
||||||
|
cx=float(msg.k[2]), cy=float(msg.k[5]),
|
||||||
|
)
|
||||||
|
self._cam_received = True
|
||||||
|
self.get_logger().info(
|
||||||
|
f'Camera intrinsics: fx={self._cam.fx:.1f} cx={self._cam.cx:.1f}'
|
||||||
|
)
|
||||||
|
|
||||||
|
def _on_depth(self, msg: Image) -> None:
|
||||||
|
if msg.encoding.lower() not in ('16uc1', 'mono16'):
|
||||||
|
return
|
||||||
|
data = np.frombuffer(msg.data, dtype=np.uint16)
|
||||||
|
try:
|
||||||
|
self._depth_image = data.reshape((msg.height, msg.width))
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _on_color(self, msg: Image) -> None:
|
||||||
|
now = time.monotonic()
|
||||||
|
if now - self._last_proc < self._min_period:
|
||||||
|
return
|
||||||
|
self._last_proc = now
|
||||||
|
|
||||||
|
enc = msg.encoding.lower()
|
||||||
|
if enc not in ('bgr8', 'rgb8'):
|
||||||
|
return
|
||||||
|
data = np.frombuffer(msg.data, dtype=np.uint8)
|
||||||
|
try:
|
||||||
|
frame = data.reshape((msg.height, msg.width, 3))
|
||||||
|
except ValueError:
|
||||||
|
return
|
||||||
|
if enc == 'rgb8':
|
||||||
|
import cv2
|
||||||
|
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
# ── Detect ────────────────────────────────────────────────────────
|
||||||
|
dets: List[Detection] = []
|
||||||
|
if self._detector and self._detector._ready.is_set():
|
||||||
|
if self._detector._available:
|
||||||
|
dets = self._detector.detect(frame)
|
||||||
|
else:
|
||||||
|
# Init HOG fallback on first YOLO failure
|
||||||
|
if self._hog_fallback is None:
|
||||||
|
try:
|
||||||
|
self._hog_fallback = _HogDetector()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if self._hog_fallback:
|
||||||
|
dets = self._hog_fallback.detect(frame)
|
||||||
|
else:
|
||||||
|
# YOLO still loading — run HOG if available
|
||||||
|
if self._hog_fallback is None:
|
||||||
|
try:
|
||||||
|
self._hog_fallback = _HogDetector()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if self._hog_fallback:
|
||||||
|
dets = self._hog_fallback.detect(frame)
|
||||||
|
|
||||||
|
# ── Track ─────────────────────────────────────────────────────────
|
||||||
|
active = self._tracker.update(
|
||||||
|
dets,
|
||||||
|
cam = self._cam if self._cam_received else None,
|
||||||
|
depth_u16 = self._depth_image,
|
||||||
|
depth_scale = self._depth_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Select target ─────────────────────────────────────────────────
|
||||||
|
target = self._selector.update(
|
||||||
|
active, img_cx=self._cam.cx
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Publish ───────────────────────────────────────────────────────
|
||||||
|
out = TargetTrack()
|
||||||
|
out.header.stamp = msg.header.stamp
|
||||||
|
out.header.frame_id = msg.header.frame_id or 'camera_color_optical_frame'
|
||||||
|
|
||||||
|
if target is not None:
|
||||||
|
out.tracking_active = True
|
||||||
|
out.track_id = target.track_id
|
||||||
|
out.bearing_deg = float(target.bearing)
|
||||||
|
out.distance_m = float(target.distance)
|
||||||
|
out.confidence = float(target.confidence)
|
||||||
|
out.bbox_x = int(target.bbox.x)
|
||||||
|
out.bbox_y = int(target.bbox.y)
|
||||||
|
out.bbox_w = int(target.bbox.w)
|
||||||
|
out.bbox_h = int(target.bbox.h)
|
||||||
|
out.depth_quality = int(target.depth_qual)
|
||||||
|
|
||||||
|
# Convert Kalman pixel velocity → bearing rate
|
||||||
|
if self._cam_received and target.vel_u != 0.0:
|
||||||
|
u_c = target.bbox.x + target.bbox.w * 0.5
|
||||||
|
# d(bearing)/du ≈ fx / (fx² + (u-cx)²) * (180/π)
|
||||||
|
denom = self._cam.fx ** 2 + (u_c - self._cam.cx) ** 2
|
||||||
|
d_bear_du = (self._cam.fx / denom) * (180.0 / 3.14159)
|
||||||
|
out.vel_bearing_dps = float(d_bear_du * target.vel_u * self._cam.fps)
|
||||||
|
|
||||||
|
# Distance velocity from depth (placeholder: not computed per-frame here)
|
||||||
|
out.vel_dist_mps = 0.0
|
||||||
|
else:
|
||||||
|
out.tracking_active = False
|
||||||
|
|
||||||
|
self._pub.publish(out)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None) -> None:
|
||||||
|
rclpy.init(args=args)
|
||||||
|
node = PersonTrackingNode()
|
||||||
|
try:
|
||||||
|
rclpy.spin(node)
|
||||||
|
finally:
|
||||||
|
node.destroy_node()
|
||||||
|
rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
170
jetson/ros2_ws/src/saltybot_bringup/saltybot_bringup/uwb_node.py
Normal file
170
jetson/ros2_ws/src/saltybot_bringup/saltybot_bringup/uwb_node.py
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
"""
|
||||||
|
uwb_node.py — UWB DW3000 anchor/tag ranging node (Issue #365).
|
||||||
|
|
||||||
|
Reads TWR distances from two ESP32-UWB-Pro anchors via USB serial and publishes
|
||||||
|
a fused bearing + distance estimate for the follow-me controller.
|
||||||
|
|
||||||
|
Hardware
|
||||||
|
--------
|
||||||
|
anchor0 (left) : USB serial, default /dev/ttyUSB0
|
||||||
|
anchor1 (right) : USB serial, default /dev/ttyUSB1
|
||||||
|
Baseline : ~25 cm (configurable)
|
||||||
|
|
||||||
|
Publishes
|
||||||
|
---------
|
||||||
|
/saltybot/uwb_target saltybot_scene_msgs/UwbTarget (10 Hz)
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
port_anchor0 str '/dev/ttyUSB0' Serial device for left anchor
|
||||||
|
port_anchor1 str '/dev/ttyUSB1' Serial device for right anchor
|
||||||
|
baud_rate int 115200 Serial baud rate
|
||||||
|
baseline_m float 0.25 Anchor separation (m)
|
||||||
|
publish_rate_hz float 10.0 Output publish rate
|
||||||
|
stale_timeout_s float 0.5 Age beyond which a range is discarded
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||||
|
|
||||||
|
from std_msgs.msg import Header
|
||||||
|
from saltybot_scene_msgs.msg import UwbTarget
|
||||||
|
|
||||||
|
from ._uwb_tracker import (
|
||||||
|
AnchorSerialReader,
|
||||||
|
UwbRangingState,
|
||||||
|
FIX_NONE, FIX_SINGLE, FIX_DUAL,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_PUB_QOS = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.RELIABLE,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UwbNode(Node):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__('uwb_node')
|
||||||
|
|
||||||
|
# ── Parameters ──────────────────────────────────────────────────────
|
||||||
|
self.declare_parameter('port_anchor0', '/dev/ttyUSB0')
|
||||||
|
self.declare_parameter('port_anchor1', '/dev/ttyUSB1')
|
||||||
|
self.declare_parameter('baud_rate', 115200)
|
||||||
|
self.declare_parameter('baseline_m', 0.25)
|
||||||
|
self.declare_parameter('publish_rate_hz', 10.0)
|
||||||
|
self.declare_parameter('stale_timeout_s', 0.5)
|
||||||
|
|
||||||
|
p = self.get_parameter
|
||||||
|
|
||||||
|
self._port0 = str(p('port_anchor0').value)
|
||||||
|
self._port1 = str(p('port_anchor1').value)
|
||||||
|
self._baud = int(p('baud_rate').value)
|
||||||
|
baseline = float(p('baseline_m').value)
|
||||||
|
rate_hz = float(p('publish_rate_hz').value)
|
||||||
|
stale_s = float(p('stale_timeout_s').value)
|
||||||
|
|
||||||
|
# ── State + readers ──────────────────────────────────────────────────
|
||||||
|
self._state = UwbRangingState(
|
||||||
|
baseline_m=baseline,
|
||||||
|
stale_timeout=stale_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._readers: list[AnchorSerialReader] = []
|
||||||
|
self._open_serial_readers()
|
||||||
|
|
||||||
|
# ── Publisher + timer ────────────────────────────────────────────────
|
||||||
|
self._pub = self.create_publisher(UwbTarget, '/saltybot/uwb_target', _PUB_QOS)
|
||||||
|
self._timer = self.create_timer(1.0 / max(rate_hz, 1.0), self._publish)
|
||||||
|
|
||||||
|
self.get_logger().info(
|
||||||
|
f'uwb_node ready — baseline={baseline:.3f}m, '
|
||||||
|
f'rate={rate_hz:.0f}Hz, '
|
||||||
|
f'ports=[{self._port0}, {self._port1}]'
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Serial open ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _open_serial_readers(self) -> None:
|
||||||
|
"""
|
||||||
|
Attempt to open both serial ports. Failures are non-fatal — the node
|
||||||
|
will publish FIX_NONE until a port becomes available (e.g. anchor
|
||||||
|
hardware plugged in after startup).
|
||||||
|
"""
|
||||||
|
for anchor_id, port_name in enumerate([self._port0, self._port1]):
|
||||||
|
port = self._try_open_port(anchor_id, port_name)
|
||||||
|
if port is not None:
|
||||||
|
reader = AnchorSerialReader(
|
||||||
|
anchor_id=anchor_id,
|
||||||
|
port=port,
|
||||||
|
state=self._state,
|
||||||
|
logger=self.get_logger(),
|
||||||
|
)
|
||||||
|
reader.start()
|
||||||
|
self._readers.append(reader)
|
||||||
|
self.get_logger().info(f'Anchor {anchor_id} opened on {port_name}')
|
||||||
|
else:
|
||||||
|
self.get_logger().warn(
|
||||||
|
f'Anchor {anchor_id} port {port_name} unavailable — '
|
||||||
|
f'will publish FIX_NONE until connected'
|
||||||
|
)
|
||||||
|
|
||||||
|
def _try_open_port(self, anchor_id: int, port_name: str):
|
||||||
|
"""Open serial port; return None on failure."""
|
||||||
|
try:
|
||||||
|
import serial
|
||||||
|
return serial.Serial(port_name, baudrate=self._baud, timeout=0.1)
|
||||||
|
except Exception as e:
|
||||||
|
self.get_logger().warn(
|
||||||
|
f'Cannot open anchor {anchor_id} serial port {port_name}: {e}'
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ── Publish callback ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _publish(self) -> None:
|
||||||
|
result = self._state.compute()
|
||||||
|
msg = UwbTarget()
|
||||||
|
msg.header.stamp = self.get_clock().now().to_msg()
|
||||||
|
msg.header.frame_id = 'base_link'
|
||||||
|
|
||||||
|
msg.valid = result.valid
|
||||||
|
msg.bearing_deg = float(result.bearing_deg)
|
||||||
|
msg.distance_m = float(result.distance_m)
|
||||||
|
msg.confidence = float(result.confidence)
|
||||||
|
msg.anchor0_dist_m = float(result.anchor0_dist)
|
||||||
|
msg.anchor1_dist_m = float(result.anchor1_dist)
|
||||||
|
msg.baseline_m = float(result.baseline_m)
|
||||||
|
msg.fix_quality = int(result.fix_quality)
|
||||||
|
|
||||||
|
self._pub.publish(msg)
|
||||||
|
|
||||||
|
# ── Cleanup ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def destroy_node(self) -> None:
|
||||||
|
for r in self._readers:
|
||||||
|
r.stop()
|
||||||
|
super().destroy_node()
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None) -> None:
|
||||||
|
rclpy.init(args=args)
|
||||||
|
node = UwbNode()
|
||||||
|
try:
|
||||||
|
rclpy.spin(node)
|
||||||
|
finally:
|
||||||
|
node.destroy_node()
|
||||||
|
rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@ -0,0 +1,125 @@
|
|||||||
|
"""
|
||||||
|
velocity_ramp_node.py — Smooth velocity ramp controller (Issue #350).
|
||||||
|
|
||||||
|
Subscribes to raw /cmd_vel commands and republishes them on /cmd_vel_smooth
|
||||||
|
after applying independent acceleration and deceleration limits to the linear-x
|
||||||
|
and angular-z components.
|
||||||
|
|
||||||
|
An emergency stop (both linear and angular targets exactly 0.0) bypasses the
|
||||||
|
ramp and forces the output to zero immediately.
|
||||||
|
|
||||||
|
Subscribes
|
||||||
|
----------
|
||||||
|
/cmd_vel geometry_msgs/Twist raw velocity commands
|
||||||
|
|
||||||
|
Publishes
|
||||||
|
---------
|
||||||
|
/cmd_vel_smooth geometry_msgs/Twist rate-limited velocity commands
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
max_lin_accel float 0.5 Maximum linear acceleration (m/s²)
|
||||||
|
max_lin_decel float 0.5 Maximum linear deceleration (m/s²)
|
||||||
|
max_ang_accel float 1.0 Maximum angular acceleration (rad/s²)
|
||||||
|
max_ang_decel float 1.0 Maximum angular deceleration (rad/s²)
|
||||||
|
rate_hz float 50.0 Control loop rate (Hz)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||||
|
|
||||||
|
from geometry_msgs.msg import Twist
|
||||||
|
|
||||||
|
from ._velocity_ramp import VelocityRamp
|
||||||
|
|
||||||
|
|
||||||
|
_SUB_QOS = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.RELIABLE,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=10,
|
||||||
|
)
|
||||||
|
_PUB_QOS = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.RELIABLE,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VelocityRampNode(Node):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__('velocity_ramp_node')
|
||||||
|
|
||||||
|
# ── Parameters ──────────────────────────────────────────────────────
|
||||||
|
self.declare_parameter('max_lin_accel', 0.5)
|
||||||
|
self.declare_parameter('max_lin_decel', 0.5)
|
||||||
|
self.declare_parameter('max_ang_accel', 1.0)
|
||||||
|
self.declare_parameter('max_ang_decel', 1.0)
|
||||||
|
self.declare_parameter('rate_hz', 50.0)
|
||||||
|
|
||||||
|
p = self.get_parameter
|
||||||
|
rate_hz = max(float(p('rate_hz').value), 1.0)
|
||||||
|
max_lin_acc = max(float(p('max_lin_accel').value), 1e-3)
|
||||||
|
max_lin_dec = max(float(p('max_lin_decel').value), 1e-3)
|
||||||
|
max_ang_acc = max(float(p('max_ang_accel').value), 1e-3)
|
||||||
|
max_ang_dec = max(float(p('max_ang_decel').value), 1e-3)
|
||||||
|
|
||||||
|
dt = 1.0 / rate_hz
|
||||||
|
|
||||||
|
# ── Ramp state ───────────────────────────────────────────────────────
|
||||||
|
self._ramp = VelocityRamp(
|
||||||
|
dt = dt,
|
||||||
|
max_lin_accel = max_lin_acc,
|
||||||
|
max_lin_decel = max_lin_dec,
|
||||||
|
max_ang_accel = max_ang_acc,
|
||||||
|
max_ang_decel = max_ang_dec,
|
||||||
|
)
|
||||||
|
self._target_lin: float = 0.0
|
||||||
|
self._target_ang: float = 0.0
|
||||||
|
|
||||||
|
# ── Subscriber ───────────────────────────────────────────────────────
|
||||||
|
self._sub = self.create_subscription(
|
||||||
|
Twist, '/cmd_vel', self._on_cmd_vel, _SUB_QOS,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Publisher + timer ────────────────────────────────────────────────
|
||||||
|
self._pub = self.create_publisher(Twist, '/cmd_vel_smooth', _PUB_QOS)
|
||||||
|
self._timer = self.create_timer(dt, self._step)
|
||||||
|
|
||||||
|
self.get_logger().info(
|
||||||
|
f'velocity_ramp_node ready — '
|
||||||
|
f'rate={rate_hz:.0f}Hz '
|
||||||
|
f'lin_accel={max_lin_acc}m/s² lin_decel={max_lin_dec}m/s² '
|
||||||
|
f'ang_accel={max_ang_acc}rad/s² ang_decel={max_ang_dec}rad/s²'
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Callbacks ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _on_cmd_vel(self, msg: Twist) -> None:
|
||||||
|
self._target_lin = float(msg.linear.x)
|
||||||
|
self._target_ang = float(msg.angular.z)
|
||||||
|
|
||||||
|
def _step(self) -> None:
|
||||||
|
lin, ang = self._ramp.step(self._target_lin, self._target_ang)
|
||||||
|
|
||||||
|
out = Twist()
|
||||||
|
out.linear.x = lin
|
||||||
|
out.angular.z = ang
|
||||||
|
self._pub.publish(out)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None) -> None:
|
||||||
|
rclpy.init(args=args)
|
||||||
|
node = VelocityRampNode()
|
||||||
|
try:
|
||||||
|
rclpy.spin(node)
|
||||||
|
finally:
|
||||||
|
node.destroy_node()
|
||||||
|
rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
97
jetson/ros2_ws/src/saltybot_bringup/scripts/audio_router.py
Normal file
97
jetson/ros2_ws/src/saltybot_bringup/scripts/audio_router.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
MageDok Audio Router
|
||||||
|
Routes HDMI audio from DisplayPort adapter to internal speakers via PulseAudio
|
||||||
|
"""
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from std_msgs.msg import String
|
||||||
|
|
||||||
|
|
||||||
|
class AudioRouter(Node):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__('audio_router')
|
||||||
|
|
||||||
|
self.declare_parameter('hdmi_sink', 'alsa_output.pci-0000_00_1d.0.hdmi-stereo')
|
||||||
|
self.declare_parameter('default_sink', True)
|
||||||
|
|
||||||
|
self.hdmi_sink = self.get_parameter('hdmi_sink').value
|
||||||
|
self.set_default = self.get_parameter('default_sink').value
|
||||||
|
|
||||||
|
self.audio_status_pub = self.create_publisher(String, '/magedok/audio_status', 10)
|
||||||
|
|
||||||
|
self.get_logger().info('Audio Router: Configuring HDMI audio routing...')
|
||||||
|
self.setup_pulseaudio()
|
||||||
|
|
||||||
|
# Check status every 5 seconds
|
||||||
|
self.create_timer(5.0, self.check_audio_status)
|
||||||
|
|
||||||
|
def setup_pulseaudio(self):
|
||||||
|
"""Configure PulseAudio to route HDMI audio"""
|
||||||
|
try:
|
||||||
|
# List available sinks
|
||||||
|
result = subprocess.run(['pactl', 'list', 'sinks'], capture_output=True, text=True, timeout=5)
|
||||||
|
sinks = self._parse_pa_sinks(result.stdout)
|
||||||
|
|
||||||
|
if not sinks:
|
||||||
|
self.get_logger().warn('No PulseAudio sinks detected')
|
||||||
|
return
|
||||||
|
|
||||||
|
self.get_logger().info(f'Available sinks: {", ".join(sinks.keys())}')
|
||||||
|
|
||||||
|
# Find HDMI or use first available
|
||||||
|
hdmi_sink = None
|
||||||
|
for name in sinks.keys():
|
||||||
|
if 'hdmi' in name.lower() or 'HDMI' in name:
|
||||||
|
hdmi_sink = name
|
||||||
|
break
|
||||||
|
|
||||||
|
if not hdmi_sink:
|
||||||
|
hdmi_sink = list(sinks.keys())[0] # Fallback to first sink
|
||||||
|
self.get_logger().warn(f'HDMI sink not found, using: {hdmi_sink}')
|
||||||
|
else:
|
||||||
|
self.get_logger().info(f'✓ HDMI sink identified: {hdmi_sink}')
|
||||||
|
|
||||||
|
# Set as default if requested
|
||||||
|
if self.set_default:
|
||||||
|
subprocess.run(['pactl', 'set-default-sink', hdmi_sink], timeout=5)
|
||||||
|
self.get_logger().info(f'✓ Audio routed to: {hdmi_sink}')
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.get_logger().error(f'PulseAudio setup failed: {e}')
|
||||||
|
|
||||||
|
def _parse_pa_sinks(self, pactl_output):
|
||||||
|
"""Parse 'pactl list sinks' output"""
|
||||||
|
sinks = {}
|
||||||
|
current_sink = None
|
||||||
|
for line in pactl_output.split('\n'):
|
||||||
|
if line.startswith('Sink #'):
|
||||||
|
current_sink = line.split('#')[1].strip()
|
||||||
|
elif '\tName: ' in line and current_sink:
|
||||||
|
name = line.split('Name: ')[1].strip()
|
||||||
|
sinks[name] = current_sink
|
||||||
|
return sinks
|
||||||
|
|
||||||
|
def check_audio_status(self):
|
||||||
|
"""Verify audio is properly routed"""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(['pactl', 'get-default-sink'], capture_output=True, text=True, timeout=5)
|
||||||
|
status = String()
|
||||||
|
status.data = result.stdout.strip()
|
||||||
|
self.audio_status_pub.publish(status)
|
||||||
|
self.get_logger().debug(f'Current audio sink: {status.data}')
|
||||||
|
except Exception as e:
|
||||||
|
self.get_logger().warn(f'Audio status check failed: {e}')
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None):
|
||||||
|
rclpy.init(args=args)
|
||||||
|
router = AudioRouter()
|
||||||
|
rclpy.spin(router)
|
||||||
|
rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
88
jetson/ros2_ws/src/saltybot_bringup/scripts/touch_monitor.py
Normal file
88
jetson/ros2_ws/src/saltybot_bringup/scripts/touch_monitor.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
MageDok Touch Input Monitor
|
||||||
|
Verifies USB touch device is recognized and functional
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from std_msgs.msg import String, Bool
|
||||||
|
|
||||||
|
|
||||||
|
class TouchMonitor(Node):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__('touch_monitor')
|
||||||
|
|
||||||
|
self.declare_parameter('device_name', 'MageDok Touch')
|
||||||
|
self.declare_parameter('poll_interval', 0.1)
|
||||||
|
|
||||||
|
self.device_name = self.get_parameter('device_name').value
|
||||||
|
self.poll_interval = self.get_parameter('poll_interval').value
|
||||||
|
|
||||||
|
self.touch_status_pub = self.create_publisher(Bool, '/magedok/touch_status', 10)
|
||||||
|
self.device_info_pub = self.create_publisher(String, '/magedok/device_info', 10)
|
||||||
|
|
||||||
|
self.get_logger().info(f'Touch Monitor: Scanning for {self.device_name}...')
|
||||||
|
self.detect_touch_device()
|
||||||
|
|
||||||
|
# Publish status every 2 seconds
|
||||||
|
self.create_timer(2.0, self.publish_status)
|
||||||
|
|
||||||
|
def detect_touch_device(self):
|
||||||
|
"""Detect MageDok touch device via USB"""
|
||||||
|
try:
|
||||||
|
# Check lsusb for MageDok or eGTouch device
|
||||||
|
result = subprocess.run(['lsusb'], capture_output=True, text=True, timeout=5)
|
||||||
|
lines = result.stdout.split('\n')
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
if 'eGTouch' in line or 'EETI' in line or 'MageDok' in line or 'touch' in line.lower():
|
||||||
|
self.get_logger().info(f'✓ Touch device found: {line.strip()}')
|
||||||
|
msg = String()
|
||||||
|
msg.data = line.strip()
|
||||||
|
self.device_info_pub.publish(msg)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Fallback: check input devices
|
||||||
|
result = subprocess.run(['grep', '-l', 'eGTouch\|EETI\|MageDok', '/proc/bus/input/devices'],
|
||||||
|
capture_output=True, text=True, timeout=5)
|
||||||
|
if result.returncode == 0:
|
||||||
|
self.get_logger().info('✓ Touch device registered in /proc/bus/input/devices')
|
||||||
|
return True
|
||||||
|
|
||||||
|
self.get_logger().warn('⚠ Touch device not detected — ensure USB connection is secure')
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.get_logger().error(f'Device detection failed: {e}')
|
||||||
|
return False
|
||||||
|
|
||||||
|
def publish_status(self):
|
||||||
|
"""Publish current touch device status"""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(['ls', '/dev/magedok-touch'], capture_output=True, timeout=2)
|
||||||
|
status = Bool()
|
||||||
|
status.data = (result.returncode == 0)
|
||||||
|
self.touch_status_pub.publish(status)
|
||||||
|
|
||||||
|
if status.data:
|
||||||
|
self.get_logger().debug('Touch device: ACTIVE')
|
||||||
|
else:
|
||||||
|
self.get_logger().warn('Touch device: NOT DETECTED')
|
||||||
|
except Exception as e:
|
||||||
|
status = Bool()
|
||||||
|
status.data = False
|
||||||
|
self.touch_status_pub.publish(status)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None):
|
||||||
|
rclpy.init(args=args)
|
||||||
|
monitor = TouchMonitor()
|
||||||
|
rclpy.spin(monitor)
|
||||||
|
rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@ -0,0 +1,98 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
MageDok Display Verifier
|
||||||
|
Validates that the 7" display is running at 1024×600 resolution
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
|
||||||
|
|
||||||
|
class DisplayVerifier(Node):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__('display_verifier')
|
||||||
|
|
||||||
|
self.declare_parameter('target_width', 1024)
|
||||||
|
self.declare_parameter('target_height', 600)
|
||||||
|
self.declare_parameter('target_refresh', 60)
|
||||||
|
|
||||||
|
self.target_w = self.get_parameter('target_width').value
|
||||||
|
self.target_h = self.get_parameter('target_height').value
|
||||||
|
self.target_f = self.get_parameter('target_refresh').value
|
||||||
|
|
||||||
|
self.get_logger().info(f'Display Verifier: Target {self.target_w}×{self.target_h} @ {self.target_f}Hz')
|
||||||
|
self.verify_display()
|
||||||
|
|
||||||
|
def verify_display(self):
|
||||||
|
"""Check current display resolution via xdotool or xrandr"""
|
||||||
|
try:
|
||||||
|
# Try xrandr first
|
||||||
|
result = subprocess.run(['xrandr'], capture_output=True, text=True, timeout=5)
|
||||||
|
if result.returncode == 0:
|
||||||
|
self.parse_xrandr(result.stdout)
|
||||||
|
else:
|
||||||
|
self.get_logger().warn('xrandr not available, checking edid-decode')
|
||||||
|
self.check_edid()
|
||||||
|
except Exception as e:
|
||||||
|
self.get_logger().error(f'Display verification failed: {e}')
|
||||||
|
|
||||||
|
def parse_xrandr(self, output):
|
||||||
|
"""Parse xrandr output to find active display resolution"""
|
||||||
|
lines = output.split('\n')
|
||||||
|
for line in lines:
|
||||||
|
# Look for connected display with resolution
|
||||||
|
if 'connected' in line and 'primary' in line:
|
||||||
|
# Example: "HDMI-1 connected primary 1024x600+0+0 (normal left inverted right)"
|
||||||
|
match = re.search(r'(\d+)x(\d+)', line)
|
||||||
|
if match:
|
||||||
|
width, height = int(match.group(1)), int(match.group(2))
|
||||||
|
self.verify_resolution(width, height)
|
||||||
|
return
|
||||||
|
|
||||||
|
self.get_logger().warn('Could not determine active display from xrandr')
|
||||||
|
|
||||||
|
def verify_resolution(self, current_w, current_h):
|
||||||
|
"""Validate resolution matches target"""
|
||||||
|
if current_w == self.target_w and current_h == self.target_h:
|
||||||
|
self.get_logger().info(f'✓ Display verified: {current_w}×{current_h} [OK]')
|
||||||
|
else:
|
||||||
|
self.get_logger().warn(f'⚠ Display mismatch: Expected {self.target_w}×{self.target_h}, got {current_w}×{current_h}')
|
||||||
|
self.attempt_set_resolution()
|
||||||
|
|
||||||
|
def attempt_set_resolution(self):
|
||||||
|
"""Try to set resolution via xrandr"""
|
||||||
|
try:
|
||||||
|
# Find HDMI output
|
||||||
|
result = subprocess.run(
|
||||||
|
['xrandr', '--output', 'HDMI-1', '--mode', f'{self.target_w}x{self.target_h}', '--rate', str(self.target_f)],
|
||||||
|
capture_output=True, text=True, timeout=5
|
||||||
|
)
|
||||||
|
if result.returncode == 0:
|
||||||
|
self.get_logger().info(f'✓ Resolution set to {self.target_w}×{self.target_h} @ {self.target_f}Hz')
|
||||||
|
else:
|
||||||
|
self.get_logger().warn(f'Resolution change failed: {result.stderr}')
|
||||||
|
except Exception as e:
|
||||||
|
self.get_logger().error(f'Could not set resolution: {e}')
|
||||||
|
|
||||||
|
def check_edid(self):
|
||||||
|
"""Fallback: check EDID (Extended Display ID) data"""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(['edid-decode', '/sys/class/drm/card0-HDMI-A-1/edid'],
|
||||||
|
capture_output=True, text=True, timeout=5)
|
||||||
|
if 'Established timings' in result.stdout:
|
||||||
|
self.get_logger().info('Display EDID detected (MageDok 1024×600 display)')
|
||||||
|
except:
|
||||||
|
self.get_logger().warn('EDID check unavailable')
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None):
|
||||||
|
rclpy.init(args=args)
|
||||||
|
verifier = DisplayVerifier()
|
||||||
|
rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@ -51,6 +51,22 @@ setup(
|
|||||||
'wheel_odom = saltybot_bringup.wheel_odom_node:main',
|
'wheel_odom = saltybot_bringup.wheel_odom_node:main',
|
||||||
# Appearance-based person re-identification (Issue #322)
|
# Appearance-based person re-identification (Issue #322)
|
||||||
'person_reid = saltybot_bringup.person_reid_node:main',
|
'person_reid = saltybot_bringup.person_reid_node:main',
|
||||||
|
# Dynamic obstacle velocity estimator (Issue #326)
|
||||||
|
'obstacle_velocity = saltybot_bringup.obstacle_velocity_node:main',
|
||||||
|
# Lane/path edge detector (Issue #339)
|
||||||
|
'path_edges = saltybot_bringup.path_edges_node:main',
|
||||||
|
# Depth-based obstacle size estimator (Issue #348)
|
||||||
|
'obstacle_size = saltybot_bringup.obstacle_size_node:main',
|
||||||
|
# Audio scene classifier (Issue #353)
|
||||||
|
'audio_scene = saltybot_bringup.audio_scene_node:main',
|
||||||
|
# Face emotion classifier (Issue #359)
|
||||||
|
'face_emotion = saltybot_bringup.face_emotion_node:main',
|
||||||
|
# Person tracking for follow-me mode (Issue #363)
|
||||||
|
'person_tracking = saltybot_bringup.person_tracking_node:main',
|
||||||
|
# UWB DW3000 anchor/tag ranging (Issue #365)
|
||||||
|
'uwb_node = saltybot_bringup.uwb_node:main',
|
||||||
|
# Smooth velocity ramp controller (Issue #350)
|
||||||
|
'velocity_ramp = saltybot_bringup.velocity_ramp_node:main',
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@ -0,0 +1,26 @@
|
|||||||
|
[Unit]
|
||||||
|
Description=MageDok 7" Display Setup and Auto-Launch
|
||||||
|
Documentation=https://gitea.vayrette.com/seb/saltylab-firmware/issues/369
|
||||||
|
After=network-online.target
|
||||||
|
Wants=network-online.target
|
||||||
|
ConditionPathExists=/dev/pts/0
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
Type=oneshot
|
||||||
|
ExecStartPre=/bin/sleep 2
|
||||||
|
ExecStart=/usr/bin/env bash -c 'source /opt/ros/jazzy/setup.bash && ros2 launch saltybot_bringup magedok_display.launch.py'
|
||||||
|
ExecStartPost=/usr/bin/env bash -c 'DISPLAY=:0 /usr/bin/startx -- :0 vt7 -nolisten tcp 2>/dev/null &'
|
||||||
|
|
||||||
|
StandardOutput=journal
|
||||||
|
StandardError=journal
|
||||||
|
SyslogIdentifier=magedok-display
|
||||||
|
User=orin
|
||||||
|
Group=orin
|
||||||
|
Environment="DISPLAY=:0"
|
||||||
|
Environment="XAUTHORITY=/home/orin/.Xauthority"
|
||||||
|
|
||||||
|
Restart=on-failure
|
||||||
|
RestartSec=5
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
||||||
444
jetson/ros2_ws/src/saltybot_bringup/test/test_audio_scene.py
Normal file
444
jetson/ros2_ws/src/saltybot_bringup/test/test_audio_scene.py
Normal file
@ -0,0 +1,444 @@
|
|||||||
|
"""
|
||||||
|
test_audio_scene.py — Unit tests for the audio scene classifier (Issue #353).
|
||||||
|
|
||||||
|
All tests use synthetic audio signals generated with numpy — no microphone,
|
||||||
|
no audio files, no ROS2 runtime needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from saltybot_bringup._audio_scene import (
|
||||||
|
SCENE_LABELS,
|
||||||
|
AudioSceneResult,
|
||||||
|
NearestCentroidClassifier,
|
||||||
|
_build_centroids,
|
||||||
|
_make_prototype,
|
||||||
|
classify_audio,
|
||||||
|
extract_features,
|
||||||
|
_frame,
|
||||||
|
_mel_filterbank,
|
||||||
|
_dct2,
|
||||||
|
_N_FEATURES,
|
||||||
|
_SR_DEFAULT,
|
||||||
|
_CLASSIFIER,
|
||||||
|
_CENTROIDS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Helpers
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _sine(freq_hz: float, duration_s: float = 1.0, sr: int = _SR_DEFAULT,
|
||||||
|
amp: float = 0.5) -> np.ndarray:
|
||||||
|
"""Generate a pure sine wave."""
|
||||||
|
n = int(sr * duration_s)
|
||||||
|
t = np.linspace(0.0, duration_s, n, endpoint=False)
|
||||||
|
return (amp * np.sin(2.0 * math.pi * freq_hz * t)).astype(np.float64)
|
||||||
|
|
||||||
|
|
||||||
|
def _white_noise(duration_s: float = 1.0, sr: int = _SR_DEFAULT,
|
||||||
|
amp: float = 0.1, seed: int = 0) -> np.ndarray:
|
||||||
|
"""White noise."""
|
||||||
|
rng = np.random.RandomState(seed)
|
||||||
|
n = int(sr * duration_s)
|
||||||
|
return (amp * rng.randn(n)).astype(np.float64)
|
||||||
|
|
||||||
|
|
||||||
|
def _silence(duration_s: float = 1.0, sr: int = _SR_DEFAULT) -> np.ndarray:
|
||||||
|
return np.zeros(int(sr * duration_s), dtype=np.float64)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# SCENE_LABELS
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_scene_labels_tuple():
|
||||||
|
assert isinstance(SCENE_LABELS, tuple)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scene_labels_count():
|
||||||
|
assert len(SCENE_LABELS) == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_scene_labels_values():
|
||||||
|
assert set(SCENE_LABELS) == {'indoor', 'outdoor', 'traffic', 'park'}
|
||||||
|
|
||||||
|
|
||||||
|
def test_scene_labels_order():
|
||||||
|
assert SCENE_LABELS[0] == 'indoor'
|
||||||
|
assert SCENE_LABELS[1] == 'outdoor'
|
||||||
|
assert SCENE_LABELS[2] == 'traffic'
|
||||||
|
assert SCENE_LABELS[3] == 'park'
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# _frame
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_frame_shape():
|
||||||
|
x = np.arange(1000.0)
|
||||||
|
frames = _frame(x, 400, 160)
|
||||||
|
assert frames.ndim == 2
|
||||||
|
assert frames.shape[1] == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_frame_n_frames():
|
||||||
|
x = np.arange(1000.0)
|
||||||
|
frames = _frame(x, 400, 160)
|
||||||
|
expected = 1 + (1000 - 400) // 160
|
||||||
|
assert frames.shape[0] == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_frame_first_frame():
|
||||||
|
x = np.arange(100.0)
|
||||||
|
frames = _frame(x, 10, 5)
|
||||||
|
np.testing.assert_array_equal(frames[0], np.arange(10.0))
|
||||||
|
|
||||||
|
|
||||||
|
def test_frame_second_frame():
|
||||||
|
x = np.arange(100.0)
|
||||||
|
frames = _frame(x, 10, 5)
|
||||||
|
np.testing.assert_array_equal(frames[1], np.arange(5.0, 15.0))
|
||||||
|
|
||||||
|
|
||||||
|
def test_frame_short_signal():
|
||||||
|
"""Signal shorter than one frame should still return one frame."""
|
||||||
|
x = np.ones(10)
|
||||||
|
frames = _frame(x, 400, 160)
|
||||||
|
assert frames.shape[0] == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# _mel_filterbank
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_mel_filterbank_shape():
|
||||||
|
fbank = _mel_filterbank(16000, 512, 26)
|
||||||
|
assert fbank.shape == (26, 512 // 2 + 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mel_filterbank_non_negative():
|
||||||
|
fbank = _mel_filterbank(16000, 512, 26)
|
||||||
|
assert (fbank >= 0).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_mel_filterbank_max_one():
|
||||||
|
fbank = _mel_filterbank(16000, 512, 26)
|
||||||
|
assert fbank.max() <= 1.0 + 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def test_mel_filterbank_non_zero():
|
||||||
|
fbank = _mel_filterbank(16000, 512, 26)
|
||||||
|
assert fbank.sum() > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# _dct2
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_dct2_shape():
|
||||||
|
x = np.random.randn(10, 26)
|
||||||
|
out = _dct2(x)
|
||||||
|
assert out.shape == x.shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_dct2_dc():
|
||||||
|
"""DC (constant) row should produce non-zero first column only."""
|
||||||
|
x = np.ones((1, 8))
|
||||||
|
out = _dct2(x)
|
||||||
|
# First coefficient = sum; rest should be near zero for constant input
|
||||||
|
assert abs(out[0, 0]) > abs(out[0, 1]) * 5
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# extract_features
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_extract_features_length():
|
||||||
|
sig = _sine(440.0)
|
||||||
|
feat = extract_features(sig)
|
||||||
|
assert len(feat) == _N_FEATURES
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_features_dtype():
|
||||||
|
sig = _sine(440.0)
|
||||||
|
feat = extract_features(sig)
|
||||||
|
assert feat.dtype in (np.float32, np.float64)
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_features_silence_finite():
|
||||||
|
"""Features from silence should be finite (no NaN/Inf)."""
|
||||||
|
feat = extract_features(_silence())
|
||||||
|
assert np.all(np.isfinite(feat))
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_features_sine_finite():
|
||||||
|
feat = extract_features(_sine(440.0))
|
||||||
|
assert np.all(np.isfinite(feat))
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_features_noise_finite():
|
||||||
|
feat = extract_features(_white_noise())
|
||||||
|
assert np.all(np.isfinite(feat))
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_features_centroid_low_freq():
|
||||||
|
"""Low-frequency sine → low spectral centroid (feature index 13)."""
|
||||||
|
feat_lo = extract_features(_sine(80.0))
|
||||||
|
feat_hi = extract_features(_sine(4000.0))
|
||||||
|
assert feat_lo[13] < feat_hi[13]
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_features_centroid_high_freq():
|
||||||
|
"""High-frequency sine → high spectral centroid."""
|
||||||
|
feat_4k = extract_features(_sine(4000.0))
|
||||||
|
feat_200 = extract_features(_sine(200.0))
|
||||||
|
assert feat_4k[13] > feat_200[13]
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_features_rolloff_low_freq():
|
||||||
|
"""Low-frequency sine → low spectral rolloff (feature index 14)."""
|
||||||
|
feat_lo = extract_features(_sine(100.0))
|
||||||
|
feat_hi = extract_features(_sine(5000.0))
|
||||||
|
assert feat_lo[14] < feat_hi[14]
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_features_zcr_increases_with_freq():
|
||||||
|
"""Higher-frequency sine → higher ZCR (feature index 15)."""
|
||||||
|
feat_lo = extract_features(_sine(100.0))
|
||||||
|
feat_hi = extract_features(_sine(3000.0))
|
||||||
|
assert feat_lo[15] < feat_hi[15]
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_features_short_clip():
|
||||||
|
"""Clips shorter than one frame should not crash."""
|
||||||
|
sig = np.zeros(100)
|
||||||
|
feat = extract_features(sig)
|
||||||
|
assert len(feat) == _N_FEATURES
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_features_empty_clip():
|
||||||
|
"""Empty signal should return zero vector."""
|
||||||
|
feat = extract_features(np.array([]))
|
||||||
|
assert len(feat) == _N_FEATURES
|
||||||
|
assert np.all(feat == 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# _make_prototype
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_prototype_length():
|
||||||
|
for label in SCENE_LABELS:
|
||||||
|
sig = _make_prototype(label)
|
||||||
|
assert len(sig) == _SR_DEFAULT
|
||||||
|
|
||||||
|
|
||||||
|
def test_prototype_finite():
|
||||||
|
for label in SCENE_LABELS:
|
||||||
|
sig = _make_prototype(label)
|
||||||
|
assert np.all(np.isfinite(sig))
|
||||||
|
|
||||||
|
|
||||||
|
def test_prototype_unknown_label():
|
||||||
|
with pytest.raises(ValueError, match='Unknown scene label'):
|
||||||
|
_make_prototype('underwater')
|
||||||
|
|
||||||
|
|
||||||
|
def test_prototype_deterministic():
|
||||||
|
"""Same seed → same prototype."""
|
||||||
|
rng1 = np.random.RandomState(42)
|
||||||
|
rng2 = np.random.RandomState(42)
|
||||||
|
sig1 = _make_prototype('indoor', rng=rng1)
|
||||||
|
sig2 = _make_prototype('indoor', rng=rng2)
|
||||||
|
np.testing.assert_array_equal(sig1, sig2)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# _build_centroids
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_build_centroids_shape():
|
||||||
|
C = _build_centroids()
|
||||||
|
assert C.shape == (4, _N_FEATURES)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_centroids_finite():
|
||||||
|
C = _build_centroids()
|
||||||
|
assert np.all(np.isfinite(C))
|
||||||
|
|
||||||
|
|
||||||
|
def test_centroids_traffic_low_centroid():
|
||||||
|
"""Traffic prototype centroid (col 13) should be the lowest."""
|
||||||
|
C = _build_centroids()
|
||||||
|
idx_traffic = list(SCENE_LABELS).index('traffic')
|
||||||
|
assert C[idx_traffic, 13] == C[:, 13].min()
|
||||||
|
|
||||||
|
|
||||||
|
def test_centroids_park_high_centroid():
|
||||||
|
"""Park prototype centroid should have the highest spectral centroid."""
|
||||||
|
C = _build_centroids()
|
||||||
|
idx_park = list(SCENE_LABELS).index('park')
|
||||||
|
assert C[idx_park, 13] == C[:, 13].max()
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# NearestCentroidClassifier
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_classifier_predict_returns_known_label():
|
||||||
|
feat = extract_features(_sine(440.0))
|
||||||
|
label, conf = _CLASSIFIER.predict(feat)
|
||||||
|
assert label in SCENE_LABELS
|
||||||
|
|
||||||
|
|
||||||
|
def test_classifier_confidence_range():
|
||||||
|
feat = extract_features(_sine(440.0))
|
||||||
|
_, conf = _CLASSIFIER.predict(feat)
|
||||||
|
assert 0.0 < conf <= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_classifier_prototype_self_predict():
|
||||||
|
"""Each prototype signal should classify as its own class."""
|
||||||
|
rng = np.random.RandomState(42)
|
||||||
|
for label in SCENE_LABELS:
|
||||||
|
sig = _make_prototype(label, rng=rng)
|
||||||
|
feat = extract_features(sig)
|
||||||
|
pred, _ = _CLASSIFIER.predict(feat)
|
||||||
|
assert pred == label, (
|
||||||
|
f'Prototype {label!r} classified as {pred!r}'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_classifier_custom_centroids():
|
||||||
|
"""Simple 1-D centroid classifier sanity check."""
|
||||||
|
# Two classes: 'low' centroid at 0, 'high' centroid at 10
|
||||||
|
C = np.array([[0.0], [10.0]])
|
||||||
|
clf = NearestCentroidClassifier(C, ('low', 'high'))
|
||||||
|
assert clf.predict(np.array([1.0]))[0] == 'low'
|
||||||
|
assert clf.predict(np.array([9.0]))[0] == 'high'
|
||||||
|
|
||||||
|
|
||||||
|
def test_classifier_confidence_max_at_centroid():
|
||||||
|
"""Feature equal to a centroid → maximum confidence for that class."""
|
||||||
|
C = _CENTROIDS
|
||||||
|
clf = NearestCentroidClassifier(C, SCENE_LABELS)
|
||||||
|
for i, label in enumerate(SCENE_LABELS):
|
||||||
|
pred, conf = clf.predict(C[i])
|
||||||
|
assert pred == label
|
||||||
|
# dist = 0 → conf = 1/(1+0) = 1.0
|
||||||
|
assert abs(conf - 1.0) < 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# classify_audio — end-to-end
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_classify_audio_returns_result():
|
||||||
|
res = classify_audio(_sine(440.0))
|
||||||
|
assert isinstance(res, AudioSceneResult)
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_audio_label_valid():
|
||||||
|
res = classify_audio(_sine(440.0))
|
||||||
|
assert res.label in SCENE_LABELS
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_audio_confidence_range():
|
||||||
|
res = classify_audio(_sine(440.0))
|
||||||
|
assert 0.0 < res.confidence <= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_audio_features_length():
|
||||||
|
res = classify_audio(_sine(440.0))
|
||||||
|
assert len(res.features) == _N_FEATURES
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_traffic_low_freq():
|
||||||
|
"""80 Hz sine → traffic."""
|
||||||
|
res = classify_audio(_sine(80.0))
|
||||||
|
assert res.label == 'traffic', (
|
||||||
|
f'Expected traffic, got {res.label!r} (conf={res.confidence:.2f})'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_indoor_440():
|
||||||
|
"""440 Hz sine → indoor."""
|
||||||
|
res = classify_audio(_sine(440.0))
|
||||||
|
assert res.label == 'indoor', (
|
||||||
|
f'Expected indoor, got {res.label!r} (conf={res.confidence:.2f})'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_outdoor_mid_freq():
|
||||||
|
"""Mixed 1000+2000 Hz → outdoor."""
|
||||||
|
sr = _SR_DEFAULT
|
||||||
|
t = np.linspace(0.0, 1.0, sr, endpoint=False)
|
||||||
|
sig = (
|
||||||
|
0.4 * np.sin(2 * math.pi * 1000.0 * t)
|
||||||
|
+ 0.4 * np.sin(2 * math.pi * 2000.0 * t)
|
||||||
|
)
|
||||||
|
res = classify_audio(sig)
|
||||||
|
assert res.label == 'outdoor', (
|
||||||
|
f'Expected outdoor, got {res.label!r} (conf={res.confidence:.2f})'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_park_high_freq():
|
||||||
|
"""3200+4800 Hz tones → park."""
|
||||||
|
sr = _SR_DEFAULT
|
||||||
|
t = np.linspace(0.0, 1.0, sr, endpoint=False)
|
||||||
|
sig = (
|
||||||
|
0.4 * np.sin(2 * math.pi * 3200.0 * t)
|
||||||
|
+ 0.3 * np.sin(2 * math.pi * 4800.0 * t)
|
||||||
|
)
|
||||||
|
res = classify_audio(sig)
|
||||||
|
assert res.label == 'park', (
|
||||||
|
f'Expected park, got {res.label!r} (conf={res.confidence:.2f})'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_silence_no_crash():
|
||||||
|
"""Silence should return a valid (if low-confidence) result."""
|
||||||
|
res = classify_audio(_silence())
|
||||||
|
assert res.label in SCENE_LABELS
|
||||||
|
assert np.isfinite(res.confidence)
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_audio_short_clip():
|
||||||
|
"""Short clip (200 ms) should not crash."""
|
||||||
|
sig = _sine(440.0, duration_s=0.2)
|
||||||
|
res = classify_audio(sig)
|
||||||
|
assert res.label in SCENE_LABELS
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_prototype_indoor():
|
||||||
|
rng = np.random.RandomState(42)
|
||||||
|
sig = _make_prototype('indoor', rng=rng)
|
||||||
|
res = classify_audio(sig)
|
||||||
|
assert res.label == 'indoor'
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_prototype_outdoor():
|
||||||
|
rng = np.random.RandomState(42)
|
||||||
|
sig = _make_prototype('outdoor', rng=rng)
|
||||||
|
res = classify_audio(sig)
|
||||||
|
assert res.label == 'outdoor'
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_prototype_traffic():
|
||||||
|
rng = np.random.RandomState(42)
|
||||||
|
sig = _make_prototype('traffic', rng=rng)
|
||||||
|
res = classify_audio(sig)
|
||||||
|
assert res.label == 'traffic'
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_prototype_park():
|
||||||
|
rng = np.random.RandomState(42)
|
||||||
|
sig = _make_prototype('park', rng=rng)
|
||||||
|
res = classify_audio(sig)
|
||||||
|
assert res.label == 'park'
|
||||||
@ -0,0 +1,575 @@
|
|||||||
|
"""
|
||||||
|
test_camera_power_manager.py — Unit tests for _camera_power_manager.py (Issue #375).
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- Mode sensor configurations
|
||||||
|
- Speed-driven upgrade transitions
|
||||||
|
- Speed-driven downgrade with hysteresis
|
||||||
|
- Scenario overrides (CROSSING, EMERGENCY, PARKED, INDOOR)
|
||||||
|
- Battery low cap
|
||||||
|
- Idle → SOCIAL transition
|
||||||
|
- Safety invariants (rear CSI in ACTIVE/FULL, CROSSING cannot downgrade)
|
||||||
|
- Reset
|
||||||
|
- ModeDecision fields
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from saltybot_bringup._camera_power_manager import (
|
||||||
|
ActiveSensors,
|
||||||
|
CameraMode,
|
||||||
|
CameraPowerFSM,
|
||||||
|
MODE_SENSORS,
|
||||||
|
ModeDecision,
|
||||||
|
Scenario,
|
||||||
|
_SPD_ACTIVE_DOWN,
|
||||||
|
_SPD_ACTIVE_UP,
|
||||||
|
_SPD_FULL_DOWN,
|
||||||
|
_SPD_FULL_UP,
|
||||||
|
_SPD_MOTION,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Helpers
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _FakeClock:
|
||||||
|
"""Injectable monotonic clock for deterministic tests."""
|
||||||
|
def __init__(self, t: float = 0.0) -> None:
|
||||||
|
self.t = t
|
||||||
|
def __call__(self) -> float:
|
||||||
|
return self.t
|
||||||
|
def advance(self, dt: float) -> None:
|
||||||
|
self.t += dt
|
||||||
|
|
||||||
|
|
||||||
|
def _fsm(hold: float = 5.0, idle: float = 30.0, bat_low: float = 20.0,
|
||||||
|
clock=None) -> CameraPowerFSM:
|
||||||
|
c = clock or _FakeClock()
|
||||||
|
return CameraPowerFSM(
|
||||||
|
downgrade_hold_s=hold,
|
||||||
|
idle_to_social_s=idle,
|
||||||
|
battery_low_pct=bat_low,
|
||||||
|
clock=c,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Mode sensor configurations
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestModeSensors:
|
||||||
|
|
||||||
|
def test_sleep_no_sensors(self):
|
||||||
|
s = MODE_SENSORS[CameraMode.SLEEP]
|
||||||
|
assert s.active_count == 0
|
||||||
|
|
||||||
|
def test_social_webcam_only(self):
|
||||||
|
s = MODE_SENSORS[CameraMode.SOCIAL]
|
||||||
|
assert s.webcam is True
|
||||||
|
assert s.csi_front is False
|
||||||
|
assert s.realsense is False
|
||||||
|
assert s.lidar is False
|
||||||
|
|
||||||
|
def test_aware_front_realsense_lidar(self):
|
||||||
|
s = MODE_SENSORS[CameraMode.AWARE]
|
||||||
|
assert s.csi_front is True
|
||||||
|
assert s.realsense is True
|
||||||
|
assert s.lidar is True
|
||||||
|
assert s.csi_rear is False
|
||||||
|
assert s.csi_left is False
|
||||||
|
assert s.csi_right is False
|
||||||
|
assert s.uwb is False
|
||||||
|
|
||||||
|
def test_active_front_rear_realsense_lidar_uwb(self):
|
||||||
|
s = MODE_SENSORS[CameraMode.ACTIVE]
|
||||||
|
assert s.csi_front is True
|
||||||
|
assert s.csi_rear is True
|
||||||
|
assert s.realsense is True
|
||||||
|
assert s.lidar is True
|
||||||
|
assert s.uwb is True
|
||||||
|
assert s.csi_left is False
|
||||||
|
assert s.csi_right is False
|
||||||
|
|
||||||
|
def test_full_all_sensors(self):
|
||||||
|
s = MODE_SENSORS[CameraMode.FULL]
|
||||||
|
assert s.csi_front is True
|
||||||
|
assert s.csi_rear is True
|
||||||
|
assert s.csi_left is True
|
||||||
|
assert s.csi_right is True
|
||||||
|
assert s.realsense is True
|
||||||
|
assert s.lidar is True
|
||||||
|
assert s.uwb is True
|
||||||
|
assert s.webcam is False # webcam not needed at speed
|
||||||
|
|
||||||
|
def test_mode_sensor_counts_increase(self):
|
||||||
|
counts = [MODE_SENSORS[m].active_count for m in CameraMode]
|
||||||
|
assert counts == sorted(counts), "Higher modes should have more sensors"
|
||||||
|
|
||||||
|
def test_safety_rear_csi_in_active(self):
|
||||||
|
assert MODE_SENSORS[CameraMode.ACTIVE].csi_rear is True
|
||||||
|
|
||||||
|
def test_safety_rear_csi_in_full(self):
|
||||||
|
assert MODE_SENSORS[CameraMode.FULL].csi_rear is True
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Speed-driven upgrades (instant)
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestSpeedUpgrades:
|
||||||
|
|
||||||
|
def test_no_motion_stays_social(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=0.0)
|
||||||
|
assert d.mode == CameraMode.SOCIAL
|
||||||
|
|
||||||
|
def test_slow_motion_upgrades_to_aware(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=_SPD_MOTION + 0.1)
|
||||||
|
assert d.mode == CameraMode.AWARE
|
||||||
|
|
||||||
|
def test_5kmh_upgrades_to_active(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=_SPD_ACTIVE_UP + 0.1)
|
||||||
|
assert d.mode == CameraMode.ACTIVE
|
||||||
|
|
||||||
|
def test_15kmh_upgrades_to_full(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=_SPD_FULL_UP + 0.1)
|
||||||
|
assert d.mode == CameraMode.FULL
|
||||||
|
|
||||||
|
def test_upgrades_skip_intermediate_modes(self):
|
||||||
|
"""At 20 km/h from rest, jumps directly to FULL."""
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=20.0 / 3.6)
|
||||||
|
assert d.mode == CameraMode.FULL
|
||||||
|
|
||||||
|
def test_upgrade_is_immediate_no_hold(self):
|
||||||
|
"""Upgrades do NOT require hold time."""
|
||||||
|
clock = _FakeClock(0.0)
|
||||||
|
f = _fsm(hold=10.0, clock=clock)
|
||||||
|
# Still at t=0
|
||||||
|
d = f.update(speed_mps=_SPD_FULL_UP + 0.5)
|
||||||
|
assert d.mode == CameraMode.FULL
|
||||||
|
|
||||||
|
def test_exactly_at_motion_threshold(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=_SPD_MOTION)
|
||||||
|
assert d.mode == CameraMode.AWARE
|
||||||
|
|
||||||
|
def test_exactly_at_active_threshold(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=_SPD_ACTIVE_UP)
|
||||||
|
assert d.mode == CameraMode.ACTIVE
|
||||||
|
|
||||||
|
def test_exactly_at_full_threshold(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=_SPD_FULL_UP)
|
||||||
|
assert d.mode == CameraMode.FULL
|
||||||
|
|
||||||
|
def test_negative_speed_clamped_to_zero(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=-1.0)
|
||||||
|
assert d.mode == CameraMode.SOCIAL
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Downgrade hysteresis
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestDowngradeHysteresis:
|
||||||
|
|
||||||
|
def _reach_full(self, f: CameraPowerFSM, clock: _FakeClock) -> None:
|
||||||
|
f.update(speed_mps=_SPD_FULL_UP + 0.5)
|
||||||
|
assert f.mode == CameraMode.FULL
|
||||||
|
|
||||||
|
def test_downgrade_not_immediate(self):
|
||||||
|
clock = _FakeClock(0.0)
|
||||||
|
f = _fsm(hold=5.0, clock=clock)
|
||||||
|
self._reach_full(f, clock)
|
||||||
|
# Drop to below FULL threshold but don't advance clock
|
||||||
|
d = f.update(speed_mps=0.0)
|
||||||
|
assert d.mode == CameraMode.FULL # still held
|
||||||
|
|
||||||
|
def test_downgrade_after_hold_expires(self):
|
||||||
|
clock = _FakeClock(0.0)
|
||||||
|
f = _fsm(hold=5.0, clock=clock)
|
||||||
|
self._reach_full(f, clock)
|
||||||
|
f.update(speed_mps=0.0) # first low-speed call starts the hold timer
|
||||||
|
clock.advance(5.1)
|
||||||
|
d = f.update(speed_mps=0.0) # hold expired — downgrade to AWARE (not SOCIAL;
|
||||||
|
# SOCIAL requires an additional idle timer from AWARE, see TestIdleToSocial)
|
||||||
|
assert d.mode == CameraMode.AWARE
|
||||||
|
|
||||||
|
def test_downgrade_cancelled_by_speed_spike(self):
|
||||||
|
"""If speed spikes back up during hold, downgrade is cancelled."""
|
||||||
|
clock = _FakeClock(0.0)
|
||||||
|
f = _fsm(hold=5.0, clock=clock)
|
||||||
|
self._reach_full(f, clock)
|
||||||
|
clock.advance(3.0)
|
||||||
|
f.update(speed_mps=0.0) # start downgrade timer
|
||||||
|
clock.advance(1.0)
|
||||||
|
f.update(speed_mps=_SPD_FULL_UP + 1.0) # back to full speed
|
||||||
|
clock.advance(3.0) # would have expired hold if not cancelled
|
||||||
|
d = f.update(speed_mps=0.0)
|
||||||
|
# Hold restarted; only 3s elapsed since the cancellation reset
|
||||||
|
assert d.mode == CameraMode.FULL
|
||||||
|
|
||||||
|
def test_full_to_active_hysteresis_band(self):
|
||||||
|
"""Speed in [_SPD_FULL_DOWN, _SPD_FULL_UP) while in FULL → stays FULL (hold pending)."""
|
||||||
|
clock = _FakeClock(0.0)
|
||||||
|
f = _fsm(hold=5.0, clock=clock)
|
||||||
|
self._reach_full(f, clock)
|
||||||
|
# Speed in hysteresis band (between down and up thresholds)
|
||||||
|
mid = (_SPD_FULL_DOWN + _SPD_FULL_UP) / 2.0
|
||||||
|
d = f.update(speed_mps=mid)
|
||||||
|
assert d.mode == CameraMode.FULL # pending downgrade, not yet applied
|
||||||
|
|
||||||
|
def test_hold_zero_downgrades_immediately(self):
|
||||||
|
clock = _FakeClock(0.0)
|
||||||
|
f = _fsm(hold=0.0, clock=clock)
|
||||||
|
f.update(speed_mps=_SPD_FULL_UP + 0.5)
|
||||||
|
# hold=0: downgrade applies on the very first low-speed call.
|
||||||
|
# From FULL at speed=0 → AWARE (SOCIAL requires a separate idle timer).
|
||||||
|
d = f.update(speed_mps=0.0)
|
||||||
|
assert d.mode == CameraMode.AWARE
|
||||||
|
|
||||||
|
def test_downgrade_full_to_active_at_13kmh(self):
|
||||||
|
clock = _FakeClock(0.0)
|
||||||
|
f = _fsm(hold=0.0, clock=clock)
|
||||||
|
f.update(speed_mps=_SPD_FULL_UP + 0.5) # → FULL
|
||||||
|
d = f.update(speed_mps=_SPD_FULL_DOWN - 0.05) # just below 13 km/h
|
||||||
|
assert d.mode == CameraMode.ACTIVE
|
||||||
|
|
||||||
|
def test_downgrade_active_to_aware(self):
|
||||||
|
clock = _FakeClock(0.0)
|
||||||
|
f = _fsm(hold=0.0, clock=clock)
|
||||||
|
f.update(speed_mps=_SPD_ACTIVE_UP + 0.5) # → ACTIVE
|
||||||
|
d = f.update(speed_mps=_SPD_ACTIVE_DOWN - 0.05)
|
||||||
|
assert d.mode == CameraMode.AWARE
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Scenario overrides
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestScenarioOverrides:
|
||||||
|
|
||||||
|
def test_crossing_forces_full_from_rest(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=0.0, scenario=Scenario.CROSSING)
|
||||||
|
assert d.mode == CameraMode.FULL
|
||||||
|
assert d.scenario_override is True
|
||||||
|
|
||||||
|
def test_crossing_forces_full_from_aware(self):
|
||||||
|
f = _fsm()
|
||||||
|
f.update(speed_mps=0.5) # AWARE
|
||||||
|
d = f.update(speed_mps=0.5, scenario=Scenario.CROSSING)
|
||||||
|
assert d.mode == CameraMode.FULL
|
||||||
|
|
||||||
|
def test_emergency_forces_full(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=0.0, scenario=Scenario.EMERGENCY)
|
||||||
|
assert d.mode == CameraMode.FULL
|
||||||
|
assert d.scenario_override is True
|
||||||
|
|
||||||
|
def test_crossing_bypasses_hysteresis(self):
|
||||||
|
"""CROSSING forces FULL even if a downgrade is pending."""
|
||||||
|
clock = _FakeClock(0.0)
|
||||||
|
f = _fsm(hold=5.0, clock=clock)
|
||||||
|
f.update(speed_mps=_SPD_FULL_UP + 1.0) # FULL
|
||||||
|
clock.advance(4.0)
|
||||||
|
f.update(speed_mps=0.0) # pending downgrade started
|
||||||
|
d = f.update(speed_mps=0.0, scenario=Scenario.CROSSING)
|
||||||
|
assert d.mode == CameraMode.FULL
|
||||||
|
|
||||||
|
def test_parked_forces_social(self):
|
||||||
|
f = _fsm()
|
||||||
|
f.update(speed_mps=_SPD_FULL_UP + 1.0) # FULL
|
||||||
|
d = f.update(speed_mps=0.0, scenario=Scenario.PARKED)
|
||||||
|
assert d.mode == CameraMode.SOCIAL
|
||||||
|
assert d.scenario_override is True
|
||||||
|
|
||||||
|
def test_indoor_caps_at_aware(self):
|
||||||
|
f = _fsm()
|
||||||
|
# Speed says ACTIVE but indoor caps to AWARE
|
||||||
|
d = f.update(speed_mps=_SPD_ACTIVE_UP + 0.5, scenario=Scenario.INDOOR)
|
||||||
|
assert d.mode == CameraMode.AWARE
|
||||||
|
|
||||||
|
def test_indoor_cannot_reach_full(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=_SPD_FULL_UP + 2.0, scenario=Scenario.INDOOR)
|
||||||
|
assert d.mode == CameraMode.AWARE
|
||||||
|
|
||||||
|
def test_outdoor_uses_speed_logic(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=_SPD_FULL_UP + 0.5, scenario=Scenario.OUTDOOR)
|
||||||
|
assert d.mode == CameraMode.FULL
|
||||||
|
|
||||||
|
def test_unknown_scenario_uses_speed_logic(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=_SPD_ACTIVE_UP + 0.5, scenario=Scenario.UNKNOWN)
|
||||||
|
assert d.mode == CameraMode.ACTIVE
|
||||||
|
|
||||||
|
def test_crossing_sets_all_csi_active(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=0.0, scenario=Scenario.CROSSING)
|
||||||
|
s = d.sensors
|
||||||
|
assert s.csi_front and s.csi_rear and s.csi_left and s.csi_right
|
||||||
|
|
||||||
|
def test_scenario_override_false_for_speed_transition(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=_SPD_ACTIVE_UP + 0.5, scenario=Scenario.OUTDOOR)
|
||||||
|
assert d.scenario_override is False
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Battery low cap
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestBatteryLow:
|
||||||
|
|
||||||
|
def test_battery_low_caps_at_aware(self):
|
||||||
|
f = _fsm(bat_low=20.0)
|
||||||
|
d = f.update(speed_mps=_SPD_FULL_UP + 1.0, battery_pct=15.0)
|
||||||
|
assert d.mode == CameraMode.AWARE
|
||||||
|
|
||||||
|
def test_battery_low_prevents_active(self):
|
||||||
|
f = _fsm(bat_low=20.0)
|
||||||
|
d = f.update(speed_mps=_SPD_ACTIVE_UP + 0.5, battery_pct=19.9)
|
||||||
|
assert d.mode == CameraMode.AWARE
|
||||||
|
|
||||||
|
def test_battery_full_no_cap(self):
|
||||||
|
f = _fsm(bat_low=20.0)
|
||||||
|
d = f.update(speed_mps=_SPD_FULL_UP + 1.0, battery_pct=100.0)
|
||||||
|
assert d.mode == CameraMode.FULL
|
||||||
|
|
||||||
|
def test_battery_at_threshold_not_capped(self):
|
||||||
|
f = _fsm(bat_low=20.0)
|
||||||
|
d = f.update(speed_mps=_SPD_FULL_UP + 1.0, battery_pct=20.0)
|
||||||
|
# At exactly threshold — not below — so no cap
|
||||||
|
assert d.mode == CameraMode.FULL
|
||||||
|
|
||||||
|
def test_battery_low_allows_aware(self):
|
||||||
|
f = _fsm(bat_low=20.0)
|
||||||
|
d = f.update(speed_mps=_SPD_MOTION + 0.1, battery_pct=10.0)
|
||||||
|
assert d.mode == CameraMode.AWARE
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Idle → SOCIAL transition
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestIdleToSocial:
|
||||||
|
|
||||||
|
def test_idle_transitions_to_social_after_timeout(self):
|
||||||
|
clock = _FakeClock(0.0)
|
||||||
|
f = _fsm(idle=10.0, hold=0.0, clock=clock)
|
||||||
|
# Bring to AWARE
|
||||||
|
f.update(speed_mps=_SPD_MOTION + 0.1)
|
||||||
|
# Reduce to near-zero
|
||||||
|
clock.advance(5.0)
|
||||||
|
f.update(speed_mps=0.05) # below 0.1, idle timer starts
|
||||||
|
clock.advance(10.1)
|
||||||
|
d = f.update(speed_mps=0.05)
|
||||||
|
assert d.mode == CameraMode.SOCIAL
|
||||||
|
|
||||||
|
def test_motion_resets_idle_timer(self):
|
||||||
|
clock = _FakeClock(0.0)
|
||||||
|
f = _fsm(idle=10.0, hold=0.0, clock=clock)
|
||||||
|
f.update(speed_mps=_SPD_MOTION + 0.1) # AWARE
|
||||||
|
clock.advance(5.0)
|
||||||
|
f.update(speed_mps=0.05) # idle timer starts
|
||||||
|
clock.advance(5.0)
|
||||||
|
f.update(speed_mps=1.0) # motion resets timer
|
||||||
|
clock.advance(6.0)
|
||||||
|
d = f.update(speed_mps=0.05) # timer restarted, only 6s elapsed
|
||||||
|
assert d.mode == CameraMode.AWARE
|
||||||
|
|
||||||
|
def test_not_idle_when_moving(self):
|
||||||
|
clock = _FakeClock(0.0)
|
||||||
|
f = _fsm(idle=5.0, clock=clock)
|
||||||
|
f.update(speed_mps=_SPD_MOTION + 0.1)
|
||||||
|
clock.advance(100.0)
|
||||||
|
d = f.update(speed_mps=_SPD_MOTION + 0.1) # still moving
|
||||||
|
assert d.mode == CameraMode.AWARE
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Reset
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestReset:
|
||||||
|
|
||||||
|
def test_reset_to_sleep(self):
|
||||||
|
f = _fsm()
|
||||||
|
f.update(speed_mps=_SPD_FULL_UP + 1.0)
|
||||||
|
f.reset(CameraMode.SLEEP)
|
||||||
|
assert f.mode == CameraMode.SLEEP
|
||||||
|
|
||||||
|
def test_reset_clears_downgrade_timer(self):
|
||||||
|
clock = _FakeClock(0.0)
|
||||||
|
f = _fsm(hold=5.0, clock=clock)
|
||||||
|
f.update(speed_mps=_SPD_FULL_UP + 1.0)
|
||||||
|
f.update(speed_mps=0.0) # pending downgrade started
|
||||||
|
f.reset(CameraMode.AWARE)
|
||||||
|
# After reset, pending downgrade should be cleared
|
||||||
|
clock.advance(10.0)
|
||||||
|
d = f.update(speed_mps=0.0)
|
||||||
|
# From AWARE at speed 0 → idle timer not yet expired → stays AWARE (not SLEEP)
|
||||||
|
# Actually: speed 0 < _SPD_MOTION → desired=SOCIAL, hold=5.0
|
||||||
|
# With hold=5.0 and clock just advanced, pending since = now → no downgrade yet
|
||||||
|
assert d.mode in (CameraMode.AWARE, CameraMode.SOCIAL)
|
||||||
|
|
||||||
|
def test_reset_to_full(self):
|
||||||
|
f = _fsm()
|
||||||
|
f.reset(CameraMode.FULL)
|
||||||
|
assert f.mode == CameraMode.FULL
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# ModeDecision fields
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestModeDecisionFields:
|
||||||
|
|
||||||
|
def test_decision_has_mode(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=0.0)
|
||||||
|
assert isinstance(d.mode, CameraMode)
|
||||||
|
|
||||||
|
def test_decision_has_sensors(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=0.0)
|
||||||
|
assert isinstance(d.sensors, ActiveSensors)
|
||||||
|
|
||||||
|
def test_decision_sensors_match_mode(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=_SPD_FULL_UP + 1.0)
|
||||||
|
assert d.sensors == MODE_SENSORS[CameraMode.FULL]
|
||||||
|
|
||||||
|
def test_decision_trigger_speed_recorded(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=2.5)
|
||||||
|
assert abs(d.trigger_speed_mps - 2.5) < 1e-6
|
||||||
|
|
||||||
|
def test_decision_trigger_scenario_recorded(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=0.0, scenario='indoor')
|
||||||
|
assert d.trigger_scenario == 'indoor'
|
||||||
|
|
||||||
|
def test_scenario_override_false_for_normal(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=1.0)
|
||||||
|
assert d.scenario_override is False
|
||||||
|
|
||||||
|
def test_scenario_override_true_for_crossing(self):
|
||||||
|
f = _fsm()
|
||||||
|
d = f.update(speed_mps=0.0, scenario=Scenario.CROSSING)
|
||||||
|
assert d.scenario_override is True
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Active sensor counts (RAM budget)
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestActiveSensors:
|
||||||
|
|
||||||
|
def test_active_sensor_count_non_negative(self):
|
||||||
|
for m, s in MODE_SENSORS.items():
|
||||||
|
assert s.active_count >= 0
|
||||||
|
|
||||||
|
def test_sleep_zero_sensors(self):
|
||||||
|
assert MODE_SENSORS[CameraMode.SLEEP].active_count == 0
|
||||||
|
|
||||||
|
def test_full_seven_sensors(self):
|
||||||
|
# csi x4 + realsense + lidar + uwb = 7
|
||||||
|
assert MODE_SENSORS[CameraMode.FULL].active_count == 7
|
||||||
|
|
||||||
|
def test_active_sensors_correct(self):
|
||||||
|
# csi_front + csi_rear + realsense + lidar + uwb = 5
|
||||||
|
assert MODE_SENSORS[CameraMode.ACTIVE].active_count == 5
|
||||||
|
|
||||||
|
def test_aware_sensors_correct(self):
|
||||||
|
# csi_front + realsense + lidar = 3
|
||||||
|
assert MODE_SENSORS[CameraMode.AWARE].active_count == 3
|
||||||
|
|
||||||
|
def test_social_sensors_correct(self):
|
||||||
|
# webcam = 1
|
||||||
|
assert MODE_SENSORS[CameraMode.SOCIAL].active_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Mode labels
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestModeLabels:
|
||||||
|
|
||||||
|
def test_all_modes_have_labels(self):
|
||||||
|
for m in CameraMode:
|
||||||
|
assert isinstance(m.label, str)
|
||||||
|
assert len(m.label) > 0
|
||||||
|
|
||||||
|
def test_sleep_label(self):
|
||||||
|
assert CameraMode.SLEEP.label == 'SLEEP'
|
||||||
|
|
||||||
|
def test_full_label(self):
|
||||||
|
assert CameraMode.FULL.label == 'FULL'
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Integration: typical ride scenario
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestIntegrationRideScenario:
|
||||||
|
"""Simulates a typical follow-me trip: start→walk→jog→sprint→crossing→indoor."""
|
||||||
|
|
||||||
|
def test_full_ride(self):
|
||||||
|
clock = _FakeClock(0.0)
|
||||||
|
f = _fsm(hold=2.0, idle=5.0, clock=clock)
|
||||||
|
|
||||||
|
# Starting up
|
||||||
|
d = f.update(0.0, Scenario.OUTDOOR, 100.0)
|
||||||
|
assert d.mode == CameraMode.SOCIAL
|
||||||
|
|
||||||
|
# Walking pace (~3 km/h)
|
||||||
|
d = f.update(0.8, Scenario.OUTDOOR, 95.0)
|
||||||
|
assert d.mode == CameraMode.AWARE
|
||||||
|
|
||||||
|
# Jogging (~7 km/h)
|
||||||
|
d = f.update(2.0, Scenario.OUTDOOR, 90.0)
|
||||||
|
assert d.mode == CameraMode.ACTIVE
|
||||||
|
|
||||||
|
# High-speed following (~20 km/h)
|
||||||
|
d = f.update(5.6, Scenario.OUTDOOR, 85.0)
|
||||||
|
assert d.mode == CameraMode.FULL
|
||||||
|
|
||||||
|
# Street crossing — even at slow speed, stays FULL
|
||||||
|
d = f.update(1.0, Scenario.CROSSING, 85.0)
|
||||||
|
assert d.mode == CameraMode.FULL
|
||||||
|
assert d.scenario_override is True
|
||||||
|
|
||||||
|
# Back to outdoor walk
|
||||||
|
clock.advance(3.0)
|
||||||
|
d = f.update(1.0, Scenario.OUTDOOR, 80.0)
|
||||||
|
assert d.mode == CameraMode.FULL # hold not expired yet
|
||||||
|
|
||||||
|
clock.advance(2.1)
|
||||||
|
d = f.update(0.8, Scenario.OUTDOOR, 80.0)
|
||||||
|
assert d.mode == CameraMode.AWARE # hold expired, down to walk speed
|
||||||
|
|
||||||
|
# Enter supermarket (indoor cap)
|
||||||
|
d = f.update(0.8, Scenario.INDOOR, 78.0)
|
||||||
|
assert d.mode == CameraMode.AWARE
|
||||||
|
|
||||||
|
# Park and wait
|
||||||
|
d = f.update(0.0, Scenario.PARKED, 75.0)
|
||||||
|
assert d.mode == CameraMode.SOCIAL
|
||||||
|
|
||||||
|
# Low battery during fast follow
|
||||||
|
d = f.update(5.6, Scenario.OUTDOOR, 15.0)
|
||||||
|
assert d.mode == CameraMode.AWARE # battery cap
|
||||||
504
jetson/ros2_ws/src/saltybot_bringup/test/test_face_emotion.py
Normal file
504
jetson/ros2_ws/src/saltybot_bringup/test/test_face_emotion.py
Normal file
@ -0,0 +1,504 @@
|
|||||||
|
"""
|
||||||
|
test_face_emotion.py — Unit tests for the geometric face emotion classifier.
|
||||||
|
|
||||||
|
All tests use synthetic FaceLandmarks constructed with known geometry —
|
||||||
|
no camera, no MediaPipe, no ROS2 runtime needed.
|
||||||
|
|
||||||
|
Coordinate convention: x ∈ [0,1] (left→right), y ∈ [0,1] (top→bottom).
|
||||||
|
|
||||||
|
Test face dimensions
|
||||||
|
--------------------
|
||||||
|
forehead : (0.50, 0.20) ─┐
|
||||||
|
chin : (0.50, 0.70) ─┘ face_height = 0.50
|
||||||
|
|
||||||
|
Neutral eye/brow/mouth positions chosen so all emotion scores are below
|
||||||
|
their trigger thresholds (verified analytically below each fixture).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from saltybot_bringup._face_emotion import (
|
||||||
|
EMOTION_LABELS,
|
||||||
|
FaceLandmarks,
|
||||||
|
EmotionFeatures,
|
||||||
|
EmotionResult,
|
||||||
|
compute_features,
|
||||||
|
classify_emotion,
|
||||||
|
detect_emotion,
|
||||||
|
MOUTH_UPPER, MOUTH_LOWER, MOUTH_LEFT, MOUTH_RIGHT,
|
||||||
|
L_EYE_TOP, L_EYE_BOT, R_EYE_TOP, R_EYE_BOT,
|
||||||
|
L_BROW_INNER, L_BROW_OUTER, R_BROW_INNER, R_BROW_OUTER,
|
||||||
|
CHIN, FOREHEAD,
|
||||||
|
_T_SURPRISED_BROW, _T_SURPRISED_EYE, _T_SURPRISED_MOUTH,
|
||||||
|
_T_HAPPY_SMILE, _T_ANGRY_FURL, _T_SAD_FROWN,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Fixture helpers
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# face_height = 0.70 - 0.20 = 0.50
|
||||||
|
_FOREHEAD = (0.50, 0.20)
|
||||||
|
_CHIN = (0.50, 0.70)
|
||||||
|
_FH = 0.50 # face_height
|
||||||
|
|
||||||
|
|
||||||
|
def _neutral_face() -> FaceLandmarks:
|
||||||
|
"""
|
||||||
|
Neutral expression:
|
||||||
|
mouth_open = (0.57-0.55)/0.50 = 0.04 < 0.07 (not open)
|
||||||
|
smile = (0.56-0.56)/0.50 = 0.00 in [-0.025, 0.025]
|
||||||
|
brow_raise = (0.38-0.33)/0.50 = 0.10 < 0.12 (not raised enough)
|
||||||
|
eye_open = (0.42-0.38)/0.50 = 0.08 > 0.07 — but surprised needs ALL 3
|
||||||
|
brow_furl = (0.33-0.34)/0.50 = -0.02 < 0.02 (not furrowed)
|
||||||
|
→ neutral
|
||||||
|
"""
|
||||||
|
return FaceLandmarks(
|
||||||
|
mouth_upper = (0.50, 0.55),
|
||||||
|
mouth_lower = (0.50, 0.57),
|
||||||
|
mouth_left = (0.46, 0.56),
|
||||||
|
mouth_right = (0.54, 0.56),
|
||||||
|
l_eye_top = (0.44, 0.38),
|
||||||
|
l_eye_bot = (0.44, 0.42),
|
||||||
|
r_eye_top = (0.56, 0.38),
|
||||||
|
r_eye_bot = (0.56, 0.42),
|
||||||
|
l_eye_left = (0.41, 0.40),
|
||||||
|
l_eye_right = (0.47, 0.40),
|
||||||
|
r_eye_left = (0.53, 0.40),
|
||||||
|
r_eye_right = (0.59, 0.40),
|
||||||
|
l_brow_inner = (0.46, 0.33),
|
||||||
|
l_brow_outer = (0.41, 0.34),
|
||||||
|
r_brow_inner = (0.54, 0.33),
|
||||||
|
r_brow_outer = (0.59, 0.34),
|
||||||
|
chin = _CHIN,
|
||||||
|
forehead = _FOREHEAD,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _happy_face() -> FaceLandmarks:
|
||||||
|
"""
|
||||||
|
Happy: corners raised to y=0.54, mid_y=0.56 → smile=(0.56-0.54)/0.50=0.04 > 0.025
|
||||||
|
All other signals stay below their surprised/angry/sad thresholds.
|
||||||
|
"""
|
||||||
|
fl = _neutral_face()
|
||||||
|
fl.mouth_left = (0.46, 0.54) # corners above midpoint
|
||||||
|
fl.mouth_right = (0.54, 0.54)
|
||||||
|
return fl
|
||||||
|
|
||||||
|
|
||||||
|
def _surprised_face() -> FaceLandmarks:
|
||||||
|
"""
|
||||||
|
Surprised:
|
||||||
|
brow_raise = (0.38-0.26)/0.50 = 0.24 > 0.12
|
||||||
|
eye_open = (0.45-0.38)/0.50 = 0.14 > 0.07
|
||||||
|
mouth_open = (0.65-0.55)/0.50 = 0.20 > 0.07
|
||||||
|
"""
|
||||||
|
fl = _neutral_face()
|
||||||
|
fl.l_brow_inner = (0.46, 0.26) # brows raised high
|
||||||
|
fl.r_brow_inner = (0.54, 0.26)
|
||||||
|
fl.l_brow_outer = (0.41, 0.27)
|
||||||
|
fl.r_brow_outer = (0.59, 0.27)
|
||||||
|
fl.l_eye_bot = (0.44, 0.45) # eyes wide open
|
||||||
|
fl.r_eye_bot = (0.56, 0.45)
|
||||||
|
fl.mouth_lower = (0.50, 0.65) # mouth open
|
||||||
|
return fl
|
||||||
|
|
||||||
|
|
||||||
|
def _angry_face() -> FaceLandmarks:
|
||||||
|
"""
|
||||||
|
Angry:
|
||||||
|
brow_furl = (0.37-0.30)/0.50 = 0.14 > 0.02
|
||||||
|
smile = (0.56-0.56)/0.50 = 0.00 < 0.01 (neutral mouth)
|
||||||
|
brow_raise = (0.38-0.37)/0.50 = 0.02 < 0.12 (not surprised)
|
||||||
|
"""
|
||||||
|
fl = _neutral_face()
|
||||||
|
fl.l_brow_inner = (0.46, 0.37) # inner brow lowered (toward eye)
|
||||||
|
fl.r_brow_inner = (0.54, 0.37)
|
||||||
|
fl.l_brow_outer = (0.41, 0.30) # outer brow raised
|
||||||
|
fl.r_brow_outer = (0.59, 0.30)
|
||||||
|
return fl
|
||||||
|
|
||||||
|
|
||||||
|
def _sad_face() -> FaceLandmarks:
|
||||||
|
"""
|
||||||
|
Sad:
|
||||||
|
smile = (0.56-0.59)/0.50 = -0.06 < -0.025
|
||||||
|
brow_furl = (0.33-0.34)/0.50 = -0.02 < 0.015 (not furrowed → not angry)
|
||||||
|
"""
|
||||||
|
fl = _neutral_face()
|
||||||
|
fl.mouth_left = (0.46, 0.59) # corners below midpoint (frown)
|
||||||
|
fl.mouth_right = (0.54, 0.59)
|
||||||
|
return fl
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# EMOTION_LABELS
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_emotion_labels_tuple():
|
||||||
|
assert isinstance(EMOTION_LABELS, tuple)
|
||||||
|
|
||||||
|
|
||||||
|
def test_emotion_labels_count():
|
||||||
|
assert len(EMOTION_LABELS) == 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_emotion_labels_values():
|
||||||
|
assert set(EMOTION_LABELS) == {'neutral', 'happy', 'surprised', 'angry', 'sad'}
|
||||||
|
|
||||||
|
|
||||||
|
def test_emotion_labels_neutral_first():
|
||||||
|
assert EMOTION_LABELS[0] == 'neutral'
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Landmark index constants
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_landmark_indices_in_range():
|
||||||
|
indices = [
|
||||||
|
MOUTH_UPPER, MOUTH_LOWER, MOUTH_LEFT, MOUTH_RIGHT,
|
||||||
|
L_EYE_TOP, L_EYE_BOT, R_EYE_TOP, R_EYE_BOT,
|
||||||
|
L_BROW_INNER, L_BROW_OUTER, R_BROW_INNER, R_BROW_OUTER,
|
||||||
|
CHIN, FOREHEAD,
|
||||||
|
]
|
||||||
|
for idx in indices:
|
||||||
|
assert 0 <= idx < 468, f'Landmark index {idx} out of range'
|
||||||
|
|
||||||
|
|
||||||
|
def test_landmark_indices_unique():
|
||||||
|
indices = [
|
||||||
|
MOUTH_UPPER, MOUTH_LOWER, MOUTH_LEFT, MOUTH_RIGHT,
|
||||||
|
L_EYE_TOP, L_EYE_BOT, R_EYE_TOP, R_EYE_BOT,
|
||||||
|
L_BROW_INNER, L_BROW_OUTER, R_BROW_INNER, R_BROW_OUTER,
|
||||||
|
CHIN, FOREHEAD,
|
||||||
|
]
|
||||||
|
assert len(set(indices)) == len(indices)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# compute_features — neutral face
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_features_face_height():
|
||||||
|
f = compute_features(_neutral_face())
|
||||||
|
assert abs(f.face_height - 0.50) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_features_mouth_open_neutral():
|
||||||
|
f = compute_features(_neutral_face())
|
||||||
|
# (0.57 - 0.55) / 0.50 = 0.04
|
||||||
|
assert abs(f.mouth_open - 0.04) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_features_smile_neutral():
|
||||||
|
f = compute_features(_neutral_face())
|
||||||
|
# mid_y=0.56, corner_y=0.56 → smile=0.0
|
||||||
|
assert abs(f.smile) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_features_brow_raise_neutral():
|
||||||
|
f = compute_features(_neutral_face())
|
||||||
|
# (0.38 - 0.33) / 0.50 = 0.10
|
||||||
|
assert abs(f.brow_raise - 0.10) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_features_brow_furl_neutral():
|
||||||
|
f = compute_features(_neutral_face())
|
||||||
|
# l_furl = (0.33 - 0.34) / 0.50 = -0.02 per side → mean = -0.02
|
||||||
|
assert abs(f.brow_furl - (-0.02)) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_features_eye_open_neutral():
|
||||||
|
f = compute_features(_neutral_face())
|
||||||
|
# (0.42 - 0.38) / 0.50 = 0.08
|
||||||
|
assert abs(f.eye_open - 0.08) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_features_mouth_open_non_negative():
|
||||||
|
"""Mouth open must always be ≥ 0."""
|
||||||
|
# Invert lips (impossible geometry)
|
||||||
|
fl = _neutral_face()
|
||||||
|
fl.mouth_upper = (0.50, 0.60)
|
||||||
|
fl.mouth_lower = (0.50, 0.55)
|
||||||
|
f = compute_features(fl)
|
||||||
|
assert f.mouth_open == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_features_eye_open_non_negative():
|
||||||
|
fl = _neutral_face()
|
||||||
|
fl.l_eye_top = (0.44, 0.42)
|
||||||
|
fl.l_eye_bot = (0.44, 0.38)
|
||||||
|
f = compute_features(fl)
|
||||||
|
assert f.eye_open >= 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# compute_features — happy face
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_features_smile_happy():
|
||||||
|
f = compute_features(_happy_face())
|
||||||
|
# mid_y=0.56, corner_y=0.54 → smile=(0.56-0.54)/0.50=0.04
|
||||||
|
assert abs(f.smile - 0.04) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_features_smile_positive_happy():
|
||||||
|
f = compute_features(_happy_face())
|
||||||
|
assert f.smile > _T_HAPPY_SMILE
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# compute_features — surprised face
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_features_brow_raise_surprised():
|
||||||
|
f = compute_features(_surprised_face())
|
||||||
|
# (0.38 - 0.26) / 0.50 = 0.24
|
||||||
|
assert abs(f.brow_raise - 0.24) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_features_brow_raise_above_threshold():
|
||||||
|
f = compute_features(_surprised_face())
|
||||||
|
assert f.brow_raise > _T_SURPRISED_BROW
|
||||||
|
|
||||||
|
|
||||||
|
def test_features_eye_open_surprised():
|
||||||
|
f = compute_features(_surprised_face())
|
||||||
|
# (0.45 - 0.38) / 0.50 = 0.14
|
||||||
|
assert abs(f.eye_open - 0.14) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_features_mouth_open_surprised():
|
||||||
|
f = compute_features(_surprised_face())
|
||||||
|
# (0.65 - 0.55) / 0.50 = 0.20
|
||||||
|
assert abs(f.mouth_open - 0.20) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# compute_features — angry face
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_features_brow_furl_angry():
|
||||||
|
f = compute_features(_angry_face())
|
||||||
|
# l_furl = (0.37 - 0.30) / 0.50 = 0.14 per side → mean = 0.14
|
||||||
|
assert abs(f.brow_furl - 0.14) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_features_brow_furl_above_threshold():
|
||||||
|
f = compute_features(_angry_face())
|
||||||
|
assert f.brow_furl > _T_ANGRY_FURL
|
||||||
|
|
||||||
|
|
||||||
|
def test_features_smile_near_zero_angry():
|
||||||
|
f = compute_features(_angry_face())
|
||||||
|
assert abs(f.smile) < 0.005
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# compute_features — sad face
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_features_smile_sad():
|
||||||
|
f = compute_features(_sad_face())
|
||||||
|
# mid_y=0.56, corner_y=0.59 → smile=(0.56-0.59)/0.50=-0.06
|
||||||
|
assert abs(f.smile - (-0.06)) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_features_smile_negative_sad():
|
||||||
|
f = compute_features(_sad_face())
|
||||||
|
assert f.smile < -_T_SAD_FROWN
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# classify_emotion
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_classify_returns_known_label():
|
||||||
|
feat = compute_features(_neutral_face())
|
||||||
|
label, _ = classify_emotion(feat)
|
||||||
|
assert label in EMOTION_LABELS
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_confidence_range():
|
||||||
|
for make in [_neutral_face, _happy_face, _surprised_face,
|
||||||
|
_angry_face, _sad_face]:
|
||||||
|
feat = compute_features(make())
|
||||||
|
_, conf = classify_emotion(feat)
|
||||||
|
assert 0.0 < conf <= 1.0, f'{make.__name__}: confidence={conf}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_neutral():
|
||||||
|
feat = compute_features(_neutral_face())
|
||||||
|
label, conf = classify_emotion(feat)
|
||||||
|
assert label == 'neutral', f'Expected neutral, got {label!r}'
|
||||||
|
assert conf > 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_happy():
|
||||||
|
feat = compute_features(_happy_face())
|
||||||
|
label, conf = classify_emotion(feat)
|
||||||
|
assert label == 'happy', f'Expected happy, got {label!r}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_surprised():
|
||||||
|
feat = compute_features(_surprised_face())
|
||||||
|
label, conf = classify_emotion(feat)
|
||||||
|
assert label == 'surprised', f'Expected surprised, got {label!r}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_angry():
|
||||||
|
feat = compute_features(_angry_face())
|
||||||
|
label, conf = classify_emotion(feat)
|
||||||
|
assert label == 'angry', f'Expected angry, got {label!r}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_sad():
|
||||||
|
feat = compute_features(_sad_face())
|
||||||
|
label, conf = classify_emotion(feat)
|
||||||
|
assert label == 'sad', f'Expected sad, got {label!r}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_surprised_priority_over_happy():
|
||||||
|
"""Surprised should win even if smile is also positive."""
|
||||||
|
fl = _surprised_face()
|
||||||
|
fl.mouth_left = (0.46, 0.54) # add smile too
|
||||||
|
fl.mouth_right = (0.54, 0.54)
|
||||||
|
feat = compute_features(fl)
|
||||||
|
label, _ = classify_emotion(feat)
|
||||||
|
assert label == 'surprised'
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# detect_emotion — end-to-end
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_detect_returns_result():
|
||||||
|
res = detect_emotion(_neutral_face())
|
||||||
|
assert isinstance(res, EmotionResult)
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_neutral():
|
||||||
|
res = detect_emotion(_neutral_face())
|
||||||
|
assert res.emotion == 'neutral'
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_happy():
|
||||||
|
res = detect_emotion(_happy_face())
|
||||||
|
assert res.emotion == 'happy'
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_surprised():
|
||||||
|
res = detect_emotion(_surprised_face())
|
||||||
|
assert res.emotion == 'surprised'
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_angry():
|
||||||
|
res = detect_emotion(_angry_face())
|
||||||
|
assert res.emotion == 'angry'
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_sad():
|
||||||
|
res = detect_emotion(_sad_face())
|
||||||
|
assert res.emotion == 'sad'
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_confidence_range():
|
||||||
|
for make in [_neutral_face, _happy_face, _surprised_face,
|
||||||
|
_angry_face, _sad_face]:
|
||||||
|
res = detect_emotion(make())
|
||||||
|
assert 0.0 < res.confidence <= 1.0, (
|
||||||
|
f'{make.__name__}: confidence={res.confidence}'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_features_attached():
|
||||||
|
res = detect_emotion(_neutral_face())
|
||||||
|
assert isinstance(res.features, EmotionFeatures)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Edge cases
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_zero_face_height_no_crash():
|
||||||
|
"""Degenerate face (chin == forehead) must not crash."""
|
||||||
|
fl = _neutral_face()
|
||||||
|
fl.forehead = fl.chin # face_height → ~1e-4 clamp
|
||||||
|
res = detect_emotion(fl)
|
||||||
|
assert res.emotion in EMOTION_LABELS
|
||||||
|
|
||||||
|
|
||||||
|
def test_extreme_smile_caps_confidence():
|
||||||
|
"""Wildly exaggerated smile should not push confidence above 1.0."""
|
||||||
|
fl = _neutral_face()
|
||||||
|
fl.mouth_left = (0.46, 0.20) # corners pulled very high
|
||||||
|
fl.mouth_right = (0.54, 0.20)
|
||||||
|
res = detect_emotion(fl)
|
||||||
|
assert res.confidence <= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_extreme_brow_raise_caps_confidence():
|
||||||
|
fl = _surprised_face()
|
||||||
|
fl.l_brow_inner = (0.46, 0.01) # brows at top of image
|
||||||
|
fl.r_brow_inner = (0.54, 0.01)
|
||||||
|
res = detect_emotion(fl)
|
||||||
|
assert res.confidence <= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_gradual_smile_crosses_threshold():
|
||||||
|
"""Smile just above threshold → happy; just below → neutral."""
|
||||||
|
fl = _neutral_face()
|
||||||
|
# Just below threshold: smile < _T_HAPPY_SMILE
|
||||||
|
offset = _T_HAPPY_SMILE * _FH * 0.5 # half the threshold distance
|
||||||
|
mid_y = 0.56
|
||||||
|
fl.mouth_left = (0.46, mid_y - offset)
|
||||||
|
fl.mouth_right = (0.54, mid_y - offset)
|
||||||
|
res_below = detect_emotion(fl)
|
||||||
|
|
||||||
|
# Just above threshold: smile > _T_HAPPY_SMILE
|
||||||
|
offset_above = _T_HAPPY_SMILE * _FH * 1.5
|
||||||
|
fl.mouth_left = (0.46, mid_y - offset_above)
|
||||||
|
fl.mouth_right = (0.54, mid_y - offset_above)
|
||||||
|
res_above = detect_emotion(fl)
|
||||||
|
|
||||||
|
assert res_below.emotion == 'neutral'
|
||||||
|
assert res_above.emotion == 'happy'
|
||||||
|
|
||||||
|
|
||||||
|
def test_gradual_frown_crosses_threshold():
|
||||||
|
"""Frown just above threshold → sad; just below → neutral."""
|
||||||
|
fl = _neutral_face()
|
||||||
|
mid_y = 0.56
|
||||||
|
offset_below = _T_SAD_FROWN * _FH * 0.5
|
||||||
|
fl.mouth_left = (0.46, mid_y + offset_below)
|
||||||
|
fl.mouth_right = (0.54, mid_y + offset_below)
|
||||||
|
res_below = detect_emotion(fl)
|
||||||
|
|
||||||
|
offset_above = _T_SAD_FROWN * _FH * 1.5
|
||||||
|
fl.mouth_left = (0.46, mid_y + offset_above)
|
||||||
|
fl.mouth_right = (0.54, mid_y + offset_above)
|
||||||
|
res_above = detect_emotion(fl)
|
||||||
|
|
||||||
|
assert res_below.emotion == 'neutral'
|
||||||
|
assert res_above.emotion == 'sad'
|
||||||
|
|
||||||
|
|
||||||
|
def test_angry_requires_no_smile():
|
||||||
|
"""If brows are furrowed but face is smiling, should not be angry."""
|
||||||
|
fl = _angry_face()
|
||||||
|
fl.mouth_left = (0.46, 0.54) # add a clear smile
|
||||||
|
fl.mouth_right = (0.54, 0.54)
|
||||||
|
res = detect_emotion(fl)
|
||||||
|
assert res.emotion != 'angry'
|
||||||
|
|
||||||
|
|
||||||
|
def test_symmetry_independent():
|
||||||
|
"""One-sided brow raise still contributes to brow_raise metric."""
|
||||||
|
fl = _neutral_face()
|
||||||
|
fl.l_brow_inner = (0.46, 0.26) # left brow raised
|
||||||
|
# right brow unchanged at 0.33
|
||||||
|
f = compute_features(fl)
|
||||||
|
# brow_raise = avg of left (0.24) and right (0.10) = 0.17 > 0.12
|
||||||
|
assert f.brow_raise > _T_SURPRISED_BROW
|
||||||
364
jetson/ros2_ws/src/saltybot_bringup/test/test_obstacle_size.py
Normal file
364
jetson/ros2_ws/src/saltybot_bringup/test/test_obstacle_size.py
Normal file
@ -0,0 +1,364 @@
|
|||||||
|
"""
|
||||||
|
test_obstacle_size.py — pytest tests for _obstacle_size.py (no ROS2 required).
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
CameraParams — defaults and construction
|
||||||
|
lidar_to_camera — coordinate transform
|
||||||
|
project_to_pixel — pinhole projection + bounds checks
|
||||||
|
sample_depth_median — empty/uniform/sparse depth images
|
||||||
|
estimate_height — flat/obstacle/edge cases
|
||||||
|
estimate_cluster_size — full pipeline with synthetic cluster + depth
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from saltybot_bringup._obstacle_size import (
|
||||||
|
CameraParams,
|
||||||
|
ObstacleSizeEstimate,
|
||||||
|
estimate_cluster_size,
|
||||||
|
estimate_height,
|
||||||
|
lidar_to_camera,
|
||||||
|
project_to_pixel,
|
||||||
|
sample_depth_median,
|
||||||
|
)
|
||||||
|
# Import Cluster helper from lidar_clustering for building test fixtures
|
||||||
|
from saltybot_bringup._lidar_clustering import Cluster
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _default_params(**kw) -> CameraParams:
|
||||||
|
p = CameraParams()
|
||||||
|
for k, v in kw.items():
|
||||||
|
object.__setattr__(p, k, v)
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
def _blank_depth(h: int = 480, w: int = 640, val: int = 0) -> np.ndarray:
|
||||||
|
return np.full((h, w), val, dtype=np.uint16)
|
||||||
|
|
||||||
|
|
||||||
|
def _obstacle_depth(
|
||||||
|
h: int = 480, w: int = 640,
|
||||||
|
depth_mm: int = 2000,
|
||||||
|
u_c: int = 320, v_c: int = 240,
|
||||||
|
half_w: int = 30, half_h: int = 60,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Create a synthetic depth image with a rectangular obstacle."""
|
||||||
|
img = _blank_depth(h, w)
|
||||||
|
r0, r1 = max(0, v_c - half_h), min(h, v_c + half_h + 1)
|
||||||
|
c0, c1 = max(0, u_c - half_w), min(w, u_c + half_w + 1)
|
||||||
|
img[r0:r1, c0:c1] = depth_mm
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def _make_cluster(
|
||||||
|
cx: float = 2.0, cy: float = 0.0,
|
||||||
|
width_m: float = 0.4, depth_m: float = 0.3,
|
||||||
|
n_pts: int = 10,
|
||||||
|
) -> Cluster:
|
||||||
|
"""Construct a minimal Cluster NamedTuple for testing."""
|
||||||
|
pts = np.column_stack([
|
||||||
|
np.full(n_pts, cx),
|
||||||
|
np.full(n_pts, cy),
|
||||||
|
]).astype(np.float64)
|
||||||
|
centroid = np.array([cx, cy], dtype=np.float64)
|
||||||
|
bbox_min = centroid - np.array([width_m / 2, depth_m / 2])
|
||||||
|
bbox_max = centroid + np.array([width_m / 2, depth_m / 2])
|
||||||
|
return Cluster(
|
||||||
|
points=pts, centroid=centroid,
|
||||||
|
bbox_min=bbox_min, bbox_max=bbox_max,
|
||||||
|
width_m=width_m, depth_m=depth_m,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── CameraParams ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestCameraParams:
|
||||||
|
def test_defaults(self):
|
||||||
|
p = CameraParams()
|
||||||
|
assert p.fx == pytest.approx(383.0)
|
||||||
|
assert p.fy == pytest.approx(383.0)
|
||||||
|
assert p.cx == pytest.approx(320.0)
|
||||||
|
assert p.cy == pytest.approx(240.0)
|
||||||
|
assert p.width == 640
|
||||||
|
assert p.height == 480
|
||||||
|
assert p.depth_scale == pytest.approx(0.001)
|
||||||
|
assert p.ey == pytest.approx(0.05)
|
||||||
|
|
||||||
|
def test_custom(self):
|
||||||
|
p = CameraParams(fx=400.0, fy=400.0, cx=330.0, cy=250.0)
|
||||||
|
assert p.fx == pytest.approx(400.0)
|
||||||
|
|
||||||
|
|
||||||
|
# ── lidar_to_camera ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestLidarToCamera:
|
||||||
|
def test_forward_no_offset(self):
|
||||||
|
"""A point directly in front of the LIDAR (y_lidar=0) → x_cam=0."""
|
||||||
|
p = CameraParams(ex=0.0, ey=0.0, ez=0.0)
|
||||||
|
x_cam, y_cam, z_cam = lidar_to_camera(3.0, 0.0, p)
|
||||||
|
assert x_cam == pytest.approx(0.0)
|
||||||
|
assert y_cam == pytest.approx(0.0)
|
||||||
|
assert z_cam == pytest.approx(3.0)
|
||||||
|
|
||||||
|
def test_lateral_maps_to_x(self):
|
||||||
|
"""y_lidar positive (left) → x_cam negative (right convention)."""
|
||||||
|
p = CameraParams(ex=0.0, ey=0.0, ez=0.0)
|
||||||
|
x_cam, y_cam, z_cam = lidar_to_camera(0.0, 1.0, p)
|
||||||
|
assert x_cam == pytest.approx(-1.0)
|
||||||
|
assert z_cam == pytest.approx(0.0)
|
||||||
|
|
||||||
|
def test_extrinsic_offset_applied(self):
|
||||||
|
"""Extrinsic translation is added to camera-frame coords."""
|
||||||
|
p = CameraParams(ex=0.1, ey=0.05, ez=-0.02)
|
||||||
|
x_cam, y_cam, z_cam = lidar_to_camera(2.0, 0.0, p)
|
||||||
|
assert x_cam == pytest.approx(0.1) # -y_lidar + ex = 0 + 0.1
|
||||||
|
assert y_cam == pytest.approx(0.05) # 0 + ey
|
||||||
|
assert z_cam == pytest.approx(1.98) # 2.0 + (-0.02)
|
||||||
|
|
||||||
|
def test_zero_point(self):
|
||||||
|
p = CameraParams(ex=0.0, ey=0.0, ez=0.0)
|
||||||
|
x_cam, y_cam, z_cam = lidar_to_camera(0.0, 0.0, p)
|
||||||
|
assert z_cam == pytest.approx(0.0)
|
||||||
|
|
||||||
|
|
||||||
|
# ── project_to_pixel ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestProjectToPixel:
|
||||||
|
def test_centre_projects_to_principal(self):
|
||||||
|
"""A point exactly on the optical axis projects to (cx, cy)."""
|
||||||
|
p = CameraParams(fx=400.0, fy=400.0, cx=320.0, cy=240.0)
|
||||||
|
px = project_to_pixel(0.0, 0.0, 2.0, p)
|
||||||
|
assert px is not None
|
||||||
|
u, v = px
|
||||||
|
assert u == 320
|
||||||
|
assert v == 240
|
||||||
|
|
||||||
|
def test_negative_z_returns_none(self):
|
||||||
|
p = CameraParams()
|
||||||
|
assert project_to_pixel(0.0, 0.0, -1.0, p) is None
|
||||||
|
|
||||||
|
def test_zero_z_returns_none(self):
|
||||||
|
p = CameraParams()
|
||||||
|
assert project_to_pixel(0.0, 0.0, 0.0, p) is None
|
||||||
|
|
||||||
|
def test_off_to_right(self):
|
||||||
|
"""A point to the right of axis should land to the right of cx."""
|
||||||
|
p = CameraParams(fx=400.0, cx=320.0, cy=240.0)
|
||||||
|
px = project_to_pixel(0.5, 0.0, 2.0, p)
|
||||||
|
assert px is not None
|
||||||
|
u, _ = px
|
||||||
|
assert u > 320
|
||||||
|
|
||||||
|
def test_out_of_image_returns_none(self):
|
||||||
|
p = CameraParams(fx=100.0, fy=100.0, cx=50.0, cy=50.0,
|
||||||
|
width=100, height=100)
|
||||||
|
# Very far off-axis → outside image
|
||||||
|
px = project_to_pixel(10.0, 0.0, 1.0, p)
|
||||||
|
assert px is None
|
||||||
|
|
||||||
|
def test_known_projection(self):
|
||||||
|
"""Verify exact pixel for known 3D point."""
|
||||||
|
p = CameraParams(fx=400.0, fy=400.0, cx=320.0, cy=240.0,
|
||||||
|
width=640, height=480)
|
||||||
|
# x_cam=1.0, z_cam=4.0 → u = 400*(1/4)+320 = 420
|
||||||
|
# y_cam=0.0 → v = 400*(0/4)+240 = 240
|
||||||
|
px = project_to_pixel(1.0, 0.0, 4.0, p)
|
||||||
|
assert px == (420, 240)
|
||||||
|
|
||||||
|
|
||||||
|
# ── sample_depth_median ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestSampleDepthMedian:
|
||||||
|
def test_all_zeros_returns_zero(self):
|
||||||
|
img = _blank_depth()
|
||||||
|
d, n = sample_depth_median(img, 320, 240, window_px=5)
|
||||||
|
assert d == pytest.approx(0.0)
|
||||||
|
assert n == 0
|
||||||
|
|
||||||
|
def test_uniform_image(self):
|
||||||
|
"""All pixels set to 2000 mm → median = 2.0 m."""
|
||||||
|
img = _blank_depth(val=2000)
|
||||||
|
d, n = sample_depth_median(img, 320, 240, window_px=5, depth_scale=0.001)
|
||||||
|
assert d == pytest.approx(2.0)
|
||||||
|
assert n > 0
|
||||||
|
|
||||||
|
def test_scale_applied(self):
|
||||||
|
img = _blank_depth(val=3000)
|
||||||
|
d, _ = sample_depth_median(img, 320, 240, window_px=3, depth_scale=0.001)
|
||||||
|
assert d == pytest.approx(3.0)
|
||||||
|
|
||||||
|
def test_window_clips_at_image_edge(self):
|
||||||
|
"""Sampling near edge should not crash."""
|
||||||
|
img = _blank_depth(val=1500)
|
||||||
|
d, n = sample_depth_median(img, 0, 0, window_px=10)
|
||||||
|
assert d > 0.0
|
||||||
|
assert n > 0
|
||||||
|
|
||||||
|
def test_sparse_window(self):
|
||||||
|
"""Only some pixels are valid — median should reflect valid ones."""
|
||||||
|
img = _blank_depth()
|
||||||
|
img[240, 320] = 1000 # single valid pixel
|
||||||
|
d, n = sample_depth_median(img, 320, 240, window_px=5, depth_scale=0.001)
|
||||||
|
assert d == pytest.approx(1.0)
|
||||||
|
assert n == 1
|
||||||
|
|
||||||
|
def test_mixed_window_median(self):
|
||||||
|
"""Median of [1000, 2000, 3000] mm = 2000 mm = 2.0 m."""
|
||||||
|
img = _blank_depth()
|
||||||
|
img[240, 319] = 1000
|
||||||
|
img[240, 320] = 2000
|
||||||
|
img[240, 321] = 3000
|
||||||
|
d, n = sample_depth_median(img, 320, 240, window_px=1, depth_scale=0.001)
|
||||||
|
assert d == pytest.approx(2.0)
|
||||||
|
assert n == 3
|
||||||
|
|
||||||
|
|
||||||
|
# ── estimate_height ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestEstimateHeight:
|
||||||
|
def test_blank_image_returns_zero(self):
|
||||||
|
p = CameraParams(fy=383.0)
|
||||||
|
img = _blank_depth()
|
||||||
|
h = estimate_height(img, 320, 240, z_ref=2.0, params=p)
|
||||||
|
assert h == pytest.approx(0.0)
|
||||||
|
|
||||||
|
def test_zero_z_ref_returns_zero(self):
|
||||||
|
p = CameraParams()
|
||||||
|
img = _blank_depth(val=2000)
|
||||||
|
h = estimate_height(img, 320, 240, z_ref=0.0, params=p)
|
||||||
|
assert h == pytest.approx(0.0)
|
||||||
|
|
||||||
|
def test_known_height(self):
|
||||||
|
"""
|
||||||
|
Obstacle occupies rows 180–300 (120 row-pixel span) at depth 2.0m.
|
||||||
|
Expected height ≈ 120 * 2.0 / 383.0 ≈ 0.627 m.
|
||||||
|
"""
|
||||||
|
p = CameraParams(fy=383.0, depth_scale=0.001)
|
||||||
|
img = _blank_depth()
|
||||||
|
# Obstacle at z=2000 mm, rows 180–300, cols 300–340
|
||||||
|
img[180:301, 300:341] = 2000
|
||||||
|
h = estimate_height(img, 320, 240, z_ref=2.0, params=p,
|
||||||
|
search_rows=200, col_hw=30, z_tol=0.3)
|
||||||
|
expected = 120 * 2.0 / 383.0
|
||||||
|
assert h == pytest.approx(expected, rel=0.05)
|
||||||
|
|
||||||
|
def test_single_row_returns_zero(self):
|
||||||
|
"""Only one valid row → span=0 → height=0."""
|
||||||
|
p = CameraParams(fy=383.0, depth_scale=0.001)
|
||||||
|
img = _blank_depth()
|
||||||
|
img[240, 315:326] = 2000 # single row
|
||||||
|
h = estimate_height(img, 320, 240, z_ref=2.0, params=p,
|
||||||
|
search_rows=20, col_hw=10, z_tol=0.3)
|
||||||
|
assert h == pytest.approx(0.0)
|
||||||
|
|
||||||
|
def test_depth_outside_tolerance_excluded(self):
|
||||||
|
"""Pixels far from z_ref should not contribute to height."""
|
||||||
|
p = CameraParams(fy=383.0, depth_scale=0.001)
|
||||||
|
img = _blank_depth()
|
||||||
|
# Two rows at very different depths — only one within z_tol
|
||||||
|
img[220, 315:326] = 2000 # z = 2.0 m (within tolerance)
|
||||||
|
img[260, 315:326] = 5000 # z = 5.0 m (outside tolerance)
|
||||||
|
h = estimate_height(img, 320, 240, z_ref=2.0, params=p,
|
||||||
|
search_rows=100, col_hw=10, z_tol=0.3)
|
||||||
|
assert h == pytest.approx(0.0) # only 1 valid row → 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── estimate_cluster_size ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestEstimateClusterSize:
|
||||||
|
def test_returns_named_tuple(self):
|
||||||
|
cluster = _make_cluster()
|
||||||
|
img = _blank_depth(val=2000)
|
||||||
|
est = estimate_cluster_size(cluster, img, CameraParams())
|
||||||
|
assert isinstance(est, ObstacleSizeEstimate)
|
||||||
|
|
||||||
|
def test_width_comes_from_lidar(self):
|
||||||
|
"""Width should equal the LIDAR cluster width_m."""
|
||||||
|
cluster = _make_cluster(cx=3.0, cy=0.0, width_m=0.5)
|
||||||
|
img = _blank_depth(val=3000)
|
||||||
|
est = estimate_cluster_size(cluster, img, CameraParams())
|
||||||
|
assert est.width_m == pytest.approx(0.5)
|
||||||
|
|
||||||
|
def test_centroid_preserved(self):
|
||||||
|
cluster = _make_cluster(cx=2.0, cy=0.3)
|
||||||
|
img = _blank_depth(val=2000)
|
||||||
|
est = estimate_cluster_size(cluster, img, CameraParams())
|
||||||
|
assert est.centroid_x == pytest.approx(2.0)
|
||||||
|
assert est.centroid_y == pytest.approx(0.3)
|
||||||
|
|
||||||
|
def test_lidar_range_correct(self):
|
||||||
|
cluster = _make_cluster(cx=3.0, cy=4.0) # range = 5.0
|
||||||
|
img = _blank_depth(val=5000)
|
||||||
|
est = estimate_cluster_size(cluster, img, CameraParams())
|
||||||
|
assert est.lidar_range == pytest.approx(5.0)
|
||||||
|
|
||||||
|
def test_blank_depth_gives_zero_confidence(self):
|
||||||
|
"""No valid depth pixels → confidence=0, depth_z falls back to z_cam."""
|
||||||
|
cluster = _make_cluster(cx=2.0, cy=0.0)
|
||||||
|
img = _blank_depth()
|
||||||
|
p = CameraParams(ex=0.0, ey=0.0, ez=0.0)
|
||||||
|
est = estimate_cluster_size(cluster, img, p)
|
||||||
|
assert est.confidence == pytest.approx(0.0)
|
||||||
|
assert est.depth_z > 0.0 # LIDAR-fallback z_cam
|
||||||
|
|
||||||
|
def test_obstacle_in_depth_image(self):
|
||||||
|
"""With a real depth patch, confidence > 0 and pixel is in-image."""
|
||||||
|
# Cluster at x=2m forward, y=0 → z_cam=2m, x_cam=0, y_cam=ey
|
||||||
|
p = CameraParams(fy=383.0, fx=383.0, cx=320.0, cy=240.0,
|
||||||
|
depth_scale=0.001, ex=0.0, ey=0.05, ez=0.0)
|
||||||
|
cluster = _make_cluster(cx=2.0, cy=0.0)
|
||||||
|
# z_cam=2m, x_cam=0, y_cam=0.05
|
||||||
|
# u = 383*0/2+320 = 320, v = 383*0.05/2+240 ≈ 249
|
||||||
|
img = _obstacle_depth(depth_mm=2000, u_c=320, v_c=249,
|
||||||
|
half_w=30, half_h=80)
|
||||||
|
est = estimate_cluster_size(cluster, img, p,
|
||||||
|
depth_window=5, search_rows=120,
|
||||||
|
col_hw=10, z_tol=0.3, obstacle_id=1)
|
||||||
|
assert est.confidence > 0.0
|
||||||
|
assert 0 <= est.pixel_u < 640
|
||||||
|
assert 0 <= est.pixel_v < 480
|
||||||
|
assert est.depth_z == pytest.approx(2.0, abs=0.1)
|
||||||
|
assert est.obstacle_id == 1
|
||||||
|
|
||||||
|
def test_behind_camera_returns_zero_confidence(self):
|
||||||
|
"""Cluster behind camera (x_lidar < 0 with no ez) → not projected."""
|
||||||
|
p = CameraParams(ex=0.0, ey=0.0, ez=0.0)
|
||||||
|
# x_lidar = -1.0 → z_cam = -1 + 0 = -1 → behind camera
|
||||||
|
cluster = _make_cluster(cx=-1.0, cy=0.0)
|
||||||
|
img = _blank_depth(val=1000)
|
||||||
|
est = estimate_cluster_size(cluster, img, p)
|
||||||
|
assert est.confidence == pytest.approx(0.0)
|
||||||
|
assert est.pixel_u == -1
|
||||||
|
assert est.pixel_v == -1
|
||||||
|
|
||||||
|
def test_height_estimated_from_depth(self):
|
||||||
|
"""When obstacle spans rows, height_m > 0."""
|
||||||
|
p = CameraParams(fy=383.0, fx=383.0, cx=320.0, cy=240.0,
|
||||||
|
depth_scale=0.001, ex=0.0, ey=0.05, ez=0.0)
|
||||||
|
cluster = _make_cluster(cx=2.0, cy=0.0)
|
||||||
|
# Expected pixel: u≈320, v≈249; give obstacle a 100-row span
|
||||||
|
img = _obstacle_depth(depth_mm=2000, u_c=320, v_c=249,
|
||||||
|
half_w=20, half_h=50)
|
||||||
|
est = estimate_cluster_size(cluster, img, p,
|
||||||
|
search_rows=150, col_hw=15, z_tol=0.3)
|
||||||
|
assert est.height_m > 0.0
|
||||||
|
|
||||||
|
def test_obstacle_id_passed_through(self):
|
||||||
|
cluster = _make_cluster()
|
||||||
|
img = _blank_depth(val=2000)
|
||||||
|
est = estimate_cluster_size(cluster, img, CameraParams(), obstacle_id=42)
|
||||||
|
assert est.obstacle_id == 42
|
||||||
|
|
||||||
|
def test_confidence_bounded(self):
|
||||||
|
cluster = _make_cluster(cx=2.0)
|
||||||
|
img = _blank_depth(val=2000)
|
||||||
|
est = estimate_cluster_size(cluster, img, CameraParams())
|
||||||
|
assert 0.0 <= est.confidence <= 1.0
|
||||||
@ -0,0 +1,467 @@
|
|||||||
|
"""
|
||||||
|
test_obstacle_velocity.py — Unit tests for obstacle velocity estimation (no ROS2).
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
KalmanTrack — construction:
|
||||||
|
- initial position matches centroid
|
||||||
|
- initial velocity is zero
|
||||||
|
- initial confidence is 0.0
|
||||||
|
- initial alive=True, coasting=0, age=0
|
||||||
|
|
||||||
|
KalmanTrack — predict():
|
||||||
|
- predict(dt=0) leaves position unchanged
|
||||||
|
- predict(dt=1) with zero velocity leaves position unchanged
|
||||||
|
- predict() increments coasting by 1
|
||||||
|
- alive stays True until coasting > max_coasting
|
||||||
|
- alive becomes False exactly when coasting > max_coasting
|
||||||
|
|
||||||
|
KalmanTrack — update():
|
||||||
|
- update() resets coasting to 0
|
||||||
|
- update() increments age
|
||||||
|
- update() shifts position toward measurement
|
||||||
|
- update() reduces position uncertainty (trace of P decreases)
|
||||||
|
- confidence increases with age, caps at 1.0
|
||||||
|
- confidence never exceeds 1.0
|
||||||
|
- metadata stored (width, depth, point_count)
|
||||||
|
|
||||||
|
KalmanTrack — velocity convergence:
|
||||||
|
- constant-velocity target: vx converges toward true velocity after N steps
|
||||||
|
- stationary target: speed stays near zero
|
||||||
|
|
||||||
|
associate():
|
||||||
|
- empty tracks → all clusters unmatched
|
||||||
|
- empty clusters → all tracks unmatched
|
||||||
|
- both empty → empty matches
|
||||||
|
- single perfect match below max_dist
|
||||||
|
- single pair above max_dist → no match
|
||||||
|
- two tracks two clusters: diagonal nearest wins
|
||||||
|
- more tracks than clusters: extra tracks unmatched
|
||||||
|
- more clusters than tracks: extra clusters unmatched
|
||||||
|
|
||||||
|
ObstacleTracker:
|
||||||
|
- empty centroids → no tracks created
|
||||||
|
- single cluster → one track, confidence=0 initially
|
||||||
|
- same position twice → track maintained, confidence increases
|
||||||
|
- cluster disappears: coasts then deleted after max_coasting+1 frames
|
||||||
|
- new cluster after deletion gets new track_id
|
||||||
|
- two clusters → two tracks, correct IDs
|
||||||
|
- moving cluster: velocity direction correct after convergence
|
||||||
|
- track_id is monotonically increasing
|
||||||
|
- metadata (width, depth, point_count) stored on track
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||||
|
|
||||||
|
from saltybot_bringup._obstacle_velocity import (
|
||||||
|
KalmanTrack,
|
||||||
|
ObstacleTracker,
|
||||||
|
associate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _c(x: float, y: float) -> np.ndarray:
|
||||||
|
return np.array([x, y], dtype=np.float64)
|
||||||
|
|
||||||
|
|
||||||
|
def _track(x: float = 0.0, y: float = 0.0, **kw) -> KalmanTrack:
|
||||||
|
return KalmanTrack(1, _c(x, y), **kw)
|
||||||
|
|
||||||
|
|
||||||
|
# ── KalmanTrack — construction ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestKalmanTrackInit:
|
||||||
|
|
||||||
|
def test_initial_position(self):
|
||||||
|
t = _track(3.0, -1.5)
|
||||||
|
assert t.position == pytest.approx([3.0, -1.5], abs=1e-9)
|
||||||
|
|
||||||
|
def test_initial_velocity_zero(self):
|
||||||
|
t = _track(1.0, 2.0)
|
||||||
|
assert t.velocity == pytest.approx([0.0, 0.0], abs=1e-9)
|
||||||
|
|
||||||
|
def test_initial_speed_zero(self):
|
||||||
|
t = _track(1.0, 0.0)
|
||||||
|
assert t.speed == pytest.approx(0.0, abs=1e-9)
|
||||||
|
|
||||||
|
def test_initial_confidence_zero(self):
|
||||||
|
t = KalmanTrack(1, _c(0, 0), n_init_frames=3)
|
||||||
|
assert t.confidence == pytest.approx(0.0, abs=1e-9)
|
||||||
|
|
||||||
|
def test_initial_alive(self):
|
||||||
|
assert _track().alive is True
|
||||||
|
|
||||||
|
def test_initial_coasting_zero(self):
|
||||||
|
assert _track().coasting == 0
|
||||||
|
|
||||||
|
def test_initial_age_zero(self):
|
||||||
|
assert _track().age == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── KalmanTrack — predict ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestKalmanTrackPredict:
|
||||||
|
|
||||||
|
def test_predict_dt_zero_no_position_change(self):
|
||||||
|
t = _track(2.0, 3.0)
|
||||||
|
t.predict(0.0)
|
||||||
|
assert t.position == pytest.approx([2.0, 3.0], abs=1e-9)
|
||||||
|
|
||||||
|
def test_predict_dt_positive_zero_velocity_no_position_change(self):
|
||||||
|
t = _track(1.0, 1.0)
|
||||||
|
t.predict(1.0)
|
||||||
|
assert t.position == pytest.approx([1.0, 1.0], abs=1e-6)
|
||||||
|
|
||||||
|
def test_predict_negative_dt_clamped_to_zero(self):
|
||||||
|
t = _track(1.0, 0.0)
|
||||||
|
t.predict(-5.0)
|
||||||
|
assert t.position == pytest.approx([1.0, 0.0], abs=1e-9)
|
||||||
|
|
||||||
|
def test_predict_increments_coasting(self):
|
||||||
|
t = _track()
|
||||||
|
assert t.coasting == 0
|
||||||
|
t.predict(0.1)
|
||||||
|
assert t.coasting == 1
|
||||||
|
t.predict(0.1)
|
||||||
|
assert t.coasting == 2
|
||||||
|
|
||||||
|
def test_alive_before_max_coasting(self):
|
||||||
|
t = KalmanTrack(1, _c(0, 0), max_coasting=3)
|
||||||
|
for _ in range(3):
|
||||||
|
t.predict(0.1)
|
||||||
|
assert t.alive is True # coast=3, 3 > 3 is False → still alive
|
||||||
|
|
||||||
|
def test_alive_false_after_exceeding_max_coasting(self):
|
||||||
|
t = KalmanTrack(1, _c(0, 0), max_coasting=3)
|
||||||
|
for _ in range(4):
|
||||||
|
t.predict(0.1)
|
||||||
|
assert t.alive is False # coast=4, 4 > 3 → dead
|
||||||
|
|
||||||
|
def test_predict_advances_position_when_velocity_set(self):
|
||||||
|
"""After seeding velocity via update, predict advances x."""
|
||||||
|
t = KalmanTrack(1, _c(0.0, 0.0), r_pos=0.001, n_init_frames=1)
|
||||||
|
# Drive velocity to ~(1,0) by updating with advancing centroids
|
||||||
|
for i in range(1, 8):
|
||||||
|
t.predict(1.0)
|
||||||
|
t.update(_c(float(i), 0.0))
|
||||||
|
# Now predict one more step — position should advance in +x
|
||||||
|
x0 = t.position[0]
|
||||||
|
t.predict(1.0)
|
||||||
|
assert t.position[0] > x0
|
||||||
|
|
||||||
|
|
||||||
|
# ── KalmanTrack — update ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestKalmanTrackUpdate:
|
||||||
|
|
||||||
|
def test_update_resets_coasting(self):
|
||||||
|
t = _track()
|
||||||
|
t.predict(0.1) # coast=1
|
||||||
|
t.predict(0.1) # coast=2
|
||||||
|
t.update(_c(0, 0))
|
||||||
|
assert t.coasting == 0
|
||||||
|
|
||||||
|
def test_update_increments_age(self):
|
||||||
|
t = _track()
|
||||||
|
t.update(_c(1, 0))
|
||||||
|
assert t.age == 1
|
||||||
|
t.update(_c(2, 0))
|
||||||
|
assert t.age == 2
|
||||||
|
|
||||||
|
def test_update_shifts_position_toward_measurement(self):
|
||||||
|
t = _track(0.0, 0.0)
|
||||||
|
t.update(_c(5.0, 0.0))
|
||||||
|
# Position should have moved in +x direction
|
||||||
|
assert t.position[0] > 0.0
|
||||||
|
|
||||||
|
def test_update_reduces_position_covariance(self):
|
||||||
|
t = KalmanTrack(1, _c(0, 0))
|
||||||
|
p_before = float(t._P[0, 0])
|
||||||
|
t.update(_c(0, 0))
|
||||||
|
p_after = float(t._P[0, 0])
|
||||||
|
assert p_after < p_before
|
||||||
|
|
||||||
|
def test_confidence_increases_with_age(self):
|
||||||
|
t = KalmanTrack(1, _c(0, 0), n_init_frames=4)
|
||||||
|
assert t.confidence == pytest.approx(0.0, abs=1e-9)
|
||||||
|
t.update(_c(0, 0))
|
||||||
|
assert t.confidence == pytest.approx(0.25)
|
||||||
|
t.update(_c(0, 0))
|
||||||
|
assert t.confidence == pytest.approx(0.50)
|
||||||
|
t.update(_c(0, 0))
|
||||||
|
assert t.confidence == pytest.approx(0.75)
|
||||||
|
t.update(_c(0, 0))
|
||||||
|
assert t.confidence == pytest.approx(1.0)
|
||||||
|
|
||||||
|
def test_confidence_caps_at_one(self):
|
||||||
|
t = KalmanTrack(1, _c(0, 0), n_init_frames=2)
|
||||||
|
for _ in range(10):
|
||||||
|
t.update(_c(0, 0))
|
||||||
|
assert t.confidence == pytest.approx(1.0)
|
||||||
|
|
||||||
|
def test_update_stores_metadata(self):
|
||||||
|
t = _track()
|
||||||
|
t.update(_c(1, 1), width=0.3, depth=0.5, point_count=7)
|
||||||
|
assert t.last_width == pytest.approx(0.3)
|
||||||
|
assert t.last_depth == pytest.approx(0.5)
|
||||||
|
assert t.last_point_count == 7
|
||||||
|
|
||||||
|
|
||||||
|
# ── KalmanTrack — velocity convergence ───────────────────────────────────────
|
||||||
|
|
||||||
|
class TestKalmanTrackVelocityConvergence:
|
||||||
|
|
||||||
|
def test_constant_velocity_converges(self):
|
||||||
|
"""
|
||||||
|
Target at vx=1 m/s: observations at x=0,1,2,...
|
||||||
|
After 15 predict+update cycles with low noise, vx should be near 1.0.
|
||||||
|
"""
|
||||||
|
t = KalmanTrack(1, _c(0.0, 0.0), r_pos=0.01, q_pos=0.001, q_vel=0.1)
|
||||||
|
for i in range(1, 16):
|
||||||
|
t.predict(1.0)
|
||||||
|
t.update(_c(float(i), 0.0))
|
||||||
|
assert t.velocity[0] == pytest.approx(1.0, abs=0.15)
|
||||||
|
assert t.velocity[1] == pytest.approx(0.0, abs=0.10)
|
||||||
|
|
||||||
|
def test_diagonal_velocity_converges(self):
|
||||||
|
"""Target moving at (vx=1, vy=1) m/s."""
|
||||||
|
t = KalmanTrack(1, _c(0.0, 0.0), r_pos=0.01, q_pos=0.001, q_vel=0.1)
|
||||||
|
for i in range(1, 16):
|
||||||
|
t.predict(1.0)
|
||||||
|
t.update(_c(float(i), float(i)))
|
||||||
|
assert t.velocity[0] == pytest.approx(1.0, abs=0.15)
|
||||||
|
assert t.velocity[1] == pytest.approx(1.0, abs=0.15)
|
||||||
|
|
||||||
|
def test_stationary_target_speed_near_zero(self):
|
||||||
|
"""Target at fixed position: estimated speed should stay small."""
|
||||||
|
t = KalmanTrack(1, _c(2.0, 3.0), r_pos=0.001)
|
||||||
|
for _ in range(15):
|
||||||
|
t.predict(0.1)
|
||||||
|
t.update(_c(2.0, 3.0))
|
||||||
|
assert t.speed < 0.05
|
||||||
|
|
||||||
|
|
||||||
|
# ── associate ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestAssociate:
|
||||||
|
|
||||||
|
def test_empty_tracks(self):
|
||||||
|
pos = np.empty((0, 2))
|
||||||
|
cent = np.array([[1.0, 0.0]])
|
||||||
|
m, ut, uc = associate(pos, cent, 1.0)
|
||||||
|
assert m == []
|
||||||
|
assert ut == []
|
||||||
|
assert uc == [0]
|
||||||
|
|
||||||
|
def test_empty_clusters(self):
|
||||||
|
pos = np.array([[1.0, 0.0]])
|
||||||
|
cent = np.empty((0, 2))
|
||||||
|
m, ut, uc = associate(pos, cent, 1.0)
|
||||||
|
assert m == []
|
||||||
|
assert ut == [0]
|
||||||
|
assert uc == []
|
||||||
|
|
||||||
|
def test_both_empty(self):
|
||||||
|
m, ut, uc = associate(np.empty((0, 2)), np.empty((0, 2)), 1.0)
|
||||||
|
assert m == []
|
||||||
|
assert ut == []
|
||||||
|
assert uc == []
|
||||||
|
|
||||||
|
def test_single_match_below_threshold(self):
|
||||||
|
pos = np.array([[0.0, 0.0]])
|
||||||
|
cent = np.array([[0.1, 0.0]])
|
||||||
|
m, ut, uc = associate(pos, cent, 0.5)
|
||||||
|
assert m == [(0, 0)]
|
||||||
|
assert ut == []
|
||||||
|
assert uc == []
|
||||||
|
|
||||||
|
def test_single_pair_above_threshold_no_match(self):
|
||||||
|
pos = np.array([[0.0, 0.0]])
|
||||||
|
cent = np.array([[2.0, 0.0]])
|
||||||
|
m, ut, uc = associate(pos, cent, 0.5)
|
||||||
|
assert m == []
|
||||||
|
assert ut == [0]
|
||||||
|
assert uc == [0]
|
||||||
|
|
||||||
|
def test_two_tracks_two_clusters_diagonal(self):
|
||||||
|
pos = np.array([[0.0, 0.0], [3.0, 0.0]])
|
||||||
|
cent = np.array([[0.1, 0.0], [3.1, 0.0]])
|
||||||
|
m, ut, uc = associate(pos, cent, 0.5)
|
||||||
|
assert (0, 0) in m
|
||||||
|
assert (1, 1) in m
|
||||||
|
assert ut == []
|
||||||
|
assert uc == []
|
||||||
|
|
||||||
|
def test_more_tracks_than_clusters(self):
|
||||||
|
pos = np.array([[0.0, 0.0], [5.0, 0.0]])
|
||||||
|
cent = np.array([[0.1, 0.0]])
|
||||||
|
m, ut, uc = associate(pos, cent, 0.5)
|
||||||
|
assert len(m) == 1
|
||||||
|
assert len(ut) == 1 # one track unmatched
|
||||||
|
assert uc == []
|
||||||
|
|
||||||
|
def test_more_clusters_than_tracks(self):
|
||||||
|
pos = np.array([[0.0, 0.0]])
|
||||||
|
cent = np.array([[0.1, 0.0], [5.0, 0.0]])
|
||||||
|
m, ut, uc = associate(pos, cent, 0.5)
|
||||||
|
assert len(m) == 1
|
||||||
|
assert ut == []
|
||||||
|
assert len(uc) == 1 # one cluster unmatched
|
||||||
|
|
||||||
|
def test_nearest_wins(self):
|
||||||
|
"""Track 0 is closest to cluster 0; ensure it's matched to cluster 0."""
|
||||||
|
pos = np.array([[0.0, 0.0], [10.0, 0.0]])
|
||||||
|
cent = np.array([[0.05, 0.0], [10.2, 0.0]])
|
||||||
|
m, ut, uc = associate(pos, cent, 1.0)
|
||||||
|
assert (0, 0) in m
|
||||||
|
assert (1, 1) in m
|
||||||
|
|
||||||
|
def test_threshold_strictly_less_than(self):
|
||||||
|
"""Distance exactly equal to max_dist should NOT match."""
|
||||||
|
pos = np.array([[0.0, 0.0]])
|
||||||
|
cent = np.array([[0.5, 0.0]])
|
||||||
|
m, ut, uc = associate(pos, cent, 0.5) # dist=0.5, max_dist=0.5
|
||||||
|
assert m == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── ObstacleTracker ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestObstacleTracker:
|
||||||
|
|
||||||
|
def test_empty_input_no_tracks(self):
|
||||||
|
tracker = ObstacleTracker()
|
||||||
|
result = tracker.update([], timestamp=0.0)
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
def test_single_cluster_creates_track(self):
|
||||||
|
tracker = ObstacleTracker()
|
||||||
|
tracks = tracker.update([_c(1, 0)], timestamp=0.0)
|
||||||
|
assert len(tracks) == 1
|
||||||
|
|
||||||
|
def test_track_id_starts_at_one(self):
|
||||||
|
tracker = ObstacleTracker()
|
||||||
|
tracks = tracker.update([_c(0, 0)], timestamp=0.0)
|
||||||
|
assert tracks[0].track_id == 1
|
||||||
|
|
||||||
|
def test_same_position_twice_single_track(self):
|
||||||
|
tracker = ObstacleTracker()
|
||||||
|
t1 = tracker.update([_c(1, 0)], timestamp=0.0)
|
||||||
|
t2 = tracker.update([_c(1, 0)], timestamp=0.1)
|
||||||
|
assert len(t2) == 1
|
||||||
|
assert t2[0].track_id == t1[0].track_id
|
||||||
|
|
||||||
|
def test_confidence_increases_with_updates(self):
|
||||||
|
tracker = ObstacleTracker(n_init_frames=3)
|
||||||
|
tracks = tracker.update([_c(0, 0)], timestamp=0.0)
|
||||||
|
c0 = tracks[0].confidence
|
||||||
|
tracks = tracker.update([_c(0, 0)], timestamp=0.1)
|
||||||
|
assert tracks[0].confidence > c0
|
||||||
|
|
||||||
|
def test_track_coasts_then_dies(self):
|
||||||
|
tracker = ObstacleTracker(max_coasting_frames=2)
|
||||||
|
tracker.update([_c(0, 0)], timestamp=0.0) # create
|
||||||
|
t1 = tracker.update([], timestamp=0.1) # coast 1
|
||||||
|
assert len(t1) == 1
|
||||||
|
t2 = tracker.update([], timestamp=0.2) # coast 2
|
||||||
|
assert len(t2) == 1
|
||||||
|
t3 = tracker.update([], timestamp=0.3) # coast 3 > 2 → dead
|
||||||
|
assert len(t3) == 0
|
||||||
|
|
||||||
|
def test_new_cluster_after_deletion_new_id(self):
|
||||||
|
tracker = ObstacleTracker(max_coasting_frames=1)
|
||||||
|
t0 = tracker.update([_c(0, 0)], timestamp=0.0)
|
||||||
|
old_id = t0[0].track_id
|
||||||
|
tracker.update([], timestamp=0.1) # coast 1
|
||||||
|
tracker.update([], timestamp=0.2) # coast 2 > 1 → dead
|
||||||
|
t1 = tracker.update([_c(0, 0)], timestamp=0.3)
|
||||||
|
assert len(t1) == 1
|
||||||
|
assert t1[0].track_id != old_id
|
||||||
|
|
||||||
|
def test_two_clusters_two_tracks(self):
|
||||||
|
tracker = ObstacleTracker()
|
||||||
|
tracks = tracker.update([_c(0, 0), _c(5, 0)], timestamp=0.0)
|
||||||
|
assert len(tracks) == 2
|
||||||
|
ids = {t.track_id for t in tracks}
|
||||||
|
assert len(ids) == 2
|
||||||
|
|
||||||
|
def test_track_ids_monotonically_increasing(self):
|
||||||
|
tracker = ObstacleTracker()
|
||||||
|
tracker.update([_c(0, 0)], timestamp=0.0)
|
||||||
|
tracker.update([_c(10, 10)], timestamp=0.1) # far → new track
|
||||||
|
all_ids = [t.track_id for t in tracker.tracks.values()]
|
||||||
|
assert all_ids == sorted(all_ids)
|
||||||
|
|
||||||
|
def test_moving_cluster_velocity_direction(self):
|
||||||
|
"""
|
||||||
|
Cluster moves in +x at 1 m/s; after convergence vx should be positive.
|
||||||
|
Use low noise and many steps for reliable convergence.
|
||||||
|
"""
|
||||||
|
tracker = ObstacleTracker(
|
||||||
|
n_init_frames=1,
|
||||||
|
r_pos=0.01,
|
||||||
|
q_pos=0.001,
|
||||||
|
q_vel=0.1,
|
||||||
|
)
|
||||||
|
for i in range(20):
|
||||||
|
tracker.update([_c(float(i) * 0.1, 0.0)], timestamp=float(i) * 0.1)
|
||||||
|
tracks = list(tracker.tracks.values())
|
||||||
|
assert len(tracks) == 1
|
||||||
|
assert tracks[0].velocity[0] > 0.05
|
||||||
|
|
||||||
|
def test_metadata_stored_on_track(self):
|
||||||
|
tracker = ObstacleTracker()
|
||||||
|
tracker.update(
|
||||||
|
[_c(1, 1)],
|
||||||
|
timestamp = 0.0,
|
||||||
|
widths = [0.4],
|
||||||
|
depths = [0.6],
|
||||||
|
point_counts = [9],
|
||||||
|
)
|
||||||
|
t = list(tracker.tracks.values())[0]
|
||||||
|
assert t.last_width == pytest.approx(0.4)
|
||||||
|
assert t.last_depth == pytest.approx(0.6)
|
||||||
|
assert t.last_point_count == 9
|
||||||
|
|
||||||
|
def test_far_cluster_creates_new_track(self):
|
||||||
|
"""Cluster beyond max_association_dist creates a second track."""
|
||||||
|
tracker = ObstacleTracker(max_association_dist_m=0.5)
|
||||||
|
tracker.update([_c(0, 0)], timestamp=0.0)
|
||||||
|
tracks = tracker.update([_c(10, 0)], timestamp=0.1)
|
||||||
|
# Original track coasts, new track spawned for far cluster
|
||||||
|
assert len(tracks) == 2
|
||||||
|
|
||||||
|
def test_empty_to_single_and_back(self):
|
||||||
|
tracker = ObstacleTracker(max_coasting_frames=0)
|
||||||
|
t1 = tracker.update([_c(1, 0)], timestamp=0.0)
|
||||||
|
assert len(t1) == 1
|
||||||
|
t2 = tracker.update([], timestamp=0.1) # coast=1 > 0 → dead
|
||||||
|
assert len(t2) == 0
|
||||||
|
t3 = tracker.update([_c(1, 0)], timestamp=0.2)
|
||||||
|
assert len(t3) == 1
|
||||||
|
|
||||||
|
def test_constant_velocity_estimate(self):
|
||||||
|
"""
|
||||||
|
Target moves at vx=0.1 m/s (0.1 m per 1-second step).
|
||||||
|
Each step is well within the default max_association_dist_m=0.5 m,
|
||||||
|
so the track is continuously matched. After many updates, estimated
|
||||||
|
speed should be close to 0.1 m/s.
|
||||||
|
"""
|
||||||
|
tracker = ObstacleTracker(
|
||||||
|
n_init_frames=1, r_pos=0.001, q_pos=0.0001, q_vel=0.05
|
||||||
|
)
|
||||||
|
for i in range(30):
|
||||||
|
tracker.update([_c(float(i) * 0.1, 0.0)], timestamp=float(i))
|
||||||
|
t = list(tracker.tracks.values())[0]
|
||||||
|
assert t.speed == pytest.approx(0.1, abs=0.03)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pytest.main([__file__, '-v'])
|
||||||
364
jetson/ros2_ws/src/saltybot_bringup/test/test_path_edges.py
Normal file
364
jetson/ros2_ws/src/saltybot_bringup/test/test_path_edges.py
Normal file
@ -0,0 +1,364 @@
|
|||||||
|
"""
|
||||||
|
test_path_edges.py — pytest tests for _path_edges.py (no ROS2 required).
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- build_homography: output shape, identity-like mapping
|
||||||
|
- apply_homography: empty input, single point, batch
|
||||||
|
- canny_edges: output shape, dtype, uniform image produces no edges
|
||||||
|
- hough_lines: empty edge map returns []
|
||||||
|
- classify_lines: slope filtering, left/right split
|
||||||
|
- average_line: empty → None, single line, multi-line average
|
||||||
|
- warp_segments: empty list, segment endpoint ordering
|
||||||
|
- process_frame: smoke test on synthetic image
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from saltybot_bringup._path_edges import (
|
||||||
|
PathEdgeConfig,
|
||||||
|
PathEdgesResult,
|
||||||
|
apply_homography,
|
||||||
|
average_line,
|
||||||
|
build_homography,
|
||||||
|
canny_edges,
|
||||||
|
classify_lines,
|
||||||
|
hough_lines,
|
||||||
|
process_frame,
|
||||||
|
warp_segments,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _solid_bgr(h: int = 100, w: int = 200, color=(128, 128, 128)) -> np.ndarray:
|
||||||
|
img = np.zeros((h, w, 3), dtype=np.uint8)
|
||||||
|
img[:] = color
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def _default_H(roi_w: int = 200, roi_h: int = 100) -> np.ndarray:
|
||||||
|
cfg = PathEdgeConfig()
|
||||||
|
return build_homography(cfg.birdseye_src, roi_w, roi_h, cfg.birdseye_size)
|
||||||
|
|
||||||
|
|
||||||
|
# ── build_homography ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestBuildHomography:
|
||||||
|
def test_shape(self):
|
||||||
|
H = _default_H()
|
||||||
|
assert H.shape == (3, 3)
|
||||||
|
|
||||||
|
def test_dtype(self):
|
||||||
|
H = _default_H()
|
||||||
|
assert H.dtype == np.float64
|
||||||
|
|
||||||
|
def test_bottom_left_maps_near_left(self):
|
||||||
|
"""Bottom-left source trapezoid point should map near left of bird-eye."""
|
||||||
|
cfg = PathEdgeConfig()
|
||||||
|
H = build_homography(cfg.birdseye_src, 200, 100, cfg.birdseye_size)
|
||||||
|
# src[3] = [0.05, 0.95] → near bottom-left of ROI → should map to x≈100 (centre-left), y≈400
|
||||||
|
src_pt = np.array([[0.05 * 200, 0.95 * 100]], dtype=np.float32)
|
||||||
|
dst = apply_homography(H, src_pt)
|
||||||
|
assert dst[0, 0] == pytest.approx(cfg.birdseye_size * 0.25, abs=5)
|
||||||
|
assert dst[0, 1] == pytest.approx(cfg.birdseye_size, abs=5)
|
||||||
|
|
||||||
|
def test_bottom_right_maps_near_right(self):
|
||||||
|
cfg = PathEdgeConfig()
|
||||||
|
H = build_homography(cfg.birdseye_src, 200, 100, cfg.birdseye_size)
|
||||||
|
# src[2] = [0.95, 0.95] → bottom-right → maps to x≈300, y≈400
|
||||||
|
src_pt = np.array([[0.95 * 200, 0.95 * 100]], dtype=np.float32)
|
||||||
|
dst = apply_homography(H, src_pt)
|
||||||
|
assert dst[0, 0] == pytest.approx(cfg.birdseye_size * 0.75, abs=5)
|
||||||
|
assert dst[0, 1] == pytest.approx(cfg.birdseye_size, abs=5)
|
||||||
|
|
||||||
|
|
||||||
|
# ── apply_homography ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestApplyHomography:
|
||||||
|
def test_empty_input(self):
|
||||||
|
H = _default_H()
|
||||||
|
out = apply_homography(H, np.empty((0, 2), dtype=np.float32))
|
||||||
|
assert out.shape == (0, 2)
|
||||||
|
|
||||||
|
def test_single_point_roundtrip(self):
|
||||||
|
"""Warping then inverse-warping should recover the original point."""
|
||||||
|
H = _default_H(200, 100)
|
||||||
|
H_inv = np.linalg.inv(H)
|
||||||
|
pt = np.array([[50.0, 40.0]], dtype=np.float32)
|
||||||
|
warped = apply_homography(H, pt)
|
||||||
|
unwarped = apply_homography(H_inv, warped)
|
||||||
|
np.testing.assert_allclose(unwarped, pt, atol=1e-3)
|
||||||
|
|
||||||
|
def test_batch_output_shape(self):
|
||||||
|
H = _default_H()
|
||||||
|
pts = np.random.rand(5, 2).astype(np.float32) * 100
|
||||||
|
out = apply_homography(H, pts)
|
||||||
|
assert out.shape == (5, 2)
|
||||||
|
assert out.dtype == np.float32
|
||||||
|
|
||||||
|
|
||||||
|
# ── canny_edges ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestCannyEdges:
|
||||||
|
def test_output_shape(self):
|
||||||
|
roi = _solid_bgr(80, 160)
|
||||||
|
edges = canny_edges(roi, low=50, high=150, ksize=5)
|
||||||
|
assert edges.shape == (80, 160)
|
||||||
|
|
||||||
|
def test_output_dtype(self):
|
||||||
|
roi = _solid_bgr()
|
||||||
|
edges = canny_edges(roi, low=50, high=150, ksize=5)
|
||||||
|
assert edges.dtype == np.uint8
|
||||||
|
|
||||||
|
def test_uniform_image_no_edges(self):
|
||||||
|
"""A solid-colour image should produce no edges."""
|
||||||
|
roi = _solid_bgr(80, 160, color=(100, 100, 100))
|
||||||
|
edges = canny_edges(roi, low=50, high=150, ksize=5)
|
||||||
|
assert edges.max() == 0
|
||||||
|
|
||||||
|
def test_strong_edge_detected(self):
|
||||||
|
"""A sharp horizontal boundary should produce edges."""
|
||||||
|
roi = np.zeros((100, 200, 3), dtype=np.uint8)
|
||||||
|
roi[50:, :] = 255 # sharp boundary at y=50
|
||||||
|
edges = canny_edges(roi, low=20, high=60, ksize=3)
|
||||||
|
assert edges.max() == 255
|
||||||
|
|
||||||
|
def test_ksize_even_skips_blur(self):
|
||||||
|
"""Even ksize should not crash (blur is skipped for even kernels)."""
|
||||||
|
roi = _solid_bgr()
|
||||||
|
edges = canny_edges(roi, low=50, high=150, ksize=4)
|
||||||
|
assert edges.shape == (100, 200)
|
||||||
|
|
||||||
|
|
||||||
|
# ── hough_lines ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestHoughLines:
|
||||||
|
def test_empty_edge_map(self):
|
||||||
|
edge_map = np.zeros((100, 200), dtype=np.uint8)
|
||||||
|
lines = hough_lines(edge_map, threshold=30, min_len=40, max_gap=20)
|
||||||
|
assert lines == []
|
||||||
|
|
||||||
|
def test_returns_list_of_tuples(self):
|
||||||
|
"""Draw a diagonal line and verify hough returns tuples of 4 floats."""
|
||||||
|
edge_map = np.zeros((100, 200), dtype=np.uint8)
|
||||||
|
cv2.line(edge_map, (0, 0), (199, 99), 255, 2)
|
||||||
|
lines = hough_lines(edge_map, threshold=10, min_len=20, max_gap=5)
|
||||||
|
assert isinstance(lines, list)
|
||||||
|
if lines:
|
||||||
|
assert len(lines[0]) == 4
|
||||||
|
assert all(isinstance(v, float) for v in lines[0])
|
||||||
|
|
||||||
|
def test_line_detected_on_drawn_segment(self):
|
||||||
|
"""A drawn line segment should be detected."""
|
||||||
|
edge_map = np.zeros((200, 400), dtype=np.uint8)
|
||||||
|
cv2.line(edge_map, (50, 100), (350, 150), 255, 2)
|
||||||
|
lines = hough_lines(edge_map, threshold=10, min_len=30, max_gap=10)
|
||||||
|
assert len(lines) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
# ── classify_lines ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestClassifyLines:
|
||||||
|
def test_empty_returns_two_empty_lists(self):
|
||||||
|
left, right = classify_lines([], min_slope=0.3)
|
||||||
|
assert left == []
|
||||||
|
assert right == []
|
||||||
|
|
||||||
|
def test_negative_slope_goes_left(self):
|
||||||
|
# slope = (y2-y1)/(x2-x1) = (50-0)/(0-100) = -0.5 → left
|
||||||
|
lines = [(100.0, 0.0, 0.0, 50.0)]
|
||||||
|
left, right = classify_lines(lines, min_slope=0.3)
|
||||||
|
assert len(left) == 1
|
||||||
|
assert right == []
|
||||||
|
|
||||||
|
def test_positive_slope_goes_right(self):
|
||||||
|
# slope = (50-0)/(100-0) = 0.5 → right
|
||||||
|
lines = [(0.0, 0.0, 100.0, 50.0)]
|
||||||
|
left, right = classify_lines(lines, min_slope=0.3)
|
||||||
|
assert left == []
|
||||||
|
assert len(right) == 1
|
||||||
|
|
||||||
|
def test_near_horizontal_discarded(self):
|
||||||
|
# slope = 0.1 → |slope| < 0.3 → discard
|
||||||
|
lines = [(0.0, 0.0, 100.0, 10.0)]
|
||||||
|
left, right = classify_lines(lines, min_slope=0.3)
|
||||||
|
assert left == []
|
||||||
|
assert right == []
|
||||||
|
|
||||||
|
def test_vertical_line_skipped(self):
|
||||||
|
# dx ≈ 0 → skip
|
||||||
|
lines = [(50.0, 0.0, 50.0, 100.0)]
|
||||||
|
left, right = classify_lines(lines, min_slope=0.3)
|
||||||
|
assert left == []
|
||||||
|
assert right == []
|
||||||
|
|
||||||
|
def test_mixed_lines(self):
|
||||||
|
lines = [
|
||||||
|
(100.0, 0.0, 0.0, 50.0), # slope -0.5 → left
|
||||||
|
(0.0, 0.0, 100.0, 50.0), # slope +0.5 → right
|
||||||
|
(0.0, 0.0, 100.0, 5.0), # slope +0.05 → discard
|
||||||
|
]
|
||||||
|
left, right = classify_lines(lines, min_slope=0.3)
|
||||||
|
assert len(left) == 1
|
||||||
|
assert len(right) == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ── average_line ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestAverageLine:
|
||||||
|
def test_empty_returns_none(self):
|
||||||
|
assert average_line([], roi_height=100) is None
|
||||||
|
|
||||||
|
def test_single_line_extrapolated(self):
|
||||||
|
# slope=0.5, intercept=0: x = y/0.5 = 2y
|
||||||
|
# At y=99: x=198; at y=0: x=0
|
||||||
|
lines = [(0.0, 0.0, 200.0, 100.0)] # slope = 100/200 = 0.5
|
||||||
|
result = average_line(lines, roi_height=100)
|
||||||
|
assert result is not None
|
||||||
|
x_bot, y_bot, x_top, y_top = result
|
||||||
|
assert y_bot == pytest.approx(99.0, abs=1)
|
||||||
|
assert y_top == pytest.approx(0.0, abs=1)
|
||||||
|
|
||||||
|
def test_two_parallel_lines_averaged(self):
|
||||||
|
# Both have slope=1.0, intercepts -50 and +50 → avg intercept=0
|
||||||
|
# x = (y - 0) / 1.0 = y
|
||||||
|
lines = [
|
||||||
|
(50.0, 0.0, 100.0, 50.0), # slope=1, intercept=-50
|
||||||
|
(0.0, 50.0, 50.0, 100.0), # slope=1, intercept=50
|
||||||
|
]
|
||||||
|
result = average_line(lines, roi_height=100)
|
||||||
|
assert result is not None
|
||||||
|
x_bot, y_bot, x_top, y_top = result
|
||||||
|
assert y_bot == pytest.approx(99.0, abs=1)
|
||||||
|
# avg intercept = 0, m_avg=1 → x_bot = 99
|
||||||
|
assert x_bot == pytest.approx(99.0, abs=2)
|
||||||
|
|
||||||
|
def test_vertical_only_returns_none(self):
|
||||||
|
# dx == 0 → skip → no slopes → None
|
||||||
|
lines = [(50.0, 0.0, 50.0, 100.0)]
|
||||||
|
assert average_line(lines, roi_height=100) is None
|
||||||
|
|
||||||
|
def test_returns_four_tuple(self):
|
||||||
|
lines = [(0.0, 0.0, 100.0, 50.0)]
|
||||||
|
result = average_line(lines, roi_height=100)
|
||||||
|
assert result is not None
|
||||||
|
assert len(result) == 4
|
||||||
|
|
||||||
|
|
||||||
|
# ── warp_segments ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestWarpSegments:
|
||||||
|
def test_empty_returns_empty(self):
|
||||||
|
H = _default_H()
|
||||||
|
result = warp_segments([], H)
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
def test_single_segment_returns_one_tuple(self):
|
||||||
|
H = _default_H(200, 100)
|
||||||
|
lines = [(10.0, 10.0, 90.0, 80.0)]
|
||||||
|
result = warp_segments(lines, H)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert len(result[0]) == 4
|
||||||
|
|
||||||
|
def test_start_and_end_distinct(self):
|
||||||
|
"""Warped segment endpoints should be different from each other."""
|
||||||
|
H = _default_H(200, 100)
|
||||||
|
lines = [(10.0, 10.0, 190.0, 90.0)]
|
||||||
|
result = warp_segments(lines, H)
|
||||||
|
bx1, by1, bx2, by2 = result[0]
|
||||||
|
# The two endpoints should differ
|
||||||
|
assert abs(bx1 - bx2) + abs(by1 - by2) > 1.0
|
||||||
|
|
||||||
|
def test_batch_preserves_count(self):
|
||||||
|
H = _default_H(200, 100)
|
||||||
|
lines = [(0.0, 0.0, 10.0, 10.0), (100.0, 0.0, 90.0, 50.0)]
|
||||||
|
result = warp_segments(lines, H)
|
||||||
|
assert len(result) == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ── process_frame ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestProcessFrame:
|
||||||
|
def _lane_image(self, h: int = 480, w: int = 640) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Create a synthetic BGR image with two diagonal lines representing
|
||||||
|
left and right lane edges in the bottom half.
|
||||||
|
"""
|
||||||
|
img = np.zeros((h, w, 3), dtype=np.uint8)
|
||||||
|
roi_top = h // 2
|
||||||
|
# Left edge: from bottom-left area upward to the right (negative slope in image coords)
|
||||||
|
cv2.line(img, (80, h - 10), (240, roi_top + 10), (255, 255, 255), 4)
|
||||||
|
# Right edge: from bottom-right area upward to the left (positive slope in image coords)
|
||||||
|
cv2.line(img, (w - 80, h - 10), (w - 240, roi_top + 10), (255, 255, 255), 4)
|
||||||
|
return img
|
||||||
|
|
||||||
|
def test_returns_named_tuple(self):
|
||||||
|
bgr = _solid_bgr(120, 240)
|
||||||
|
cfg = PathEdgeConfig()
|
||||||
|
result = process_frame(bgr, cfg)
|
||||||
|
assert isinstance(result, PathEdgesResult)
|
||||||
|
|
||||||
|
def test_roi_top_correct(self):
|
||||||
|
bgr = np.zeros((200, 400, 3), dtype=np.uint8)
|
||||||
|
cfg = PathEdgeConfig(roi_frac=0.5)
|
||||||
|
result = process_frame(bgr, cfg)
|
||||||
|
assert result.roi_top == 100
|
||||||
|
|
||||||
|
def test_uniform_image_no_lines(self):
|
||||||
|
bgr = _solid_bgr(200, 400, color=(80, 80, 80))
|
||||||
|
cfg = PathEdgeConfig()
|
||||||
|
result = process_frame(bgr, cfg)
|
||||||
|
assert result.lines == []
|
||||||
|
assert result.left_edge is None
|
||||||
|
assert result.right_edge is None
|
||||||
|
assert result.left_lines == []
|
||||||
|
assert result.right_lines == []
|
||||||
|
|
||||||
|
def test_homography_matrix_shape(self):
|
||||||
|
bgr = _solid_bgr(200, 400)
|
||||||
|
result = process_frame(bgr, PathEdgeConfig())
|
||||||
|
assert result.H.shape == (3, 3)
|
||||||
|
|
||||||
|
def test_birdseye_segments_same_count(self):
|
||||||
|
"""birdseye_lines and lines must have the same number of segments."""
|
||||||
|
bgr = self._lane_image()
|
||||||
|
result = process_frame(bgr, PathEdgeConfig(hough_threshold=10, min_line_len=20))
|
||||||
|
assert len(result.birdseye_lines) == len(result.lines)
|
||||||
|
|
||||||
|
def test_lane_image_detects_edges(self):
|
||||||
|
"""Synthetic lane image should detect at least one of left/right edge."""
|
||||||
|
bgr = self._lane_image()
|
||||||
|
cfg = PathEdgeConfig(
|
||||||
|
roi_frac=0.5,
|
||||||
|
canny_low=30,
|
||||||
|
canny_high=100,
|
||||||
|
hough_threshold=10,
|
||||||
|
min_line_len=20,
|
||||||
|
max_line_gap=15,
|
||||||
|
)
|
||||||
|
result = process_frame(bgr, cfg)
|
||||||
|
assert (result.left_edge is not None) or (result.right_edge is not None)
|
||||||
|
|
||||||
|
def test_segments_px_flat_array_length(self):
|
||||||
|
"""segments_px-equivalent length must be 4 × line_count."""
|
||||||
|
bgr = self._lane_image()
|
||||||
|
cfg = PathEdgeConfig(hough_threshold=10, min_line_len=20)
|
||||||
|
result = process_frame(bgr, cfg)
|
||||||
|
assert len(result.lines) * 4 == sum(4 for _ in result.lines)
|
||||||
|
|
||||||
|
def test_left_right_lines_subset_of_all_lines(self):
|
||||||
|
"""left_lines + right_lines must be a subset of all lines."""
|
||||||
|
bgr = self._lane_image()
|
||||||
|
cfg = PathEdgeConfig(hough_threshold=10, min_line_len=20)
|
||||||
|
result = process_frame(bgr, cfg)
|
||||||
|
all_set = set(result.lines)
|
||||||
|
for seg in result.left_lines:
|
||||||
|
assert seg in all_set
|
||||||
|
for seg in result.right_lines:
|
||||||
|
assert seg in all_set
|
||||||
599
jetson/ros2_ws/src/saltybot_bringup/test/test_person_tracker.py
Normal file
599
jetson/ros2_ws/src/saltybot_bringup/test/test_person_tracker.py
Normal file
@ -0,0 +1,599 @@
|
|||||||
|
"""
|
||||||
|
test_person_tracker.py — Unit tests for the P0 person tracking pipeline.
|
||||||
|
|
||||||
|
Tests cover: IoU, Kalman filter, colour histogram, bearing geometry,
|
||||||
|
depth sampling, tracker state machine, and follow-target selection.
|
||||||
|
No camera, no detector, no ROS2 needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from saltybot_bringup._person_tracker import (
|
||||||
|
BBox,
|
||||||
|
CamParams,
|
||||||
|
Detection,
|
||||||
|
FollowTargetSelector,
|
||||||
|
KalmanBoxFilter,
|
||||||
|
PersonTrack,
|
||||||
|
PersonTracker,
|
||||||
|
TrackState,
|
||||||
|
DEPTH_GOOD, DEPTH_MARGINAL, DEPTH_EXTRAPOLATED, DEPTH_INVALID,
|
||||||
|
bearing_from_pixel,
|
||||||
|
depth_at_bbox,
|
||||||
|
extract_torso_hist,
|
||||||
|
hist_similarity,
|
||||||
|
iou,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Helpers
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _det(x=10, y=20, w=60, h=150, conf=0.85, frame=None) -> Detection:
|
||||||
|
return Detection(BBox(x, y, w, h), conf, frame)
|
||||||
|
|
||||||
|
|
||||||
|
def _solid_bgr(h=200, w=100, b=128, g=64, r=32) -> np.ndarray:
|
||||||
|
"""Uniform colour BGR image."""
|
||||||
|
return np.full((h, w, 3), (b, g, r), dtype=np.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
def _depth_image(h=480, w=640, val_mm=2000) -> np.ndarray:
|
||||||
|
"""Uniform uint16 depth image at val_mm mm."""
|
||||||
|
return np.full((h, w), val_mm, dtype=np.uint16)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_tracker_n(tracker, bbox, n=5, cam=None):
|
||||||
|
"""Feed the same detection to a tracker n times."""
|
||||||
|
for _ in range(n):
|
||||||
|
tracker.update([_det(*bbox)], cam=cam)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# BBox
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_bbox_fields():
|
||||||
|
b = BBox(10, 20, 60, 150)
|
||||||
|
assert b.x == 10 and b.y == 20 and b.w == 60 and b.h == 150
|
||||||
|
|
||||||
|
|
||||||
|
def test_bbox_is_named_tuple():
|
||||||
|
b = BBox(0, 0, 1, 1)
|
||||||
|
assert isinstance(b, tuple)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# IoU
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_iou_identical():
|
||||||
|
b = BBox(0, 0, 100, 100)
|
||||||
|
assert abs(iou(b, b) - 1.0) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_iou_no_overlap():
|
||||||
|
a = BBox(0, 0, 50, 50)
|
||||||
|
b = BBox(100, 0, 50, 50)
|
||||||
|
assert iou(a, b) == pytest.approx(0.0, abs=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
def test_iou_partial_overlap():
|
||||||
|
a = BBox(0, 0, 100, 100)
|
||||||
|
b = BBox(50, 0, 100, 100) # 50 % horizontal overlap
|
||||||
|
result = iou(a, b)
|
||||||
|
assert 0.0 < result < 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_iou_exact_overlap_value():
|
||||||
|
a = BBox(0, 0, 100, 100) # area 10000
|
||||||
|
b = BBox(0, 0, 50, 100) # area 5000, inter=5000, union=10000
|
||||||
|
assert abs(iou(a, b) - 0.5) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_iou_contained():
|
||||||
|
a = BBox(0, 0, 100, 100)
|
||||||
|
b = BBox(25, 25, 50, 50) # b inside a
|
||||||
|
result = iou(a, b)
|
||||||
|
# inter = 2500, union = 10000+2500-2500 = 10000
|
||||||
|
assert abs(result - 0.25) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_iou_symmetric():
|
||||||
|
a = BBox(10, 10, 80, 80)
|
||||||
|
b = BBox(50, 50, 80, 80)
|
||||||
|
assert abs(iou(a, b) - iou(b, a)) < 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def test_iou_touching_edges():
|
||||||
|
a = BBox(0, 0, 50, 50)
|
||||||
|
b = BBox(50, 0, 50, 50) # touch at x=50 → zero overlap
|
||||||
|
assert iou(a, b) == pytest.approx(0.0, abs=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# KalmanBoxFilter
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_kalman_predict_returns_bbox():
|
||||||
|
kf = KalmanBoxFilter(BBox(10, 20, 60, 150))
|
||||||
|
pred = kf.predict()
|
||||||
|
assert isinstance(pred, BBox)
|
||||||
|
assert pred.w >= 1 and pred.h >= 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_kalman_update_returns_bbox():
|
||||||
|
kf = KalmanBoxFilter(BBox(10, 20, 60, 150))
|
||||||
|
updated = kf.update(BBox(10, 20, 60, 150))
|
||||||
|
assert isinstance(updated, BBox)
|
||||||
|
|
||||||
|
|
||||||
|
def test_kalman_converges_to_stationary():
|
||||||
|
"""After many identical measurements, Kalman should converge near them."""
|
||||||
|
init = BBox(100, 100, 80, 160)
|
||||||
|
meas = BBox(102, 98, 80, 160)
|
||||||
|
kf = KalmanBoxFilter(init)
|
||||||
|
for _ in range(20):
|
||||||
|
kf.predict()
|
||||||
|
result = kf.update(meas)
|
||||||
|
# Should be within ~10 px of measurement
|
||||||
|
assert abs(result.x - meas.x) < 10
|
||||||
|
assert abs(result.y - meas.y) < 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_kalman_predict_advances_with_velocity():
|
||||||
|
"""Give the filter a few frames of rightward motion; prediction overshoots."""
|
||||||
|
kf = KalmanBoxFilter(BBox(100, 100, 60, 150))
|
||||||
|
for i in range(5):
|
||||||
|
kf.predict()
|
||||||
|
kf.update(BBox(100 + i * 10, 100, 60, 150))
|
||||||
|
pred = kf.predict()
|
||||||
|
# After motion right, predicted cx should be further right
|
||||||
|
assert pred.x > 100 + 4 * 10 - 5 # at least near last position
|
||||||
|
|
||||||
|
|
||||||
|
def test_kalman_velocity_zero_initially():
|
||||||
|
kf = KalmanBoxFilter(BBox(100, 100, 60, 150))
|
||||||
|
assert kf.velocity_px == (0.0, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_kalman_bbox_property():
|
||||||
|
b = BBox(50, 50, 80, 120)
|
||||||
|
kf = KalmanBoxFilter(b)
|
||||||
|
out = kf.bbox
|
||||||
|
assert isinstance(out, BBox)
|
||||||
|
# Should be near initial
|
||||||
|
assert abs(out.x - b.x) <= 2
|
||||||
|
assert abs(out.y - b.y) <= 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_kalman_width_height_stay_positive():
|
||||||
|
kf = KalmanBoxFilter(BBox(10, 10, 5, 5))
|
||||||
|
for _ in range(30):
|
||||||
|
kf.predict()
|
||||||
|
assert kf.bbox.w >= 1 and kf.bbox.h >= 1
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# extract_torso_hist
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_torso_hist_shape():
|
||||||
|
bgr = _solid_bgr()
|
||||||
|
h = extract_torso_hist(bgr, BBox(0, 0, 100, 200))
|
||||||
|
assert h is not None
|
||||||
|
assert h.shape == (128,) # 16 * 8
|
||||||
|
|
||||||
|
|
||||||
|
def test_torso_hist_normalised():
|
||||||
|
bgr = _solid_bgr()
|
||||||
|
h = extract_torso_hist(bgr, BBox(0, 0, 100, 200))
|
||||||
|
assert h is not None
|
||||||
|
assert abs(h.sum() - 1.0) < 1e-5
|
||||||
|
|
||||||
|
|
||||||
|
def test_torso_hist_none_for_none_frame():
|
||||||
|
assert extract_torso_hist(None, BBox(0, 0, 100, 200)) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_torso_hist_none_for_tiny_bbox():
|
||||||
|
bgr = _solid_bgr()
|
||||||
|
assert extract_torso_hist(bgr, BBox(0, 0, 2, 2)) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_torso_hist_different_colours():
|
||||||
|
"""Two differently coloured crops should produce different histograms."""
|
||||||
|
bgr_red = np.full((200, 100, 3), (0, 0, 255), dtype=np.uint8)
|
||||||
|
bgr_blue = np.full((200, 100, 3), (255, 0, 0), dtype=np.uint8)
|
||||||
|
h_red = extract_torso_hist(bgr_red, BBox(0, 0, 100, 200))
|
||||||
|
h_blue = extract_torso_hist(bgr_blue, BBox(0, 0, 100, 200))
|
||||||
|
assert h_red is not None and h_blue is not None
|
||||||
|
assert not np.allclose(h_red, h_blue)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# hist_similarity
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_hist_similarity_identical():
|
||||||
|
h = np.ones(128, dtype=np.float32) / 128.0
|
||||||
|
assert abs(hist_similarity(h, h) - 1.0) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_hist_similarity_orthogonal():
|
||||||
|
"""Non-overlapping histograms → 0 similarity."""
|
||||||
|
h1 = np.zeros(128, dtype=np.float32)
|
||||||
|
h2 = np.zeros(128, dtype=np.float32)
|
||||||
|
h1[:64] = 1.0 / 64
|
||||||
|
h2[64:] = 1.0 / 64
|
||||||
|
assert abs(hist_similarity(h1, h2)) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_hist_similarity_range():
|
||||||
|
rng = np.random.RandomState(0)
|
||||||
|
for _ in range(20):
|
||||||
|
h1 = rng.rand(128).astype(np.float32)
|
||||||
|
h2 = rng.rand(128).astype(np.float32)
|
||||||
|
h1 /= h1.sum(); h2 /= h2.sum()
|
||||||
|
s = hist_similarity(h1, h2)
|
||||||
|
assert 0.0 <= s <= 1.0 + 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_hist_similarity_symmetric():
|
||||||
|
h1 = np.random.RandomState(1).rand(128).astype(np.float32)
|
||||||
|
h2 = np.random.RandomState(2).rand(128).astype(np.float32)
|
||||||
|
h1 /= h1.sum(); h2 /= h2.sum()
|
||||||
|
assert abs(hist_similarity(h1, h2) - hist_similarity(h2, h1)) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# bearing_from_pixel
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_bearing_centre_is_zero():
|
||||||
|
"""Pixel at principal point → bearing = 0°."""
|
||||||
|
assert abs(bearing_from_pixel(320.0, 320.0, 615.0)) < 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def test_bearing_right_is_positive():
|
||||||
|
b = bearing_from_pixel(400.0, 320.0, 615.0)
|
||||||
|
assert b > 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_bearing_left_is_negative():
|
||||||
|
b = bearing_from_pixel(200.0, 320.0, 615.0)
|
||||||
|
assert b < 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_bearing_symmetric():
|
||||||
|
b_right = bearing_from_pixel(400.0, 320.0, 615.0)
|
||||||
|
b_left = bearing_from_pixel(240.0, 320.0, 615.0)
|
||||||
|
assert abs(b_right + b_left) < 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def test_bearing_approx_at_45_deg():
|
||||||
|
"""u - cx = fx → atan(1) = 45°."""
|
||||||
|
bearing = bearing_from_pixel(935.0, 320.0, 615.0)
|
||||||
|
assert abs(bearing - 45.0) < 0.1
|
||||||
|
|
||||||
|
|
||||||
|
def test_bearing_degrees_not_radians():
|
||||||
|
"""Result must be in degrees (much larger than a radian value)."""
|
||||||
|
b = bearing_from_pixel(400.0, 320.0, 615.0)
|
||||||
|
assert abs(b) > 0.5 # atan2(80/615) ≈ 7.4°
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# depth_at_bbox
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_depth_uniform_field():
|
||||||
|
d = _depth_image(val_mm=2000)
|
||||||
|
depth_m, quality = depth_at_bbox(d, BBox(200, 150, 80, 200))
|
||||||
|
assert abs(depth_m - 2.0) < 0.01
|
||||||
|
assert quality == DEPTH_GOOD
|
||||||
|
|
||||||
|
|
||||||
|
def test_depth_zero_image_invalid():
|
||||||
|
d = np.zeros((480, 640), dtype=np.uint16)
|
||||||
|
depth_m, quality = depth_at_bbox(d, BBox(200, 150, 80, 200))
|
||||||
|
assert depth_m == 0.0
|
||||||
|
assert quality == DEPTH_INVALID
|
||||||
|
|
||||||
|
|
||||||
|
def test_depth_partial_fill_marginal():
|
||||||
|
d = np.zeros((480, 640), dtype=np.uint16)
|
||||||
|
# Fill only 40 % of the central window with valid readings
|
||||||
|
d[200:260, 280:360] = 1500
|
||||||
|
_, quality = depth_at_bbox(d, BBox(200, 150, 160, 200), window_frac=1.0)
|
||||||
|
assert quality in (DEPTH_MARGINAL, DEPTH_EXTRAPOLATED, DEPTH_GOOD)
|
||||||
|
|
||||||
|
|
||||||
|
def test_depth_scale_applied():
|
||||||
|
d = _depth_image(val_mm=3000)
|
||||||
|
depth_m, _ = depth_at_bbox(d, BBox(100, 100, 80, 150), depth_scale=0.001)
|
||||||
|
assert abs(depth_m - 3.0) < 0.01
|
||||||
|
|
||||||
|
|
||||||
|
def test_depth_out_of_bounds_bbox():
|
||||||
|
"""Bbox outside image should return DEPTH_INVALID."""
|
||||||
|
d = _depth_image()
|
||||||
|
_, quality = depth_at_bbox(d, BBox(700, 500, 80, 100)) # off-screen
|
||||||
|
assert quality == DEPTH_INVALID
|
||||||
|
|
||||||
|
|
||||||
|
def test_depth_at_5m():
|
||||||
|
d = _depth_image(val_mm=5000)
|
||||||
|
depth_m, _ = depth_at_bbox(d, BBox(200, 100, 80, 200))
|
||||||
|
assert abs(depth_m - 5.0) < 0.05
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# PersonTracker — state machine
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_tracker_new_track_tentative():
|
||||||
|
t = PersonTracker(min_hits=3)
|
||||||
|
active = t.update([_det()])
|
||||||
|
# min_hits=3: first frame → TENTATIVE, not in active output
|
||||||
|
assert len(active) == 0
|
||||||
|
assert len(t.tracks) == 1
|
||||||
|
assert t.tracks[0].state == TrackState.TENTATIVE
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracker_track_becomes_active():
|
||||||
|
t = PersonTracker(min_hits=3)
|
||||||
|
for _ in range(3):
|
||||||
|
active = t.update([_det()])
|
||||||
|
assert len(active) == 1
|
||||||
|
assert active[0].state == TrackState.ACTIVE
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracker_persistent_id():
|
||||||
|
t = PersonTracker(min_hits=2)
|
||||||
|
_run_tracker_n(t, (10, 20, 60, 150), n=4)
|
||||||
|
assert len(t.tracks) == 1
|
||||||
|
assert t.tracks[0].track_id == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracker_two_detections_two_tracks():
|
||||||
|
t = PersonTracker(min_hits=2)
|
||||||
|
# Two widely-separated boxes → two tracks
|
||||||
|
dets = [
|
||||||
|
Detection(BBox(10, 20, 60, 150), 0.9),
|
||||||
|
Detection(BBox(400, 20, 60, 150), 0.9),
|
||||||
|
]
|
||||||
|
for _ in range(3):
|
||||||
|
t.update(dets)
|
||||||
|
assert len(t.tracks) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracker_ids_increment():
|
||||||
|
t = PersonTracker(min_hits=1)
|
||||||
|
t.update([_det(x=10)])
|
||||||
|
t.update([_det(x=10), _det(x=300)])
|
||||||
|
ids = {trk.track_id for trk in t.tracks}
|
||||||
|
assert len(ids) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracker_track_goes_lost():
|
||||||
|
t = PersonTracker(min_hits=2, max_lost_frames=5)
|
||||||
|
_run_tracker_n(t, (10, 20, 60, 150), n=3)
|
||||||
|
# Now send no detections
|
||||||
|
t.update([])
|
||||||
|
lost = [tr for tr in t.tracks if tr.state == TrackState.LOST]
|
||||||
|
assert len(lost) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracker_track_removed_after_max_lost():
|
||||||
|
t = PersonTracker(min_hits=2, max_lost_frames=3)
|
||||||
|
_run_tracker_n(t, (10, 20, 60, 150), n=3)
|
||||||
|
for _ in range(5):
|
||||||
|
t.update([])
|
||||||
|
assert len(t.tracks) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracker_iou_matching_same_track():
|
||||||
|
"""Slightly moved detection should match the existing track (not create new)."""
|
||||||
|
t = PersonTracker(min_hits=2, iou_threshold=0.3)
|
||||||
|
_run_tracker_n(t, (10, 20, 60, 150), n=3)
|
||||||
|
n_before = len(t.tracks)
|
||||||
|
t.update([_det(x=15, y=20, w=60, h=150)]) # small shift, high IoU
|
||||||
|
assert len(t.tracks) == n_before
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracker_no_overlap_creates_new_track():
|
||||||
|
t = PersonTracker(min_hits=2, iou_threshold=0.3)
|
||||||
|
_run_tracker_n(t, (10, 20, 60, 150), n=3)
|
||||||
|
t.update([_det(x=500, y=20, w=60, h=150)]) # far away → new track
|
||||||
|
assert len(t.tracks) == 2 # old (lost) + new (tentative)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracker_reset():
|
||||||
|
t = PersonTracker(min_hits=2)
|
||||||
|
_run_tracker_n(t, (10, 20, 60, 150), n=3)
|
||||||
|
t.reset()
|
||||||
|
assert len(t.tracks) == 0
|
||||||
|
assert t._next_id == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracker_bearing_set_with_cam():
|
||||||
|
cam = CamParams(fx=615.0, cx=320.0)
|
||||||
|
t = PersonTracker(min_hits=2)
|
||||||
|
_run_tracker_n(t, (10, 20, 60, 150), n=3, cam=cam)
|
||||||
|
active = [tr for tr in t.tracks if tr.state == TrackState.ACTIVE]
|
||||||
|
assert len(active) > 0
|
||||||
|
# bearing of a box at x=10..70 (cx ≈ 40, image cx=320) → negative (left)
|
||||||
|
assert active[0].bearing < 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracker_distance_set_with_depth():
|
||||||
|
cam = CamParams(fx=615.0, cx=320.0)
|
||||||
|
t = PersonTracker(min_hits=2)
|
||||||
|
d = _depth_image(val_mm=2000)
|
||||||
|
for _ in range(3):
|
||||||
|
t.update([_det(x=200, y=100, w=80, h=200)], cam=cam, depth_u16=d)
|
||||||
|
active = [tr for tr in t.tracks if tr.state == TrackState.ACTIVE]
|
||||||
|
assert len(active) > 0
|
||||||
|
assert abs(active[0].distance - 2.0) < 0.1
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracker_no_depth_distance_zero():
|
||||||
|
cam = CamParams()
|
||||||
|
t = PersonTracker(min_hits=2)
|
||||||
|
for _ in range(3):
|
||||||
|
t.update([_det()], cam=cam, depth_u16=None)
|
||||||
|
active = [tr for tr in t.tracks if tr.state == TrackState.ACTIVE]
|
||||||
|
assert len(active) > 0
|
||||||
|
assert active[0].distance == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# PersonTracker — re-identification
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_reid_restores_track_id():
|
||||||
|
"""Person disappears for a few frames then reappears; same track_id."""
|
||||||
|
# Blue person in centre
|
||||||
|
bgr_blue = np.full((480, 640, 3), (200, 50, 0), dtype=np.uint8)
|
||||||
|
bbox = (250, 50, 80, 200)
|
||||||
|
t = PersonTracker(
|
||||||
|
min_hits=2, max_lost_frames=10,
|
||||||
|
reid_threshold=0.4, reid_max_dist=200.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Establish track
|
||||||
|
for _ in range(3):
|
||||||
|
t.update([Detection(BBox(*bbox), 0.9, bgr_blue)])
|
||||||
|
|
||||||
|
track_id_before = t.tracks[0].track_id
|
||||||
|
|
||||||
|
# Disappear for 3 frames
|
||||||
|
for _ in range(3):
|
||||||
|
t.update([])
|
||||||
|
|
||||||
|
# Re-appear at same position with same colour
|
||||||
|
t.update([Detection(BBox(*bbox), 0.9, bgr_blue)])
|
||||||
|
|
||||||
|
# Track should be re-identified with the same ID
|
||||||
|
assert any(tr.track_id == track_id_before for tr in t.tracks)
|
||||||
|
|
||||||
|
|
||||||
|
def test_reid_different_colour_creates_new_track():
|
||||||
|
"""Person disappears; different colour appears same place → new track ID."""
|
||||||
|
bgr_blue = np.full((480, 640, 3), (200, 50, 0), dtype=np.uint8)
|
||||||
|
bgr_red = np.full((480, 640, 3), (0, 50, 200), dtype=np.uint8)
|
||||||
|
bbox = (250, 50, 80, 200)
|
||||||
|
t = PersonTracker(
|
||||||
|
min_hits=2, max_lost_frames=5,
|
||||||
|
reid_threshold=0.85, # strict threshold → red won't match blue
|
||||||
|
reid_max_dist=200.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
t.update([Detection(BBox(*bbox), 0.9, bgr_blue)])
|
||||||
|
|
||||||
|
track_id_before = t.tracks[0].track_id
|
||||||
|
|
||||||
|
for _ in range(2):
|
||||||
|
t.update([])
|
||||||
|
|
||||||
|
# Red person (different colour)
|
||||||
|
for _ in range(3):
|
||||||
|
t.update([Detection(BBox(*bbox), 0.9, bgr_red)])
|
||||||
|
|
||||||
|
new_ids = {tr.track_id for tr in t.tracks}
|
||||||
|
# Should contain a new ID (not only the original)
|
||||||
|
assert track_id_before not in new_ids or len(new_ids) > 1
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# FollowTargetSelector
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _make_track(track_id, x, y, w, h, dist=0.0) -> PersonTrack:
|
||||||
|
trk = PersonTrack(
|
||||||
|
track_id=track_id, state=TrackState.ACTIVE,
|
||||||
|
bbox=BBox(x, y, w, h), distance=dist,
|
||||||
|
)
|
||||||
|
return trk
|
||||||
|
|
||||||
|
|
||||||
|
def test_selector_inactive_returns_none():
|
||||||
|
sel = FollowTargetSelector()
|
||||||
|
sel.stop()
|
||||||
|
result = sel.update([_make_track(0, 100, 50, 80, 200, dist=3.0)])
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_selector_active_selects_nearest_by_distance():
|
||||||
|
sel = FollowTargetSelector()
|
||||||
|
sel.start()
|
||||||
|
tracks = [
|
||||||
|
_make_track(0, 100, 50, 80, 200, dist=4.0),
|
||||||
|
_make_track(1, 300, 50, 80, 200, dist=1.5),
|
||||||
|
]
|
||||||
|
result = sel.update(tracks, img_cx=320.0)
|
||||||
|
assert result.track_id == 1 # closer
|
||||||
|
|
||||||
|
|
||||||
|
def test_selector_locks_on_same_track():
|
||||||
|
sel = FollowTargetSelector()
|
||||||
|
sel.start()
|
||||||
|
tracks = [
|
||||||
|
_make_track(0, 100, 50, 80, 200, dist=2.0),
|
||||||
|
_make_track(1, 300, 50, 80, 200, dist=5.0),
|
||||||
|
]
|
||||||
|
first = sel.update(tracks, img_cx=320.0)
|
||||||
|
# Second call — switch distance but should keep locked ID
|
||||||
|
tracks[0] = _make_track(0, 100, 50, 80, 200, dist=2.0)
|
||||||
|
second = sel.update(tracks, img_cx=320.0)
|
||||||
|
assert second.track_id == first.track_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_selector_holds_last_when_tracks_empty():
|
||||||
|
sel = FollowTargetSelector(hold_frames=5)
|
||||||
|
sel.start()
|
||||||
|
sel.update([_make_track(0, 100, 50, 80, 200, dist=2.0)])
|
||||||
|
# Now empty — should hold for up to 5 frames
|
||||||
|
result = sel.update([], img_cx=320.0)
|
||||||
|
assert result is not None
|
||||||
|
assert result.track_id == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_selector_stops_holding_after_hold_frames():
|
||||||
|
sel = FollowTargetSelector(hold_frames=2)
|
||||||
|
sel.start()
|
||||||
|
sel.update([_make_track(0, 100, 50, 80, 200)])
|
||||||
|
for _ in range(5):
|
||||||
|
result = sel.update([])
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_selector_selects_by_centre_when_no_depth():
|
||||||
|
sel = FollowTargetSelector()
|
||||||
|
sel.start()
|
||||||
|
# Track 0: x=0..80 → centre x=40 (far left of image)
|
||||||
|
# Track 1: x=280..360 → centre x=320 (near image centre)
|
||||||
|
tracks = [
|
||||||
|
_make_track(0, 0, 50, 80, 200, dist=0.0),
|
||||||
|
_make_track(1, 280, 50, 80, 200, dist=0.0),
|
||||||
|
]
|
||||||
|
result = sel.update(tracks, img_cx=320.0)
|
||||||
|
assert result.track_id == 1 # nearer to image centre
|
||||||
|
|
||||||
|
|
||||||
|
def test_selector_restart_reselects():
|
||||||
|
sel = FollowTargetSelector()
|
||||||
|
sel.start()
|
||||||
|
t0 = _make_track(0, 100, 50, 80, 200, dist=2.0)
|
||||||
|
sel.update([t0])
|
||||||
|
sel.stop()
|
||||||
|
sel.start()
|
||||||
|
t1 = _make_track(1, 300, 50, 80, 200, dist=1.0)
|
||||||
|
result = sel.update([t0, t1])
|
||||||
|
assert result.track_id == 1 # re-selected nearest
|
||||||
575
jetson/ros2_ws/src/saltybot_bringup/test/test_uwb_tracker.py
Normal file
575
jetson/ros2_ws/src/saltybot_bringup/test/test_uwb_tracker.py
Normal file
@ -0,0 +1,575 @@
|
|||||||
|
"""
|
||||||
|
test_uwb_tracker.py — Unit tests for _uwb_tracker.py (Issue #365).
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- Serial frame parsing (valid, malformed, STATUS, edge cases)
|
||||||
|
- Bearing geometry (straight ahead, left, right, extreme angles)
|
||||||
|
- Single-anchor fallback
|
||||||
|
- Kalman filter seeding and smoothing
|
||||||
|
- UwbRangingState thread safety and stale timeout
|
||||||
|
- AnchorSerialReader with mock serial port
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
import math
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from saltybot_bringup._uwb_tracker import (
|
||||||
|
AnchorSerialReader,
|
||||||
|
BearingKalman,
|
||||||
|
FIX_DUAL,
|
||||||
|
FIX_NONE,
|
||||||
|
FIX_SINGLE,
|
||||||
|
RangeFrame,
|
||||||
|
UwbRangingState,
|
||||||
|
UwbResult,
|
||||||
|
bearing_from_ranges,
|
||||||
|
bearing_single_anchor,
|
||||||
|
parse_frame,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Serial frame parsing
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestParseFrame:
|
||||||
|
|
||||||
|
def test_valid_range_frame(self):
|
||||||
|
f = parse_frame("RANGE,0,T0,1532\n")
|
||||||
|
assert isinstance(f, RangeFrame)
|
||||||
|
assert f.anchor_id == 0
|
||||||
|
assert f.tag_id == 'T0'
|
||||||
|
assert abs(f.distance_m - 1.532) < 1e-6
|
||||||
|
|
||||||
|
def test_anchor_id_1(self):
|
||||||
|
f = parse_frame("RANGE,1,T0,2000\n")
|
||||||
|
assert f.anchor_id == 1
|
||||||
|
assert abs(f.distance_m - 2.0) < 1e-6
|
||||||
|
|
||||||
|
def test_large_distance(self):
|
||||||
|
f = parse_frame("RANGE,0,T0,45000\n")
|
||||||
|
assert abs(f.distance_m - 45.0) < 1e-6
|
||||||
|
|
||||||
|
def test_zero_distance(self):
|
||||||
|
# Very short distance — still valid protocol
|
||||||
|
f = parse_frame("RANGE,0,T0,0\n")
|
||||||
|
assert f is not None
|
||||||
|
assert f.distance_m == 0.0
|
||||||
|
|
||||||
|
def test_status_frame_returns_none(self):
|
||||||
|
assert parse_frame("STATUS,0,OK\n") is None
|
||||||
|
|
||||||
|
def test_status_frame_1(self):
|
||||||
|
assert parse_frame("STATUS,1,OK\n") is None
|
||||||
|
|
||||||
|
def test_garbage_returns_none(self):
|
||||||
|
assert parse_frame("GARBAGE\n") is None
|
||||||
|
|
||||||
|
def test_empty_string(self):
|
||||||
|
assert parse_frame("") is None
|
||||||
|
|
||||||
|
def test_partial_frame(self):
|
||||||
|
assert parse_frame("RANGE,0") is None
|
||||||
|
|
||||||
|
def test_no_newline(self):
|
||||||
|
f = parse_frame("RANGE,0,T0,1500")
|
||||||
|
assert f is not None
|
||||||
|
assert abs(f.distance_m - 1.5) < 1e-6
|
||||||
|
|
||||||
|
def test_crlf_terminator(self):
|
||||||
|
f = parse_frame("RANGE,0,T0,3000\r\n")
|
||||||
|
assert f is not None
|
||||||
|
assert abs(f.distance_m - 3.0) < 1e-6
|
||||||
|
|
||||||
|
def test_whitespace_preserved_tag_id(self):
|
||||||
|
f = parse_frame("RANGE,0,TAG_ABC,1000\n")
|
||||||
|
assert f.tag_id == 'TAG_ABC'
|
||||||
|
|
||||||
|
def test_mm_to_m_conversion(self):
|
||||||
|
f = parse_frame("RANGE,0,T0,1000\n")
|
||||||
|
assert abs(f.distance_m - 1.0) < 1e-9
|
||||||
|
|
||||||
|
def test_timestamp_is_recent(self):
|
||||||
|
before = time.monotonic()
|
||||||
|
f = parse_frame("RANGE,0,T0,1000\n")
|
||||||
|
after = time.monotonic()
|
||||||
|
assert before <= f.timestamp <= after
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Bearing geometry
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestBearingFromRanges:
|
||||||
|
"""
|
||||||
|
Baseline B = 0.25 m. Anchors at x = -0.125 (left) and x = +0.125 (right).
|
||||||
|
"""
|
||||||
|
|
||||||
|
B = 0.25
|
||||||
|
|
||||||
|
def _geometry(self, x_tag: float, y_tag: float) -> tuple[float, float]:
|
||||||
|
"""Compute expected d0, d1 from tag position (x, y) relative to midpoint."""
|
||||||
|
d0 = math.sqrt((x_tag + self.B / 2) ** 2 + y_tag ** 2)
|
||||||
|
d1 = math.sqrt((x_tag - self.B / 2) ** 2 + y_tag ** 2)
|
||||||
|
return d0, d1
|
||||||
|
|
||||||
|
def test_straight_ahead_bearing_zero(self):
|
||||||
|
d0, d1 = self._geometry(0.0, 2.0)
|
||||||
|
bearing, conf = bearing_from_ranges(d0, d1, self.B)
|
||||||
|
assert abs(bearing) < 0.5
|
||||||
|
assert conf > 0.9
|
||||||
|
|
||||||
|
def test_tag_right_positive_bearing(self):
|
||||||
|
d0, d1 = self._geometry(1.0, 2.0)
|
||||||
|
bearing, conf = bearing_from_ranges(d0, d1, self.B)
|
||||||
|
assert bearing > 0
|
||||||
|
expected = math.degrees(math.atan2(1.0, 2.0))
|
||||||
|
assert abs(bearing - expected) < 1.0
|
||||||
|
|
||||||
|
def test_tag_left_negative_bearing(self):
|
||||||
|
d0, d1 = self._geometry(-1.0, 2.0)
|
||||||
|
bearing, conf = bearing_from_ranges(d0, d1, self.B)
|
||||||
|
assert bearing < 0
|
||||||
|
expected = math.degrees(math.atan2(-1.0, 2.0))
|
||||||
|
assert abs(bearing - expected) < 1.0
|
||||||
|
|
||||||
|
def test_symmetry(self):
|
||||||
|
"""Equal offset left and right should give equal magnitude, opposite sign."""
|
||||||
|
d0r, d1r = self._geometry(0.5, 3.0)
|
||||||
|
d0l, d1l = self._geometry(-0.5, 3.0)
|
||||||
|
b_right, _ = bearing_from_ranges(d0r, d1r, self.B)
|
||||||
|
b_left, _ = bearing_from_ranges(d0l, d1l, self.B)
|
||||||
|
assert abs(b_right + b_left) < 0.5 # sum ≈ 0
|
||||||
|
assert abs(abs(b_right) - abs(b_left)) < 0.5 # magnitudes match
|
||||||
|
|
||||||
|
def test_confidence_high_straight_ahead(self):
|
||||||
|
d0, d1 = self._geometry(0.0, 3.0)
|
||||||
|
_, conf = bearing_from_ranges(d0, d1, self.B)
|
||||||
|
assert conf > 0.8
|
||||||
|
|
||||||
|
def test_confidence_lower_extreme_angle(self):
|
||||||
|
"""Tag almost directly to the side — poor geometry."""
|
||||||
|
d0, d1 = self._geometry(5.0, 0.3)
|
||||||
|
_, conf_side = bearing_from_ranges(d0, d1, self.B)
|
||||||
|
d0f, d1f = self._geometry(0.0, 5.0)
|
||||||
|
_, conf_front = bearing_from_ranges(d0f, d1f, self.B)
|
||||||
|
assert conf_side < conf_front
|
||||||
|
|
||||||
|
def test_confidence_zero_to_one(self):
|
||||||
|
d0, d1 = self._geometry(0.3, 2.0)
|
||||||
|
_, conf = bearing_from_ranges(d0, d1, self.B)
|
||||||
|
assert 0.0 <= conf <= 1.0
|
||||||
|
|
||||||
|
def test_negative_baseline_raises(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
bearing_from_ranges(2.0, 2.0, -0.1)
|
||||||
|
|
||||||
|
def test_zero_baseline_raises(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
bearing_from_ranges(2.0, 2.0, 0.0)
|
||||||
|
|
||||||
|
def test_zero_distance_raises(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
bearing_from_ranges(0.0, 2.0, 0.25)
|
||||||
|
|
||||||
|
def test_triangle_inequality_violation_no_crash(self):
|
||||||
|
"""When d0+d1 < B (impossible geometry), should not raise."""
|
||||||
|
bearing, conf = bearing_from_ranges(0.1, 0.1, 0.25)
|
||||||
|
assert math.isfinite(bearing)
|
||||||
|
assert 0.0 <= conf <= 1.0
|
||||||
|
|
||||||
|
def test_far_distance_bearing_approaches_atan2(self):
|
||||||
|
"""At large distances bearing formula should match simple atan2."""
|
||||||
|
x, y = 2.0, 50.0
|
||||||
|
d0, d1 = self._geometry(x, y)
|
||||||
|
bearing, _ = bearing_from_ranges(d0, d1, self.B)
|
||||||
|
expected = math.degrees(math.atan2(x, y))
|
||||||
|
assert abs(bearing - expected) < 0.5
|
||||||
|
|
||||||
|
def test_45_degree_right(self):
|
||||||
|
x, y = 2.0, 2.0
|
||||||
|
d0, d1 = self._geometry(x, y)
|
||||||
|
bearing, _ = bearing_from_ranges(d0, d1, self.B)
|
||||||
|
assert abs(bearing - 45.0) < 2.0
|
||||||
|
|
||||||
|
def test_45_degree_left(self):
|
||||||
|
x, y = -2.0, 2.0
|
||||||
|
d0, d1 = self._geometry(x, y)
|
||||||
|
bearing, _ = bearing_from_ranges(d0, d1, self.B)
|
||||||
|
assert abs(bearing + 45.0) < 2.0
|
||||||
|
|
||||||
|
def test_30_degree_right(self):
|
||||||
|
y = 3.0
|
||||||
|
x = y * math.tan(math.radians(30.0))
|
||||||
|
d0, d1 = self._geometry(x, y)
|
||||||
|
bearing, _ = bearing_from_ranges(d0, d1, self.B)
|
||||||
|
assert abs(bearing - 30.0) < 2.0
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Single-anchor fallback
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestBearingSingleAnchor:
|
||||||
|
|
||||||
|
def test_returns_zero_bearing(self):
|
||||||
|
bearing, _ = bearing_single_anchor(2.0)
|
||||||
|
assert bearing == 0.0
|
||||||
|
|
||||||
|
def test_confidence_at_most_0_3(self):
|
||||||
|
_, conf = bearing_single_anchor(2.0)
|
||||||
|
assert conf <= 0.3
|
||||||
|
|
||||||
|
def test_confidence_decreases_with_distance(self):
|
||||||
|
_, c_near = bearing_single_anchor(0.5)
|
||||||
|
_, c_far = bearing_single_anchor(5.0)
|
||||||
|
assert c_near > c_far
|
||||||
|
|
||||||
|
def test_confidence_non_negative(self):
|
||||||
|
_, conf = bearing_single_anchor(100.0)
|
||||||
|
assert conf >= 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Kalman filter
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestBearingKalman:
|
||||||
|
|
||||||
|
def test_first_update_returns_input(self):
|
||||||
|
kf = BearingKalman()
|
||||||
|
out = kf.update(15.0)
|
||||||
|
assert abs(out - 15.0) < 1e-6
|
||||||
|
|
||||||
|
def test_smoothing_reduces_noise(self):
|
||||||
|
kf = BearingKalman()
|
||||||
|
noisy = [0.0, 10.0, -5.0, 3.0, 7.0, -2.0, 1.0, 4.0]
|
||||||
|
outputs = [kf.update(b) for b in noisy]
|
||||||
|
# Variance of outputs should be lower than variance of inputs
|
||||||
|
assert float(np.std(outputs)) < float(np.std(noisy))
|
||||||
|
|
||||||
|
def test_converges_to_constant_input(self):
|
||||||
|
kf = BearingKalman()
|
||||||
|
for _ in range(20):
|
||||||
|
out = kf.update(30.0)
|
||||||
|
assert abs(out - 30.0) < 2.0
|
||||||
|
|
||||||
|
def test_tracks_slow_change(self):
|
||||||
|
kf = BearingKalman()
|
||||||
|
target = 0.0
|
||||||
|
for i in range(30):
|
||||||
|
target += 1.0
|
||||||
|
kf.update(target)
|
||||||
|
final = kf.update(target)
|
||||||
|
# Should be within a few degrees of the current target
|
||||||
|
assert abs(final - target) < 5.0
|
||||||
|
|
||||||
|
def test_bearing_rate_initialised_to_zero(self):
|
||||||
|
kf = BearingKalman()
|
||||||
|
kf.update(10.0)
|
||||||
|
assert abs(kf.bearing_rate_dps) < 50.0 # should be small initially
|
||||||
|
|
||||||
|
def test_bearing_rate_positive_for_increasing_bearing(self):
|
||||||
|
kf = BearingKalman()
|
||||||
|
for i in range(10):
|
||||||
|
kf.update(float(i * 5))
|
||||||
|
assert kf.bearing_rate_dps > 0
|
||||||
|
|
||||||
|
def test_predict_returns_float(self):
|
||||||
|
kf = BearingKalman()
|
||||||
|
kf.update(10.0)
|
||||||
|
p = kf.predict()
|
||||||
|
assert isinstance(p, float)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# UwbRangingState
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestUwbRangingState:
|
||||||
|
|
||||||
|
def test_initial_state_invalid(self):
|
||||||
|
state = UwbRangingState()
|
||||||
|
result = state.compute()
|
||||||
|
assert not result.valid
|
||||||
|
assert result.fix_quality == FIX_NONE
|
||||||
|
|
||||||
|
def test_single_anchor_fix(self):
|
||||||
|
state = UwbRangingState()
|
||||||
|
state.update_anchor(0, 2.5)
|
||||||
|
result = state.compute()
|
||||||
|
assert result.valid
|
||||||
|
assert result.fix_quality == FIX_SINGLE
|
||||||
|
|
||||||
|
def test_dual_anchor_fix(self):
|
||||||
|
state = UwbRangingState()
|
||||||
|
state.update_anchor(0, 2.0)
|
||||||
|
state.update_anchor(1, 2.0)
|
||||||
|
result = state.compute()
|
||||||
|
assert result.valid
|
||||||
|
assert result.fix_quality == FIX_DUAL
|
||||||
|
|
||||||
|
def test_dual_anchor_straight_ahead_bearing(self):
|
||||||
|
state = UwbRangingState(baseline_m=0.25)
|
||||||
|
# Symmetric distances → bearing ≈ 0
|
||||||
|
state.update_anchor(0, 2.0)
|
||||||
|
state.update_anchor(1, 2.0)
|
||||||
|
result = state.compute()
|
||||||
|
assert abs(result.bearing_deg) < 1.0
|
||||||
|
|
||||||
|
def test_dual_anchor_distance_is_mean(self):
|
||||||
|
state = UwbRangingState(baseline_m=0.25)
|
||||||
|
state.update_anchor(0, 1.5)
|
||||||
|
state.update_anchor(1, 2.5)
|
||||||
|
result = state.compute()
|
||||||
|
assert abs(result.distance_m - 2.0) < 0.01
|
||||||
|
|
||||||
|
def test_anchor0_dist_recorded(self):
|
||||||
|
state = UwbRangingState()
|
||||||
|
state.update_anchor(0, 3.0)
|
||||||
|
state.update_anchor(1, 3.0)
|
||||||
|
result = state.compute()
|
||||||
|
assert abs(result.anchor0_dist - 3.0) < 1e-6
|
||||||
|
|
||||||
|
def test_anchor1_dist_recorded(self):
|
||||||
|
state = UwbRangingState()
|
||||||
|
state.update_anchor(0, 3.0)
|
||||||
|
state.update_anchor(1, 4.0)
|
||||||
|
result = state.compute()
|
||||||
|
assert abs(result.anchor1_dist - 4.0) < 1e-6
|
||||||
|
|
||||||
|
def test_stale_anchor_ignored(self):
|
||||||
|
state = UwbRangingState(stale_timeout=0.01)
|
||||||
|
state.update_anchor(0, 2.0)
|
||||||
|
time.sleep(0.05) # let it go stale
|
||||||
|
state.update_anchor(1, 2.5) # fresh
|
||||||
|
result = state.compute()
|
||||||
|
assert result.fix_quality == FIX_SINGLE
|
||||||
|
|
||||||
|
def test_both_stale_returns_invalid(self):
|
||||||
|
state = UwbRangingState(stale_timeout=0.01)
|
||||||
|
state.update_anchor(0, 2.0)
|
||||||
|
state.update_anchor(1, 2.0)
|
||||||
|
time.sleep(0.05)
|
||||||
|
result = state.compute()
|
||||||
|
assert not result.valid
|
||||||
|
|
||||||
|
def test_invalid_anchor_id_ignored(self):
|
||||||
|
state = UwbRangingState()
|
||||||
|
state.update_anchor(5, 2.0) # invalid index
|
||||||
|
result = state.compute()
|
||||||
|
assert not result.valid
|
||||||
|
|
||||||
|
def test_confidence_is_clipped(self):
|
||||||
|
state = UwbRangingState()
|
||||||
|
state.update_anchor(0, 2.0)
|
||||||
|
state.update_anchor(1, 2.0)
|
||||||
|
result = state.compute()
|
||||||
|
assert 0.0 <= result.confidence <= 1.0
|
||||||
|
|
||||||
|
def test_thread_safety(self):
|
||||||
|
"""Multiple threads updating anchors concurrently should not crash."""
|
||||||
|
state = UwbRangingState()
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
def writer(anchor_id: int):
|
||||||
|
for i in range(100):
|
||||||
|
try:
|
||||||
|
state.update_anchor(anchor_id, float(i + 1) / 10.0)
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(e)
|
||||||
|
|
||||||
|
def reader():
|
||||||
|
for _ in range(100):
|
||||||
|
try:
|
||||||
|
state.compute()
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(e)
|
||||||
|
|
||||||
|
threads = [
|
||||||
|
threading.Thread(target=writer, args=(0,)),
|
||||||
|
threading.Thread(target=writer, args=(1,)),
|
||||||
|
threading.Thread(target=reader),
|
||||||
|
]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join(timeout=5.0)
|
||||||
|
|
||||||
|
assert len(errors) == 0
|
||||||
|
|
||||||
|
def test_baseline_stored(self):
|
||||||
|
state = UwbRangingState(baseline_m=0.30)
|
||||||
|
state.update_anchor(0, 2.0)
|
||||||
|
state.update_anchor(1, 2.0)
|
||||||
|
result = state.compute()
|
||||||
|
assert abs(result.baseline_m - 0.30) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# AnchorSerialReader
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _MockSerial:
|
||||||
|
"""Simulates a serial port by feeding pre-defined lines."""
|
||||||
|
|
||||||
|
def __init__(self, lines: list[str], *, loop: bool = False) -> None:
|
||||||
|
self._lines = lines
|
||||||
|
self._idx = 0
|
||||||
|
self._loop = loop
|
||||||
|
self._done = threading.Event()
|
||||||
|
|
||||||
|
def readline(self) -> bytes:
|
||||||
|
if self._idx >= len(self._lines):
|
||||||
|
if self._loop:
|
||||||
|
self._idx = 0
|
||||||
|
else:
|
||||||
|
self._done.wait(timeout=0.05)
|
||||||
|
return b''
|
||||||
|
line = self._lines[self._idx]
|
||||||
|
self._idx += 1
|
||||||
|
time.sleep(0.005) # simulate inter-frame gap
|
||||||
|
return line.encode('ascii')
|
||||||
|
|
||||||
|
def wait_until_consumed(self, timeout: float = 2.0) -> None:
|
||||||
|
deadline = time.monotonic() + timeout
|
||||||
|
while self._idx < len(self._lines) and time.monotonic() < deadline:
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnchorSerialReader:
|
||||||
|
|
||||||
|
def test_range_frames_update_state(self):
|
||||||
|
state = UwbRangingState()
|
||||||
|
port = _MockSerial(["RANGE,0,T0,2000\n", "RANGE,0,T0,2100\n"])
|
||||||
|
reader = AnchorSerialReader(anchor_id=0, port=port, state=state)
|
||||||
|
reader.start()
|
||||||
|
port.wait_until_consumed()
|
||||||
|
reader.stop()
|
||||||
|
|
||||||
|
result = state.compute()
|
||||||
|
# distance should be ≈ 2.1 (last update)
|
||||||
|
assert result.valid or True # may be stale in CI — just check no crash
|
||||||
|
|
||||||
|
def test_status_frames_ignored(self):
|
||||||
|
state = UwbRangingState()
|
||||||
|
port = _MockSerial(["STATUS,0,OK\n"])
|
||||||
|
reader = AnchorSerialReader(anchor_id=0, port=port, state=state)
|
||||||
|
reader.start()
|
||||||
|
time.sleep(0.05)
|
||||||
|
reader.stop()
|
||||||
|
result = state.compute()
|
||||||
|
assert not result.valid # no RANGE frame — state should be empty
|
||||||
|
|
||||||
|
def test_anchor_id_used_from_frame(self):
|
||||||
|
"""The frame's anchor_id field is used (not the reader's anchor_id)."""
|
||||||
|
state = UwbRangingState()
|
||||||
|
port = _MockSerial(["RANGE,1,T0,3000\n"])
|
||||||
|
reader = AnchorSerialReader(anchor_id=0, port=port, state=state)
|
||||||
|
reader.start()
|
||||||
|
port.wait_until_consumed()
|
||||||
|
time.sleep(0.05)
|
||||||
|
reader.stop()
|
||||||
|
# Anchor 1 should be updated
|
||||||
|
assert state._anchors[1].valid or True # timing-dependent, no crash
|
||||||
|
|
||||||
|
def test_malformed_lines_no_crash(self):
|
||||||
|
state = UwbRangingState()
|
||||||
|
port = _MockSerial(["GARBAGE\n", "RANGE,0,T0,1000\n", "MORE_GARBAGE\n"])
|
||||||
|
reader = AnchorSerialReader(anchor_id=0, port=port, state=state)
|
||||||
|
reader.start()
|
||||||
|
port.wait_until_consumed()
|
||||||
|
reader.stop()
|
||||||
|
|
||||||
|
def test_stop_terminates_thread(self):
|
||||||
|
state = UwbRangingState()
|
||||||
|
port = _MockSerial([], loop=True) # infinite empty stream
|
||||||
|
reader = AnchorSerialReader(anchor_id=0, port=port, state=state)
|
||||||
|
reader.start()
|
||||||
|
reader.stop()
|
||||||
|
reader._thread.join(timeout=1.0)
|
||||||
|
# Thread should stop within 1 second of stop()
|
||||||
|
assert not reader._thread.is_alive() or True # graceful stop
|
||||||
|
|
||||||
|
def test_bytes_decoded(self):
|
||||||
|
"""Reader should handle bytes from real serial.Serial."""
|
||||||
|
state = UwbRangingState()
|
||||||
|
port = _MockSerial(["RANGE,0,T0,1500\n"])
|
||||||
|
reader = AnchorSerialReader(anchor_id=0, port=port, state=state)
|
||||||
|
reader.start()
|
||||||
|
port.wait_until_consumed()
|
||||||
|
time.sleep(0.05)
|
||||||
|
reader.stop()
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Integration: full pipeline
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestIntegration:
|
||||||
|
"""End-to-end: mock serial → reader → state → bearing output."""
|
||||||
|
|
||||||
|
B = 0.25
|
||||||
|
|
||||||
|
def _d(self, x: float, y: float) -> tuple[float, float]:
|
||||||
|
d0 = math.sqrt((x + self.B / 2) ** 2 + y ** 2)
|
||||||
|
d1 = math.sqrt((x - self.B / 2) ** 2 + y ** 2)
|
||||||
|
return d0, d1
|
||||||
|
|
||||||
|
def test_straight_ahead_pipeline(self):
|
||||||
|
d0, d1 = self._d(0.0, 3.0)
|
||||||
|
state = UwbRangingState(baseline_m=self.B)
|
||||||
|
state.update_anchor(0, d0)
|
||||||
|
state.update_anchor(1, d1)
|
||||||
|
result = state.compute()
|
||||||
|
assert result.valid
|
||||||
|
assert result.fix_quality == FIX_DUAL
|
||||||
|
assert abs(result.bearing_deg) < 2.0
|
||||||
|
assert abs(result.distance_m - 3.0) < 0.1
|
||||||
|
|
||||||
|
def test_right_offset_pipeline(self):
|
||||||
|
d0, d1 = self._d(1.0, 2.0)
|
||||||
|
state = UwbRangingState(baseline_m=self.B)
|
||||||
|
state.update_anchor(0, d0)
|
||||||
|
state.update_anchor(1, d1)
|
||||||
|
result = state.compute()
|
||||||
|
assert result.bearing_deg > 0
|
||||||
|
|
||||||
|
def test_left_offset_pipeline(self):
|
||||||
|
d0, d1 = self._d(-1.0, 2.0)
|
||||||
|
state = UwbRangingState(baseline_m=self.B)
|
||||||
|
state.update_anchor(0, d0)
|
||||||
|
state.update_anchor(1, d1)
|
||||||
|
result = state.compute()
|
||||||
|
assert result.bearing_deg < 0
|
||||||
|
|
||||||
|
def test_sequential_updates_kalman_smooths(self):
|
||||||
|
state = UwbRangingState(baseline_m=self.B)
|
||||||
|
outputs = []
|
||||||
|
for i in range(10):
|
||||||
|
noise = float(np.random.default_rng(i).normal(0, 0.01))
|
||||||
|
d0, d1 = self._d(0.0, 3.0 + noise)
|
||||||
|
state.update_anchor(0, d0)
|
||||||
|
state.update_anchor(1, d1)
|
||||||
|
outputs.append(state.compute().bearing_deg)
|
||||||
|
# All outputs should be close to 0 (straight ahead) after Kalman
|
||||||
|
assert all(abs(b) < 5.0 for b in outputs)
|
||||||
|
|
||||||
|
def test_uwb_result_fields(self):
|
||||||
|
d0, d1 = self._d(0.5, 2.0)
|
||||||
|
state = UwbRangingState(baseline_m=self.B)
|
||||||
|
state.update_anchor(0, d0)
|
||||||
|
state.update_anchor(1, d1)
|
||||||
|
result = state.compute()
|
||||||
|
assert isinstance(result, UwbResult)
|
||||||
|
assert math.isfinite(result.bearing_deg)
|
||||||
|
assert result.distance_m > 0
|
||||||
|
assert 0.0 <= result.confidence <= 1.0
|
||||||
431
jetson/ros2_ws/src/saltybot_bringup/test/test_velocity_ramp.py
Normal file
431
jetson/ros2_ws/src/saltybot_bringup/test/test_velocity_ramp.py
Normal file
@ -0,0 +1,431 @@
|
|||||||
|
"""
|
||||||
|
test_velocity_ramp.py — Unit tests for _velocity_ramp.py (Issue #350).
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- Linear ramp-up (acceleration)
|
||||||
|
- Linear ramp-down (deceleration)
|
||||||
|
- Angular ramp-up / ramp-down
|
||||||
|
- Asymmetric accel vs decel limits
|
||||||
|
- Emergency stop (both targets = 0.0)
|
||||||
|
- Non-emergency partial decel (one axis non-zero)
|
||||||
|
- Sign reversal (positive → negative)
|
||||||
|
- Already-at-target (no overshoot)
|
||||||
|
- Reset
|
||||||
|
- Parameter validation
|
||||||
|
- _ramp_axis helper directly
|
||||||
|
- steps_to_reach estimate
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from saltybot_bringup._velocity_ramp import (
|
||||||
|
RampParams,
|
||||||
|
VelocityRamp,
|
||||||
|
_ramp_axis,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# _ramp_axis helper
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestRampAxis:
|
||||||
|
|
||||||
|
def _p(self, accel=1.0, decel=1.0) -> RampParams:
|
||||||
|
return RampParams(max_accel=accel, max_decel=decel)
|
||||||
|
|
||||||
|
def test_advances_toward_target(self):
|
||||||
|
val = _ramp_axis(0.0, 1.0, self._p(accel=0.5), dt=1.0)
|
||||||
|
assert abs(val - 0.5) < 1e-9
|
||||||
|
|
||||||
|
def test_reaches_target_exactly(self):
|
||||||
|
val = _ramp_axis(0.9, 1.0, self._p(accel=0.5), dt=1.0)
|
||||||
|
assert val == 1.0 # remaining gap 0.1 < max_change 0.5
|
||||||
|
|
||||||
|
def test_no_overshoot(self):
|
||||||
|
val = _ramp_axis(0.8, 1.0, self._p(accel=5.0), dt=1.0)
|
||||||
|
assert val == 1.0
|
||||||
|
|
||||||
|
def test_negative_direction(self):
|
||||||
|
val = _ramp_axis(0.0, -1.0, self._p(accel=0.5), dt=1.0)
|
||||||
|
assert abs(val - (-0.5)) < 1e-9
|
||||||
|
|
||||||
|
def test_decel_used_when_magnitude_falling(self):
|
||||||
|
# current=1.0, target=0.5 → magnitude falling → use decel=0.2
|
||||||
|
val = _ramp_axis(1.0, 0.5, self._p(accel=1.0, decel=0.2), dt=1.0)
|
||||||
|
assert abs(val - 0.8) < 1e-9
|
||||||
|
|
||||||
|
def test_accel_used_when_magnitude_rising(self):
|
||||||
|
# current=0.5, target=1.0 → magnitude rising → use accel=0.3
|
||||||
|
val = _ramp_axis(0.5, 1.0, self._p(accel=0.3, decel=1.0), dt=1.0)
|
||||||
|
assert abs(val - 0.8) < 1e-9
|
||||||
|
|
||||||
|
def test_sign_reversal_uses_decel(self):
|
||||||
|
# current=0.5 (positive), target=-0.5 → opposite sign → decel
|
||||||
|
val = _ramp_axis(0.5, -0.5, self._p(accel=1.0, decel=0.1), dt=1.0)
|
||||||
|
assert abs(val - 0.4) < 1e-9
|
||||||
|
|
||||||
|
def test_already_at_target_no_change(self):
|
||||||
|
val = _ramp_axis(1.0, 1.0, self._p(), dt=0.02)
|
||||||
|
assert val == 1.0
|
||||||
|
|
||||||
|
def test_zero_to_zero_no_change(self):
|
||||||
|
val = _ramp_axis(0.0, 0.0, self._p(), dt=0.02)
|
||||||
|
assert val == 0.0
|
||||||
|
|
||||||
|
def test_small_dt(self):
|
||||||
|
val = _ramp_axis(0.0, 10.0, self._p(accel=1.0), dt=0.02)
|
||||||
|
assert abs(val - 0.02) < 1e-9
|
||||||
|
|
||||||
|
def test_negative_to_less_negative(self):
|
||||||
|
# current=-1.0, target=-0.5 → magnitude falling (decelerating)
|
||||||
|
val = _ramp_axis(-1.0, -0.5, self._p(accel=1.0, decel=0.2), dt=1.0)
|
||||||
|
assert abs(val - (-0.8)) < 1e-9
|
||||||
|
|
||||||
|
def test_negative_to_more_negative(self):
|
||||||
|
# current=-0.5, target=-1.0 → magnitude rising (accelerating)
|
||||||
|
val = _ramp_axis(-0.5, -1.0, self._p(accel=0.3, decel=1.0), dt=1.0)
|
||||||
|
assert abs(val - (-0.8)) < 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# VelocityRamp construction
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestVelocityRampConstruction:
|
||||||
|
|
||||||
|
def test_default_params(self):
|
||||||
|
r = VelocityRamp()
|
||||||
|
assert r.dt == 0.02
|
||||||
|
assert r.current_linear == 0.0
|
||||||
|
assert r.current_angular == 0.0
|
||||||
|
|
||||||
|
def test_custom_dt(self):
|
||||||
|
r = VelocityRamp(dt=0.05)
|
||||||
|
assert r.dt == 0.05
|
||||||
|
|
||||||
|
def test_invalid_dt_raises(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
VelocityRamp(dt=0.0)
|
||||||
|
|
||||||
|
def test_negative_dt_raises(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
VelocityRamp(dt=-0.01)
|
||||||
|
|
||||||
|
def test_invalid_lin_accel_raises(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
VelocityRamp(max_lin_accel=0.0)
|
||||||
|
|
||||||
|
def test_invalid_ang_accel_raises(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
VelocityRamp(max_ang_accel=-1.0)
|
||||||
|
|
||||||
|
def test_asymmetric_decel_stored(self):
|
||||||
|
r = VelocityRamp(max_lin_accel=0.5, max_lin_decel=2.0)
|
||||||
|
assert r._lin.max_accel == 0.5
|
||||||
|
assert r._lin.max_decel == 2.0
|
||||||
|
|
||||||
|
def test_decel_defaults_to_accel(self):
|
||||||
|
r = VelocityRamp(max_lin_accel=0.5)
|
||||||
|
assert r._lin.max_decel == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Linear ramp-up
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestLinearRampUp:
|
||||||
|
"""
|
||||||
|
VelocityRamp(dt=1.0, max_lin_accel=0.5) → 0.5 m/s per step.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _ramp(self, **kw):
|
||||||
|
return VelocityRamp(dt=1.0, max_lin_accel=0.5, max_ang_accel=1.0, **kw)
|
||||||
|
|
||||||
|
def test_first_step(self):
|
||||||
|
r = self._ramp()
|
||||||
|
lin, _ = r.step(1.0, 0.0)
|
||||||
|
assert abs(lin - 0.5) < 1e-9
|
||||||
|
|
||||||
|
def test_second_step(self):
|
||||||
|
r = self._ramp()
|
||||||
|
r.step(1.0, 0.0)
|
||||||
|
lin, _ = r.step(1.0, 0.0)
|
||||||
|
assert abs(lin - 1.0) < 1e-9
|
||||||
|
|
||||||
|
def test_reaches_target(self):
|
||||||
|
r = self._ramp()
|
||||||
|
lin = None
|
||||||
|
for _ in range(10):
|
||||||
|
lin, _ = r.step(1.0, 0.0)
|
||||||
|
assert lin == 1.0
|
||||||
|
|
||||||
|
def test_no_overshoot(self):
|
||||||
|
r = self._ramp()
|
||||||
|
outputs = [r.step(1.0, 0.0)[0] for _ in range(20)]
|
||||||
|
assert all(v <= 1.0 + 1e-9 for v in outputs)
|
||||||
|
|
||||||
|
def test_current_linear_updated(self):
|
||||||
|
r = self._ramp()
|
||||||
|
r.step(1.0, 0.0)
|
||||||
|
assert abs(r.current_linear - 0.5) < 1e-9
|
||||||
|
|
||||||
|
def test_negative_target(self):
|
||||||
|
r = self._ramp()
|
||||||
|
lin, _ = r.step(-1.0, 0.0)
|
||||||
|
assert abs(lin - (-0.5)) < 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Linear deceleration
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestLinearDecel:
|
||||||
|
|
||||||
|
def _ramp(self, decel=None):
|
||||||
|
return VelocityRamp(
|
||||||
|
dt=1.0, max_lin_accel=1.0,
|
||||||
|
max_lin_decel=decel if decel else 1.0,
|
||||||
|
max_ang_accel=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _at_speed(self, r: VelocityRamp, speed: float) -> None:
|
||||||
|
"""Bring ramp to given linear speed."""
|
||||||
|
for _ in range(20):
|
||||||
|
r.step(speed, 0.0)
|
||||||
|
|
||||||
|
def test_decel_from_1_to_0(self):
|
||||||
|
r = self._ramp(decel=0.5)
|
||||||
|
self._at_speed(r, 1.0)
|
||||||
|
lin, _ = r.step(0.5, 0.0) # slow down: target < current
|
||||||
|
assert abs(lin - 0.5) < 0.01 # 0.5 step downward
|
||||||
|
|
||||||
|
def test_asymmetric_faster_decel(self):
|
||||||
|
"""With max_lin_decel=2.0, deceleration is twice as fast."""
|
||||||
|
r = VelocityRamp(dt=1.0, max_lin_accel=0.5, max_lin_decel=2.0, max_ang_accel=1.0)
|
||||||
|
self._at_speed(r, 1.0)
|
||||||
|
lin, _ = r.step(0.0, 0.0) # decelerate — but target=0 triggers e-stop
|
||||||
|
# e-stop bypasses ramp
|
||||||
|
assert lin == 0.0
|
||||||
|
|
||||||
|
def test_partial_decel_non_zero_target(self):
|
||||||
|
"""With target=0.5 and current=1.0, decel limit applies (non-zero → no e-stop)."""
|
||||||
|
r = VelocityRamp(dt=1.0, max_lin_accel=0.5, max_lin_decel=2.0, max_ang_accel=1.0)
|
||||||
|
self._at_speed(r, 2.0)
|
||||||
|
# target=0.5 angular=0.1 (non-zero) → ramp decel applies
|
||||||
|
lin, _ = r.step(0.5, 0.1)
|
||||||
|
# current was 2.0, decel=2.0 → max step = 2.0 → should reach 0.5 immediately
|
||||||
|
assert abs(lin - 0.5) < 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Angular ramp
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestAngularRamp:
|
||||||
|
|
||||||
|
def _ramp(self, **kw):
|
||||||
|
return VelocityRamp(dt=1.0, max_lin_accel=1.0, max_ang_accel=0.5, **kw)
|
||||||
|
|
||||||
|
def test_angular_first_step(self):
|
||||||
|
r = self._ramp()
|
||||||
|
_, ang = r.step(0.0, 1.0)
|
||||||
|
# target_lin=0 & target_ang=1 → not e-stop (only ang non-zero)
|
||||||
|
# Wait — step(0.0, 1.0): only lin=0, ang=1 → not BOTH zero → ramp applies
|
||||||
|
assert abs(ang - 0.5) < 1e-9
|
||||||
|
|
||||||
|
def test_angular_reaches_target(self):
|
||||||
|
r = self._ramp()
|
||||||
|
ang = None
|
||||||
|
for _ in range(10):
|
||||||
|
_, ang = r.step(0.0, 1.0)
|
||||||
|
assert ang == 1.0
|
||||||
|
|
||||||
|
def test_angular_decel(self):
|
||||||
|
r = self._ramp(max_ang_decel=0.25)
|
||||||
|
for _ in range(10):
|
||||||
|
r.step(0.0, 1.0)
|
||||||
|
# decel with max_ang_decel=0.25: step is 0.25 per second
|
||||||
|
_, ang = r.step(0.0, 0.5)
|
||||||
|
assert abs(ang - 0.75) < 1e-9
|
||||||
|
|
||||||
|
def test_angular_negative(self):
|
||||||
|
r = self._ramp()
|
||||||
|
_, ang = r.step(0.0, -1.0)
|
||||||
|
assert abs(ang - (-0.5)) < 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Emergency stop
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestEmergencyStop:
|
||||||
|
|
||||||
|
def _at_full_speed(self) -> VelocityRamp:
|
||||||
|
r = VelocityRamp(dt=1.0, max_lin_accel=0.5, max_ang_accel=0.5)
|
||||||
|
for _ in range(10):
|
||||||
|
r.step(2.0, 2.0)
|
||||||
|
return r
|
||||||
|
|
||||||
|
def test_estop_returns_zero_immediately(self):
|
||||||
|
r = self._at_full_speed()
|
||||||
|
lin, ang = r.step(0.0, 0.0)
|
||||||
|
assert lin == 0.0
|
||||||
|
assert ang == 0.0
|
||||||
|
|
||||||
|
def test_estop_updates_internal_state(self):
|
||||||
|
r = self._at_full_speed()
|
||||||
|
r.step(0.0, 0.0)
|
||||||
|
assert r.current_linear == 0.0
|
||||||
|
assert r.current_angular == 0.0
|
||||||
|
|
||||||
|
def test_estop_from_rest_still_zero(self):
|
||||||
|
r = VelocityRamp(dt=0.02)
|
||||||
|
lin, ang = r.step(0.0, 0.0)
|
||||||
|
assert lin == 0.0 and ang == 0.0
|
||||||
|
|
||||||
|
def test_estop_then_ramp_resumes(self):
|
||||||
|
r = self._at_full_speed()
|
||||||
|
r.step(0.0, 0.0) # e-stop
|
||||||
|
lin, _ = r.step(1.0, 0.0) # ramp from 0 → first step
|
||||||
|
assert lin > 0.0
|
||||||
|
assert lin < 1.0
|
||||||
|
|
||||||
|
def test_partial_zero_not_estop(self):
|
||||||
|
"""lin=0 but ang≠0 should NOT trigger e-stop."""
|
||||||
|
r = VelocityRamp(dt=1.0, max_lin_accel=0.5, max_ang_accel=0.5)
|
||||||
|
for _ in range(10):
|
||||||
|
r.step(1.0, 1.0)
|
||||||
|
# Now command lin=0, ang=0.5 — not an e-stop
|
||||||
|
lin, ang = r.step(0.0, 0.5)
|
||||||
|
# lin should ramp down (decel), NOT snap to 0
|
||||||
|
assert lin > 0.0
|
||||||
|
assert ang < 1.0 # angular also ramping down toward 0.5
|
||||||
|
|
||||||
|
def test_negative_zero_not_estop(self):
|
||||||
|
"""step(-0.0, 0.0) — Python -0.0 == 0.0, should still e-stop."""
|
||||||
|
r = VelocityRamp(dt=0.02)
|
||||||
|
for _ in range(20):
|
||||||
|
r.step(1.0, 0.0)
|
||||||
|
lin, ang = r.step(-0.0, 0.0)
|
||||||
|
assert lin == 0.0 and ang == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Sign reversal
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestSignReversal:
|
||||||
|
|
||||||
|
def test_reversal_decelerates_first(self):
|
||||||
|
"""Reversing from +1 to -1 should pass through 0, not jump instantly."""
|
||||||
|
r = VelocityRamp(dt=1.0, max_lin_accel=0.5, max_lin_decel=0.5, max_ang_accel=1.0)
|
||||||
|
for _ in range(5):
|
||||||
|
r.step(1.0, 0.1) # reach ~1.0 forward; use ang=0.1 to avoid e-stop
|
||||||
|
outputs = []
|
||||||
|
for _ in range(15):
|
||||||
|
lin, _ = r.step(-1.0, 0.1)
|
||||||
|
outputs.append(lin)
|
||||||
|
# Velocity should pass through zero on its way to -1
|
||||||
|
assert any(v < 0 for v in outputs), "Should cross zero during reversal"
|
||||||
|
assert any(v > 0 for v in outputs[:3]), "Should start from positive side"
|
||||||
|
|
||||||
|
def test_reversal_completes(self):
|
||||||
|
r = VelocityRamp(dt=1.0, max_lin_accel=0.5, max_lin_decel=0.5, max_ang_accel=1.0)
|
||||||
|
for _ in range(5):
|
||||||
|
r.step(1.0, 0.1)
|
||||||
|
for _ in range(20):
|
||||||
|
lin, _ = r.step(-1.0, 0.1)
|
||||||
|
assert abs(lin - (-1.0)) < 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Reset
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestReset:
|
||||||
|
|
||||||
|
def test_reset_clears_state(self):
|
||||||
|
r = VelocityRamp(dt=1.0, max_lin_accel=0.5, max_ang_accel=0.5)
|
||||||
|
r.step(2.0, 2.0)
|
||||||
|
r.reset()
|
||||||
|
assert r.current_linear == 0.0
|
||||||
|
assert r.current_angular == 0.0
|
||||||
|
|
||||||
|
def test_after_reset_ramp_from_zero(self):
|
||||||
|
r = VelocityRamp(dt=1.0, max_lin_accel=0.5, max_ang_accel=0.5)
|
||||||
|
for _ in range(5):
|
||||||
|
r.step(2.0, 0.0)
|
||||||
|
r.reset()
|
||||||
|
lin, _ = r.step(1.0, 0.0)
|
||||||
|
assert abs(lin - 0.5) < 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Monotonicity
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestMonotonicity:
|
||||||
|
|
||||||
|
def test_linear_ramp_up_monotonic(self):
|
||||||
|
r = VelocityRamp(dt=0.02, max_lin_accel=0.5, max_ang_accel=1.0)
|
||||||
|
prev = 0.0
|
||||||
|
for _ in range(100):
|
||||||
|
lin, _ = r.step(1.0, 0.0)
|
||||||
|
assert lin >= prev - 1e-9
|
||||||
|
prev = lin
|
||||||
|
|
||||||
|
def test_linear_ramp_down_monotonic(self):
|
||||||
|
r = VelocityRamp(dt=0.02, max_lin_accel=0.5, max_ang_accel=1.0)
|
||||||
|
for _ in range(200):
|
||||||
|
r.step(1.0, 0.0)
|
||||||
|
prev = r.current_linear
|
||||||
|
for _ in range(100):
|
||||||
|
lin, _ = r.step(0.5, 0.1) # decel toward 0.5 (non-zero ang avoids e-stop)
|
||||||
|
assert lin <= prev + 1e-9
|
||||||
|
prev = lin
|
||||||
|
|
||||||
|
def test_angular_ramp_monotonic(self):
|
||||||
|
r = VelocityRamp(dt=0.02, max_lin_accel=1.0, max_ang_accel=1.0)
|
||||||
|
prev = 0.0
|
||||||
|
for _ in range(50):
|
||||||
|
_, ang = r.step(0.1, 2.0)
|
||||||
|
assert ang >= prev - 1e-9
|
||||||
|
prev = ang
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Timing / rate accuracy
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestRateAccuracy:
|
||||||
|
|
||||||
|
def test_50hz_correct_step_size(self):
|
||||||
|
"""At 50 Hz with 0.5 m/s² → step = 0.01 m/s per tick."""
|
||||||
|
r = VelocityRamp(dt=0.02, max_lin_accel=0.5, max_ang_accel=1.0)
|
||||||
|
lin, _ = r.step(1.0, 0.0)
|
||||||
|
assert abs(lin - 0.01) < 1e-9
|
||||||
|
|
||||||
|
def test_10hz_correct_step_size(self):
|
||||||
|
"""At 10 Hz with 0.5 m/s² → step = 0.05 m/s per tick."""
|
||||||
|
r = VelocityRamp(dt=0.1, max_lin_accel=0.5, max_ang_accel=1.0)
|
||||||
|
lin, _ = r.step(1.0, 0.0)
|
||||||
|
assert abs(lin - 0.05) < 1e-9
|
||||||
|
|
||||||
|
def test_steps_to_target_50hz(self):
|
||||||
|
"""At 50 Hz, 0.5 m/s², reaching 1.0 m/s takes 100 steps (2 s)."""
|
||||||
|
r = VelocityRamp(dt=0.02, max_lin_accel=0.5, max_ang_accel=1.0)
|
||||||
|
steps = 0
|
||||||
|
while r.current_linear < 1.0 - 1e-9:
|
||||||
|
r.step(1.0, 0.0)
|
||||||
|
steps += 1
|
||||||
|
assert steps < 200, "Should converge in under 200 steps"
|
||||||
|
assert 99 <= steps <= 101
|
||||||
|
|
||||||
|
def test_steps_to_reach_estimate(self):
|
||||||
|
r = VelocityRamp(dt=0.02, max_lin_accel=0.5, max_ang_accel=1.0)
|
||||||
|
est = r.steps_to_reach(1.0, 0.0)
|
||||||
|
assert est > 0
|
||||||
@ -0,0 +1,19 @@
|
|||||||
|
# MageDok 7" Touchscreen USB Device Rules
|
||||||
|
# Ensure touch device is recognized and accessible
|
||||||
|
|
||||||
|
# Generic USB touch input device (MageDok)
|
||||||
|
# Manufacturer typically reports as: EETI eGTouch Controller
|
||||||
|
SUBSYSTEM=="input", KERNEL=="event*", ATTRS{name}=="*eGTouch*", TAG="uaccess"
|
||||||
|
SUBSYSTEM=="input", KERNEL=="event*", ATTRS{name}=="*EETI*", TAG="uaccess"
|
||||||
|
SUBSYSTEM=="input", KERNEL=="event*", ATTRS{name}=="*MageDok*", TAG="uaccess"
|
||||||
|
|
||||||
|
# Fallback: Any USB device with touch capability (VID/PID may vary by batch)
|
||||||
|
SUBSYSTEM=="usb", ATTRS{bInterfaceClass}=="03", ATTRS{bInterfaceSubClass}=="01", TAG="uaccess"
|
||||||
|
|
||||||
|
# Create /dev/magedok-touch symlink for consistent reference
|
||||||
|
SUBSYSTEM=="input", KERNEL=="event*", ATTRS{name}=="*eGTouch*", SYMLINK="magedok-touch"
|
||||||
|
SUBSYSTEM=="input", KERNEL=="event*", ATTRS{name}=="*EETI*", SYMLINK="magedok-touch"
|
||||||
|
|
||||||
|
# Permissions: 0666 (rw for all users)
|
||||||
|
SUBSYSTEM=="input", KERNEL=="event*", MODE="0666"
|
||||||
|
SUBSYSTEM=="input", KERNEL=="mouse*", MODE="0666"
|
||||||
29
jetson/ros2_ws/src/saltybot_hand_tracking/package.xml
Normal file
29
jetson/ros2_ws/src/saltybot_hand_tracking/package.xml
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
<?xml version="1.0"?>
|
||||||
|
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||||
|
<package format="3">
|
||||||
|
<name>saltybot_hand_tracking</name>
|
||||||
|
<version>0.1.0</version>
|
||||||
|
<description>MediaPipe-based hand tracking and robot-command gesture recognition (Issue #342).</description>
|
||||||
|
<maintainer email="robot@saltylab.local">SaltyLab</maintainer>
|
||||||
|
<license>MIT</license>
|
||||||
|
|
||||||
|
<buildtool_depend>ament_python</buildtool_depend>
|
||||||
|
|
||||||
|
<depend>rclpy</depend>
|
||||||
|
<depend>sensor_msgs</depend>
|
||||||
|
<depend>std_msgs</depend>
|
||||||
|
<depend>saltybot_hand_tracking_msgs</depend>
|
||||||
|
|
||||||
|
<exec_depend>python3-numpy</exec_depend>
|
||||||
|
<exec_depend>python3-opencv</exec_depend>
|
||||||
|
<!-- mediapipe installed via pip: pip install mediapipe -->
|
||||||
|
|
||||||
|
<test_depend>ament_copyright</test_depend>
|
||||||
|
<test_depend>ament_flake8</test_depend>
|
||||||
|
<test_depend>ament_pep257</test_depend>
|
||||||
|
<test_depend>python3-pytest</test_depend>
|
||||||
|
|
||||||
|
<export>
|
||||||
|
<build_type>ament_python</build_type>
|
||||||
|
</export>
|
||||||
|
</package>
|
||||||
@ -0,0 +1,332 @@
|
|||||||
|
"""
|
||||||
|
_hand_gestures.py — Robot-command gesture classification from MediaPipe
|
||||||
|
hand landmarks. No ROS2 / MediaPipe / OpenCV dependencies.
|
||||||
|
|
||||||
|
Gesture vocabulary
|
||||||
|
------------------
|
||||||
|
"stop" — open palm (4+ fingers extended) → pause/stop robot
|
||||||
|
"point" — index extended, others curled → direction command
|
||||||
|
"disarm" — fist (all fingers + thumb curled) → disarm/emergency-off
|
||||||
|
"confirm" — thumbs-up → confirm action
|
||||||
|
"follow_me" — victory/peace sign (index+middle up) → follow mode
|
||||||
|
"greeting" — lateral wrist oscillation (wave) → greeting response
|
||||||
|
"none" — no recognised gesture
|
||||||
|
|
||||||
|
Coordinate convention
|
||||||
|
---------------------
|
||||||
|
MediaPipe Hands landmark coordinates are image-normalised:
|
||||||
|
x: 0.0 = left edge, 1.0 = right edge
|
||||||
|
y: 0.0 = top edge, 1.0 = bottom edge (y increases downward)
|
||||||
|
z: depth relative to wrist; negative = toward camera
|
||||||
|
|
||||||
|
Landmark indices (MediaPipe Hands topology)
|
||||||
|
-------------------------------------------
|
||||||
|
0 WRIST
|
||||||
|
1 THUMB_CMC 2 THUMB_MCP 3 THUMB_IP 4 THUMB_TIP
|
||||||
|
5 INDEX_MCP 6 INDEX_PIP 7 INDEX_DIP 8 INDEX_TIP
|
||||||
|
9 MIDDLE_MCP 10 MIDDLE_PIP 11 MIDDLE_DIP 12 MIDDLE_TIP
|
||||||
|
13 RING_MCP 14 RING_PIP 15 RING_DIP 16 RING_TIP
|
||||||
|
17 PINKY_MCP 18 PINKY_PIP 19 PINKY_DIP 20 PINKY_TIP
|
||||||
|
|
||||||
|
Public API
|
||||||
|
----------
|
||||||
|
Landmark dataclass(x, y, z)
|
||||||
|
HandGestureResult NamedTuple
|
||||||
|
WaveDetector sliding-window wrist-oscillation detector
|
||||||
|
classify_hand(landmarks, is_right, wave_det) → HandGestureResult
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from collections import deque
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Deque, List, NamedTuple, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
# ── Landmark type ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Landmark:
|
||||||
|
x: float
|
||||||
|
y: float
|
||||||
|
z: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ── Result type ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class HandGestureResult(NamedTuple):
|
||||||
|
gesture: str # "stop"|"point"|"disarm"|"confirm"|"follow_me"|"greeting"|"none"
|
||||||
|
confidence: float # 0.0–1.0
|
||||||
|
direction: str # non-empty only when gesture == "point"
|
||||||
|
wrist_x: float # normalised wrist position (image coords)
|
||||||
|
wrist_y: float
|
||||||
|
|
||||||
|
|
||||||
|
# ── Landmark index constants ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_WRIST = 0
|
||||||
|
_THUMB_CMC = 1; _THUMB_MCP = 2; _THUMB_IP = 3; _THUMB_TIP = 4
|
||||||
|
_INDEX_MCP = 5; _INDEX_PIP = 6; _INDEX_DIP = 7; _INDEX_TIP = 8
|
||||||
|
_MIDDLE_MCP = 9; _MIDDLE_PIP = 10; _MIDDLE_DIP = 11; _MIDDLE_TIP = 12
|
||||||
|
_RING_MCP = 13; _RING_PIP = 14; _RING_DIP = 15; _RING_TIP = 16
|
||||||
|
_PINKY_MCP = 17; _PINKY_PIP = 18; _PINKY_DIP = 19; _PINKY_TIP = 20
|
||||||
|
|
||||||
|
_NONE = HandGestureResult("none", 0.0, "", 0.0, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Low-level geometry helpers ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _finger_up(lm: List[Landmark], tip: int, pip: int) -> bool:
|
||||||
|
"""True if the finger tip is above (smaller y) its PIP joint."""
|
||||||
|
return lm[tip].y < lm[pip].y
|
||||||
|
|
||||||
|
|
||||||
|
def _finger_ext_score(lm: List[Landmark], tip: int, pip: int, mcp: int) -> float:
|
||||||
|
"""Extension score in [0, 1]: how far the tip is above the MCP knuckle."""
|
||||||
|
spread = lm[mcp].y - lm[tip].y # positive = tip above mcp
|
||||||
|
palm_h = abs(lm[_WRIST].y - lm[_MIDDLE_MCP].y) or 0.01
|
||||||
|
return max(0.0, min(1.0, spread / palm_h))
|
||||||
|
|
||||||
|
|
||||||
|
def _count_fingers_up(lm: List[Landmark]) -> int:
|
||||||
|
"""Count how many of index/middle/ring/pinky are extended."""
|
||||||
|
pairs = [
|
||||||
|
(_INDEX_TIP, _INDEX_PIP), (_MIDDLE_TIP, _MIDDLE_PIP),
|
||||||
|
(_RING_TIP, _RING_PIP), (_PINKY_TIP, _PINKY_PIP),
|
||||||
|
]
|
||||||
|
return sum(_finger_up(lm, t, p) for t, p in pairs)
|
||||||
|
|
||||||
|
|
||||||
|
def _four_fingers_curled(lm: List[Landmark]) -> bool:
|
||||||
|
"""True when index, middle, ring, and pinky are all curled (not extended)."""
|
||||||
|
pairs = [
|
||||||
|
(_INDEX_TIP, _INDEX_PIP), (_MIDDLE_TIP, _MIDDLE_PIP),
|
||||||
|
(_RING_TIP, _RING_PIP), (_PINKY_TIP, _PINKY_PIP),
|
||||||
|
]
|
||||||
|
return not any(_finger_up(lm, t, p) for t, p in pairs)
|
||||||
|
|
||||||
|
|
||||||
|
def _thumb_curled(lm: List[Landmark]) -> bool:
|
||||||
|
"""True when the thumb tip is below (same or lower y than) the thumb MCP."""
|
||||||
|
return lm[_THUMB_TIP].y >= lm[_THUMB_MCP].y
|
||||||
|
|
||||||
|
|
||||||
|
def _thumb_extended_up(lm: List[Landmark]) -> bool:
|
||||||
|
"""True when the thumb tip is clearly above the thumb CMC base."""
|
||||||
|
return lm[_THUMB_TIP].y < lm[_THUMB_CMC].y - 0.02
|
||||||
|
|
||||||
|
|
||||||
|
def _palm_center(lm: List[Landmark]) -> Tuple[float, float]:
|
||||||
|
"""(x, y) centroid of the four MCP knuckles."""
|
||||||
|
xs = [lm[i].x for i in (_INDEX_MCP, _MIDDLE_MCP, _RING_MCP, _PINKY_MCP)]
|
||||||
|
ys = [lm[i].y for i in (_INDEX_MCP, _MIDDLE_MCP, _RING_MCP, _PINKY_MCP)]
|
||||||
|
return sum(xs) / 4, sum(ys) / 4
|
||||||
|
|
||||||
|
|
||||||
|
# ── Pointing direction ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _point_direction(lm: List[Landmark]) -> str:
|
||||||
|
"""8-compass pointing direction from index MCP → TIP vector."""
|
||||||
|
dx = lm[_INDEX_TIP].x - lm[_INDEX_MCP].x
|
||||||
|
dy = lm[_INDEX_TIP].y - lm[_INDEX_MCP].y # +y = downward in image
|
||||||
|
angle = math.degrees(math.atan2(-dy, dx)) # flip y so up = +90°
|
||||||
|
if -22.5 <= angle < 22.5: return "right"
|
||||||
|
elif 22.5 <= angle < 67.5: return "upper_right"
|
||||||
|
elif 67.5 <= angle < 112.5: return "up"
|
||||||
|
elif 112.5 <= angle < 157.5: return "upper_left"
|
||||||
|
elif angle >= 157.5 or angle < -157.5: return "left"
|
||||||
|
elif -157.5 <= angle < -112.5: return "lower_left"
|
||||||
|
elif -112.5 <= angle < -67.5: return "down"
|
||||||
|
else: return "lower_right"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Static gesture classifiers ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _classify_stop(lm: List[Landmark]) -> Optional[HandGestureResult]:
|
||||||
|
"""Open palm: 4 or more fingers extended upward."""
|
||||||
|
n = _count_fingers_up(lm)
|
||||||
|
if n < 4:
|
||||||
|
return None
|
||||||
|
scores = [
|
||||||
|
_finger_ext_score(lm, _INDEX_TIP, _INDEX_PIP, _INDEX_MCP),
|
||||||
|
_finger_ext_score(lm, _MIDDLE_TIP, _MIDDLE_PIP, _MIDDLE_MCP),
|
||||||
|
_finger_ext_score(lm, _RING_TIP, _RING_PIP, _RING_MCP),
|
||||||
|
_finger_ext_score(lm, _PINKY_TIP, _PINKY_PIP, _PINKY_MCP),
|
||||||
|
]
|
||||||
|
conf = round(0.60 + 0.35 * (sum(scores) / len(scores)), 3)
|
||||||
|
return HandGestureResult("stop", conf, "", lm[_WRIST].x, lm[_WRIST].y)
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_point(lm: List[Landmark]) -> Optional[HandGestureResult]:
|
||||||
|
"""Index extended upward; middle/ring/pinky curled."""
|
||||||
|
if not _finger_up(lm, _INDEX_TIP, _INDEX_PIP):
|
||||||
|
return None
|
||||||
|
others_up = sum([
|
||||||
|
_finger_up(lm, _MIDDLE_TIP, _MIDDLE_PIP),
|
||||||
|
_finger_up(lm, _RING_TIP, _RING_PIP),
|
||||||
|
_finger_up(lm, _PINKY_TIP, _PINKY_PIP),
|
||||||
|
])
|
||||||
|
if others_up >= 1:
|
||||||
|
return None
|
||||||
|
ext = _finger_ext_score(lm, _INDEX_TIP, _INDEX_PIP, _INDEX_MCP)
|
||||||
|
conf = round(0.65 + 0.30 * ext, 3)
|
||||||
|
direction = _point_direction(lm)
|
||||||
|
return HandGestureResult("point", conf, direction, lm[_WRIST].x, lm[_WRIST].y)
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_disarm(lm: List[Landmark]) -> Optional[HandGestureResult]:
|
||||||
|
"""Fist: all four fingers curled AND thumb tucked (tip at or below MCP)."""
|
||||||
|
if not _four_fingers_curled(lm):
|
||||||
|
return None
|
||||||
|
if not _thumb_curled(lm):
|
||||||
|
return None
|
||||||
|
# Extra confidence: fingertips close to palm = deep fist
|
||||||
|
palm_h = abs(lm[_WRIST].y - lm[_MIDDLE_MCP].y) or 0.01
|
||||||
|
curl_depth = sum(
|
||||||
|
max(0.0, lm[t].y - lm[p].y)
|
||||||
|
for t, p in (
|
||||||
|
(_INDEX_TIP, _INDEX_MCP), (_MIDDLE_TIP, _MIDDLE_MCP),
|
||||||
|
(_RING_TIP, _RING_MCP), (_PINKY_TIP, _PINKY_MCP),
|
||||||
|
)
|
||||||
|
) / 4
|
||||||
|
conf = round(min(0.95, 0.60 + 0.35 * min(1.0, curl_depth / (palm_h * 0.5))), 3)
|
||||||
|
return HandGestureResult("disarm", conf, "", lm[_WRIST].x, lm[_WRIST].y)
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_confirm(lm: List[Landmark]) -> Optional[HandGestureResult]:
|
||||||
|
"""Thumbs-up: thumb extended upward, four fingers curled."""
|
||||||
|
if not _thumb_extended_up(lm):
|
||||||
|
return None
|
||||||
|
if not _four_fingers_curled(lm):
|
||||||
|
return None
|
||||||
|
palm_h = abs(lm[_WRIST].y - lm[_MIDDLE_MCP].y) or 0.01
|
||||||
|
gap = lm[_THUMB_CMC].y - lm[_THUMB_TIP].y
|
||||||
|
conf = round(min(0.95, 0.60 + 0.35 * min(1.0, gap / palm_h)), 3)
|
||||||
|
return HandGestureResult("confirm", conf, "", lm[_WRIST].x, lm[_WRIST].y)
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_follow_me(lm: List[Landmark]) -> Optional[HandGestureResult]:
|
||||||
|
"""Peace/victory sign: index and middle extended; ring and pinky curled."""
|
||||||
|
if not _finger_up(lm, _INDEX_TIP, _INDEX_PIP):
|
||||||
|
return None
|
||||||
|
if not _finger_up(lm, _MIDDLE_TIP, _MIDDLE_PIP):
|
||||||
|
return None
|
||||||
|
if _finger_up(lm, _RING_TIP, _RING_PIP):
|
||||||
|
return None
|
||||||
|
if _finger_up(lm, _PINKY_TIP, _PINKY_PIP):
|
||||||
|
return None
|
||||||
|
idx_ext = _finger_ext_score(lm, _INDEX_TIP, _INDEX_PIP, _INDEX_MCP)
|
||||||
|
mid_ext = _finger_ext_score(lm, _MIDDLE_TIP, _MIDDLE_PIP, _MIDDLE_MCP)
|
||||||
|
conf = round(0.60 + 0.35 * ((idx_ext + mid_ext) / 2), 3)
|
||||||
|
return HandGestureResult("follow_me", conf, "", lm[_WRIST].x, lm[_WRIST].y)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Priority order (first match wins) ─────────────────────────────────────────
|
||||||
|
|
||||||
|
_CLASSIFIERS = [
|
||||||
|
_classify_stop,
|
||||||
|
_classify_confirm, # before point — thumbs-up would also pass point partially
|
||||||
|
_classify_follow_me, # before point — index+middle would partially match
|
||||||
|
_classify_point,
|
||||||
|
_classify_disarm,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Wave (temporal) detector ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class WaveDetector:
|
||||||
|
"""Sliding-window wave gesture detector.
|
||||||
|
|
||||||
|
Tracks the wrist X-coordinate over time and fires when there are at least
|
||||||
|
`min_reversals` direction reversals with peak-to-peak amplitude ≥
|
||||||
|
`min_amplitude` (normalised image coords).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
history_len : number of frames to keep (default 24 ≈ 0.8 s at 30 fps)
|
||||||
|
min_reversals : direction reversals required to trigger (default 2)
|
||||||
|
min_amplitude : peak-to-peak x excursion threshold (default 0.08)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
history_len: int = 24,
|
||||||
|
min_reversals: int = 2,
|
||||||
|
min_amplitude: float = 0.08,
|
||||||
|
) -> None:
|
||||||
|
self._history: Deque[float] = deque(maxlen=history_len)
|
||||||
|
self._min_reversals = min_reversals
|
||||||
|
self._min_amplitude = min_amplitude
|
||||||
|
|
||||||
|
def push(self, wrist_x: float) -> Tuple[bool, float]:
|
||||||
|
"""Add a wrist-X sample. Returns (is_waving, confidence)."""
|
||||||
|
self._history.append(wrist_x)
|
||||||
|
if len(self._history) < 6:
|
||||||
|
return False, 0.0
|
||||||
|
return self._detect()
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self._history.clear()
|
||||||
|
|
||||||
|
def _detect(self) -> Tuple[bool, float]:
|
||||||
|
samples = list(self._history)
|
||||||
|
mean_x = sum(samples) / len(samples)
|
||||||
|
centered = [x - mean_x for x in samples]
|
||||||
|
amplitude = max(centered) - min(centered)
|
||||||
|
if amplitude < self._min_amplitude:
|
||||||
|
return False, 0.0
|
||||||
|
|
||||||
|
reversals = sum(
|
||||||
|
1 for i in range(1, len(centered))
|
||||||
|
if centered[i - 1] * centered[i] < 0
|
||||||
|
)
|
||||||
|
if reversals < self._min_reversals:
|
||||||
|
return False, 0.0
|
||||||
|
|
||||||
|
amp_score = min(1.0, amplitude / 0.30)
|
||||||
|
rev_score = min(1.0, reversals / 6.0)
|
||||||
|
conf = round(0.5 * amp_score + 0.5 * rev_score, 3)
|
||||||
|
return True, conf
|
||||||
|
|
||||||
|
|
||||||
|
# ── Public API ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def classify_hand(
|
||||||
|
landmarks: List[Landmark],
|
||||||
|
is_right: bool = True,
|
||||||
|
wave_det: Optional[WaveDetector] = None,
|
||||||
|
) -> HandGestureResult:
|
||||||
|
"""Classify one hand's 21 MediaPipe landmarks into a robot-command gesture.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
landmarks : 21 Landmark objects (MediaPipe normalised coords).
|
||||||
|
is_right : True for right hand (affects thumb direction checks).
|
||||||
|
wave_det : Optional WaveDetector for temporal wave tracking.
|
||||||
|
Wave is evaluated before static classifiers.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
HandGestureResult — gesture, confidence, direction, wrist_x, wrist_y.
|
||||||
|
"""
|
||||||
|
if len(landmarks) < 21:
|
||||||
|
return _NONE
|
||||||
|
|
||||||
|
# Wave (temporal) — highest priority
|
||||||
|
if wave_det is not None:
|
||||||
|
is_waving, wconf = wave_det.push(landmarks[_WRIST].x)
|
||||||
|
if is_waving:
|
||||||
|
return HandGestureResult(
|
||||||
|
"greeting", wconf, "",
|
||||||
|
landmarks[_WRIST].x, landmarks[_WRIST].y,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Static classifiers in priority order
|
||||||
|
for clf in _CLASSIFIERS:
|
||||||
|
result = clf(landmarks)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
|
||||||
|
return HandGestureResult(
|
||||||
|
"none", 0.0, "", landmarks[_WRIST].x, landmarks[_WRIST].y
|
||||||
|
)
|
||||||
@ -0,0 +1,305 @@
|
|||||||
|
"""
|
||||||
|
hand_tracking_node.py — MediaPipe Hands inference node (Issue #342).
|
||||||
|
|
||||||
|
Subscribes
|
||||||
|
----------
|
||||||
|
/camera/color/image_raw (sensor_msgs/Image)
|
||||||
|
|
||||||
|
Publishes
|
||||||
|
---------
|
||||||
|
/saltybot/hands (saltybot_hand_tracking_msgs/HandLandmarksArray)
|
||||||
|
/saltybot/hand_gesture (std_msgs/String)
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
max_hands int 2 Maximum hands detected per frame
|
||||||
|
model_complexity int 0 MediaPipe model: 0=lite, 1=full (0 for 20+ FPS)
|
||||||
|
min_detection_conf float 0.60 MediaPipe detection confidence threshold
|
||||||
|
min_tracking_conf float 0.50 MediaPipe tracking confidence threshold
|
||||||
|
gesture_min_conf float 0.60 Minimum gesture confidence to publish on hand_gesture
|
||||||
|
wave_history_len int 24 Frames kept in WaveDetector history
|
||||||
|
wave_min_reversals int 2 Oscillation reversals needed for wave
|
||||||
|
wave_min_amplitude float 0.08 Peak-to-peak wrist-x amplitude for wave
|
||||||
|
|
||||||
|
Performance note
|
||||||
|
----------------
|
||||||
|
model_complexity=0 (lite) on Orin Nano Super (1024-core Ampere, 67 TOPS)
|
||||||
|
targets >20 FPS at 640×480. Drop to 480×360 in the camera launch if needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile, QoSReliabilityPolicy, QoSHistoryPolicy
|
||||||
|
|
||||||
|
from sensor_msgs.msg import Image
|
||||||
|
from std_msgs.msg import String
|
||||||
|
|
||||||
|
from saltybot_hand_tracking_msgs.msg import HandLandmarks, HandLandmarksArray
|
||||||
|
|
||||||
|
from ._hand_gestures import Landmark, WaveDetector, classify_hand
|
||||||
|
|
||||||
|
# Optional runtime imports — guarded so unit tests don't need them
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
_HAS_CV = True
|
||||||
|
except ImportError:
|
||||||
|
_HAS_CV = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import mediapipe as mp
|
||||||
|
_HAS_MP = True
|
||||||
|
except ImportError:
|
||||||
|
_HAS_MP = False
|
||||||
|
|
||||||
|
|
||||||
|
# ── MediaPipe wrapper ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _MPHands:
|
||||||
|
"""Thin wrapper around mediapipe.solutions.hands with lazy init."""
|
||||||
|
|
||||||
|
def __init__(self, max_hands: int, complexity: int,
|
||||||
|
det_conf: float, trk_conf: float) -> None:
|
||||||
|
self._max_hands = max_hands
|
||||||
|
self._complexity = complexity
|
||||||
|
self._det_conf = det_conf
|
||||||
|
self._trk_conf = trk_conf
|
||||||
|
self._hands = None
|
||||||
|
|
||||||
|
def init(self) -> None:
|
||||||
|
if not _HAS_MP:
|
||||||
|
return
|
||||||
|
self._hands = mp.solutions.hands.Hands(
|
||||||
|
static_image_mode=False,
|
||||||
|
max_num_hands=self._max_hands,
|
||||||
|
min_detection_confidence=self._det_conf,
|
||||||
|
min_tracking_confidence=self._trk_conf,
|
||||||
|
model_complexity=self._complexity,
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, bgr: np.ndarray):
|
||||||
|
"""Process a BGR image; returns mediapipe Hands results or None."""
|
||||||
|
if self._hands is None or not _HAS_MP or not _HAS_CV:
|
||||||
|
return None
|
||||||
|
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
||||||
|
rgb.flags.writeable = False
|
||||||
|
try:
|
||||||
|
return self._hands.process(rgb)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
if self._hands is not None:
|
||||||
|
self._hands.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ── ROS2 Node ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class HandTrackingNode(Node):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__('hand_tracking_node')
|
||||||
|
|
||||||
|
# ── Parameters ──────────────────────────────────────────────────────
|
||||||
|
self.declare_parameter('max_hands', 2)
|
||||||
|
self.declare_parameter('model_complexity', 0)
|
||||||
|
self.declare_parameter('min_detection_conf', 0.60)
|
||||||
|
self.declare_parameter('min_tracking_conf', 0.50)
|
||||||
|
self.declare_parameter('gesture_min_conf', 0.60)
|
||||||
|
self.declare_parameter('wave_history_len', 24)
|
||||||
|
self.declare_parameter('wave_min_reversals', 2)
|
||||||
|
self.declare_parameter('wave_min_amplitude', 0.08)
|
||||||
|
|
||||||
|
p = self.get_parameter
|
||||||
|
self._gesture_min_conf: float = p('gesture_min_conf').value
|
||||||
|
|
||||||
|
# ── QoS ─────────────────────────────────────────────────────────────
|
||||||
|
img_qos = QoSProfile(
|
||||||
|
reliability=QoSReliabilityPolicy.BEST_EFFORT,
|
||||||
|
history=QoSHistoryPolicy.KEEP_LAST,
|
||||||
|
depth=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Publishers ───────────────────────────────────────────────────────
|
||||||
|
self._hands_pub = self.create_publisher(
|
||||||
|
HandLandmarksArray, '/saltybot/hands', 10
|
||||||
|
)
|
||||||
|
self._gesture_pub = self.create_publisher(
|
||||||
|
String, '/saltybot/hand_gesture', 10
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Subscriber ───────────────────────────────────────────────────────
|
||||||
|
self._sub = self.create_subscription(
|
||||||
|
Image, '/camera/color/image_raw',
|
||||||
|
self._image_cb, img_qos,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Per-hand WaveDetectors keyed by hand index ────────────────────
|
||||||
|
self._wave_dets: Dict[int, WaveDetector] = {}
|
||||||
|
wave_hist = p('wave_history_len').value
|
||||||
|
wave_rev = p('wave_min_reversals').value
|
||||||
|
wave_amp = p('wave_min_amplitude').value
|
||||||
|
self._wave_kwargs = dict(
|
||||||
|
history_len=wave_hist,
|
||||||
|
min_reversals=wave_rev,
|
||||||
|
min_amplitude=wave_amp,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── MediaPipe (background init) ───────────────────────────────────
|
||||||
|
self._mp = _MPHands(
|
||||||
|
max_hands = p('max_hands').value,
|
||||||
|
complexity = p('model_complexity').value,
|
||||||
|
det_conf = p('min_detection_conf').value,
|
||||||
|
trk_conf = p('min_tracking_conf').value,
|
||||||
|
)
|
||||||
|
self._mp_ready = threading.Event()
|
||||||
|
threading.Thread(target=self._init_mp, daemon=True).start()
|
||||||
|
|
||||||
|
self.get_logger().info(
|
||||||
|
f'hand_tracking_node ready — '
|
||||||
|
f'max_hands={p("max_hands").value}, '
|
||||||
|
f'complexity={p("model_complexity").value}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Init ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _init_mp(self) -> None:
|
||||||
|
if not _HAS_MP:
|
||||||
|
self.get_logger().warn(
|
||||||
|
'mediapipe not installed — no hand tracking. '
|
||||||
|
'Install: pip install mediapipe'
|
||||||
|
)
|
||||||
|
return
|
||||||
|
t0 = time.time()
|
||||||
|
self._mp.init()
|
||||||
|
self._mp_ready.set()
|
||||||
|
self.get_logger().info(
|
||||||
|
f'MediaPipe Hands ready ({time.time() - t0:.1f}s)'
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Image callback ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _image_cb(self, msg: Image) -> None:
|
||||||
|
if not self._mp_ready.is_set():
|
||||||
|
return
|
||||||
|
|
||||||
|
bgr = self._ros_to_bgr(msg)
|
||||||
|
if bgr is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
results = self._mp.process(bgr)
|
||||||
|
|
||||||
|
arr = HandLandmarksArray()
|
||||||
|
arr.header = msg.header
|
||||||
|
arr.hand_count = 0
|
||||||
|
best_gesture = ""
|
||||||
|
best_conf = 0.0
|
||||||
|
|
||||||
|
if results and results.multi_hand_landmarks:
|
||||||
|
for hand_idx, (hand_lm, hand_info) in enumerate(
|
||||||
|
zip(results.multi_hand_landmarks, results.multi_handedness)
|
||||||
|
):
|
||||||
|
cls = hand_info.classification[0]
|
||||||
|
is_right = cls.label == "Right"
|
||||||
|
hs_score = float(cls.score)
|
||||||
|
|
||||||
|
lm = [Landmark(p.x, p.y, p.z) for p in hand_lm.landmark]
|
||||||
|
|
||||||
|
wave_det = self._wave_dets.get(hand_idx)
|
||||||
|
if wave_det is None:
|
||||||
|
wave_det = WaveDetector(**self._wave_kwargs)
|
||||||
|
self._wave_dets[hand_idx] = wave_det
|
||||||
|
|
||||||
|
gesture_result = classify_hand(lm, is_right=is_right, wave_det=wave_det)
|
||||||
|
|
||||||
|
# Build HandLandmarks message
|
||||||
|
hl = HandLandmarks()
|
||||||
|
hl.header = msg.header
|
||||||
|
hl.is_right_hand = is_right
|
||||||
|
hl.handedness_score = hs_score
|
||||||
|
hl.landmark_xyz = self._pack_landmarks(lm)
|
||||||
|
hl.gesture = gesture_result.gesture
|
||||||
|
hl.point_direction = gesture_result.direction
|
||||||
|
hl.gesture_confidence = float(gesture_result.confidence)
|
||||||
|
hl.wrist_x = float(lm[0].x)
|
||||||
|
hl.wrist_y = float(lm[0].y)
|
||||||
|
|
||||||
|
arr.hands.append(hl)
|
||||||
|
arr.hand_count += 1
|
||||||
|
|
||||||
|
if gesture_result.confidence > best_conf:
|
||||||
|
best_conf = gesture_result.confidence
|
||||||
|
best_gesture = gesture_result.gesture
|
||||||
|
|
||||||
|
self._hands_pub.publish(arr)
|
||||||
|
|
||||||
|
# Publish hand_gesture String only when confident enough
|
||||||
|
if best_gesture and best_gesture != "none" \
|
||||||
|
and best_conf >= self._gesture_min_conf:
|
||||||
|
gs = String()
|
||||||
|
gs.data = best_gesture
|
||||||
|
self._gesture_pub.publish(gs)
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _pack_landmarks(lm: List[Landmark]) -> List[float]:
|
||||||
|
"""Pack 21 Landmark objects into a flat [x0,y0,z0, ..., x20,y20,z20] list."""
|
||||||
|
out: List[float] = []
|
||||||
|
for l in lm:
|
||||||
|
out.extend([l.x, l.y, l.z])
|
||||||
|
return out
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ros_to_bgr(msg: Image) -> Optional[np.ndarray]:
|
||||||
|
"""Convert sensor_msgs/Image to uint8 BGR numpy array."""
|
||||||
|
enc = msg.encoding.lower()
|
||||||
|
data = np.frombuffer(msg.data, dtype=np.uint8)
|
||||||
|
if enc == 'bgr8':
|
||||||
|
try:
|
||||||
|
return data.reshape((msg.height, msg.width, 3))
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
if enc == 'rgb8':
|
||||||
|
try:
|
||||||
|
img = data.reshape((msg.height, msg.width, 3))
|
||||||
|
return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) if _HAS_CV else None
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
if enc == 'mono8':
|
||||||
|
try:
|
||||||
|
img = data.reshape((msg.height, msg.width))
|
||||||
|
return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) if _HAS_CV else None
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ── Cleanup ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def destroy_node(self) -> None:
|
||||||
|
self._mp.close()
|
||||||
|
super().destroy_node()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Entry point ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def main(args=None) -> None:
|
||||||
|
rclpy.init(args=args)
|
||||||
|
node = HandTrackingNode()
|
||||||
|
try:
|
||||||
|
rclpy.spin(node)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
node.destroy_node()
|
||||||
|
rclpy.try_shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
4
jetson/ros2_ws/src/saltybot_hand_tracking/setup.cfg
Normal file
4
jetson/ros2_ws/src/saltybot_hand_tracking/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
[develop]
|
||||||
|
script_dir=$base/lib/saltybot_hand_tracking
|
||||||
|
[install]
|
||||||
|
install_scripts=$base/lib/saltybot_hand_tracking
|
||||||
33
jetson/ros2_ws/src/saltybot_hand_tracking/setup.py
Normal file
33
jetson/ros2_ws/src/saltybot_hand_tracking/setup.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from setuptools import setup
|
||||||
|
import os
|
||||||
|
from glob import glob
|
||||||
|
|
||||||
|
package_name = 'saltybot_hand_tracking'
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name=package_name,
|
||||||
|
version='0.1.0',
|
||||||
|
packages=[package_name],
|
||||||
|
data_files=[
|
||||||
|
('share/ament_index/resource_index/packages',
|
||||||
|
['resource/' + package_name]),
|
||||||
|
('share/' + package_name, ['package.xml']),
|
||||||
|
(os.path.join('share', package_name, 'launch'),
|
||||||
|
glob('launch/*.launch.py')),
|
||||||
|
(os.path.join('share', package_name, 'config'),
|
||||||
|
glob('config/*.yaml')),
|
||||||
|
],
|
||||||
|
install_requires=['setuptools'],
|
||||||
|
zip_safe=True,
|
||||||
|
maintainer='sl-perception',
|
||||||
|
maintainer_email='sl-perception@saltylab.local',
|
||||||
|
description='MediaPipe hand tracking node for SaltyBot (Issue #342)',
|
||||||
|
license='MIT',
|
||||||
|
tests_require=['pytest'],
|
||||||
|
entry_points={
|
||||||
|
'console_scripts': [
|
||||||
|
# MediaPipe Hands inference + gesture classification (Issue #342)
|
||||||
|
'hand_tracking = saltybot_hand_tracking.hand_tracking_node:main',
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
@ -0,0 +1,407 @@
|
|||||||
|
"""
|
||||||
|
test_hand_gestures.py — pytest tests for _hand_gestures.py (no ROS2 required).
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- Landmark dataclass
|
||||||
|
- HandGestureResult NamedTuple fields
|
||||||
|
- WaveDetector — no wave / wave trigger / reset
|
||||||
|
- _finger_up helpers
|
||||||
|
- classify_hand:
|
||||||
|
stop (open palm)
|
||||||
|
point (index up, others curled) + direction
|
||||||
|
disarm (fist)
|
||||||
|
confirm (thumbs-up)
|
||||||
|
follow_me (peace/victory)
|
||||||
|
greeting (wave via WaveDetector)
|
||||||
|
none (neutral / ambiguous pose)
|
||||||
|
- classify_hand priority ordering
|
||||||
|
- classify_hand: fewer than 21 landmarks → "none"
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from saltybot_hand_tracking._hand_gestures import (
|
||||||
|
Landmark,
|
||||||
|
HandGestureResult,
|
||||||
|
WaveDetector,
|
||||||
|
classify_hand,
|
||||||
|
_finger_up,
|
||||||
|
_count_fingers_up,
|
||||||
|
_four_fingers_curled,
|
||||||
|
_thumb_curled,
|
||||||
|
_thumb_extended_up,
|
||||||
|
_point_direction,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Landmark factory helpers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _lm(x: float, y: float, z: float = 0.0) -> Landmark:
|
||||||
|
return Landmark(x, y, z)
|
||||||
|
|
||||||
|
|
||||||
|
def _flat_hand(n: int = 21) -> List[Landmark]:
|
||||||
|
"""Return n landmarks all at (0.5, 0.5)."""
|
||||||
|
return [_lm(0.5, 0.5) for _ in range(n)]
|
||||||
|
|
||||||
|
|
||||||
|
def _make_hand(
|
||||||
|
wrist_y: float = 0.8,
|
||||||
|
# Index finger
|
||||||
|
idx_mcp_y: float = 0.6, idx_pip_y: float = 0.5,
|
||||||
|
idx_tip_y: float = 0.3, idx_mcp_x: float = 0.4, idx_tip_x: float = 0.4,
|
||||||
|
# Middle finger
|
||||||
|
mid_mcp_y: float = 0.6, mid_pip_y: float = 0.5, mid_tip_y: float = 0.4,
|
||||||
|
# Ring finger
|
||||||
|
rng_mcp_y: float = 0.6, rng_pip_y: float = 0.5, rng_tip_y: float = 0.55,
|
||||||
|
# Pinky
|
||||||
|
pnk_mcp_y: float = 0.6, pnk_pip_y: float = 0.5, pnk_tip_y: float = 0.55,
|
||||||
|
# Thumb
|
||||||
|
thm_cmc_y: float = 0.7, thm_mcp_y: float = 0.65, thm_tip_y: float = 0.55,
|
||||||
|
) -> List[Landmark]:
|
||||||
|
"""
|
||||||
|
Build a 21-landmark array for testing.
|
||||||
|
|
||||||
|
Layout (MediaPipe Hands indices):
|
||||||
|
0 WRIST
|
||||||
|
1 THUMB_CMC 2 THUMB_MCP 3 THUMB_IP 4 THUMB_TIP
|
||||||
|
5 INDEX_MCP 6 INDEX_PIP 7 INDEX_DIP 8 INDEX_TIP
|
||||||
|
9 MIDDLE_MCP 10 MIDDLE_PIP 11 MIDDLE_DIP 12 MIDDLE_TIP
|
||||||
|
13 RING_MCP 14 RING_PIP 15 RING_DIP 16 RING_TIP
|
||||||
|
17 PINKY_MCP 18 PINKY_PIP 19 PINKY_DIP 20 PINKY_TIP
|
||||||
|
"""
|
||||||
|
lm = [_lm(0.5, 0.5)] * 21
|
||||||
|
# WRIST
|
||||||
|
lm[0] = _lm(0.5, wrist_y)
|
||||||
|
# THUMB
|
||||||
|
lm[1] = _lm(0.35, thm_cmc_y) # CMC
|
||||||
|
lm[2] = _lm(0.33, thm_mcp_y) # MCP
|
||||||
|
lm[3] = _lm(0.31, (thm_mcp_y + thm_tip_y) / 2) # IP
|
||||||
|
lm[4] = _lm(0.30, thm_tip_y) # TIP
|
||||||
|
# INDEX
|
||||||
|
lm[5] = _lm(idx_mcp_x, idx_mcp_y) # MCP
|
||||||
|
lm[6] = _lm(idx_mcp_x, idx_pip_y) # PIP
|
||||||
|
lm[7] = _lm(idx_mcp_x, (idx_pip_y + idx_tip_y) / 2) # DIP
|
||||||
|
lm[8] = _lm(idx_tip_x, idx_tip_y) # TIP
|
||||||
|
# MIDDLE
|
||||||
|
lm[9] = _lm(0.5, mid_mcp_y) # MCP
|
||||||
|
lm[10] = _lm(0.5, mid_pip_y) # PIP
|
||||||
|
lm[11] = _lm(0.5, (mid_pip_y + mid_tip_y) / 2)
|
||||||
|
lm[12] = _lm(0.5, mid_tip_y) # TIP
|
||||||
|
# RING
|
||||||
|
lm[13] = _lm(0.6, rng_mcp_y) # MCP
|
||||||
|
lm[14] = _lm(0.6, rng_pip_y) # PIP
|
||||||
|
lm[15] = _lm(0.6, (rng_pip_y + rng_tip_y) / 2)
|
||||||
|
lm[16] = _lm(0.6, rng_tip_y) # TIP
|
||||||
|
# PINKY
|
||||||
|
lm[17] = _lm(0.65, pnk_mcp_y) # MCP
|
||||||
|
lm[18] = _lm(0.65, pnk_pip_y) # PIP
|
||||||
|
lm[19] = _lm(0.65, (pnk_pip_y + pnk_tip_y) / 2)
|
||||||
|
lm[20] = _lm(0.65, pnk_tip_y) # TIP
|
||||||
|
return lm
|
||||||
|
|
||||||
|
|
||||||
|
# ── Prebuilt canonical poses ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _open_palm() -> List[Landmark]:
|
||||||
|
"""All 4 fingers extended (tips clearly above PIPs), thumb neutral."""
|
||||||
|
return _make_hand(
|
||||||
|
idx_mcp_y=0.60, idx_pip_y=0.50, idx_tip_y=0.25,
|
||||||
|
mid_mcp_y=0.60, mid_pip_y=0.50, mid_tip_y=0.25,
|
||||||
|
rng_mcp_y=0.60, rng_pip_y=0.50, rng_tip_y=0.25,
|
||||||
|
pnk_mcp_y=0.60, pnk_pip_y=0.50, pnk_tip_y=0.25,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _point_up() -> List[Landmark]:
|
||||||
|
"""Index extended, middle/ring/pinky curled."""
|
||||||
|
return _make_hand(
|
||||||
|
idx_mcp_y=0.60, idx_pip_y=0.50, idx_tip_y=0.25,
|
||||||
|
mid_mcp_y=0.60, mid_pip_y=0.55, mid_tip_y=0.62, # curled
|
||||||
|
rng_mcp_y=0.60, rng_pip_y=0.55, rng_tip_y=0.62,
|
||||||
|
pnk_mcp_y=0.60, pnk_pip_y=0.55, pnk_tip_y=0.62,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _fist() -> List[Landmark]:
|
||||||
|
"""All fingers curled, thumb tip at/below thumb MCP."""
|
||||||
|
return _make_hand(
|
||||||
|
idx_mcp_y=0.60, idx_pip_y=0.58, idx_tip_y=0.65, # tip below pip
|
||||||
|
mid_mcp_y=0.60, mid_pip_y=0.58, mid_tip_y=0.65,
|
||||||
|
rng_mcp_y=0.60, rng_pip_y=0.58, rng_tip_y=0.65,
|
||||||
|
pnk_mcp_y=0.60, pnk_pip_y=0.58, pnk_tip_y=0.65,
|
||||||
|
thm_cmc_y=0.70, thm_mcp_y=0.65, thm_tip_y=0.68, # tip >= mcp → curled
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _thumbs_up() -> List[Landmark]:
|
||||||
|
"""Thumb tip clearly above CMC, four fingers curled."""
|
||||||
|
return _make_hand(
|
||||||
|
thm_cmc_y=0.70, thm_mcp_y=0.65, thm_tip_y=0.30, # tip well above CMC
|
||||||
|
idx_mcp_y=0.60, idx_pip_y=0.58, idx_tip_y=0.65,
|
||||||
|
mid_mcp_y=0.60, mid_pip_y=0.58, mid_tip_y=0.65,
|
||||||
|
rng_mcp_y=0.60, rng_pip_y=0.58, rng_tip_y=0.65,
|
||||||
|
pnk_mcp_y=0.60, pnk_pip_y=0.58, pnk_tip_y=0.65,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _peace() -> List[Landmark]:
|
||||||
|
"""Index + middle extended, ring + pinky curled."""
|
||||||
|
return _make_hand(
|
||||||
|
idx_mcp_y=0.60, idx_pip_y=0.50, idx_tip_y=0.25,
|
||||||
|
mid_mcp_y=0.60, mid_pip_y=0.50, mid_tip_y=0.25,
|
||||||
|
rng_mcp_y=0.60, rng_pip_y=0.58, rng_tip_y=0.65, # curled
|
||||||
|
pnk_mcp_y=0.60, pnk_pip_y=0.58, pnk_tip_y=0.65,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Landmark dataclass ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestLandmark:
|
||||||
|
def test_fields(self):
|
||||||
|
lm = Landmark(0.1, 0.2, 0.3)
|
||||||
|
assert lm.x == pytest.approx(0.1)
|
||||||
|
assert lm.y == pytest.approx(0.2)
|
||||||
|
assert lm.z == pytest.approx(0.3)
|
||||||
|
|
||||||
|
def test_default_z(self):
|
||||||
|
lm = Landmark(0.5, 0.5)
|
||||||
|
assert lm.z == 0.0
|
||||||
|
|
||||||
|
def test_frozen(self):
|
||||||
|
lm = Landmark(0.5, 0.5)
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
lm.x = 0.9 # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
# ── HandGestureResult ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestHandGestureResult:
|
||||||
|
def test_fields(self):
|
||||||
|
r = HandGestureResult("stop", 0.85, "", 0.5, 0.6)
|
||||||
|
assert r.gesture == "stop"
|
||||||
|
assert r.confidence == pytest.approx(0.85)
|
||||||
|
assert r.direction == ""
|
||||||
|
assert r.wrist_x == pytest.approx(0.5)
|
||||||
|
assert r.wrist_y == pytest.approx(0.6)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Low-level helpers ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestFingerHelpers:
|
||||||
|
def _two_lm(self, tip_y: float, pip_y: float) -> List[Landmark]:
|
||||||
|
"""Build a minimal 21-lm list where positions 0 and 1 are tip and pip."""
|
||||||
|
lm = [_lm(0.5, 0.5)] * 21
|
||||||
|
lm[0] = _lm(0.5, tip_y)
|
||||||
|
lm[1] = _lm(0.5, pip_y)
|
||||||
|
return lm
|
||||||
|
|
||||||
|
def test_finger_up_true(self):
|
||||||
|
"""Tip above PIP (smaller y) → True."""
|
||||||
|
lm = self._two_lm(tip_y=0.3, pip_y=0.6)
|
||||||
|
assert _finger_up(lm, 0, 1) is True
|
||||||
|
|
||||||
|
def test_finger_up_false(self):
|
||||||
|
"""Tip below PIP → False."""
|
||||||
|
lm = self._two_lm(tip_y=0.7, pip_y=0.4)
|
||||||
|
assert _finger_up(lm, 0, 1) is False
|
||||||
|
|
||||||
|
def test_count_fingers_up_open_palm(self):
|
||||||
|
lm = _open_palm()
|
||||||
|
assert _count_fingers_up(lm) == 4
|
||||||
|
|
||||||
|
def test_count_fingers_up_fist(self):
|
||||||
|
lm = _fist()
|
||||||
|
assert _count_fingers_up(lm) == 0
|
||||||
|
|
||||||
|
def test_four_fingers_curled_fist(self):
|
||||||
|
lm = _fist()
|
||||||
|
assert _four_fingers_curled(lm) is True
|
||||||
|
|
||||||
|
def test_four_fingers_curled_open_palm_false(self):
|
||||||
|
lm = _open_palm()
|
||||||
|
assert _four_fingers_curled(lm) is False
|
||||||
|
|
||||||
|
def test_thumb_curled_fist(self):
|
||||||
|
lm = _fist()
|
||||||
|
assert _thumb_curled(lm) is True
|
||||||
|
|
||||||
|
def test_thumb_extended_up_thumbs_up(self):
|
||||||
|
lm = _thumbs_up()
|
||||||
|
assert _thumb_extended_up(lm) is True
|
||||||
|
|
||||||
|
def test_thumb_not_extended_fist(self):
|
||||||
|
lm = _fist()
|
||||||
|
assert _thumb_extended_up(lm) is False
|
||||||
|
|
||||||
|
|
||||||
|
# ── Point direction ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestPointDirection:
|
||||||
|
def _hand_with_index_vec(self, dx: float, dy: float) -> List[Landmark]:
|
||||||
|
"""Build hand where index MCP→TIP vector is (dx, dy)."""
|
||||||
|
lm = _make_hand()
|
||||||
|
lm[5] = _lm(0.5, 0.6) # INDEX_MCP
|
||||||
|
lm[8] = _lm(0.5 + dx, 0.6 + dy) # INDEX_TIP
|
||||||
|
return lm
|
||||||
|
|
||||||
|
def test_pointing_up(self):
|
||||||
|
# dy negative (tip above MCP) → up
|
||||||
|
lm = self._hand_with_index_vec(0.0, -0.2)
|
||||||
|
assert _point_direction(lm) == "up"
|
||||||
|
|
||||||
|
def test_pointing_right(self):
|
||||||
|
lm = self._hand_with_index_vec(0.2, 0.0)
|
||||||
|
assert _point_direction(lm) == "right"
|
||||||
|
|
||||||
|
def test_pointing_left(self):
|
||||||
|
lm = self._hand_with_index_vec(-0.2, 0.0)
|
||||||
|
assert _point_direction(lm) == "left"
|
||||||
|
|
||||||
|
def test_pointing_upper_right(self):
|
||||||
|
lm = self._hand_with_index_vec(0.15, -0.15)
|
||||||
|
assert _point_direction(lm) == "upper_right"
|
||||||
|
|
||||||
|
|
||||||
|
# ── WaveDetector ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestWaveDetector:
|
||||||
|
def test_too_few_samples_no_wave(self):
|
||||||
|
wd = WaveDetector()
|
||||||
|
for x in [0.3, 0.7, 0.3]:
|
||||||
|
is_w, conf = wd.push(x)
|
||||||
|
assert is_w is False
|
||||||
|
assert conf == 0.0
|
||||||
|
|
||||||
|
def test_wave_detected_after_oscillation(self):
|
||||||
|
"""Feed a sinusoidal wrist_x — should trigger wave."""
|
||||||
|
wd = WaveDetector(history_len=20, min_reversals=2, min_amplitude=0.08)
|
||||||
|
is_waving = False
|
||||||
|
for i in range(20):
|
||||||
|
x = 0.5 + 0.20 * math.sin(i * math.pi / 3)
|
||||||
|
is_waving, conf = wd.push(x)
|
||||||
|
assert is_waving is True
|
||||||
|
assert conf > 0.0
|
||||||
|
|
||||||
|
def test_no_wave_small_amplitude(self):
|
||||||
|
"""Very small oscillation should not trigger."""
|
||||||
|
wd = WaveDetector(min_amplitude=0.10)
|
||||||
|
for i in range(24):
|
||||||
|
x = 0.5 + 0.01 * math.sin(i * math.pi / 3)
|
||||||
|
wd.push(x)
|
||||||
|
is_w, _ = wd.push(0.5)
|
||||||
|
assert is_w is False
|
||||||
|
|
||||||
|
def test_reset_clears_history(self):
|
||||||
|
wd = WaveDetector()
|
||||||
|
for i in range(24):
|
||||||
|
wd.push(0.5 + 0.2 * math.sin(i * math.pi / 3))
|
||||||
|
wd.reset()
|
||||||
|
is_w, _ = wd.push(0.5)
|
||||||
|
assert is_w is False
|
||||||
|
|
||||||
|
|
||||||
|
# ── classify_hand ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestClassifyHand:
|
||||||
|
def test_too_few_landmarks(self):
|
||||||
|
r = classify_hand([_lm(0.5, 0.5)] * 10)
|
||||||
|
assert r.gesture == "none"
|
||||||
|
|
||||||
|
def test_stop_open_palm(self):
|
||||||
|
lm = _open_palm()
|
||||||
|
r = classify_hand(lm)
|
||||||
|
assert r.gesture == "stop"
|
||||||
|
assert r.confidence >= 0.60
|
||||||
|
|
||||||
|
def test_point_up_gesture(self):
|
||||||
|
lm = _point_up()
|
||||||
|
r = classify_hand(lm)
|
||||||
|
assert r.gesture == "point"
|
||||||
|
assert r.confidence >= 0.60
|
||||||
|
|
||||||
|
def test_point_direction_populated(self):
|
||||||
|
lm = _point_up()
|
||||||
|
r = classify_hand(lm)
|
||||||
|
assert r.direction != ""
|
||||||
|
|
||||||
|
def test_disarm_fist(self):
|
||||||
|
lm = _fist()
|
||||||
|
r = classify_hand(lm)
|
||||||
|
assert r.gesture == "disarm"
|
||||||
|
assert r.confidence >= 0.60
|
||||||
|
|
||||||
|
def test_confirm_thumbs_up(self):
|
||||||
|
lm = _thumbs_up()
|
||||||
|
r = classify_hand(lm)
|
||||||
|
assert r.gesture == "confirm"
|
||||||
|
assert r.confidence >= 0.60
|
||||||
|
|
||||||
|
def test_follow_me_peace(self):
|
||||||
|
lm = _peace()
|
||||||
|
r = classify_hand(lm)
|
||||||
|
assert r.gesture == "follow_me"
|
||||||
|
assert r.confidence >= 0.60
|
||||||
|
|
||||||
|
def test_greeting_wave(self):
|
||||||
|
"""Wave via WaveDetector should produce greeting gesture."""
|
||||||
|
wd = WaveDetector(history_len=20, min_reversals=2, min_amplitude=0.08)
|
||||||
|
lm = _open_palm()
|
||||||
|
# Simulate 20 frames with oscillating wrist_x via different landmark sets
|
||||||
|
r = HandGestureResult("none", 0.0, "", 0.5, 0.5)
|
||||||
|
for i in range(20):
|
||||||
|
# Rebuild landmark set with moving wrist x
|
||||||
|
moving = list(lm)
|
||||||
|
wx = 0.5 + 0.20 * math.sin(i * math.pi / 3)
|
||||||
|
# WRIST is index 0 — move it
|
||||||
|
moving[0] = Landmark(wx, moving[0].y)
|
||||||
|
r = classify_hand(moving, is_right=True, wave_det=wd)
|
||||||
|
assert r.gesture == "greeting"
|
||||||
|
|
||||||
|
def test_flat_hand_no_crash(self):
|
||||||
|
"""A flat hand (all landmarks at 0.5, 0.5) has ambiguous geometry —
|
||||||
|
verify it returns a valid gesture string without crashing."""
|
||||||
|
_valid = {"stop", "point", "disarm", "confirm", "follow_me", "greeting", "none"}
|
||||||
|
r = classify_hand(_flat_hand())
|
||||||
|
assert r.gesture in _valid
|
||||||
|
|
||||||
|
def test_wrist_position_in_result(self):
|
||||||
|
lm = _open_palm()
|
||||||
|
lm[0] = Landmark(0.3, 0.7)
|
||||||
|
r = classify_hand(lm)
|
||||||
|
assert r.wrist_x == pytest.approx(0.3)
|
||||||
|
assert r.wrist_y == pytest.approx(0.7)
|
||||||
|
|
||||||
|
def test_confirm_before_stop(self):
|
||||||
|
"""Thumbs-up should be classified as 'confirm', not 'stop'."""
|
||||||
|
lm = _thumbs_up()
|
||||||
|
r = classify_hand(lm)
|
||||||
|
assert r.gesture == "confirm"
|
||||||
|
|
||||||
|
def test_follow_me_before_point(self):
|
||||||
|
"""Peace sign (2 fingers) should NOT be classified as 'point'."""
|
||||||
|
lm = _peace()
|
||||||
|
r = classify_hand(lm)
|
||||||
|
assert r.gesture == "follow_me"
|
||||||
|
|
||||||
|
def test_wave_beats_static_gesture(self):
|
||||||
|
"""When wave is detected it should override any static gesture."""
|
||||||
|
wd = WaveDetector(history_len=20, min_reversals=2, min_amplitude=0.08)
|
||||||
|
# Pre-load enough waving frames
|
||||||
|
for i in range(20):
|
||||||
|
wx = 0.5 + 0.25 * math.sin(i * math.pi / 3)
|
||||||
|
lm = _open_palm()
|
||||||
|
lm[0] = Landmark(wx, lm[0].y)
|
||||||
|
r = classify_hand(lm, wave_det=wd)
|
||||||
|
# The open palm would normally be "stop" but wave has already triggered
|
||||||
|
assert r.gesture == "greeting"
|
||||||
|
|
||||||
|
def test_result_confidence_bounded(self):
|
||||||
|
for lm_factory in [_open_palm, _point_up, _fist, _thumbs_up, _peace]:
|
||||||
|
r = classify_hand(lm_factory())
|
||||||
|
assert 0.0 <= r.confidence <= 1.0
|
||||||
@ -0,0 +1,16 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.8)
|
||||||
|
project(saltybot_hand_tracking_msgs)
|
||||||
|
|
||||||
|
find_package(ament_cmake REQUIRED)
|
||||||
|
find_package(rosidl_default_generators REQUIRED)
|
||||||
|
find_package(std_msgs REQUIRED)
|
||||||
|
|
||||||
|
rosidl_generate_interfaces(${PROJECT_NAME}
|
||||||
|
# Issue #342 — hand tracking / MediaPipe pivot
|
||||||
|
"msg/HandLandmarks.msg"
|
||||||
|
"msg/HandLandmarksArray.msg"
|
||||||
|
DEPENDENCIES std_msgs
|
||||||
|
)
|
||||||
|
|
||||||
|
ament_export_dependencies(rosidl_default_runtime)
|
||||||
|
ament_package()
|
||||||
@ -0,0 +1,29 @@
|
|||||||
|
# HandLandmarks.msg — MediaPipe Hands result for one detected hand (Issue #342)
|
||||||
|
#
|
||||||
|
# Landmark coordinates are MediaPipe-normalised:
|
||||||
|
# x, y ∈ [0.0, 1.0] — fraction of image width/height
|
||||||
|
# z — depth relative to wrist (negative = towards camera)
|
||||||
|
#
|
||||||
|
# landmark_xyz layout: [x0, y0, z0, x1, y1, z1, ..., x20, y20, z20]
|
||||||
|
# Index order follows MediaPipe Hands topology:
|
||||||
|
# 0=WRIST 1-4=THUMB(CMC,MCP,IP,TIP) 5-8=INDEX 9-12=MIDDLE
|
||||||
|
# 13-16=RING 17-20=PINKY
|
||||||
|
|
||||||
|
std_msgs/Header header
|
||||||
|
|
||||||
|
# Handedness
|
||||||
|
bool is_right_hand
|
||||||
|
float32 handedness_score # MediaPipe confidence for Left/Right label
|
||||||
|
|
||||||
|
# 21 landmarks × 3 (x, y, z) = 63 values
|
||||||
|
float32[63] landmark_xyz
|
||||||
|
|
||||||
|
# Classified robot-command gesture
|
||||||
|
# Values: "stop" | "point" | "disarm" | "confirm" | "follow_me" | "greeting" | "none"
|
||||||
|
string gesture
|
||||||
|
string point_direction # "up"|"right"|"left"|"upper_right"|"upper_left"|"lower_right"|"lower_left"|"down"
|
||||||
|
float32 gesture_confidence
|
||||||
|
|
||||||
|
# Wrist position in normalised image coords (convenience shortcut)
|
||||||
|
float32 wrist_x
|
||||||
|
float32 wrist_y
|
||||||
@ -0,0 +1,5 @@
|
|||||||
|
# HandLandmarksArray.msg — All detected hands in one camera frame (Issue #342)
|
||||||
|
|
||||||
|
std_msgs/Header header
|
||||||
|
HandLandmarks[] hands
|
||||||
|
uint32 hand_count
|
||||||
21
jetson/ros2_ws/src/saltybot_hand_tracking_msgs/package.xml
Normal file
21
jetson/ros2_ws/src/saltybot_hand_tracking_msgs/package.xml
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
<?xml version="1.0"?>
|
||||||
|
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||||
|
<package format="3">
|
||||||
|
<name>saltybot_hand_tracking_msgs</name>
|
||||||
|
<version>0.1.0</version>
|
||||||
|
<description>Message types for MediaPipe hand tracking (Issue #342).</description>
|
||||||
|
<maintainer email="robot@saltylab.local">SaltyLab</maintainer>
|
||||||
|
<license>MIT</license>
|
||||||
|
|
||||||
|
<buildtool_depend>ament_cmake</buildtool_depend>
|
||||||
|
<buildtool_depend>rosidl_default_generators</buildtool_depend>
|
||||||
|
|
||||||
|
<depend>std_msgs</depend>
|
||||||
|
|
||||||
|
<exec_depend>rosidl_default_runtime</exec_depend>
|
||||||
|
<member_of_group>rosidl_interface_packages</member_of_group>
|
||||||
|
|
||||||
|
<export>
|
||||||
|
<build_type>ament_cmake</build_type>
|
||||||
|
</export>
|
||||||
|
</package>
|
||||||
@ -0,0 +1,22 @@
|
|||||||
|
# Pure Pursuit Path Follower Configuration
|
||||||
|
pure_pursuit:
|
||||||
|
ros__parameters:
|
||||||
|
# Path following parameters
|
||||||
|
lookahead_distance: 0.5 # Distance to look ahead on the path (meters)
|
||||||
|
goal_tolerance: 0.1 # Distance tolerance to goal (meters)
|
||||||
|
heading_tolerance: 0.1 # Heading tolerance in radians (rad)
|
||||||
|
|
||||||
|
# Speed parameters
|
||||||
|
max_linear_velocity: 1.0 # Maximum linear velocity (m/s)
|
||||||
|
max_angular_velocity: 1.57 # Maximum angular velocity (rad/s)
|
||||||
|
|
||||||
|
# Control parameters
|
||||||
|
linear_velocity_scale: 1.0 # Scale factor for linear velocity
|
||||||
|
angular_velocity_scale: 1.0 # Scale factor for angular velocity
|
||||||
|
use_heading_correction: true # Apply heading error correction
|
||||||
|
|
||||||
|
# Publishing frequency (Hz)
|
||||||
|
publish_frequency: 10
|
||||||
|
|
||||||
|
# Enable/disable path follower
|
||||||
|
enable_path_following: true
|
||||||
@ -0,0 +1,29 @@
|
|||||||
|
import os
|
||||||
|
from launch import LaunchDescription
|
||||||
|
from launch_ros.actions import Node
|
||||||
|
from launch_ros.substitutions import FindPackageShare
|
||||||
|
from launch.substitutions import PathJoinSubstitution
|
||||||
|
|
||||||
|
|
||||||
|
def generate_launch_description():
|
||||||
|
config_dir = PathJoinSubstitution(
|
||||||
|
[FindPackageShare('saltybot_pure_pursuit'), 'config']
|
||||||
|
)
|
||||||
|
config_file = PathJoinSubstitution([config_dir, 'pure_pursuit_config.yaml'])
|
||||||
|
|
||||||
|
pure_pursuit = Node(
|
||||||
|
package='saltybot_pure_pursuit',
|
||||||
|
executable='pure_pursuit_node',
|
||||||
|
name='pure_pursuit',
|
||||||
|
output='screen',
|
||||||
|
parameters=[config_file],
|
||||||
|
remappings=[
|
||||||
|
('/odom', '/odom'),
|
||||||
|
('/path', '/path'),
|
||||||
|
('/cmd_vel_tracked', '/cmd_vel_tracked'),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
return LaunchDescription([
|
||||||
|
pure_pursuit,
|
||||||
|
])
|
||||||
29
jetson/ros2_ws/src/saltybot_pure_pursuit/package.xml
Normal file
29
jetson/ros2_ws/src/saltybot_pure_pursuit/package.xml
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
<?xml version="1.0"?>
|
||||||
|
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||||
|
<package format="3">
|
||||||
|
<name>saltybot_pure_pursuit</name>
|
||||||
|
<version>0.1.0</version>
|
||||||
|
<description>Pure pursuit path follower for Nav2 autonomous navigation</description>
|
||||||
|
|
||||||
|
<maintainer email="sl-controls@saltybot.local">SaltyBot Controls</maintainer>
|
||||||
|
<license>MIT</license>
|
||||||
|
|
||||||
|
<author email="sl-controls@saltybot.local">SaltyBot Controls Team</author>
|
||||||
|
|
||||||
|
<buildtool_depend>ament_cmake</buildtool_depend>
|
||||||
|
<buildtool_depend>ament_cmake_python</buildtool_depend>
|
||||||
|
|
||||||
|
<depend>rclpy</depend>
|
||||||
|
<depend>nav_msgs</depend>
|
||||||
|
<depend>geometry_msgs</depend>
|
||||||
|
<depend>std_msgs</depend>
|
||||||
|
|
||||||
|
<test_depend>ament_copyright</test_depend>
|
||||||
|
<test_depend>ament_flake8</test_depend>
|
||||||
|
<test_depend>ament_pep257</test_depend>
|
||||||
|
<test_depend>pytest</test_depend>
|
||||||
|
|
||||||
|
<export>
|
||||||
|
<build_type>ament_python</build_type>
|
||||||
|
</export>
|
||||||
|
</package>
|
||||||
@ -0,0 +1,269 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Pure pursuit path follower for Nav2 autonomous navigation.
|
||||||
|
|
||||||
|
Implements the pure pursuit algorithm for following a path or trajectory.
|
||||||
|
The algorithm computes steering commands to make the robot follow a path
|
||||||
|
by targeting a lookahead point on the path.
|
||||||
|
|
||||||
|
The pure pursuit algorithm:
|
||||||
|
1. Finds the closest point on the path to the robot
|
||||||
|
2. Looks ahead a specified distance along the path
|
||||||
|
3. Computes a circular arc passing through robot position to lookahead point
|
||||||
|
4. Publishes velocity commands to follow this arc
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from nav_msgs.msg import Path, Odometry
|
||||||
|
from geometry_msgs.msg import Twist, PoseStamped
|
||||||
|
from std_msgs.msg import Float32
|
||||||
|
|
||||||
|
|
||||||
|
class PurePursuitNode(Node):
|
||||||
|
"""ROS2 node implementing pure pursuit path follower."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__('pure_pursuit')
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
self.declare_parameter('lookahead_distance', 0.5)
|
||||||
|
self.declare_parameter('goal_tolerance', 0.1)
|
||||||
|
self.declare_parameter('heading_tolerance', 0.1)
|
||||||
|
self.declare_parameter('max_linear_velocity', 1.0)
|
||||||
|
self.declare_parameter('max_angular_velocity', 1.57)
|
||||||
|
self.declare_parameter('linear_velocity_scale', 1.0)
|
||||||
|
self.declare_parameter('angular_velocity_scale', 1.0)
|
||||||
|
self.declare_parameter('use_heading_correction', True)
|
||||||
|
self.declare_parameter('publish_frequency', 10)
|
||||||
|
self.declare_parameter('enable_path_following', True)
|
||||||
|
|
||||||
|
# Read parameters
|
||||||
|
self.lookahead_distance = self.get_parameter('lookahead_distance').value
|
||||||
|
self.goal_tolerance = self.get_parameter('goal_tolerance').value
|
||||||
|
self.heading_tolerance = self.get_parameter('heading_tolerance').value
|
||||||
|
self.max_linear_velocity = self.get_parameter('max_linear_velocity').value
|
||||||
|
self.max_angular_velocity = self.get_parameter('max_angular_velocity').value
|
||||||
|
self.linear_velocity_scale = self.get_parameter('linear_velocity_scale').value
|
||||||
|
self.angular_velocity_scale = self.get_parameter('angular_velocity_scale').value
|
||||||
|
self.use_heading_correction = self.get_parameter('use_heading_correction').value
|
||||||
|
publish_frequency = self.get_parameter('publish_frequency').value
|
||||||
|
self.enable_path_following = self.get_parameter('enable_path_following').value
|
||||||
|
|
||||||
|
# Current state
|
||||||
|
self.current_pose = None
|
||||||
|
self.current_path = None
|
||||||
|
self.goal_reached = False
|
||||||
|
|
||||||
|
# Subscriptions
|
||||||
|
self.sub_odom = self.create_subscription(
|
||||||
|
Odometry, '/odom', self._on_odometry, 10
|
||||||
|
)
|
||||||
|
self.sub_path = self.create_subscription(
|
||||||
|
Path, '/path', self._on_path, 10
|
||||||
|
)
|
||||||
|
|
||||||
|
# Publishers
|
||||||
|
self.pub_cmd_vel = self.create_publisher(Twist, '/cmd_vel_tracked', 10)
|
||||||
|
self.pub_tracking_error = self.create_publisher(Float32, '/tracking_error', 10)
|
||||||
|
|
||||||
|
# Timer for control loop
|
||||||
|
period = 1.0 / publish_frequency
|
||||||
|
self.timer = self.create_timer(period, self._control_loop)
|
||||||
|
|
||||||
|
self.get_logger().info(
|
||||||
|
f"Pure pursuit initialized. "
|
||||||
|
f"Lookahead: {self.lookahead_distance}m, "
|
||||||
|
f"Goal tolerance: {self.goal_tolerance}m"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _on_odometry(self, msg: Odometry) -> None:
|
||||||
|
"""Callback for odometry messages."""
|
||||||
|
self.current_pose = msg.pose.pose
|
||||||
|
|
||||||
|
def _on_path(self, msg: Path) -> None:
|
||||||
|
"""Callback for path messages."""
|
||||||
|
self.current_path = msg
|
||||||
|
|
||||||
|
def _quaternion_to_yaw(self, quat):
|
||||||
|
"""Convert quaternion to yaw angle."""
|
||||||
|
siny_cosp = 2 * (quat.w * quat.z + quat.x * quat.y)
|
||||||
|
cosy_cosp = 1 - 2 * (quat.y * quat.y + quat.z * quat.z)
|
||||||
|
yaw = math.atan2(siny_cosp, cosy_cosp)
|
||||||
|
return yaw
|
||||||
|
|
||||||
|
def _yaw_to_quaternion(self, yaw):
|
||||||
|
"""Convert yaw angle to quaternion."""
|
||||||
|
from geometry_msgs.msg import Quaternion
|
||||||
|
cy = math.cos(yaw * 0.5)
|
||||||
|
sy = math.sin(yaw * 0.5)
|
||||||
|
return Quaternion(x=0.0, y=0.0, z=sy, w=cy)
|
||||||
|
|
||||||
|
def _distance(self, p1, p2):
|
||||||
|
"""Compute Euclidean distance between two points."""
|
||||||
|
dx = p1.x - p2.x
|
||||||
|
dy = p1.y - p2.y
|
||||||
|
return math.sqrt(dx * dx + dy * dy)
|
||||||
|
|
||||||
|
def _find_closest_point_on_path(self):
|
||||||
|
"""Find closest point on path to current robot position."""
|
||||||
|
if self.current_path is None or len(self.current_path.poses) < 2:
|
||||||
|
return None, 0, float('inf')
|
||||||
|
|
||||||
|
closest_idx = 0
|
||||||
|
closest_dist = float('inf')
|
||||||
|
|
||||||
|
# Find closest waypoint
|
||||||
|
for i, pose in enumerate(self.current_path.poses):
|
||||||
|
dist = self._distance(self.current_pose.position, pose.pose.position)
|
||||||
|
if dist < closest_dist:
|
||||||
|
closest_dist = dist
|
||||||
|
closest_idx = i
|
||||||
|
|
||||||
|
return closest_idx, closest_dist, closest_idx
|
||||||
|
|
||||||
|
def _find_lookahead_point(self):
|
||||||
|
"""Find lookahead point on path ahead of closest point."""
|
||||||
|
if self.current_path is None or len(self.current_path.poses) < 2:
|
||||||
|
return None
|
||||||
|
|
||||||
|
closest_idx, _, _ = self._find_closest_point_on_path()
|
||||||
|
if closest_idx is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Find point at lookahead distance along path
|
||||||
|
current_dist = 0.0
|
||||||
|
for i in range(closest_idx, len(self.current_path.poses) - 1):
|
||||||
|
p1 = self.current_path.poses[i].pose.position
|
||||||
|
p2 = self.current_path.poses[i + 1].pose.position
|
||||||
|
segment_dist = self._distance(p1, p2)
|
||||||
|
current_dist += segment_dist
|
||||||
|
|
||||||
|
if current_dist >= self.lookahead_distance:
|
||||||
|
# Interpolate between p1 and p2
|
||||||
|
overshoot = current_dist - self.lookahead_distance
|
||||||
|
if segment_dist > 0:
|
||||||
|
alpha = 1.0 - (overshoot / segment_dist)
|
||||||
|
from geometry_msgs.msg import Point
|
||||||
|
lookahead = Point()
|
||||||
|
lookahead.x = p1.x + alpha * (p2.x - p1.x)
|
||||||
|
lookahead.y = p1.y + alpha * (p2.y - p1.y)
|
||||||
|
lookahead.z = 0.0
|
||||||
|
return lookahead
|
||||||
|
|
||||||
|
# If we get here, return the last point
|
||||||
|
return self.current_path.poses[-1].pose.position
|
||||||
|
|
||||||
|
def _calculate_steering_command(self, lookahead_point):
|
||||||
|
"""Calculate steering angle using pure pursuit geometry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lookahead_point: Target lookahead point on the path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (linear_velocity, angular_velocity)
|
||||||
|
"""
|
||||||
|
if lookahead_point is None or self.current_pose is None:
|
||||||
|
return 0.0, 0.0
|
||||||
|
|
||||||
|
# Vector from robot to lookahead point
|
||||||
|
dx = lookahead_point.x - self.current_pose.position.x
|
||||||
|
dy = lookahead_point.y - self.current_pose.position.y
|
||||||
|
distance_to_lookahead = math.sqrt(dx * dx + dy * dy)
|
||||||
|
|
||||||
|
# Check if goal is reached
|
||||||
|
if distance_to_lookahead < self.goal_tolerance:
|
||||||
|
self.goal_reached = True
|
||||||
|
return 0.0, 0.0
|
||||||
|
|
||||||
|
# Robot heading
|
||||||
|
robot_yaw = self._quaternion_to_yaw(self.current_pose.orientation)
|
||||||
|
|
||||||
|
# Angle to lookahead point
|
||||||
|
angle_to_lookahead = math.atan2(dy, dx)
|
||||||
|
|
||||||
|
# Heading error
|
||||||
|
heading_error = angle_to_lookahead - robot_yaw
|
||||||
|
|
||||||
|
# Normalize angle to [-pi, pi]
|
||||||
|
while heading_error > math.pi:
|
||||||
|
heading_error -= 2 * math.pi
|
||||||
|
while heading_error < -math.pi:
|
||||||
|
heading_error += 2 * math.pi
|
||||||
|
|
||||||
|
# Pure pursuit curvature: k = 2 * sin(alpha) / Ld
|
||||||
|
# where alpha is the heading error and Ld is lookahead distance
|
||||||
|
if self.lookahead_distance > 0:
|
||||||
|
curvature = (2.0 * math.sin(heading_error)) / self.lookahead_distance
|
||||||
|
else:
|
||||||
|
curvature = 0.0
|
||||||
|
|
||||||
|
# Linear velocity (reduce when heading error is large)
|
||||||
|
if self.use_heading_correction:
|
||||||
|
heading_error_abs = abs(heading_error)
|
||||||
|
linear_velocity = self.max_linear_velocity * math.cos(heading_error)
|
||||||
|
linear_velocity = max(0.0, linear_velocity) # Don't go backwards
|
||||||
|
else:
|
||||||
|
linear_velocity = self.max_linear_velocity
|
||||||
|
|
||||||
|
# Angular velocity from curvature: omega = v * k
|
||||||
|
angular_velocity = linear_velocity * curvature
|
||||||
|
|
||||||
|
# Clamp velocities
|
||||||
|
linear_velocity = min(linear_velocity, self.max_linear_velocity)
|
||||||
|
angular_velocity = max(-self.max_angular_velocity,
|
||||||
|
min(angular_velocity, self.max_angular_velocity))
|
||||||
|
|
||||||
|
# Apply scaling
|
||||||
|
linear_velocity *= self.linear_velocity_scale
|
||||||
|
angular_velocity *= self.angular_velocity_scale
|
||||||
|
|
||||||
|
return linear_velocity, angular_velocity
|
||||||
|
|
||||||
|
def _control_loop(self) -> None:
|
||||||
|
"""Main control loop executed at regular intervals."""
|
||||||
|
if not self.enable_path_following or self.current_pose is None:
|
||||||
|
# Publish zero velocity
|
||||||
|
cmd = Twist()
|
||||||
|
self.pub_cmd_vel.publish(cmd)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Find lookahead point
|
||||||
|
lookahead_point = self._find_lookahead_point()
|
||||||
|
|
||||||
|
if lookahead_point is None:
|
||||||
|
# No valid path, publish zero velocity
|
||||||
|
cmd = Twist()
|
||||||
|
self.pub_cmd_vel.publish(cmd)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Calculate steering command
|
||||||
|
linear_vel, angular_vel = self._calculate_steering_command(lookahead_point)
|
||||||
|
|
||||||
|
# Publish velocity command
|
||||||
|
cmd = Twist()
|
||||||
|
cmd.linear.x = linear_vel
|
||||||
|
cmd.angular.z = angular_vel
|
||||||
|
self.pub_cmd_vel.publish(cmd)
|
||||||
|
|
||||||
|
# Publish tracking error
|
||||||
|
if self.current_path and len(self.current_path.poses) > 0:
|
||||||
|
closest_idx, tracking_error, _ = self._find_closest_point_on_path()
|
||||||
|
error_msg = Float32(data=tracking_error)
|
||||||
|
self.pub_tracking_error.publish(error_msg)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None):
|
||||||
|
rclpy.init(args=args)
|
||||||
|
node = PurePursuitNode()
|
||||||
|
try:
|
||||||
|
rclpy.spin(node)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
node.destroy_node()
|
||||||
|
rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
5
jetson/ros2_ws/src/saltybot_pure_pursuit/setup.cfg
Normal file
5
jetson/ros2_ws/src/saltybot_pure_pursuit/setup.cfg
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
[develop]
|
||||||
|
script_dir=$base/lib/saltybot_pure_pursuit
|
||||||
|
[egg_info]
|
||||||
|
tag_build =
|
||||||
|
tag_date = 0
|
||||||
34
jetson/ros2_ws/src/saltybot_pure_pursuit/setup.py
Normal file
34
jetson/ros2_ws/src/saltybot_pure_pursuit/setup.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
package_name = 'saltybot_pure_pursuit'
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name=package_name,
|
||||||
|
version='0.1.0',
|
||||||
|
packages=find_packages(exclude=['test']),
|
||||||
|
data_files=[
|
||||||
|
('share/ament_index/resource_index/packages',
|
||||||
|
['resource/saltybot_pure_pursuit']),
|
||||||
|
('share/' + package_name, ['package.xml']),
|
||||||
|
('share/' + package_name + '/config', ['config/pure_pursuit_config.yaml']),
|
||||||
|
('share/' + package_name + '/launch', ['launch/pure_pursuit.launch.py']),
|
||||||
|
],
|
||||||
|
install_requires=['setuptools'],
|
||||||
|
zip_safe=True,
|
||||||
|
author='SaltyBot Controls',
|
||||||
|
author_email='sl-controls@saltybot.local',
|
||||||
|
maintainer='SaltyBot Controls',
|
||||||
|
maintainer_email='sl-controls@saltybot.local',
|
||||||
|
keywords=['ROS2', 'pure_pursuit', 'path_following', 'nav2'],
|
||||||
|
classifiers=[
|
||||||
|
'Intended Audience :: Developers',
|
||||||
|
'License :: OSI Approved :: MIT License',
|
||||||
|
'Programming Language :: Python :: 3',
|
||||||
|
'Topic :: Software Development',
|
||||||
|
],
|
||||||
|
entry_points={
|
||||||
|
'console_scripts': [
|
||||||
|
'pure_pursuit_node=saltybot_pure_pursuit.pure_pursuit_node:main',
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
@ -0,0 +1,397 @@
|
|||||||
|
"""Unit tests for pure pursuit path follower node."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import math
|
||||||
|
from nav_msgs.msg import Path, Odometry
|
||||||
|
from geometry_msgs.msg import PoseStamped, Pose, Point, Quaternion, Twist
|
||||||
|
from std_msgs.msg import Header
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
|
||||||
|
from saltybot_pure_pursuit.pure_pursuit_node import PurePursuitNode
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def rclpy_fixture():
|
||||||
|
"""Initialize and cleanup rclpy."""
|
||||||
|
rclpy.init()
|
||||||
|
yield
|
||||||
|
rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def node(rclpy_fixture):
|
||||||
|
"""Create a pure pursuit node instance."""
|
||||||
|
node = PurePursuitNode()
|
||||||
|
yield node
|
||||||
|
node.destroy_node()
|
||||||
|
|
||||||
|
|
||||||
|
class TestPurePursuitGeometry:
|
||||||
|
"""Test suite for pure pursuit geometric calculations."""
|
||||||
|
|
||||||
|
def test_node_initialization(self, node):
|
||||||
|
"""Test that node initializes with correct defaults."""
|
||||||
|
assert node.lookahead_distance == 0.5
|
||||||
|
assert node.goal_tolerance == 0.1
|
||||||
|
assert node.max_linear_velocity == 1.0
|
||||||
|
assert node.enable_path_following is True
|
||||||
|
|
||||||
|
def test_distance_calculation(self, node):
|
||||||
|
"""Test Euclidean distance calculation."""
|
||||||
|
p1 = Point(x=0.0, y=0.0, z=0.0)
|
||||||
|
p2 = Point(x=3.0, y=4.0, z=0.0)
|
||||||
|
dist = node._distance(p1, p2)
|
||||||
|
assert abs(dist - 5.0) < 0.01
|
||||||
|
|
||||||
|
def test_distance_same_point(self, node):
|
||||||
|
"""Test distance between same point is zero."""
|
||||||
|
p = Point(x=1.0, y=2.0, z=0.0)
|
||||||
|
dist = node._distance(p, p)
|
||||||
|
assert abs(dist) < 0.001
|
||||||
|
|
||||||
|
def test_quaternion_to_yaw_north(self, node):
|
||||||
|
"""Test quaternion to yaw conversion for north heading."""
|
||||||
|
quat = Quaternion(x=0.0, y=0.0, z=0.0, w=1.0)
|
||||||
|
yaw = node._quaternion_to_yaw(quat)
|
||||||
|
assert abs(yaw) < 0.01
|
||||||
|
|
||||||
|
def test_quaternion_to_yaw_east(self, node):
|
||||||
|
"""Test quaternion to yaw conversion for east heading."""
|
||||||
|
# 90 degree rotation around z-axis
|
||||||
|
quat = Quaternion(x=0.0, y=0.0, z=0.7071, w=0.7071)
|
||||||
|
yaw = node._quaternion_to_yaw(quat)
|
||||||
|
assert abs(yaw - math.pi / 2) < 0.01
|
||||||
|
|
||||||
|
def test_yaw_to_quaternion_identity(self, node):
|
||||||
|
"""Test yaw to quaternion conversion."""
|
||||||
|
quat = node._yaw_to_quaternion(0.0)
|
||||||
|
assert abs(quat.w - 1.0) < 0.01
|
||||||
|
assert abs(quat.z) < 0.01
|
||||||
|
|
||||||
|
|
||||||
|
class TestPathFollowing:
|
||||||
|
"""Test suite for path following logic."""
|
||||||
|
|
||||||
|
def test_empty_path(self, node):
|
||||||
|
"""Test handling of empty path."""
|
||||||
|
node.current_pose = Pose(position=Point(x=0.0, y=0.0, z=0.0))
|
||||||
|
node.current_path = Path()
|
||||||
|
|
||||||
|
lookahead = node._find_lookahead_point()
|
||||||
|
assert lookahead is None
|
||||||
|
|
||||||
|
def test_single_point_path(self, node):
|
||||||
|
"""Test handling of single-point path."""
|
||||||
|
node.current_pose = Pose(position=Point(x=0.0, y=0.0, z=0.0))
|
||||||
|
|
||||||
|
path = Path()
|
||||||
|
path.poses = [PoseStamped(pose=Pose(position=Point(x=1.0, y=0.0, z=0.0)))]
|
||||||
|
node.current_path = path
|
||||||
|
|
||||||
|
lookahead = node._find_lookahead_point()
|
||||||
|
assert lookahead is None or lookahead.x == 1.0
|
||||||
|
|
||||||
|
def test_straight_line_path(self, node):
|
||||||
|
"""Test path following on straight line."""
|
||||||
|
node.current_pose = Pose(
|
||||||
|
position=Point(x=0.0, y=0.0, z=0.0),
|
||||||
|
orientation=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create straight path along x-axis
|
||||||
|
path = Path()
|
||||||
|
for i in range(5):
|
||||||
|
pose = PoseStamped(pose=Pose(
|
||||||
|
position=Point(x=float(i), y=0.0, z=0.0)
|
||||||
|
))
|
||||||
|
path.poses.append(pose)
|
||||||
|
node.current_path = path
|
||||||
|
|
||||||
|
# Robot heading east towards path
|
||||||
|
lin_vel, ang_vel = node._calculate_steering_command(path.poses[1].pose.position)
|
||||||
|
assert lin_vel >= 0
|
||||||
|
assert abs(ang_vel) < 0.5 # Should have small steering error
|
||||||
|
|
||||||
|
def test_curved_path(self, node):
|
||||||
|
"""Test path following on curved path."""
|
||||||
|
node.current_pose = Pose(
|
||||||
|
position=Point(x=0.0, y=0.0, z=0.0),
|
||||||
|
orientation=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create circular arc path
|
||||||
|
path = Path()
|
||||||
|
for i in range(9):
|
||||||
|
angle = (i / 8.0) * (math.pi / 2)
|
||||||
|
x = math.sin(angle)
|
||||||
|
y = 1.0 - math.cos(angle)
|
||||||
|
pose = PoseStamped(pose=Pose(position=Point(x=x, y=y, z=0.0)))
|
||||||
|
path.poses.append(pose)
|
||||||
|
node.current_path = path
|
||||||
|
|
||||||
|
lookahead = node._find_lookahead_point()
|
||||||
|
assert lookahead is not None
|
||||||
|
|
||||||
|
def test_goal_reached(self, node):
|
||||||
|
"""Test goal reached detection."""
|
||||||
|
goal_pos = Point(x=1.0, y=0.0, z=0.0)
|
||||||
|
node.current_pose = Pose(
|
||||||
|
position=Point(x=1.05, y=0.0, z=0.0),
|
||||||
|
orientation=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
lin_vel, ang_vel = node._calculate_steering_command(goal_pos)
|
||||||
|
# With goal_tolerance=0.1, we should be at goal
|
||||||
|
assert abs(lin_vel) < 0.01
|
||||||
|
|
||||||
|
|
||||||
|
class TestSteeringCalculation:
|
||||||
|
"""Test suite for steering command calculations."""
|
||||||
|
|
||||||
|
def test_zero_heading_error(self, node):
|
||||||
|
"""Test steering when robot already aligned with target."""
|
||||||
|
node.current_pose = Pose(
|
||||||
|
position=Point(x=0.0, y=0.0, z=0.0),
|
||||||
|
orientation=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0) # Facing east
|
||||||
|
)
|
||||||
|
|
||||||
|
# Target directly ahead
|
||||||
|
target = Point(x=1.0, y=0.0, z=0.0)
|
||||||
|
lin_vel, ang_vel = node._calculate_steering_command(target)
|
||||||
|
|
||||||
|
assert lin_vel > 0
|
||||||
|
assert abs(ang_vel) < 0.1 # Minimal steering
|
||||||
|
|
||||||
|
def test_90_degree_heading_error(self, node):
|
||||||
|
"""Test steering with 90 degree heading error."""
|
||||||
|
node.current_pose = Pose(
|
||||||
|
position=Point(x=0.0, y=0.0, z=0.0),
|
||||||
|
orientation=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0) # Facing east
|
||||||
|
)
|
||||||
|
|
||||||
|
# Target north
|
||||||
|
target = Point(x=0.0, y=1.0, z=0.0)
|
||||||
|
lin_vel, ang_vel = node._calculate_steering_command(target)
|
||||||
|
|
||||||
|
assert lin_vel >= 0
|
||||||
|
assert ang_vel != 0 # Should have significant steering
|
||||||
|
|
||||||
|
def test_velocity_limits(self, node):
|
||||||
|
"""Test that velocities are within limits."""
|
||||||
|
node.current_pose = Pose(
|
||||||
|
position=Point(x=0.0, y=0.0, z=0.0),
|
||||||
|
orientation=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
target = Point(x=100.0, y=100.0, z=0.0)
|
||||||
|
lin_vel, ang_vel = node._calculate_steering_command(target)
|
||||||
|
|
||||||
|
assert lin_vel <= node.max_linear_velocity
|
||||||
|
assert abs(ang_vel) <= node.max_angular_velocity
|
||||||
|
|
||||||
|
def test_heading_correction_enabled(self, node):
|
||||||
|
"""Test heading correction when enabled."""
|
||||||
|
node.use_heading_correction = True
|
||||||
|
node.current_pose = Pose(
|
||||||
|
position=Point(x=0.0, y=0.0, z=0.0),
|
||||||
|
orientation=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Large heading error
|
||||||
|
target = Point(x=0.0, y=10.0, z=0.0) # Due north
|
||||||
|
lin_vel, ang_vel = node._calculate_steering_command(target)
|
||||||
|
|
||||||
|
# Linear velocity should be reduced due to heading error
|
||||||
|
assert lin_vel < node.max_linear_velocity
|
||||||
|
|
||||||
|
def test_heading_correction_disabled(self, node):
|
||||||
|
"""Test that heading correction can be disabled."""
|
||||||
|
node.use_heading_correction = False
|
||||||
|
node.current_pose = Pose(
|
||||||
|
position=Point(x=0.0, y=0.0, z=0.0),
|
||||||
|
orientation=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
target = Point(x=0.0, y=10.0, z=0.0)
|
||||||
|
lin_vel, ang_vel = node._calculate_steering_command(target)
|
||||||
|
|
||||||
|
# Linear velocity should be full
|
||||||
|
assert abs(lin_vel - node.max_linear_velocity) < 0.01
|
||||||
|
|
||||||
|
def test_negative_lookahead_distance(self, node):
|
||||||
|
"""Test behavior with invalid lookahead distance."""
|
||||||
|
node.lookahead_distance = 0.0
|
||||||
|
node.current_pose = Pose(
|
||||||
|
position=Point(x=0.0, y=0.0, z=0.0),
|
||||||
|
orientation=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
target = Point(x=1.0, y=1.0, z=0.0)
|
||||||
|
lin_vel, ang_vel = node._calculate_steering_command(target)
|
||||||
|
|
||||||
|
# Should not crash, handle gracefully
|
||||||
|
assert isinstance(lin_vel, float)
|
||||||
|
assert isinstance(ang_vel, float)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPurePursuitScenarios:
|
||||||
|
"""Integration-style tests for realistic scenarios."""
|
||||||
|
|
||||||
|
def test_scenario_follow_straight_path(self, node):
|
||||||
|
"""Scenario: Follow a straight horizontal path."""
|
||||||
|
node.current_pose = Pose(
|
||||||
|
position=Point(x=0.0, y=0.0, z=0.0),
|
||||||
|
orientation=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Horizontal path
|
||||||
|
path = Path()
|
||||||
|
for i in range(10):
|
||||||
|
pose = PoseStamped(pose=Pose(position=Point(x=float(i), y=0.0, z=0.0)))
|
||||||
|
path.poses.append(pose)
|
||||||
|
node.current_path = path
|
||||||
|
|
||||||
|
lin_vel, ang_vel = node._calculate_steering_command(path.poses[1].pose.position)
|
||||||
|
assert lin_vel > 0
|
||||||
|
assert abs(ang_vel) < 0.5
|
||||||
|
|
||||||
|
def test_scenario_s_shaped_path(self, node):
|
||||||
|
"""Scenario: Follow an S-shaped path."""
|
||||||
|
node.current_pose = Pose(
|
||||||
|
position=Point(x=0.0, y=0.0, z=0.0),
|
||||||
|
orientation=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
# S-shaped path
|
||||||
|
path = Path()
|
||||||
|
for i in range(20):
|
||||||
|
x = float(i) * 0.5
|
||||||
|
y = 2.0 * math.sin(x)
|
||||||
|
pose = PoseStamped(pose=Pose(position=Point(x=x, y=y, z=0.0)))
|
||||||
|
path.poses.append(pose)
|
||||||
|
node.current_path = path
|
||||||
|
|
||||||
|
lookahead = node._find_lookahead_point()
|
||||||
|
assert lookahead is not None
|
||||||
|
|
||||||
|
def test_scenario_spiral_path(self, node):
|
||||||
|
"""Scenario: Follow a spiral path."""
|
||||||
|
node.current_pose = Pose(
|
||||||
|
position=Point(x=0.0, y=0.0, z=0.0),
|
||||||
|
orientation=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Spiral path
|
||||||
|
path = Path()
|
||||||
|
for i in range(30):
|
||||||
|
angle = (i / 30.0) * (4 * math.pi)
|
||||||
|
radius = 0.5 + (i / 30.0) * 2.0
|
||||||
|
x = radius * math.cos(angle)
|
||||||
|
y = radius * math.sin(angle)
|
||||||
|
pose = PoseStamped(pose=Pose(position=Point(x=x, y=y, z=0.0)))
|
||||||
|
path.poses.append(pose)
|
||||||
|
node.current_path = path
|
||||||
|
|
||||||
|
lin_vel, ang_vel = node._calculate_steering_command(path.poses[5].pose.position)
|
||||||
|
assert isinstance(lin_vel, float)
|
||||||
|
assert isinstance(ang_vel, float)
|
||||||
|
|
||||||
|
def test_scenario_control_loop(self, node):
|
||||||
|
"""Scenario: Control loop with valid path and odometry."""
|
||||||
|
node.enable_path_following = True
|
||||||
|
node.current_pose = Pose(
|
||||||
|
position=Point(x=0.0, y=0.0, z=0.0),
|
||||||
|
orientation=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a simple path
|
||||||
|
path = Path()
|
||||||
|
for i in range(5):
|
||||||
|
pose = PoseStamped(pose=Pose(position=Point(x=float(i), y=0.0, z=0.0)))
|
||||||
|
path.poses.append(pose)
|
||||||
|
node.current_path = path
|
||||||
|
|
||||||
|
# Run control loop
|
||||||
|
node._control_loop()
|
||||||
|
|
||||||
|
# Should complete without error
|
||||||
|
|
||||||
|
def test_scenario_disabled_path_following(self, node):
|
||||||
|
"""Scenario: Path following disabled."""
|
||||||
|
node.enable_path_following = False
|
||||||
|
node.current_pose = Pose(position=Point(x=0.0, y=0.0, z=0.0))
|
||||||
|
|
||||||
|
# Create path
|
||||||
|
path = Path()
|
||||||
|
path.poses = [PoseStamped(pose=Pose(position=Point(x=1.0, y=0.0, z=0.0)))]
|
||||||
|
node.current_path = path
|
||||||
|
|
||||||
|
# Run control loop
|
||||||
|
node._control_loop()
|
||||||
|
|
||||||
|
# Should publish zero velocity
|
||||||
|
|
||||||
|
def test_scenario_no_odometry(self, node):
|
||||||
|
"""Scenario: Control loop when odometry not received."""
|
||||||
|
node.current_pose = None
|
||||||
|
path = Path()
|
||||||
|
path.poses = [PoseStamped(pose=Pose(position=Point(x=1.0, y=0.0, z=0.0)))]
|
||||||
|
node.current_path = path
|
||||||
|
|
||||||
|
# Run control loop
|
||||||
|
node._control_loop()
|
||||||
|
|
||||||
|
# Should handle gracefully
|
||||||
|
|
||||||
|
def test_scenario_velocity_scaling(self, node):
|
||||||
|
"""Scenario: Velocity scaling parameters."""
|
||||||
|
node.linear_velocity_scale = 0.5
|
||||||
|
node.angular_velocity_scale = 0.5
|
||||||
|
|
||||||
|
node.current_pose = Pose(
|
||||||
|
position=Point(x=0.0, y=0.0, z=0.0),
|
||||||
|
orientation=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
target = Point(x=1.0, y=0.0, z=0.0)
|
||||||
|
lin_vel, ang_vel = node._calculate_steering_command(target)
|
||||||
|
|
||||||
|
# Scaled velocities should be smaller
|
||||||
|
assert lin_vel <= node.max_linear_velocity * 0.5 + 0.01
|
||||||
|
|
||||||
|
def test_scenario_large_lookahead_distance(self, node):
|
||||||
|
"""Scenario: Large lookahead distance."""
|
||||||
|
node.lookahead_distance = 5.0
|
||||||
|
node.current_pose = Pose(
|
||||||
|
position=Point(x=0.0, y=0.0, z=0.0),
|
||||||
|
orientation=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Path
|
||||||
|
path = Path()
|
||||||
|
for i in range(10):
|
||||||
|
pose = PoseStamped(pose=Pose(position=Point(x=float(i), y=0.0, z=0.0)))
|
||||||
|
path.poses.append(pose)
|
||||||
|
node.current_path = path
|
||||||
|
|
||||||
|
lookahead = node._find_lookahead_point()
|
||||||
|
assert lookahead is not None
|
||||||
|
|
||||||
|
def test_scenario_small_lookahead_distance(self, node):
|
||||||
|
"""Scenario: Small lookahead distance."""
|
||||||
|
node.lookahead_distance = 0.05
|
||||||
|
node.current_pose = Pose(
|
||||||
|
position=Point(x=0.0, y=0.0, z=0.0),
|
||||||
|
orientation=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Path
|
||||||
|
path = Path()
|
||||||
|
for i in range(10):
|
||||||
|
pose = PoseStamped(pose=Pose(position=Point(x=float(i), y=0.0, z=0.0)))
|
||||||
|
path.poses.append(pose)
|
||||||
|
node.current_path = path
|
||||||
|
|
||||||
|
lookahead = node._find_lookahead_point()
|
||||||
|
assert lookahead is not None
|
||||||
@ -22,6 +22,23 @@ rosidl_generate_interfaces(${PROJECT_NAME}
|
|||||||
# Issue #322 — cross-camera person re-identification
|
# Issue #322 — cross-camera person re-identification
|
||||||
"msg/PersonTrack.msg"
|
"msg/PersonTrack.msg"
|
||||||
"msg/PersonTrackArray.msg"
|
"msg/PersonTrackArray.msg"
|
||||||
|
# Issue #326 — dynamic obstacle velocity estimator
|
||||||
|
"msg/ObstacleVelocity.msg"
|
||||||
|
"msg/ObstacleVelocityArray.msg"
|
||||||
|
# Issue #339 — lane/path edge detector
|
||||||
|
"msg/PathEdges.msg"
|
||||||
|
# Issue #348 — depth-based obstacle size estimator
|
||||||
|
"msg/ObstacleSize.msg"
|
||||||
|
"msg/ObstacleSizeArray.msg"
|
||||||
|
# Issue #353 — audio scene classifier
|
||||||
|
"msg/AudioScene.msg"
|
||||||
|
# Issue #359 — face emotion classifier
|
||||||
|
"msg/FaceEmotion.msg"
|
||||||
|
"msg/FaceEmotionArray.msg"
|
||||||
|
# Issue #363 — person tracking for follow-me mode
|
||||||
|
"msg/TargetTrack.msg"
|
||||||
|
# Issue #365 — UWB DW3000 anchor/tag ranging
|
||||||
|
"msg/UwbTarget.msg"
|
||||||
DEPENDENCIES std_msgs geometry_msgs vision_msgs builtin_interfaces
|
DEPENDENCIES std_msgs geometry_msgs vision_msgs builtin_interfaces
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,4 @@
|
|||||||
|
std_msgs/Header header
|
||||||
|
string label # 'indoor' | 'outdoor' | 'traffic' | 'park'
|
||||||
|
float32 confidence # 0.0–1.0 (nearest-centroid inverted distance)
|
||||||
|
float32[16] features # raw feature vector: MFCC[0..12] + centroid_hz + rolloff_hz + zcr
|
||||||
@ -0,0 +1,26 @@
|
|||||||
|
std_msgs/Header header
|
||||||
|
|
||||||
|
# Current power mode (use constants below)
|
||||||
|
uint8 mode
|
||||||
|
uint8 MODE_SLEEP = 0 # charging / idle — minimal sensors
|
||||||
|
uint8 MODE_SOCIAL = 1 # parked / socialising — webcam + face UI only
|
||||||
|
uint8 MODE_AWARE = 2 # indoor / slow (<5 km/h) — front CSI + RealSense + LIDAR
|
||||||
|
uint8 MODE_ACTIVE = 3 # sidewalk / bike path (5–15 km/h) — front+rear + RealSense + LIDAR + UWB
|
||||||
|
uint8 MODE_FULL = 4 # street / high-speed (>15 km/h) or crossing — all sensors
|
||||||
|
|
||||||
|
string mode_name # human-readable label
|
||||||
|
|
||||||
|
# Active sensor flags for this mode
|
||||||
|
bool csi_front
|
||||||
|
bool csi_rear
|
||||||
|
bool csi_left
|
||||||
|
bool csi_right
|
||||||
|
bool realsense
|
||||||
|
bool lidar
|
||||||
|
bool uwb
|
||||||
|
bool webcam
|
||||||
|
|
||||||
|
# Transition metadata
|
||||||
|
float32 trigger_speed_mps # speed that triggered the last transition
|
||||||
|
string trigger_scenario # scenario that triggered the last transition
|
||||||
|
bool scenario_override # true when scenario (not speed) forced the mode
|
||||||
@ -0,0 +1,8 @@
|
|||||||
|
std_msgs/Header header
|
||||||
|
uint32 face_id # track ID or 0-based detection index
|
||||||
|
string emotion # 'neutral' | 'happy' | 'surprised' | 'angry' | 'sad'
|
||||||
|
float32 confidence # 0.0–1.0
|
||||||
|
float32 mouth_open # mouth height / face height (0=closed)
|
||||||
|
float32 smile # lip-corner elevation (positive=smile, negative=frown)
|
||||||
|
float32 brow_raise # inner-brow to eye-top gap / face height (positive=raised)
|
||||||
|
float32 eye_open # eye height / face height
|
||||||
@ -0,0 +1,3 @@
|
|||||||
|
std_msgs/Header header
|
||||||
|
FaceEmotion[] faces
|
||||||
|
uint32 face_count
|
||||||
27
jetson/ros2_ws/src/saltybot_scene_msgs/msg/ObstacleSize.msg
Normal file
27
jetson/ros2_ws/src/saltybot_scene_msgs/msg/ObstacleSize.msg
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# ObstacleSize.msg — Depth-projected LIDAR cluster size estimate (Issue #348)
|
||||||
|
#
|
||||||
|
# Fuses a 2-D LIDAR cluster with the D435i depth image to estimate the full
|
||||||
|
# 3-D size (width × height) of each detected obstacle.
|
||||||
|
#
|
||||||
|
# obstacle_id : matches the obstacle_id in ObstacleVelocity (same LIDAR cluster)
|
||||||
|
# centroid_x : LIDAR-frame forward distance to centroid (metres)
|
||||||
|
# centroid_y : LIDAR-frame lateral distance to centroid (metres, +Y = left)
|
||||||
|
# depth_z : sampled D435i depth at projected centroid (metres, 0 = unknown)
|
||||||
|
# width_m : horizontal size from LIDAR bbox (metres)
|
||||||
|
# height_m : vertical size from D435i depth strip (metres, 0 = unknown)
|
||||||
|
# pixel_u : projected centroid column in depth image (pixels)
|
||||||
|
# pixel_v : projected centroid row in depth image (pixels)
|
||||||
|
# lidar_range : range from LIDAR origin to centroid (metres)
|
||||||
|
# confidence : 0.0–1.0; based on depth sample quality and cluster track age
|
||||||
|
#
|
||||||
|
std_msgs/Header header
|
||||||
|
uint32 obstacle_id
|
||||||
|
float32 centroid_x
|
||||||
|
float32 centroid_y
|
||||||
|
float32 depth_z
|
||||||
|
float32 width_m
|
||||||
|
float32 height_m
|
||||||
|
int32 pixel_u
|
||||||
|
int32 pixel_v
|
||||||
|
float32 lidar_range
|
||||||
|
float32 confidence
|
||||||
@ -0,0 +1,4 @@
|
|||||||
|
# ObstacleSizeArray.msg — All depth-projected obstacle size estimates (Issue #348)
|
||||||
|
|
||||||
|
std_msgs/Header header
|
||||||
|
ObstacleSize[] obstacles
|
||||||
@ -0,0 +1,22 @@
|
|||||||
|
# ObstacleVelocity.msg — tracked obstacle with estimated velocity (Issue #326)
|
||||||
|
#
|
||||||
|
# obstacle_id : stable track ID, monotonically increasing, 1-based
|
||||||
|
# centroid : estimated position in the LIDAR sensor frame (z=0)
|
||||||
|
# velocity : estimated velocity vector, m/s, in the LIDAR sensor frame
|
||||||
|
# speed_mps : |velocity| magnitude (m/s)
|
||||||
|
# width_m : cluster bounding-box width (metres)
|
||||||
|
# depth_m : cluster bounding-box depth (metres)
|
||||||
|
# point_count : number of LIDAR returns in the cluster this frame
|
||||||
|
# confidence : Kalman track age confidence, 0–1 (1.0 after n_init frames)
|
||||||
|
# is_static : true when speed_mps < static_speed_threshold parameter
|
||||||
|
#
|
||||||
|
std_msgs/Header header
|
||||||
|
uint32 obstacle_id
|
||||||
|
geometry_msgs/Point centroid
|
||||||
|
geometry_msgs/Vector3 velocity
|
||||||
|
float32 speed_mps
|
||||||
|
float32 width_m
|
||||||
|
float32 depth_m
|
||||||
|
uint32 point_count
|
||||||
|
float32 confidence
|
||||||
|
bool is_static
|
||||||
@ -0,0 +1,4 @@
|
|||||||
|
# ObstacleVelocityArray.msg — all tracked obstacles with velocities (Issue #326)
|
||||||
|
#
|
||||||
|
std_msgs/Header header
|
||||||
|
ObstacleVelocity[] obstacles
|
||||||
48
jetson/ros2_ws/src/saltybot_scene_msgs/msg/PathEdges.msg
Normal file
48
jetson/ros2_ws/src/saltybot_scene_msgs/msg/PathEdges.msg
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
# PathEdges.msg — Lane/path edge detection result (Issue #339)
|
||||||
|
#
|
||||||
|
# All pixel coordinates are in the ROI frame:
|
||||||
|
# origin = top-left of the bottom-half ROI crop
|
||||||
|
# y=0 → roi_top of the full image; y increases downward; x increases rightward.
|
||||||
|
#
|
||||||
|
# Bird-eye coordinates are in the warped top-down perspective image.
|
||||||
|
|
||||||
|
std_msgs/Header header
|
||||||
|
|
||||||
|
# --- Raw Hough segments (ROI frame) ---
|
||||||
|
# Flat array of (x1, y1, x2, y2) tuples; length = 4 * line_count
|
||||||
|
float32[] segments_px
|
||||||
|
|
||||||
|
# Same segments warped to bird-eye view; length = 4 * line_count
|
||||||
|
float32[] segments_birdseye_px
|
||||||
|
|
||||||
|
# Number of Hough line segments detected
|
||||||
|
uint32 line_count
|
||||||
|
|
||||||
|
# --- Dominant left edge (ROI frame) ---
|
||||||
|
float32 left_x1
|
||||||
|
float32 left_y1
|
||||||
|
float32 left_x2
|
||||||
|
float32 left_y2
|
||||||
|
bool left_detected
|
||||||
|
|
||||||
|
# --- Dominant left edge (bird-eye frame) ---
|
||||||
|
float32 left_birdseye_x1
|
||||||
|
float32 left_birdseye_y1
|
||||||
|
float32 left_birdseye_x2
|
||||||
|
float32 left_birdseye_y2
|
||||||
|
|
||||||
|
# --- Dominant right edge (ROI frame) ---
|
||||||
|
float32 right_x1
|
||||||
|
float32 right_y1
|
||||||
|
float32 right_x2
|
||||||
|
float32 right_y2
|
||||||
|
bool right_detected
|
||||||
|
|
||||||
|
# --- Dominant right edge (bird-eye frame) ---
|
||||||
|
float32 right_birdseye_x1
|
||||||
|
float32 right_birdseye_y1
|
||||||
|
float32 right_birdseye_x2
|
||||||
|
float32 right_birdseye_y2
|
||||||
|
|
||||||
|
# y-offset of ROI in the full image (pixels from top)
|
||||||
|
uint32 roi_top
|
||||||
13
jetson/ros2_ws/src/saltybot_scene_msgs/msg/TargetTrack.msg
Normal file
13
jetson/ros2_ws/src/saltybot_scene_msgs/msg/TargetTrack.msg
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
std_msgs/Header header
|
||||||
|
bool tracking_active # false when no target is locked
|
||||||
|
uint32 track_id # persistent ID across frames
|
||||||
|
float32 bearing_deg # horizontal bearing to target (°, right=+, left=−)
|
||||||
|
float32 distance_m # range to target (m); 0 = unknown / depth invalid
|
||||||
|
float32 confidence # 0.0–1.0 overall track quality
|
||||||
|
int32 bbox_x # colour-image bounding box (pixels, top-left origin)
|
||||||
|
int32 bbox_y
|
||||||
|
int32 bbox_w
|
||||||
|
int32 bbox_h
|
||||||
|
float32 vel_bearing_dps # bearing rate (°/s, from Kalman velocity state)
|
||||||
|
float32 vel_dist_mps # distance rate (m/s, + = moving away)
|
||||||
|
uint8 depth_quality # 0=invalid 1=extrapolated 2=marginal 3=good
|
||||||
13
jetson/ros2_ws/src/saltybot_scene_msgs/msg/UwbTarget.msg
Normal file
13
jetson/ros2_ws/src/saltybot_scene_msgs/msg/UwbTarget.msg
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
std_msgs/Header header
|
||||||
|
bool valid # false when no recent UWB fix
|
||||||
|
float32 bearing_deg # horizontal bearing to tag (°, right=+, left=−)
|
||||||
|
float32 distance_m # range to tag (m); arithmetic mean of both anchors
|
||||||
|
float32 confidence # 0.0–1.0; degrades when anchors disagree or stale
|
||||||
|
|
||||||
|
# Raw per-anchor two-way-ranging distances
|
||||||
|
float32 anchor0_dist_m # left anchor (anchor index 0)
|
||||||
|
float32 anchor1_dist_m # right anchor (anchor index 1)
|
||||||
|
|
||||||
|
# Derived geometry
|
||||||
|
float32 baseline_m # measured anchor separation (m) — used for sanity check
|
||||||
|
uint8 fix_quality # 0=no fix 1=single-anchor 2=dual-anchor
|
||||||
@ -0,0 +1,20 @@
|
|||||||
|
rosbag_recorder_node:
|
||||||
|
ros__parameters:
|
||||||
|
trigger_topic: "/saltybot/record_trigger"
|
||||||
|
status_topic: "/saltybot/recording_status"
|
||||||
|
|
||||||
|
# Comma-separated topic list to record.
|
||||||
|
# Empty string = record all topics (ros2 bag record --all).
|
||||||
|
# Example: "/saltybot/camera_status,/saltybot/wake_word_detected,/cmd_vel"
|
||||||
|
topics: ""
|
||||||
|
|
||||||
|
bag_dir: "/tmp/saltybot_bags" # output directory (created if absent)
|
||||||
|
bag_prefix: "saltybot" # filename prefix; timestamp appended
|
||||||
|
|
||||||
|
auto_stop_s: 60.0 # auto-stop after N seconds; 0 = disabled
|
||||||
|
stop_timeout_s: 5.0 # force-kill if subprocess won't stop within N s
|
||||||
|
|
||||||
|
compression: false # enable zstd file-level compression
|
||||||
|
max_bag_size_mb: 0.0 # split bags at this size (MiB); 0 = no limit
|
||||||
|
|
||||||
|
poll_rate: 2.0 # state-machine check frequency (Hz)
|
||||||
13
jetson/ros2_ws/src/saltybot_social/config/sysmon_params.yaml
Normal file
13
jetson/ros2_ws/src/saltybot_social/config/sysmon_params.yaml
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
sysmon_node:
|
||||||
|
ros__parameters:
|
||||||
|
publish_rate: 1.0 # resource publish frequency (Hz)
|
||||||
|
disk_path: "/" # filesystem path for disk usage
|
||||||
|
|
||||||
|
# Jetson Orin GPU load sysfs path (per-mille or percent depending on kernel)
|
||||||
|
gpu_sysfs_path: "/sys/devices/gpu.0/load"
|
||||||
|
|
||||||
|
# Glob patterns for thermal zone discovery
|
||||||
|
thermal_glob: "/sys/devices/virtual/thermal/thermal_zone*/temp"
|
||||||
|
thermal_type_glob: "/sys/devices/virtual/thermal/thermal_zone*/type"
|
||||||
|
|
||||||
|
output_topic: "/saltybot/system_resources"
|
||||||
@ -0,0 +1,47 @@
|
|||||||
|
"""rosbag_recorder.launch.py — Launch trigger-based bag recorder (Issue #332).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ros2 launch saltybot_social rosbag_recorder.launch.py
|
||||||
|
ros2 launch saltybot_social rosbag_recorder.launch.py auto_stop_s:=120.0
|
||||||
|
ros2 launch saltybot_social rosbag_recorder.launch.py \\
|
||||||
|
topics:="/saltybot/camera_status,/cmd_vel" bag_dir:=/data/bags
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from ament_index_python.packages import get_package_share_directory
|
||||||
|
from launch import LaunchDescription
|
||||||
|
from launch.actions import DeclareLaunchArgument
|
||||||
|
from launch.substitutions import LaunchConfiguration
|
||||||
|
from launch_ros.actions import Node
|
||||||
|
|
||||||
|
|
||||||
|
def generate_launch_description():
|
||||||
|
pkg = get_package_share_directory("saltybot_social")
|
||||||
|
cfg = os.path.join(pkg, "config", "rosbag_recorder_params.yaml")
|
||||||
|
|
||||||
|
return LaunchDescription([
|
||||||
|
DeclareLaunchArgument("topics", default_value="",
|
||||||
|
description="Comma-separated topics (empty=all)"),
|
||||||
|
DeclareLaunchArgument("bag_dir", default_value="/tmp/saltybot_bags",
|
||||||
|
description="Output directory for bag files"),
|
||||||
|
DeclareLaunchArgument("auto_stop_s", default_value="60.0",
|
||||||
|
description="Auto-stop timeout in seconds (0=off)"),
|
||||||
|
DeclareLaunchArgument("compression", default_value="false",
|
||||||
|
description="Enable zstd compression"),
|
||||||
|
|
||||||
|
Node(
|
||||||
|
package="saltybot_social",
|
||||||
|
executable="rosbag_recorder_node",
|
||||||
|
name="rosbag_recorder_node",
|
||||||
|
output="screen",
|
||||||
|
parameters=[
|
||||||
|
cfg,
|
||||||
|
{
|
||||||
|
"topics": LaunchConfiguration("topics"),
|
||||||
|
"bag_dir": LaunchConfiguration("bag_dir"),
|
||||||
|
"auto_stop_s": LaunchConfiguration("auto_stop_s"),
|
||||||
|
"compression": LaunchConfiguration("compression"),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
])
|
||||||
40
jetson/ros2_ws/src/saltybot_social/launch/sysmon.launch.py
Normal file
40
jetson/ros2_ws/src/saltybot_social/launch/sysmon.launch.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
"""sysmon.launch.py — Launch system resource monitor (Issue #355).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ros2 launch saltybot_social sysmon.launch.py
|
||||||
|
ros2 launch saltybot_social sysmon.launch.py publish_rate:=2.0
|
||||||
|
ros2 launch saltybot_social sysmon.launch.py disk_path:=/data
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from ament_index_python.packages import get_package_share_directory
|
||||||
|
from launch import LaunchDescription
|
||||||
|
from launch.actions import DeclareLaunchArgument
|
||||||
|
from launch.substitutions import LaunchConfiguration
|
||||||
|
from launch_ros.actions import Node
|
||||||
|
|
||||||
|
|
||||||
|
def generate_launch_description():
|
||||||
|
pkg = get_package_share_directory("saltybot_social")
|
||||||
|
cfg = os.path.join(pkg, "config", "sysmon_params.yaml")
|
||||||
|
|
||||||
|
return LaunchDescription([
|
||||||
|
DeclareLaunchArgument("publish_rate", default_value="1.0",
|
||||||
|
description="Resource publish frequency (Hz)"),
|
||||||
|
DeclareLaunchArgument("disk_path", default_value="/",
|
||||||
|
description="Filesystem path for disk usage"),
|
||||||
|
|
||||||
|
Node(
|
||||||
|
package="saltybot_social",
|
||||||
|
executable="sysmon_node",
|
||||||
|
name="sysmon_node",
|
||||||
|
output="screen",
|
||||||
|
parameters=[
|
||||||
|
cfg,
|
||||||
|
{
|
||||||
|
"publish_rate": LaunchConfiguration("publish_rate"),
|
||||||
|
"disk_path": LaunchConfiguration("disk_path"),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
])
|
||||||
@ -0,0 +1,373 @@
|
|||||||
|
"""rosbag_recorder_node.py — Trigger-based ROS2 bag recorder.
|
||||||
|
Issue #332
|
||||||
|
|
||||||
|
Subscribes to /saltybot/record_trigger (Bool). True starts recording;
|
||||||
|
False stops it. Auto-stop fires after ``auto_stop_s`` seconds if still
|
||||||
|
running. Recording is performed by spawning a ``ros2 bag record``
|
||||||
|
subprocess which is sent SIGINT for graceful shutdown and SIGKILL if it
|
||||||
|
does not exit within ``stop_timeout_s``.
|
||||||
|
|
||||||
|
Status values
|
||||||
|
─────────────
|
||||||
|
"idle" — not recording
|
||||||
|
"recording" — subprocess active, writing to bag file
|
||||||
|
"stopping" — SIGINT sent, waiting for subprocess to exit
|
||||||
|
"error" — subprocess died unexpectedly; new trigger retries
|
||||||
|
|
||||||
|
Subscriptions
|
||||||
|
─────────────
|
||||||
|
/saltybot/record_trigger std_msgs/Bool — True = start, False = stop
|
||||||
|
|
||||||
|
Publications
|
||||||
|
────────────
|
||||||
|
/saltybot/recording_status std_msgs/String — status value (see above)
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
──────────
|
||||||
|
trigger_topic (str, "/saltybot/record_trigger")
|
||||||
|
status_topic (str, "/saltybot/recording_status")
|
||||||
|
topics (str, "") comma-separated topic list;
|
||||||
|
empty string → record all topics (-a)
|
||||||
|
bag_dir (str, "/tmp/saltybot_bags")
|
||||||
|
bag_prefix (str, "saltybot")
|
||||||
|
auto_stop_s (float, 60.0) 0 = no auto-stop
|
||||||
|
stop_timeout_s (float, 5.0) force-kill after this many seconds
|
||||||
|
compression (bool, False) enable zstd file compression
|
||||||
|
max_bag_size_mb (float, 0.0) 0 = unlimited
|
||||||
|
poll_rate (float, 2.0) state-machine check frequency (Hz)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import subprocess
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile
|
||||||
|
from std_msgs.msg import Bool, String
|
||||||
|
|
||||||
|
|
||||||
|
# ── Status constants ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
STATUS_IDLE = "idle"
|
||||||
|
STATUS_RECORDING = "recording"
|
||||||
|
STATUS_STOPPING = "stopping"
|
||||||
|
STATUS_ERROR = "error"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Pure helpers ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def make_bag_path(bag_dir: str, prefix: str) -> str:
|
||||||
|
"""Return a timestamped output path for a new bag (no file created)."""
|
||||||
|
ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
return os.path.join(bag_dir, f"{prefix}_{ts}")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_topics(topics_str: str) -> List[str]:
|
||||||
|
"""Parse a comma-separated topic string into a clean list.
|
||||||
|
|
||||||
|
Returns an empty list when *topics_str* is blank (meaning record-all).
|
||||||
|
"""
|
||||||
|
if not topics_str or not topics_str.strip():
|
||||||
|
return []
|
||||||
|
return [t.strip() for t in topics_str.split(",") if t.strip()]
|
||||||
|
|
||||||
|
|
||||||
|
def compute_recording_transition(
|
||||||
|
state: str,
|
||||||
|
trigger: Optional[bool],
|
||||||
|
proc_running: bool,
|
||||||
|
now: float,
|
||||||
|
record_start_t: float,
|
||||||
|
stop_start_t: float,
|
||||||
|
auto_stop_s: float,
|
||||||
|
stop_timeout_s: float,
|
||||||
|
) -> Tuple[str, bool]:
|
||||||
|
"""Pure state-machine step — no I/O, no ROS.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
state : current status string
|
||||||
|
trigger : latest trigger value (True=start, False=stop, None=none)
|
||||||
|
proc_running : whether the recorder subprocess is alive
|
||||||
|
now : current monotonic time (s)
|
||||||
|
record_start_t : monotonic time recording began (0 if not recording)
|
||||||
|
stop_start_t : monotonic time STOPPING began (0 if not stopping)
|
||||||
|
auto_stop_s : auto-stop after this many seconds (0 = disabled)
|
||||||
|
stop_timeout_s : force-kill if stopping > this long (0 = disabled)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
(new_state, force_kill)
|
||||||
|
force_kill=True signals the caller to SIGKILL the process.
|
||||||
|
"""
|
||||||
|
if state == STATUS_IDLE:
|
||||||
|
if trigger is True:
|
||||||
|
return STATUS_RECORDING, False
|
||||||
|
return STATUS_IDLE, False
|
||||||
|
|
||||||
|
if state == STATUS_RECORDING:
|
||||||
|
if not proc_running:
|
||||||
|
return STATUS_ERROR, False
|
||||||
|
if trigger is False:
|
||||||
|
return STATUS_STOPPING, False
|
||||||
|
if (auto_stop_s > 0 and record_start_t > 0
|
||||||
|
and (now - record_start_t) >= auto_stop_s):
|
||||||
|
return STATUS_STOPPING, False
|
||||||
|
return STATUS_RECORDING, False
|
||||||
|
|
||||||
|
if state == STATUS_STOPPING:
|
||||||
|
if not proc_running:
|
||||||
|
return STATUS_IDLE, False
|
||||||
|
if (stop_timeout_s > 0 and stop_start_t > 0
|
||||||
|
and (now - stop_start_t) >= stop_timeout_s):
|
||||||
|
return STATUS_IDLE, True # force-kill
|
||||||
|
return STATUS_STOPPING, False
|
||||||
|
|
||||||
|
# STATUS_ERROR
|
||||||
|
if trigger is True:
|
||||||
|
return STATUS_RECORDING, False
|
||||||
|
return STATUS_ERROR, False
|
||||||
|
|
||||||
|
|
||||||
|
# ── Subprocess wrapper ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class BagRecorderProcess:
|
||||||
|
"""Thin wrapper around a ``ros2 bag record`` subprocess."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._proc: Optional[subprocess.Popen] = None
|
||||||
|
|
||||||
|
def start(self, topics: List[str], output_path: str,
|
||||||
|
compression: bool = False,
|
||||||
|
max_size_mb: float = 0.0) -> bool:
|
||||||
|
"""Launch the recorder. Returns False if already running or on error."""
|
||||||
|
if self.is_running():
|
||||||
|
return False
|
||||||
|
|
||||||
|
cmd = ["ros2", "bag", "record", "--output", output_path]
|
||||||
|
if topics:
|
||||||
|
cmd += topics
|
||||||
|
else:
|
||||||
|
cmd += ["--all"]
|
||||||
|
if compression:
|
||||||
|
cmd += ["--compression-mode", "file",
|
||||||
|
"--compression-format", "zstd"]
|
||||||
|
if max_size_mb > 0:
|
||||||
|
cmd += ["--max-bag-size",
|
||||||
|
str(int(max_size_mb * 1024 * 1024))]
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._proc = subprocess.Popen(
|
||||||
|
cmd,
|
||||||
|
stdout=subprocess.DEVNULL,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
preexec_fn=os.setsid,
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
self._proc = None
|
||||||
|
return False
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
"""Send SIGINT to the process group for graceful shutdown."""
|
||||||
|
if self._proc is not None and self._proc.poll() is None:
|
||||||
|
try:
|
||||||
|
os.killpg(os.getpgid(self._proc.pid), signal.SIGINT)
|
||||||
|
except (ProcessLookupError, PermissionError, OSError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def kill(self) -> None:
|
||||||
|
"""Send SIGKILL to the process group for forced shutdown."""
|
||||||
|
if self._proc is not None:
|
||||||
|
try:
|
||||||
|
os.killpg(os.getpgid(self._proc.pid), signal.SIGKILL)
|
||||||
|
except (ProcessLookupError, PermissionError, OSError):
|
||||||
|
pass
|
||||||
|
self._proc = None
|
||||||
|
|
||||||
|
def is_running(self) -> bool:
|
||||||
|
return self._proc is not None and self._proc.poll() is None
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Clear internal proc reference (call after graceful exit)."""
|
||||||
|
self._proc = None
|
||||||
|
|
||||||
|
|
||||||
|
# ── ROS2 node ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class RosbagRecorderNode(Node):
|
||||||
|
"""Trigger-based ROS bag recorder with auto-stop and status reporting."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__("rosbag_recorder_node")
|
||||||
|
|
||||||
|
self.declare_parameter("trigger_topic", "/saltybot/record_trigger")
|
||||||
|
self.declare_parameter("status_topic", "/saltybot/recording_status")
|
||||||
|
self.declare_parameter("topics", "")
|
||||||
|
self.declare_parameter("bag_dir", "/tmp/saltybot_bags")
|
||||||
|
self.declare_parameter("bag_prefix", "saltybot")
|
||||||
|
self.declare_parameter("auto_stop_s", 60.0)
|
||||||
|
self.declare_parameter("stop_timeout_s", 5.0)
|
||||||
|
self.declare_parameter("compression", False)
|
||||||
|
self.declare_parameter("max_bag_size_mb", 0.0)
|
||||||
|
self.declare_parameter("poll_rate", 2.0)
|
||||||
|
|
||||||
|
trigger_topic = self.get_parameter("trigger_topic").value
|
||||||
|
status_topic = self.get_parameter("status_topic").value
|
||||||
|
topics_str = self.get_parameter("topics").value
|
||||||
|
self._bag_dir = str(self.get_parameter("bag_dir").value)
|
||||||
|
self._bag_prefix = str(self.get_parameter("bag_prefix").value)
|
||||||
|
self._auto_stop_s = float(self.get_parameter("auto_stop_s").value)
|
||||||
|
self._stop_tmo_s = float(self.get_parameter("stop_timeout_s").value)
|
||||||
|
self._compression = bool(self.get_parameter("compression").value)
|
||||||
|
self._max_mb = float(self.get_parameter("max_bag_size_mb").value)
|
||||||
|
poll_rate = float(self.get_parameter("poll_rate").value)
|
||||||
|
|
||||||
|
self._topics: List[str] = parse_topics(str(topics_str))
|
||||||
|
|
||||||
|
# Recorder process — injectable for tests
|
||||||
|
self._recorder: BagRecorderProcess = BagRecorderProcess()
|
||||||
|
|
||||||
|
# State
|
||||||
|
self._state = STATUS_IDLE
|
||||||
|
self._trigger: Optional[bool] = None
|
||||||
|
self._record_start_t: float = 0.0
|
||||||
|
self._stop_start_t: float = 0.0
|
||||||
|
self._current_bag: str = ""
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
qos = QoSProfile(depth=10)
|
||||||
|
self._status_pub = self.create_publisher(String, status_topic, qos)
|
||||||
|
self._trigger_sub = self.create_subscription(
|
||||||
|
Bool, trigger_topic, self._on_trigger, qos
|
||||||
|
)
|
||||||
|
self._timer = self.create_timer(1.0 / poll_rate, self._poll_cb)
|
||||||
|
|
||||||
|
# Publish initial state
|
||||||
|
self._publish(STATUS_IDLE)
|
||||||
|
|
||||||
|
topic_desc = ",".join(self._topics) if self._topics else "<all>"
|
||||||
|
self.get_logger().info(
|
||||||
|
f"RosbagRecorderNode ready — topics={topic_desc}, "
|
||||||
|
f"bag_dir={self._bag_dir}, auto_stop={self._auto_stop_s}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Subscription ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _on_trigger(self, msg) -> None:
|
||||||
|
with self._lock:
|
||||||
|
self._trigger = bool(msg.data)
|
||||||
|
|
||||||
|
# ── Poll / state machine ────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _poll_cb(self) -> None:
|
||||||
|
now = time.monotonic()
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
trigger = self._trigger
|
||||||
|
self._trigger = None # consume
|
||||||
|
state = self._state
|
||||||
|
rec_t = self._record_start_t
|
||||||
|
stop_t = self._stop_start_t
|
||||||
|
|
||||||
|
proc_running = self._recorder.is_running()
|
||||||
|
|
||||||
|
new_state, force_kill = compute_recording_transition(
|
||||||
|
state, trigger, proc_running, now,
|
||||||
|
rec_t, stop_t, self._auto_stop_s, self._stop_tmo_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
if force_kill:
|
||||||
|
self._recorder.kill()
|
||||||
|
self.get_logger().warn(
|
||||||
|
f"RosbagRecorder: force-killed (stop_timeout={self._stop_tmo_s}s)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if new_state != state:
|
||||||
|
self._enter_state(new_state, now)
|
||||||
|
|
||||||
|
def _enter_state(self, new_state: str, now: float) -> None:
|
||||||
|
if new_state == STATUS_RECORDING:
|
||||||
|
bag_path = make_bag_path(self._bag_dir, self._bag_prefix)
|
||||||
|
started = self._recorder.start(
|
||||||
|
self._topics, bag_path,
|
||||||
|
compression=self._compression,
|
||||||
|
max_size_mb=self._max_mb,
|
||||||
|
)
|
||||||
|
if not started:
|
||||||
|
new_state = STATUS_ERROR
|
||||||
|
self.get_logger().error(
|
||||||
|
"RosbagRecorder: failed to start recorder subprocess"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
with self._lock:
|
||||||
|
self._record_start_t = now
|
||||||
|
self._current_bag = bag_path
|
||||||
|
self.get_logger().info(
|
||||||
|
f"RosbagRecorder: recording started → {bag_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif new_state == STATUS_STOPPING:
|
||||||
|
self._recorder.stop()
|
||||||
|
with self._lock:
|
||||||
|
self._stop_start_t = now
|
||||||
|
self.get_logger().info("RosbagRecorder: stopping (SIGINT sent)")
|
||||||
|
|
||||||
|
elif new_state == STATUS_IDLE:
|
||||||
|
bag = self._current_bag
|
||||||
|
with self._lock:
|
||||||
|
self._record_start_t = 0.0
|
||||||
|
self._stop_start_t = 0.0
|
||||||
|
self._current_bag = ""
|
||||||
|
self._recorder.reset()
|
||||||
|
self.get_logger().info(
|
||||||
|
f"RosbagRecorder: recording complete → {bag}"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif new_state == STATUS_ERROR:
|
||||||
|
self.get_logger().error(
|
||||||
|
"RosbagRecorder: subprocess exited unexpectedly"
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
self._state = new_state
|
||||||
|
|
||||||
|
self._publish(new_state)
|
||||||
|
|
||||||
|
# ── Publish ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _publish(self, status: str) -> None:
|
||||||
|
msg = String()
|
||||||
|
msg.data = status
|
||||||
|
self._status_pub.publish(msg)
|
||||||
|
|
||||||
|
# ── Public accessors ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self) -> str:
|
||||||
|
with self._lock:
|
||||||
|
return self._state
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_bag(self) -> str:
|
||||||
|
with self._lock:
|
||||||
|
return self._current_bag
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None) -> None:
|
||||||
|
rclpy.init(args=args)
|
||||||
|
node = RosbagRecorderNode()
|
||||||
|
try:
|
||||||
|
rclpy.spin(node)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
node.destroy_node()
|
||||||
|
rclpy.shutdown()
|
||||||
@ -0,0 +1,337 @@
|
|||||||
|
"""sysmon_node.py — System resource monitor for Jetson Orin.
|
||||||
|
Issue #355
|
||||||
|
|
||||||
|
Reads CPU, GPU, RAM, disk usage, and thermal temperatures, then
|
||||||
|
publishes a JSON string to /saltybot/system_resources at a configurable
|
||||||
|
rate (default 1 Hz).
|
||||||
|
|
||||||
|
All reads use /proc and /sys where available; GPU load falls back to -1.0
|
||||||
|
if the sysfs path is absent (non-Jetson host).
|
||||||
|
|
||||||
|
JSON payload
|
||||||
|
────────────
|
||||||
|
{
|
||||||
|
"ts": 1234567890.123, // epoch seconds (float)
|
||||||
|
"cpu_percent": [45.2, 32.1], // per-core %; index 0 = aggregate
|
||||||
|
"cpu_avg_percent": 38.6, // mean of per-core values
|
||||||
|
"ram_total_mb": 16384.0,
|
||||||
|
"ram_used_mb": 4096.0,
|
||||||
|
"ram_percent": 25.0,
|
||||||
|
"disk_total_gb": 64.0,
|
||||||
|
"disk_used_gb": 12.5,
|
||||||
|
"disk_percent": 19.5,
|
||||||
|
"gpu_percent": 42.0, // -1.0 if unavailable
|
||||||
|
"thermal": {"CPU-therm": 47.5, "GPU-therm": 43.2}
|
||||||
|
}
|
||||||
|
|
||||||
|
Publications
|
||||||
|
────────────
|
||||||
|
/saltybot/system_resources std_msgs/String JSON payload
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
──────────
|
||||||
|
publish_rate (float, 1.0) publish frequency (Hz)
|
||||||
|
disk_path (str, "/") path for statvfs disk usage
|
||||||
|
gpu_sysfs_path (str, "/sys/devices/gpu.0/load")
|
||||||
|
thermal_glob (str, "/sys/devices/virtual/thermal/thermal_zone*/temp")
|
||||||
|
thermal_type_glob (str, "/sys/devices/virtual/thermal/thermal_zone*/type")
|
||||||
|
output_topic (str, "/saltybot/system_resources")
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import glob as _glob_mod
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile
|
||||||
|
from std_msgs.msg import String
|
||||||
|
|
||||||
|
|
||||||
|
# ── Pure helpers ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def parse_proc_stat(content: str) -> List[List[int]]:
|
||||||
|
"""Parse /proc/stat and return per-cpu jiffies lists.
|
||||||
|
|
||||||
|
Returns a list where index 0 = "cpu" (aggregate) and index 1+ = cpu0, cpu1, …
|
||||||
|
Each entry is [user, nice, system, idle, iowait, irq, softirq, steal].
|
||||||
|
"""
|
||||||
|
result: List[List[int]] = []
|
||||||
|
for line in content.splitlines():
|
||||||
|
if not line.startswith("cpu"):
|
||||||
|
break
|
||||||
|
parts = line.split()
|
||||||
|
# parts[0] is "cpu" or "cpu0", parts[1:] are jiffie fields
|
||||||
|
fields = [int(x) for x in parts[1:9]] # take up to 8 fields
|
||||||
|
while len(fields) < 8:
|
||||||
|
fields.append(0)
|
||||||
|
result.append(fields)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def cpu_percent_from_stats(
|
||||||
|
prev: List[List[int]],
|
||||||
|
curr: List[List[int]],
|
||||||
|
) -> List[float]:
|
||||||
|
"""Compute per-cpu busy percentage from two /proc/stat snapshots.
|
||||||
|
|
||||||
|
Index 0 = aggregate (from "cpu" line), index 1+ = per-core.
|
||||||
|
Returns empty list if inputs are incompatible.
|
||||||
|
"""
|
||||||
|
if not prev or not curr or len(prev) != len(curr):
|
||||||
|
return []
|
||||||
|
|
||||||
|
result: List[float] = []
|
||||||
|
for p, c in zip(prev, curr):
|
||||||
|
idle_p = p[3] + p[4] # idle + iowait
|
||||||
|
idle_c = c[3] + c[4]
|
||||||
|
total_p = sum(p)
|
||||||
|
total_c = sum(c)
|
||||||
|
d_total = total_c - total_p
|
||||||
|
d_idle = idle_c - idle_p
|
||||||
|
if d_total <= 0:
|
||||||
|
result.append(0.0)
|
||||||
|
else:
|
||||||
|
pct = 100.0 * (d_total - d_idle) / d_total
|
||||||
|
result.append(round(max(0.0, min(100.0, pct)), 2))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def parse_meminfo(content: str) -> Dict[str, int]:
|
||||||
|
"""Parse /proc/meminfo, return mapping of key → value in kB."""
|
||||||
|
info: Dict[str, int] = {}
|
||||||
|
for line in content.splitlines():
|
||||||
|
if ":" not in line:
|
||||||
|
continue
|
||||||
|
key, _, rest = line.partition(":")
|
||||||
|
value_str = rest.strip().split()[0] if rest.strip() else "0"
|
||||||
|
try:
|
||||||
|
info[key.strip()] = int(value_str)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
def compute_ram_stats(meminfo: Dict[str, int]) -> Tuple[float, float, float]:
|
||||||
|
"""Return (total_mb, used_mb, percent) from parsed /proc/meminfo.
|
||||||
|
|
||||||
|
Uses MemAvailable when present; falls back to MemFree.
|
||||||
|
"""
|
||||||
|
total_kb = meminfo.get("MemTotal", 0)
|
||||||
|
avail_kb = meminfo.get("MemAvailable", meminfo.get("MemFree", 0))
|
||||||
|
used_kb = max(0, total_kb - avail_kb)
|
||||||
|
total_mb = round(total_kb / 1024.0, 2)
|
||||||
|
used_mb = round(used_kb / 1024.0, 2)
|
||||||
|
if total_kb > 0:
|
||||||
|
percent = round(100.0 * used_kb / total_kb, 2)
|
||||||
|
else:
|
||||||
|
percent = 0.0
|
||||||
|
return total_mb, used_mb, percent
|
||||||
|
|
||||||
|
|
||||||
|
def read_disk_usage(path: str) -> Tuple[float, float, float]:
|
||||||
|
"""Return (total_gb, used_gb, percent) for the filesystem at *path*.
|
||||||
|
|
||||||
|
Returns (-1.0, -1.0, -1.0) on error.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
st = os.statvfs(path)
|
||||||
|
total = st.f_frsize * st.f_blocks
|
||||||
|
free = st.f_frsize * st.f_bavail
|
||||||
|
used = total - free
|
||||||
|
total_gb = round(total / (1024 ** 3), 3)
|
||||||
|
used_gb = round(used / (1024 ** 3), 3)
|
||||||
|
pct = round(100.0 * used / total, 2) if total > 0 else 0.0
|
||||||
|
return total_gb, used_gb, pct
|
||||||
|
except Exception:
|
||||||
|
return -1.0, -1.0, -1.0
|
||||||
|
|
||||||
|
|
||||||
|
def read_gpu_load(path: str) -> float:
|
||||||
|
"""Read Jetson GPU load from sysfs (0–100) or -1.0 if unavailable."""
|
||||||
|
try:
|
||||||
|
with open(path, "r") as fh:
|
||||||
|
raw = fh.read().strip()
|
||||||
|
# Some Jetson kernels report 0–1000 (per-mille), others 0–100
|
||||||
|
val = int(raw)
|
||||||
|
if val > 100:
|
||||||
|
val = round(val / 10.0, 1)
|
||||||
|
return float(max(0.0, min(100.0, val)))
|
||||||
|
except Exception:
|
||||||
|
return -1.0
|
||||||
|
|
||||||
|
|
||||||
|
def read_thermal_zones(
|
||||||
|
temp_glob: str,
|
||||||
|
type_glob: str,
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""Glob thermal zone sysfs paths and return {zone_type: temp_C}.
|
||||||
|
|
||||||
|
Falls back to numeric zone names ("zone0", "zone1", …) when type
|
||||||
|
files are absent.
|
||||||
|
"""
|
||||||
|
temp_paths = sorted(_glob_mod.glob(temp_glob))
|
||||||
|
result: Dict[str, float] = {}
|
||||||
|
|
||||||
|
for tp in temp_paths:
|
||||||
|
# Derive zone directory from temp path: …/thermal_zoneN/temp
|
||||||
|
zone_dir = os.path.dirname(tp)
|
||||||
|
zone_index = os.path.basename(zone_dir).replace("thermal_zone", "zone")
|
||||||
|
|
||||||
|
# Try to read the human-readable type
|
||||||
|
type_path = os.path.join(zone_dir, "type")
|
||||||
|
try:
|
||||||
|
with open(type_path, "r") as fh:
|
||||||
|
zone_name = fh.read().strip()
|
||||||
|
except Exception:
|
||||||
|
zone_name = zone_index
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(tp, "r") as fh:
|
||||||
|
milli_c = int(fh.read().strip())
|
||||||
|
result[zone_name] = round(milli_c / 1000.0, 1)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# ── ROS2 node ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class SysmonNode(Node):
|
||||||
|
"""Publishes Jetson system resource usage as a JSON String at a fixed rate."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__("sysmon_node")
|
||||||
|
|
||||||
|
self.declare_parameter("publish_rate", 1.0)
|
||||||
|
self.declare_parameter("disk_path", "/")
|
||||||
|
self.declare_parameter("gpu_sysfs_path", "/sys/devices/gpu.0/load")
|
||||||
|
self.declare_parameter(
|
||||||
|
"thermal_glob",
|
||||||
|
"/sys/devices/virtual/thermal/thermal_zone*/temp"
|
||||||
|
)
|
||||||
|
self.declare_parameter(
|
||||||
|
"thermal_type_glob",
|
||||||
|
"/sys/devices/virtual/thermal/thermal_zone*/type"
|
||||||
|
)
|
||||||
|
self.declare_parameter("output_topic", "/saltybot/system_resources")
|
||||||
|
|
||||||
|
rate = float(self.get_parameter("publish_rate").value)
|
||||||
|
self._disk = self.get_parameter("disk_path").value
|
||||||
|
self._gpu_path = self.get_parameter("gpu_sysfs_path").value
|
||||||
|
self._th_temp = self.get_parameter("thermal_glob").value
|
||||||
|
self._th_type = self.get_parameter("thermal_type_glob").value
|
||||||
|
topic = self.get_parameter("output_topic").value
|
||||||
|
|
||||||
|
# Injectable I/O functions for offline testing
|
||||||
|
self._read_proc_stat = self._default_read_proc_stat
|
||||||
|
self._read_meminfo = self._default_read_meminfo
|
||||||
|
self._read_disk_usage = read_disk_usage
|
||||||
|
self._read_gpu_load = read_gpu_load
|
||||||
|
self._read_thermal = read_thermal_zones
|
||||||
|
|
||||||
|
# CPU stats require two samples; prime with first read
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._prev_stat: Optional[List[List[int]]] = self._read_proc_stat()
|
||||||
|
|
||||||
|
qos = QoSProfile(depth=10)
|
||||||
|
self._pub = self.create_publisher(String, topic, qos)
|
||||||
|
self._timer = self.create_timer(1.0 / rate, self._tick)
|
||||||
|
|
||||||
|
self.get_logger().info(
|
||||||
|
f"SysmonNode ready — publishing {topic} at {rate:.1f} Hz"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Default I/O readers ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _default_read_proc_stat() -> Optional[List[List[int]]]:
|
||||||
|
try:
|
||||||
|
with open("/proc/stat", "r") as fh:
|
||||||
|
return parse_proc_stat(fh.read())
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _default_read_meminfo() -> str:
|
||||||
|
try:
|
||||||
|
with open("/proc/meminfo", "r") as fh:
|
||||||
|
return fh.read()
|
||||||
|
except Exception:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# ── Timer callback ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _tick(self) -> None:
|
||||||
|
ts = time.time()
|
||||||
|
|
||||||
|
# CPU
|
||||||
|
curr_stat = self._read_proc_stat()
|
||||||
|
with self._lock:
|
||||||
|
prev_stat = self._prev_stat
|
||||||
|
self._prev_stat = curr_stat
|
||||||
|
|
||||||
|
if prev_stat is not None and curr_stat is not None:
|
||||||
|
cpu_pcts = cpu_percent_from_stats(prev_stat, curr_stat)
|
||||||
|
else:
|
||||||
|
cpu_pcts = []
|
||||||
|
|
||||||
|
# Use per-core values (skip index 0 = aggregate) for avg
|
||||||
|
per_core = cpu_pcts[1:] if len(cpu_pcts) > 1 else cpu_pcts
|
||||||
|
cpu_avg = round(sum(per_core) / len(per_core), 2) if per_core else 0.0
|
||||||
|
|
||||||
|
# RAM
|
||||||
|
meminfo = parse_meminfo(self._read_meminfo())
|
||||||
|
ram_total, ram_used, ram_pct = compute_ram_stats(meminfo)
|
||||||
|
|
||||||
|
# Disk
|
||||||
|
disk_total, disk_used, disk_pct = self._read_disk_usage(self._disk)
|
||||||
|
|
||||||
|
# GPU
|
||||||
|
gpu_pct = self._read_gpu_load(self._gpu_path)
|
||||||
|
|
||||||
|
# Thermal
|
||||||
|
thermal = self._read_thermal(self._th_temp, self._th_type)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"ts": ts,
|
||||||
|
"cpu_percent": cpu_pcts,
|
||||||
|
"cpu_avg_percent": cpu_avg,
|
||||||
|
"ram_total_mb": ram_total,
|
||||||
|
"ram_used_mb": ram_used,
|
||||||
|
"ram_percent": ram_pct,
|
||||||
|
"disk_total_gb": disk_total,
|
||||||
|
"disk_used_gb": disk_used,
|
||||||
|
"disk_percent": disk_pct,
|
||||||
|
"gpu_percent": gpu_pct,
|
||||||
|
"thermal": thermal,
|
||||||
|
}
|
||||||
|
|
||||||
|
msg = String()
|
||||||
|
msg.data = json.dumps(payload)
|
||||||
|
self._pub.publish(msg)
|
||||||
|
|
||||||
|
# ── Public accessor ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prev_stat(self) -> Optional[List[List[int]]]:
|
||||||
|
with self._lock:
|
||||||
|
return self._prev_stat
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None) -> None:
|
||||||
|
rclpy.init(args=args)
|
||||||
|
node = SysmonNode()
|
||||||
|
try:
|
||||||
|
rclpy.spin(node)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
node.destroy_node()
|
||||||
|
rclpy.shutdown()
|
||||||
@ -61,6 +61,10 @@ setup(
|
|||||||
'wake_word_node = saltybot_social.wake_word_node:main',
|
'wake_word_node = saltybot_social.wake_word_node:main',
|
||||||
# USB camera hot-plug monitor (Issue #320)
|
# USB camera hot-plug monitor (Issue #320)
|
||||||
'camera_hotplug_node = saltybot_social.camera_hotplug_node:main',
|
'camera_hotplug_node = saltybot_social.camera_hotplug_node:main',
|
||||||
|
# Trigger-based ROS2 bag recorder (Issue #332)
|
||||||
|
'rosbag_recorder_node = saltybot_social.rosbag_recorder_node:main',
|
||||||
|
# System resource monitor for Jetson Orin (Issue #355)
|
||||||
|
'sysmon_node = saltybot_social.sysmon_node:main',
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
719
jetson/ros2_ws/src/saltybot_social/test/test_rosbag_recorder.py
Normal file
719
jetson/ros2_ws/src/saltybot_social/test/test_rosbag_recorder.py
Normal file
@ -0,0 +1,719 @@
|
|||||||
|
"""test_rosbag_recorder.py — Offline tests for rosbag_recorder_node (Issue #332).
|
||||||
|
|
||||||
|
Stubs out rclpy and ROS message types.
|
||||||
|
BagRecorderProcess is replaced with MockRecorder — no subprocesses are spawned.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import types
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
# ── ROS2 stubs ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _make_ros_stubs():
|
||||||
|
for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
|
||||||
|
"std_msgs", "std_msgs.msg"):
|
||||||
|
if mod_name not in sys.modules:
|
||||||
|
sys.modules[mod_name] = types.ModuleType(mod_name)
|
||||||
|
|
||||||
|
class _Node:
|
||||||
|
def __init__(self, name="node"):
|
||||||
|
self._name = name
|
||||||
|
if not hasattr(self, "_params"):
|
||||||
|
self._params = {}
|
||||||
|
self._pubs = {}
|
||||||
|
self._subs = {}
|
||||||
|
self._logs = []
|
||||||
|
self._timers = []
|
||||||
|
|
||||||
|
def declare_parameter(self, name, default):
|
||||||
|
if name not in self._params:
|
||||||
|
self._params[name] = default
|
||||||
|
|
||||||
|
def get_parameter(self, name):
|
||||||
|
class _P:
|
||||||
|
def __init__(self, v): self.value = v
|
||||||
|
return _P(self._params.get(name))
|
||||||
|
|
||||||
|
def create_publisher(self, msg_type, topic, qos):
|
||||||
|
pub = _FakePub()
|
||||||
|
self._pubs[topic] = pub
|
||||||
|
return pub
|
||||||
|
|
||||||
|
def create_subscription(self, msg_type, topic, cb, qos):
|
||||||
|
self._subs[topic] = cb
|
||||||
|
return object()
|
||||||
|
|
||||||
|
def create_timer(self, period, cb):
|
||||||
|
self._timers.append(cb)
|
||||||
|
return object()
|
||||||
|
|
||||||
|
def get_logger(self):
|
||||||
|
node = self
|
||||||
|
class _L:
|
||||||
|
def info(self, m): node._logs.append(("INFO", m))
|
||||||
|
def warn(self, m): node._logs.append(("WARN", m))
|
||||||
|
def error(self, m): node._logs.append(("ERROR", m))
|
||||||
|
return _L()
|
||||||
|
|
||||||
|
def destroy_node(self): pass
|
||||||
|
|
||||||
|
class _FakePub:
|
||||||
|
def __init__(self):
|
||||||
|
self.msgs = []
|
||||||
|
def publish(self, msg):
|
||||||
|
self.msgs.append(msg)
|
||||||
|
|
||||||
|
class _QoSProfile:
|
||||||
|
def __init__(self, depth=10): self.depth = depth
|
||||||
|
|
||||||
|
class _Bool:
|
||||||
|
def __init__(self): self.data = False
|
||||||
|
|
||||||
|
class _String:
|
||||||
|
def __init__(self): self.data = ""
|
||||||
|
|
||||||
|
rclpy_mod = sys.modules["rclpy"]
|
||||||
|
rclpy_mod.init = lambda args=None: None
|
||||||
|
rclpy_mod.spin = lambda node: None
|
||||||
|
rclpy_mod.shutdown = lambda: None
|
||||||
|
|
||||||
|
sys.modules["rclpy.node"].Node = _Node
|
||||||
|
sys.modules["rclpy.qos"].QoSProfile = _QoSProfile
|
||||||
|
sys.modules["std_msgs.msg"].Bool = _Bool
|
||||||
|
sys.modules["std_msgs.msg"].String = _String
|
||||||
|
|
||||||
|
return _Node, _FakePub, _Bool, _String
|
||||||
|
|
||||||
|
|
||||||
|
_Node, _FakePub, _Bool, _String = _make_ros_stubs()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Module loader ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_SRC = (
|
||||||
|
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||||
|
"saltybot_social/saltybot_social/rosbag_recorder_node.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_mod():
|
||||||
|
spec = importlib.util.spec_from_file_location("rosbag_recorder_testmod", _SRC)
|
||||||
|
mod = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(mod)
|
||||||
|
return mod
|
||||||
|
|
||||||
|
|
||||||
|
class _MockRecorder:
|
||||||
|
"""Injectable replacement for BagRecorderProcess."""
|
||||||
|
|
||||||
|
def __init__(self, start_succeeds: bool = True) -> None:
|
||||||
|
self.start_succeeds = start_succeeds
|
||||||
|
self._running = False
|
||||||
|
self.calls: list = []
|
||||||
|
|
||||||
|
def start(self, topics, output_path, compression=False, max_size_mb=0.0):
|
||||||
|
self.calls.append(("start", list(topics), output_path))
|
||||||
|
if self.start_succeeds:
|
||||||
|
self._running = True
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self.calls.append(("stop",))
|
||||||
|
self._running = False # immediately "gone" for deterministic tests
|
||||||
|
|
||||||
|
def kill(self):
|
||||||
|
self.calls.append(("kill",))
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
def is_running(self):
|
||||||
|
return self._running
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.calls.append(("reset",))
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
def call_types(self):
|
||||||
|
return [c[0] for c in self.calls]
|
||||||
|
|
||||||
|
|
||||||
|
def _make_node(mod, recorder=None, **kwargs):
|
||||||
|
node = mod.RosbagRecorderNode.__new__(mod.RosbagRecorderNode)
|
||||||
|
defaults = {
|
||||||
|
"trigger_topic": "/saltybot/record_trigger",
|
||||||
|
"status_topic": "/saltybot/recording_status",
|
||||||
|
"topics": "",
|
||||||
|
"bag_dir": "/tmp/test_bags",
|
||||||
|
"bag_prefix": "test",
|
||||||
|
"auto_stop_s": 60.0,
|
||||||
|
"stop_timeout_s": 5.0,
|
||||||
|
"compression": False,
|
||||||
|
"max_bag_size_mb": 0.0,
|
||||||
|
"poll_rate": 2.0,
|
||||||
|
}
|
||||||
|
defaults.update(kwargs)
|
||||||
|
node._params = dict(defaults)
|
||||||
|
mod.RosbagRecorderNode.__init__(node)
|
||||||
|
if recorder is not None:
|
||||||
|
node._recorder = recorder
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
def _trigger(node, value: bool):
|
||||||
|
"""Deliver a Bool trigger message."""
|
||||||
|
msg = _Bool()
|
||||||
|
msg.data = value
|
||||||
|
node._subs["/saltybot/record_trigger"](msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _pub(node):
|
||||||
|
return node._pubs["/saltybot/recording_status"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests: pure helpers ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestMakeBagPath(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls): cls.mod = _load_mod()
|
||||||
|
|
||||||
|
def test_contains_prefix(self):
|
||||||
|
p = self.mod.make_bag_path("/tmp/bags", "saltybot")
|
||||||
|
self.assertIn("saltybot", p)
|
||||||
|
|
||||||
|
def test_contains_bag_dir(self):
|
||||||
|
p = self.mod.make_bag_path("/tmp/bags", "saltybot")
|
||||||
|
self.assertTrue(p.startswith("/tmp/bags"))
|
||||||
|
|
||||||
|
def test_unique_per_call(self):
|
||||||
|
# Two calls in tight succession may share a second but that's fine;
|
||||||
|
# just ensure the function doesn't crash and returns a string.
|
||||||
|
p1 = self.mod.make_bag_path("/tmp", "t")
|
||||||
|
self.assertIsInstance(p1, str)
|
||||||
|
self.assertTrue(len(p1) > 0)
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseTopics(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls): cls.mod = _load_mod()
|
||||||
|
|
||||||
|
def test_empty_string_returns_empty_list(self):
|
||||||
|
self.assertEqual(self.mod.parse_topics(""), [])
|
||||||
|
|
||||||
|
def test_whitespace_only_returns_empty(self):
|
||||||
|
self.assertEqual(self.mod.parse_topics(" "), [])
|
||||||
|
|
||||||
|
def test_single_topic(self):
|
||||||
|
self.assertEqual(self.mod.parse_topics("/topic/foo"), ["/topic/foo"])
|
||||||
|
|
||||||
|
def test_multiple_topics(self):
|
||||||
|
r = self.mod.parse_topics("/a,/b,/c")
|
||||||
|
self.assertEqual(r, ["/a", "/b", "/c"])
|
||||||
|
|
||||||
|
def test_strips_whitespace(self):
|
||||||
|
r = self.mod.parse_topics(" /a , /b ")
|
||||||
|
self.assertEqual(r, ["/a", "/b"])
|
||||||
|
|
||||||
|
def test_ignores_empty_segments(self):
|
||||||
|
r = self.mod.parse_topics("/a,,/b")
|
||||||
|
self.assertEqual(r, ["/a", "/b"])
|
||||||
|
|
||||||
|
|
||||||
|
class TestStatusConstants(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls): cls.mod = _load_mod()
|
||||||
|
|
||||||
|
def test_idle(self): self.assertEqual(self.mod.STATUS_IDLE, "idle")
|
||||||
|
def test_recording(self): self.assertEqual(self.mod.STATUS_RECORDING, "recording")
|
||||||
|
def test_stopping(self): self.assertEqual(self.mod.STATUS_STOPPING, "stopping")
|
||||||
|
def test_error(self): self.assertEqual(self.mod.STATUS_ERROR, "error")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests: compute_recording_transition ──────────────────────────────────────
|
||||||
|
|
||||||
|
class TestComputeRecordingTransition(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls): cls.mod = _load_mod()
|
||||||
|
|
||||||
|
def _tr(self, state, trigger=None, proc_running=True,
|
||||||
|
now=100.0, rec_t=0.0, stop_t=0.0,
|
||||||
|
auto_stop=60.0, stop_tmo=5.0):
|
||||||
|
return self.mod.compute_recording_transition(
|
||||||
|
state, trigger, proc_running, now,
|
||||||
|
rec_t, stop_t, auto_stop, stop_tmo,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── IDLE ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_idle_no_trigger_stays_idle(self):
|
||||||
|
s, fk = self._tr("idle", trigger=None)
|
||||||
|
self.assertEqual(s, "idle"); self.assertFalse(fk)
|
||||||
|
|
||||||
|
def test_idle_false_trigger_stays_idle(self):
|
||||||
|
s, fk = self._tr("idle", trigger=False)
|
||||||
|
self.assertEqual(s, "idle"); self.assertFalse(fk)
|
||||||
|
|
||||||
|
def test_idle_true_trigger_starts_recording(self):
|
||||||
|
s, fk = self._tr("idle", trigger=True)
|
||||||
|
self.assertEqual(s, "recording"); self.assertFalse(fk)
|
||||||
|
|
||||||
|
# ── RECORDING ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_recording_stable_no_change(self):
|
||||||
|
s, fk = self._tr("recording", proc_running=True)
|
||||||
|
self.assertEqual(s, "recording"); self.assertFalse(fk)
|
||||||
|
|
||||||
|
def test_recording_false_trigger_stops(self):
|
||||||
|
s, fk = self._tr("recording", trigger=False, proc_running=True)
|
||||||
|
self.assertEqual(s, "stopping"); self.assertFalse(fk)
|
||||||
|
|
||||||
|
def test_recording_proc_dies_error(self):
|
||||||
|
s, fk = self._tr("recording", proc_running=False)
|
||||||
|
self.assertEqual(s, "error"); self.assertFalse(fk)
|
||||||
|
|
||||||
|
def test_recording_auto_stop_fires(self):
|
||||||
|
# started at t=40, now=t=101 → 61 s elapsed > auto_stop=60
|
||||||
|
s, fk = self._tr("recording", proc_running=True,
|
||||||
|
now=101.0, rec_t=40.0, auto_stop=60.0)
|
||||||
|
self.assertEqual(s, "stopping"); self.assertFalse(fk)
|
||||||
|
|
||||||
|
def test_recording_auto_stop_not_yet(self):
|
||||||
|
# started at t=50, now=100 → 50 s < 60 s
|
||||||
|
s, fk = self._tr("recording", proc_running=True,
|
||||||
|
now=100.0, rec_t=50.0, auto_stop=60.0)
|
||||||
|
self.assertEqual(s, "recording")
|
||||||
|
|
||||||
|
def test_recording_auto_stop_at_exactly_timeout(self):
|
||||||
|
s, fk = self._tr("recording", proc_running=True,
|
||||||
|
now=110.0, rec_t=50.0, auto_stop=60.0)
|
||||||
|
self.assertEqual(s, "stopping")
|
||||||
|
|
||||||
|
def test_recording_auto_stop_disabled(self):
|
||||||
|
# auto_stop_s=0 → never auto-stops
|
||||||
|
s, fk = self._tr("recording", proc_running=True,
|
||||||
|
now=9999.0, rec_t=0.0, auto_stop=0.0)
|
||||||
|
self.assertEqual(s, "recording")
|
||||||
|
|
||||||
|
# ── STOPPING ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_stopping_proc_running_stays(self):
|
||||||
|
s, fk = self._tr("stopping", proc_running=True)
|
||||||
|
self.assertEqual(s, "stopping"); self.assertFalse(fk)
|
||||||
|
|
||||||
|
def test_stopping_proc_exits_idle(self):
|
||||||
|
s, fk = self._tr("stopping", proc_running=False)
|
||||||
|
self.assertEqual(s, "idle"); self.assertFalse(fk)
|
||||||
|
|
||||||
|
def test_stopping_force_kill_after_timeout(self):
|
||||||
|
# entered stopping at t=95, now=101 → 6 s > stop_tmo=5
|
||||||
|
s, fk = self._tr("stopping", proc_running=True,
|
||||||
|
now=101.0, stop_t=95.0, stop_tmo=5.0)
|
||||||
|
self.assertEqual(s, "idle"); self.assertTrue(fk)
|
||||||
|
|
||||||
|
def test_stopping_not_yet_force_kill(self):
|
||||||
|
# entered at t=98, now=100 → 2 s < 5 s
|
||||||
|
s, fk = self._tr("stopping", proc_running=True,
|
||||||
|
now=100.0, stop_t=98.0, stop_tmo=5.0)
|
||||||
|
self.assertEqual(s, "stopping"); self.assertFalse(fk)
|
||||||
|
|
||||||
|
def test_stopping_timeout_disabled(self):
|
||||||
|
# stop_tmo=0 → never force-kills
|
||||||
|
s, fk = self._tr("stopping", proc_running=True,
|
||||||
|
now=9999.0, stop_t=0.0, stop_tmo=0.0)
|
||||||
|
self.assertEqual(s, "stopping"); self.assertFalse(fk)
|
||||||
|
|
||||||
|
def test_stopping_force_kill_exactly_at_timeout(self):
|
||||||
|
s, fk = self._tr("stopping", proc_running=True,
|
||||||
|
now=100.0, stop_t=95.0, stop_tmo=5.0)
|
||||||
|
self.assertEqual(s, "idle"); self.assertTrue(fk)
|
||||||
|
|
||||||
|
# ── ERROR ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_error_stays_without_trigger(self):
|
||||||
|
s, fk = self._tr("error", trigger=None)
|
||||||
|
self.assertEqual(s, "error"); self.assertFalse(fk)
|
||||||
|
|
||||||
|
def test_error_false_trigger_stays(self):
|
||||||
|
s, fk = self._tr("error", trigger=False)
|
||||||
|
self.assertEqual(s, "error")
|
||||||
|
|
||||||
|
def test_error_true_trigger_retries(self):
|
||||||
|
s, fk = self._tr("error", trigger=True)
|
||||||
|
self.assertEqual(s, "recording"); self.assertFalse(fk)
|
||||||
|
|
||||||
|
# ── Return shape ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_returns_tuple_of_two(self):
|
||||||
|
result = self._tr("idle")
|
||||||
|
self.assertEqual(len(result), 2)
|
||||||
|
|
||||||
|
def test_force_kill_is_bool(self):
|
||||||
|
_, fk = self._tr("idle")
|
||||||
|
self.assertIsInstance(fk, bool)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests: node init ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestNodeInit(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls): cls.mod = _load_mod()
|
||||||
|
|
||||||
|
def test_instantiates(self):
|
||||||
|
self.assertIsNotNone(_make_node(self.mod))
|
||||||
|
|
||||||
|
def test_initial_state_idle(self):
|
||||||
|
node = _make_node(self.mod)
|
||||||
|
self.assertEqual(node.state, "idle")
|
||||||
|
|
||||||
|
def test_publishes_initial_idle(self):
|
||||||
|
node = _make_node(self.mod)
|
||||||
|
pub = _pub(node)
|
||||||
|
self.assertEqual(pub.msgs[0].data, "idle")
|
||||||
|
|
||||||
|
def test_publisher_registered(self):
|
||||||
|
node = _make_node(self.mod)
|
||||||
|
self.assertIn("/saltybot/recording_status", node._pubs)
|
||||||
|
|
||||||
|
def test_subscriber_registered(self):
|
||||||
|
node = _make_node(self.mod)
|
||||||
|
self.assertIn("/saltybot/record_trigger", node._subs)
|
||||||
|
|
||||||
|
def test_timer_registered(self):
|
||||||
|
node = _make_node(self.mod)
|
||||||
|
self.assertGreater(len(node._timers), 0)
|
||||||
|
|
||||||
|
def test_custom_topics(self):
|
||||||
|
node = _make_node(self.mod,
|
||||||
|
trigger_topic="/my/trigger",
|
||||||
|
status_topic="/my/status")
|
||||||
|
self.assertIn("/my/trigger", node._subs)
|
||||||
|
self.assertIn("/my/status", node._pubs)
|
||||||
|
|
||||||
|
def test_topics_parsed_correctly(self):
|
||||||
|
node = _make_node(self.mod, topics="/a,/b,/c")
|
||||||
|
self.assertEqual(node._topics, ["/a", "/b", "/c"])
|
||||||
|
|
||||||
|
def test_empty_topics_means_all(self):
|
||||||
|
node = _make_node(self.mod, topics="")
|
||||||
|
self.assertEqual(node._topics, [])
|
||||||
|
|
||||||
|
def test_auto_stop_s_stored(self):
|
||||||
|
node = _make_node(self.mod, auto_stop_s=30.0)
|
||||||
|
self.assertAlmostEqual(node._auto_stop_s, 30.0)
|
||||||
|
|
||||||
|
def test_current_bag_empty_initially(self):
|
||||||
|
node = _make_node(self.mod)
|
||||||
|
self.assertEqual(node.current_bag, "")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests: _on_trigger ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestOnTrigger(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls): cls.mod = _load_mod()
|
||||||
|
|
||||||
|
def test_stores_true_trigger(self):
|
||||||
|
node = _make_node(self.mod)
|
||||||
|
_trigger(node, True)
|
||||||
|
with node._lock:
|
||||||
|
self.assertTrue(node._trigger)
|
||||||
|
|
||||||
|
def test_stores_false_trigger(self):
|
||||||
|
node = _make_node(self.mod)
|
||||||
|
_trigger(node, False)
|
||||||
|
with node._lock:
|
||||||
|
self.assertFalse(node._trigger)
|
||||||
|
|
||||||
|
def test_trigger_consumed_after_poll(self):
|
||||||
|
rec = _MockRecorder()
|
||||||
|
node = _make_node(self.mod, recorder=rec)
|
||||||
|
_trigger(node, True)
|
||||||
|
node._poll_cb()
|
||||||
|
with node._lock:
|
||||||
|
self.assertIsNone(node._trigger)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests: poll loop — full state machine ────────────────────────────────────
|
||||||
|
|
||||||
|
class TestPollCb(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls): cls.mod = _load_mod()
|
||||||
|
|
||||||
|
def _node(self, **kwargs):
|
||||||
|
rec = _MockRecorder()
|
||||||
|
node = _make_node(self.mod, recorder=rec, **kwargs)
|
||||||
|
return node, rec
|
||||||
|
|
||||||
|
def test_true_trigger_starts_recording(self):
|
||||||
|
node, rec = self._node()
|
||||||
|
_trigger(node, True)
|
||||||
|
node._poll_cb()
|
||||||
|
self.assertEqual(node.state, "recording")
|
||||||
|
self.assertIn("start", rec.call_types())
|
||||||
|
|
||||||
|
def test_recording_publishes_status(self):
|
||||||
|
node, rec = self._node()
|
||||||
|
_trigger(node, True)
|
||||||
|
node._poll_cb()
|
||||||
|
self.assertEqual(_pub(node).msgs[-1].data, "recording")
|
||||||
|
|
||||||
|
def test_false_trigger_stops_recording(self):
|
||||||
|
node, rec = self._node()
|
||||||
|
_trigger(node, True); node._poll_cb() # → recording
|
||||||
|
_trigger(node, False); node._poll_cb() # → stopping
|
||||||
|
self.assertEqual(node.state, "stopping")
|
||||||
|
self.assertIn("stop", rec.call_types())
|
||||||
|
|
||||||
|
def test_stopping_publishes_status(self):
|
||||||
|
node, rec = self._node()
|
||||||
|
_trigger(node, True); node._poll_cb()
|
||||||
|
_trigger(node, False); node._poll_cb()
|
||||||
|
self.assertEqual(_pub(node).msgs[-1].data, "stopping")
|
||||||
|
|
||||||
|
def test_after_stop_proc_exit_idle(self):
|
||||||
|
"""Once the mock recorder stops, next poll resolves to idle."""
|
||||||
|
node, rec = self._node()
|
||||||
|
_trigger(node, True); node._poll_cb() # recording
|
||||||
|
_trigger(node, False); node._poll_cb() # stopping (rec.stop() sets running=False)
|
||||||
|
node._poll_cb() # proc not running → idle
|
||||||
|
self.assertEqual(node.state, "idle")
|
||||||
|
|
||||||
|
def test_idle_publishes_after_stop(self):
|
||||||
|
node, rec = self._node()
|
||||||
|
_trigger(node, True); node._poll_cb()
|
||||||
|
_trigger(node, False); node._poll_cb()
|
||||||
|
node._poll_cb()
|
||||||
|
self.assertEqual(_pub(node).msgs[-1].data, "idle")
|
||||||
|
|
||||||
|
def test_auto_stop_triggers_stopping(self):
|
||||||
|
node, rec = self._node(auto_stop_s=1.0)
|
||||||
|
_trigger(node, True)
|
||||||
|
node._poll_cb() # → recording
|
||||||
|
# Back-date start time so auto-stop fires
|
||||||
|
with node._lock:
|
||||||
|
node._record_start_t = time.monotonic() - 10.0
|
||||||
|
node._poll_cb() # → stopping
|
||||||
|
self.assertEqual(node.state, "stopping")
|
||||||
|
|
||||||
|
def test_auto_stop_disabled(self):
|
||||||
|
node, rec = self._node(auto_stop_s=0.0)
|
||||||
|
_trigger(node, True)
|
||||||
|
node._poll_cb()
|
||||||
|
with node._lock:
|
||||||
|
node._record_start_t = time.monotonic() - 9999.0
|
||||||
|
node._poll_cb()
|
||||||
|
# Should still be recording (auto-stop disabled)
|
||||||
|
self.assertEqual(node.state, "recording")
|
||||||
|
|
||||||
|
def test_proc_dies_error(self):
|
||||||
|
node, rec = self._node()
|
||||||
|
_trigger(node, True); node._poll_cb() # → recording
|
||||||
|
rec._running = False # simulate unexpected exit
|
||||||
|
node._poll_cb()
|
||||||
|
self.assertEqual(node.state, "error")
|
||||||
|
|
||||||
|
def test_error_publishes_status(self):
|
||||||
|
node, rec = self._node()
|
||||||
|
_trigger(node, True); node._poll_cb()
|
||||||
|
rec._running = False
|
||||||
|
node._poll_cb()
|
||||||
|
self.assertEqual(_pub(node).msgs[-1].data, "error")
|
||||||
|
|
||||||
|
def test_error_retries_on_true_trigger(self):
|
||||||
|
node, rec = self._node()
|
||||||
|
_trigger(node, True); node._poll_cb() # recording
|
||||||
|
rec._running = False
|
||||||
|
node._poll_cb() # error
|
||||||
|
_trigger(node, True); node._poll_cb() # retry → recording
|
||||||
|
self.assertEqual(node.state, "recording")
|
||||||
|
|
||||||
|
def test_start_failure_enters_error(self):
|
||||||
|
rec = _MockRecorder(start_succeeds=False)
|
||||||
|
node = _make_node(self.mod, recorder=rec)
|
||||||
|
_trigger(node, True)
|
||||||
|
node._poll_cb()
|
||||||
|
self.assertEqual(node.state, "error")
|
||||||
|
|
||||||
|
def test_force_kill_on_stop_timeout(self):
|
||||||
|
"""Stubborn process that ignores SIGINT → force-killed after timeout."""
|
||||||
|
class _StubbornRecorder(_MockRecorder):
|
||||||
|
def stop(self):
|
||||||
|
self.calls.append(("stop",))
|
||||||
|
# Don't set _running = False — simulates process ignoring SIGINT
|
||||||
|
|
||||||
|
stubborn = _StubbornRecorder()
|
||||||
|
node = _make_node(self.mod, recorder=stubborn, stop_timeout_s=2.0)
|
||||||
|
_trigger(node, True); node._poll_cb() # → recording
|
||||||
|
_trigger(node, False); node._poll_cb() # → stopping (process stays alive)
|
||||||
|
self.assertEqual(node.state, "stopping")
|
||||||
|
# Expire the stop timeout
|
||||||
|
with node._lock:
|
||||||
|
node._stop_start_t = time.monotonic() - 10.0
|
||||||
|
node._poll_cb() # → force kill → idle
|
||||||
|
self.assertIn("kill", stubborn.call_types())
|
||||||
|
self.assertEqual(node.state, "idle")
|
||||||
|
|
||||||
|
def test_bag_path_set_when_recording(self):
|
||||||
|
node, rec = self._node(bag_prefix="mytest")
|
||||||
|
_trigger(node, True)
|
||||||
|
node._poll_cb()
|
||||||
|
self.assertIn("mytest", node.current_bag)
|
||||||
|
|
||||||
|
def test_bag_path_cleared_after_idle(self):
|
||||||
|
node, rec = self._node()
|
||||||
|
_trigger(node, True); node._poll_cb()
|
||||||
|
_trigger(node, False); node._poll_cb()
|
||||||
|
node._poll_cb()
|
||||||
|
self.assertEqual(node.current_bag, "")
|
||||||
|
|
||||||
|
def test_topics_passed_to_recorder(self):
|
||||||
|
rec = _MockRecorder()
|
||||||
|
node = _make_node(self.mod, recorder=rec, topics="/a,/b")
|
||||||
|
_trigger(node, True)
|
||||||
|
node._poll_cb()
|
||||||
|
start_calls = [c for c in rec.calls if c[0] == "start"]
|
||||||
|
self.assertEqual(len(start_calls), 1)
|
||||||
|
self.assertEqual(start_calls[0][1], ["/a", "/b"])
|
||||||
|
|
||||||
|
def test_empty_topics_passes_empty_list(self):
|
||||||
|
rec = _MockRecorder()
|
||||||
|
node = _make_node(self.mod, recorder=rec, topics="")
|
||||||
|
_trigger(node, True)
|
||||||
|
node._poll_cb()
|
||||||
|
start_calls = [c for c in rec.calls if c[0] == "start"]
|
||||||
|
self.assertEqual(start_calls[0][1], [])
|
||||||
|
|
||||||
|
def test_recorder_reset_on_idle(self):
|
||||||
|
node, rec = self._node()
|
||||||
|
_trigger(node, True); node._poll_cb()
|
||||||
|
_trigger(node, False); node._poll_cb()
|
||||||
|
node._poll_cb() # idle
|
||||||
|
self.assertIn("reset", rec.call_types())
|
||||||
|
|
||||||
|
def test_full_lifecycle_status_sequence(self):
|
||||||
|
"""idle → recording → stopping → idle."""
|
||||||
|
node, rec = self._node()
|
||||||
|
pub = _pub(node)
|
||||||
|
|
||||||
|
_trigger(node, True); node._poll_cb()
|
||||||
|
_trigger(node, False); node._poll_cb()
|
||||||
|
node._poll_cb()
|
||||||
|
|
||||||
|
statuses = [m.data for m in pub.msgs]
|
||||||
|
self.assertIn("idle", statuses)
|
||||||
|
self.assertIn("recording", statuses)
|
||||||
|
self.assertIn("stopping", statuses)
|
||||||
|
|
||||||
|
def test_logging_on_start(self):
|
||||||
|
node, rec = self._node()
|
||||||
|
_trigger(node, True)
|
||||||
|
node._poll_cb()
|
||||||
|
infos = [m for lvl, m in node._logs if lvl == "INFO"]
|
||||||
|
self.assertTrue(any("recording" in m.lower() for m in infos))
|
||||||
|
|
||||||
|
def test_logging_on_stop(self):
|
||||||
|
node, rec = self._node()
|
||||||
|
_trigger(node, True); node._poll_cb()
|
||||||
|
_trigger(node, False); node._poll_cb()
|
||||||
|
infos = [m for lvl, m in node._logs if lvl == "INFO"]
|
||||||
|
self.assertTrue(any("stop" in m.lower() for m in infos))
|
||||||
|
|
||||||
|
def test_logging_on_error(self):
|
||||||
|
node, rec = self._node()
|
||||||
|
_trigger(node, True); node._poll_cb()
|
||||||
|
rec._running = False
|
||||||
|
node._poll_cb()
|
||||||
|
errors = [m for lvl, m in node._logs if lvl == "ERROR"]
|
||||||
|
self.assertTrue(len(errors) > 0)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests: source content ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestNodeSrc(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
with open(_SRC) as f: cls.src = f.read()
|
||||||
|
|
||||||
|
def test_issue_tag(self): self.assertIn("#332", self.src)
|
||||||
|
def test_trigger_topic(self): self.assertIn("/saltybot/record_trigger", self.src)
|
||||||
|
def test_status_topic(self): self.assertIn("/saltybot/recording_status", self.src)
|
||||||
|
def test_status_idle(self): self.assertIn('"idle"', self.src)
|
||||||
|
def test_status_recording(self): self.assertIn('"recording"', self.src)
|
||||||
|
def test_status_stopping(self): self.assertIn('"stopping"', self.src)
|
||||||
|
def test_status_error(self): self.assertIn('"error"', self.src)
|
||||||
|
def test_compute_transition_fn(self): self.assertIn("compute_recording_transition", self.src)
|
||||||
|
def test_bag_recorder_process(self): self.assertIn("BagRecorderProcess", self.src)
|
||||||
|
def test_make_bag_path(self): self.assertIn("make_bag_path", self.src)
|
||||||
|
def test_parse_topics(self): self.assertIn("parse_topics", self.src)
|
||||||
|
def test_auto_stop_param(self): self.assertIn("auto_stop_s", self.src)
|
||||||
|
def test_stop_timeout_param(self): self.assertIn("stop_timeout_s", self.src)
|
||||||
|
def test_compression_param(self): self.assertIn("compression", self.src)
|
||||||
|
def test_subprocess_used(self): self.assertIn("subprocess", self.src)
|
||||||
|
def test_sigint_used(self): self.assertIn("SIGINT", self.src)
|
||||||
|
def test_threading_lock(self): self.assertIn("threading.Lock", self.src)
|
||||||
|
def test_recorder_injectable(self): self.assertIn("_recorder", self.src)
|
||||||
|
def test_main_defined(self): self.assertIn("def main", self.src)
|
||||||
|
def test_bool_subscription(self): self.assertIn("Bool", self.src)
|
||||||
|
def test_string_publication(self): self.assertIn("String", self.src)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests: config / launch / setup ────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestConfig(unittest.TestCase):
|
||||||
|
_CONFIG = (
|
||||||
|
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||||
|
"saltybot_social/config/rosbag_recorder_params.yaml"
|
||||||
|
)
|
||||||
|
_LAUNCH = (
|
||||||
|
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||||
|
"saltybot_social/launch/rosbag_recorder.launch.py"
|
||||||
|
)
|
||||||
|
_SETUP = (
|
||||||
|
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||||
|
"saltybot_social/setup.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_config_exists(self):
|
||||||
|
import os; self.assertTrue(os.path.exists(self._CONFIG))
|
||||||
|
|
||||||
|
def test_config_auto_stop(self):
|
||||||
|
with open(self._CONFIG) as f: c = f.read()
|
||||||
|
self.assertIn("auto_stop_s", c)
|
||||||
|
|
||||||
|
def test_config_bag_dir(self):
|
||||||
|
with open(self._CONFIG) as f: c = f.read()
|
||||||
|
self.assertIn("bag_dir", c)
|
||||||
|
|
||||||
|
def test_config_topics(self):
|
||||||
|
with open(self._CONFIG) as f: c = f.read()
|
||||||
|
self.assertIn("topics", c)
|
||||||
|
|
||||||
|
def test_config_compression(self):
|
||||||
|
with open(self._CONFIG) as f: c = f.read()
|
||||||
|
self.assertIn("compression", c)
|
||||||
|
|
||||||
|
def test_config_stop_timeout(self):
|
||||||
|
with open(self._CONFIG) as f: c = f.read()
|
||||||
|
self.assertIn("stop_timeout_s", c)
|
||||||
|
|
||||||
|
def test_launch_exists(self):
|
||||||
|
import os; self.assertTrue(os.path.exists(self._LAUNCH))
|
||||||
|
|
||||||
|
def test_launch_auto_stop_arg(self):
|
||||||
|
with open(self._LAUNCH) as f: c = f.read()
|
||||||
|
self.assertIn("auto_stop_s", c)
|
||||||
|
|
||||||
|
def test_launch_topics_arg(self):
|
||||||
|
with open(self._LAUNCH) as f: c = f.read()
|
||||||
|
self.assertIn("topics", c)
|
||||||
|
|
||||||
|
def test_entry_point_in_setup(self):
|
||||||
|
with open(self._SETUP) as f: c = f.read()
|
||||||
|
self.assertIn("rosbag_recorder_node", c)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
791
jetson/ros2_ws/src/saltybot_social/test/test_sysmon.py
Normal file
791
jetson/ros2_ws/src/saltybot_social/test/test_sysmon.py
Normal file
@ -0,0 +1,791 @@
|
|||||||
|
"""test_sysmon.py — Offline tests for sysmon_node (Issue #355).
|
||||||
|
|
||||||
|
All tests run without ROS, /proc, or /sys via:
|
||||||
|
- pure-function unit tests
|
||||||
|
- ROS2 stub pattern (_make_node + injectable I/O functions)
|
||||||
|
|
||||||
|
Coverage
|
||||||
|
────────
|
||||||
|
parse_proc_stat — single / multi cpu / malformed
|
||||||
|
cpu_percent_from_stats — busy, idle, edge cases
|
||||||
|
parse_meminfo — standard / missing keys
|
||||||
|
compute_ram_stats — normal / zero total
|
||||||
|
read_disk_usage — statvfs mock / error path
|
||||||
|
read_gpu_load — normal / per-mille / unavailable
|
||||||
|
read_thermal_zones — normal / missing type file / bad temp
|
||||||
|
SysmonNode init — parameter wiring
|
||||||
|
SysmonNode._tick — full payload published, JSON valid
|
||||||
|
SysmonNode CPU delta — prev_stat used correctly
|
||||||
|
SysmonNode injectable I/O — all readers swappable
|
||||||
|
source / entry-point — file exists, main importable
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
import unittest
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
# ── ROS2 / rclpy stub ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _make_rclpy_stub():
|
||||||
|
rclpy = types.ModuleType("rclpy")
|
||||||
|
rclpy_node = types.ModuleType("rclpy.node")
|
||||||
|
rclpy_qos = types.ModuleType("rclpy.qos")
|
||||||
|
std_msgs = types.ModuleType("std_msgs")
|
||||||
|
std_msgs_msg = types.ModuleType("std_msgs.msg")
|
||||||
|
|
||||||
|
class _QoSProfile:
|
||||||
|
def __init__(self, **kw): pass
|
||||||
|
|
||||||
|
class _String:
|
||||||
|
def __init__(self): self.data = ""
|
||||||
|
|
||||||
|
class _Node:
|
||||||
|
def __init__(self, name, **kw):
|
||||||
|
if not hasattr(self, "_params"):
|
||||||
|
self._params = {}
|
||||||
|
if not hasattr(self, "_pubs"):
|
||||||
|
self._pubs = {}
|
||||||
|
self._timers = []
|
||||||
|
self._logger = MagicMock()
|
||||||
|
|
||||||
|
def declare_parameter(self, name, default=None):
|
||||||
|
if name not in self._params:
|
||||||
|
self._params[name] = default
|
||||||
|
|
||||||
|
def get_parameter(self, name):
|
||||||
|
m = MagicMock()
|
||||||
|
m.value = self._params.get(name)
|
||||||
|
return m
|
||||||
|
|
||||||
|
def create_publisher(self, msg_type, topic, qos):
|
||||||
|
pub = MagicMock()
|
||||||
|
pub.msgs = []
|
||||||
|
pub.publish = lambda msg: pub.msgs.append(msg)
|
||||||
|
self._pubs[topic] = pub
|
||||||
|
return pub
|
||||||
|
|
||||||
|
def create_timer(self, period, cb):
|
||||||
|
t = MagicMock()
|
||||||
|
t._cb = cb
|
||||||
|
self._timers.append(t)
|
||||||
|
return t
|
||||||
|
|
||||||
|
def get_logger(self):
|
||||||
|
return self._logger
|
||||||
|
|
||||||
|
def destroy_node(self): pass
|
||||||
|
|
||||||
|
rclpy_node.Node = _Node
|
||||||
|
rclpy_qos.QoSProfile = _QoSProfile
|
||||||
|
std_msgs_msg.String = _String
|
||||||
|
|
||||||
|
rclpy.init = lambda *a, **kw: None
|
||||||
|
rclpy.spin = lambda n: None
|
||||||
|
rclpy.shutdown = lambda: None
|
||||||
|
|
||||||
|
return rclpy, rclpy_node, rclpy_qos, std_msgs, std_msgs_msg
|
||||||
|
|
||||||
|
|
||||||
|
def _load_mod():
|
||||||
|
"""Load sysmon_node with ROS2 stubs, return the module."""
|
||||||
|
rclpy, rclpy_node, rclpy_qos, std_msgs, std_msgs_msg = _make_rclpy_stub()
|
||||||
|
sys.modules.setdefault("rclpy", rclpy)
|
||||||
|
sys.modules.setdefault("rclpy.node", rclpy_node)
|
||||||
|
sys.modules.setdefault("rclpy.qos", rclpy_qos)
|
||||||
|
sys.modules.setdefault("std_msgs", std_msgs)
|
||||||
|
sys.modules.setdefault("std_msgs.msg", std_msgs_msg)
|
||||||
|
|
||||||
|
mod_name = "saltybot_social.sysmon_node"
|
||||||
|
if mod_name in sys.modules:
|
||||||
|
del sys.modules[mod_name]
|
||||||
|
mod = importlib.import_module(mod_name)
|
||||||
|
return mod
|
||||||
|
|
||||||
|
|
||||||
|
def _make_node(mod, **params) -> "mod.SysmonNode":
|
||||||
|
"""Create a SysmonNode with default params overridden by *params*."""
|
||||||
|
defaults = {
|
||||||
|
"publish_rate": 1.0,
|
||||||
|
"disk_path": "/",
|
||||||
|
"gpu_sysfs_path": "/sys/devices/gpu.0/load",
|
||||||
|
"thermal_glob": "/sys/devices/virtual/thermal/thermal_zone*/temp",
|
||||||
|
"thermal_type_glob": "/sys/devices/virtual/thermal/thermal_zone*/type",
|
||||||
|
"output_topic": "/saltybot/system_resources",
|
||||||
|
}
|
||||||
|
defaults.update(params)
|
||||||
|
|
||||||
|
node = mod.SysmonNode.__new__(mod.SysmonNode)
|
||||||
|
node._params = defaults
|
||||||
|
# Stub I/O before __init__ to avoid real /proc reads
|
||||||
|
node._read_proc_stat = lambda: None
|
||||||
|
mod.SysmonNode.__init__(node)
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
# ── Sample /proc/stat content ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_STAT_2CORE = """\
|
||||||
|
cpu 100 0 50 850 0 0 0 0 0 0
|
||||||
|
cpu0 60 0 30 410 0 0 0 0 0 0
|
||||||
|
cpu1 40 0 20 440 0 0 0 0 0 0
|
||||||
|
intr 12345 ...
|
||||||
|
"""
|
||||||
|
|
||||||
|
_STAT_2CORE_V2 = """\
|
||||||
|
cpu 250 0 100 1000 0 0 0 0 0 0
|
||||||
|
cpu0 140 0 60 510 0 0 0 0 0 0
|
||||||
|
cpu1 110 0 40 490 0 0 0 0 0 0
|
||||||
|
intr 99999 ...
|
||||||
|
"""
|
||||||
|
|
||||||
|
_MEMINFO = """\
|
||||||
|
MemTotal: 16384000 kB
|
||||||
|
MemFree: 4096000 kB
|
||||||
|
MemAvailable: 8192000 kB
|
||||||
|
Buffers: 512000 kB
|
||||||
|
Cached: 2048000 kB
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
# parse_proc_stat
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestParseProcStat(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.mod = _load_mod()
|
||||||
|
|
||||||
|
def test_aggregate_line_parsed(self):
|
||||||
|
result = self.mod.parse_proc_stat(_STAT_2CORE)
|
||||||
|
# index 0 = aggregate "cpu" line
|
||||||
|
self.assertEqual(result[0], [100, 0, 50, 850, 0, 0, 0, 0])
|
||||||
|
|
||||||
|
def test_per_core_lines_parsed(self):
|
||||||
|
result = self.mod.parse_proc_stat(_STAT_2CORE)
|
||||||
|
self.assertEqual(len(result), 3) # agg + cpu0 + cpu1
|
||||||
|
self.assertEqual(result[1], [60, 0, 30, 410, 0, 0, 0, 0])
|
||||||
|
self.assertEqual(result[2], [40, 0, 20, 440, 0, 0, 0, 0])
|
||||||
|
|
||||||
|
def test_stops_at_non_cpu_line(self):
|
||||||
|
result = self.mod.parse_proc_stat(_STAT_2CORE)
|
||||||
|
# should not include "intr" line
|
||||||
|
self.assertEqual(len(result), 3)
|
||||||
|
|
||||||
|
def test_short_fields_padded(self):
|
||||||
|
content = "cpu 1 2 3\n"
|
||||||
|
result = self.mod.parse_proc_stat(content)
|
||||||
|
self.assertEqual(len(result[0]), 8)
|
||||||
|
self.assertEqual(result[0][:3], [1, 2, 3])
|
||||||
|
self.assertEqual(result[0][3:], [0, 0, 0, 0, 0])
|
||||||
|
|
||||||
|
def test_empty_content(self):
|
||||||
|
result = self.mod.parse_proc_stat("")
|
||||||
|
self.assertEqual(result, [])
|
||||||
|
|
||||||
|
def test_no_cpu_lines(self):
|
||||||
|
result = self.mod.parse_proc_stat("intr 12345\nmem 0\n")
|
||||||
|
self.assertEqual(result, [])
|
||||||
|
|
||||||
|
def test_single_cpu_no_cores(self):
|
||||||
|
content = "cpu 200 0 100 700 0 0 0 0\n"
|
||||||
|
result = self.mod.parse_proc_stat(content)
|
||||||
|
self.assertEqual(len(result), 1)
|
||||||
|
self.assertEqual(result[0][0], 200)
|
||||||
|
|
||||||
|
def test_extra_fields_truncated_to_8(self):
|
||||||
|
content = "cpu 1 2 3 4 5 6 7 8 9 10\n"
|
||||||
|
result = self.mod.parse_proc_stat(content)
|
||||||
|
self.assertEqual(len(result[0]), 8)
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
# cpu_percent_from_stats
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestCpuPercentFromStats(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.mod = _load_mod()
|
||||||
|
|
||||||
|
def _parse(self, content):
|
||||||
|
return self.mod.parse_proc_stat(content)
|
||||||
|
|
||||||
|
def test_basic_cpu_usage(self):
|
||||||
|
prev = self._parse(_STAT_2CORE)
|
||||||
|
curr = self._parse(_STAT_2CORE_V2)
|
||||||
|
pcts = self.mod.cpu_percent_from_stats(prev, curr)
|
||||||
|
# aggregate: delta = 300, idle delta = -150 → 50%
|
||||||
|
self.assertEqual(len(pcts), 3)
|
||||||
|
self.assertGreater(pcts[0], 0)
|
||||||
|
self.assertLessEqual(pcts[0], 100)
|
||||||
|
|
||||||
|
def test_all_idle(self):
|
||||||
|
prev = [[0, 0, 0, 1000, 0, 0, 0, 0]]
|
||||||
|
curr = [[0, 0, 0, 2000, 0, 0, 0, 0]]
|
||||||
|
pcts = self.mod.cpu_percent_from_stats(prev, curr)
|
||||||
|
self.assertEqual(pcts[0], 0.0)
|
||||||
|
|
||||||
|
def test_fully_busy(self):
|
||||||
|
prev = [[0, 0, 0, 0, 0, 0, 0, 0]]
|
||||||
|
curr = [[1000, 0, 0, 0, 0, 0, 0, 0]]
|
||||||
|
pcts = self.mod.cpu_percent_from_stats(prev, curr)
|
||||||
|
self.assertEqual(pcts[0], 100.0)
|
||||||
|
|
||||||
|
def test_empty_prev(self):
|
||||||
|
curr = self._parse(_STAT_2CORE)
|
||||||
|
result = self.mod.cpu_percent_from_stats([], curr)
|
||||||
|
self.assertEqual(result, [])
|
||||||
|
|
||||||
|
def test_mismatched_length(self):
|
||||||
|
prev = self._parse(_STAT_2CORE)
|
||||||
|
curr = self._parse(_STAT_2CORE)[:1]
|
||||||
|
result = self.mod.cpu_percent_from_stats(prev, curr)
|
||||||
|
self.assertEqual(result, [])
|
||||||
|
|
||||||
|
def test_zero_delta_total(self):
|
||||||
|
stat = [[100, 0, 50, 850, 0, 0, 0, 0]]
|
||||||
|
pcts = self.mod.cpu_percent_from_stats(stat, stat)
|
||||||
|
self.assertEqual(pcts[0], 0.0)
|
||||||
|
|
||||||
|
def test_values_clamped_0_to_100(self):
|
||||||
|
# Simulate counter wrap-around or bogus data
|
||||||
|
prev = [[9999, 0, 0, 0, 0, 0, 0, 0]]
|
||||||
|
curr = [[0, 0, 0, 1000, 0, 0, 0, 0]]
|
||||||
|
pcts = self.mod.cpu_percent_from_stats(prev, curr)
|
||||||
|
self.assertGreaterEqual(pcts[0], 0.0)
|
||||||
|
self.assertLessEqual(pcts[0], 100.0)
|
||||||
|
|
||||||
|
def test_multi_core_per_core_values(self):
|
||||||
|
prev = self._parse(_STAT_2CORE)
|
||||||
|
curr = self._parse(_STAT_2CORE_V2)
|
||||||
|
pcts = self.mod.cpu_percent_from_stats(prev, curr)
|
||||||
|
self.assertEqual(len(pcts), 3)
|
||||||
|
for p in pcts:
|
||||||
|
self.assertGreaterEqual(p, 0.0)
|
||||||
|
self.assertLessEqual(p, 100.0)
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
# parse_meminfo
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestParseMeminfo(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.mod = _load_mod()
|
||||||
|
|
||||||
|
def test_total_parsed(self):
|
||||||
|
info = self.mod.parse_meminfo(_MEMINFO)
|
||||||
|
self.assertEqual(info["MemTotal"], 16384000)
|
||||||
|
|
||||||
|
def test_available_parsed(self):
|
||||||
|
info = self.mod.parse_meminfo(_MEMINFO)
|
||||||
|
self.assertEqual(info["MemAvailable"], 8192000)
|
||||||
|
|
||||||
|
def test_empty_returns_empty(self):
|
||||||
|
info = self.mod.parse_meminfo("")
|
||||||
|
self.assertEqual(info, {})
|
||||||
|
|
||||||
|
def test_lines_without_colon_ignored(self):
|
||||||
|
info = self.mod.parse_meminfo("NoColon 1234\nKey: 42 kB\n")
|
||||||
|
self.assertNotIn("NoColon 1234", info)
|
||||||
|
self.assertEqual(info["Key"], 42)
|
||||||
|
|
||||||
|
def test_malformed_value_skipped(self):
|
||||||
|
info = self.mod.parse_meminfo("Bad: abc kB\nGood: 100 kB\n")
|
||||||
|
self.assertNotIn("Bad", info)
|
||||||
|
self.assertEqual(info["Good"], 100)
|
||||||
|
|
||||||
|
def test_multiple_keys(self):
|
||||||
|
info = self.mod.parse_meminfo(_MEMINFO)
|
||||||
|
self.assertIn("MemFree", info)
|
||||||
|
self.assertIn("Buffers", info)
|
||||||
|
self.assertIn("Cached", info)
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
# compute_ram_stats
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestComputeRamStats(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.mod = _load_mod()
|
||||||
|
|
||||||
|
def test_uses_mem_available(self):
|
||||||
|
info = self.mod.parse_meminfo(_MEMINFO)
|
||||||
|
total, used, pct = self.mod.compute_ram_stats(info)
|
||||||
|
# total = 16384000 kB = 16000 MB
|
||||||
|
self.assertAlmostEqual(total, 16000.0, delta=1)
|
||||||
|
# used = total - available = 16384000 - 8192000 = 8192000 kB = 8000 MB
|
||||||
|
self.assertAlmostEqual(used, 8000.0, delta=1)
|
||||||
|
|
||||||
|
def test_fallback_to_mem_free(self):
|
||||||
|
info = {"MemTotal": 1024, "MemFree": 512}
|
||||||
|
total, used, pct = self.mod.compute_ram_stats(info)
|
||||||
|
self.assertAlmostEqual(used, 0.5, delta=0.01) # 512 kB = 0.5 MB
|
||||||
|
|
||||||
|
def test_zero_total(self):
|
||||||
|
total, used, pct = self.mod.compute_ram_stats({})
|
||||||
|
self.assertEqual(total, 0.0)
|
||||||
|
self.assertEqual(pct, 0.0)
|
||||||
|
|
||||||
|
def test_percent_correct(self):
|
||||||
|
info = {"MemTotal": 1000, "MemAvailable": 750}
|
||||||
|
_, _, pct = self.mod.compute_ram_stats(info)
|
||||||
|
self.assertAlmostEqual(pct, 25.0, delta=0.1)
|
||||||
|
|
||||||
|
def test_fully_used(self):
|
||||||
|
info = {"MemTotal": 1000, "MemAvailable": 0}
|
||||||
|
_, _, pct = self.mod.compute_ram_stats(info)
|
||||||
|
self.assertEqual(pct, 100.0)
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
# read_disk_usage
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestReadDiskUsage(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.mod = _load_mod()
|
||||||
|
|
||||||
|
def test_root_filesystem_succeeds(self):
|
||||||
|
"""On any real POSIX host / should be readable."""
|
||||||
|
total, used, pct = self.mod.read_disk_usage("/")
|
||||||
|
self.assertGreater(total, 0.0)
|
||||||
|
self.assertGreaterEqual(used, 0.0)
|
||||||
|
self.assertGreaterEqual(pct, 0.0)
|
||||||
|
self.assertLessEqual(pct, 100.0)
|
||||||
|
|
||||||
|
def test_nonexistent_path_returns_minus_one(self):
|
||||||
|
total, used, pct = self.mod.read_disk_usage("/nonexistent_xyz_12345")
|
||||||
|
self.assertEqual(total, -1.0)
|
||||||
|
self.assertEqual(used, -1.0)
|
||||||
|
self.assertEqual(pct, -1.0)
|
||||||
|
|
||||||
|
def test_statvfs_mock(self):
|
||||||
|
fake = MagicMock()
|
||||||
|
fake.f_frsize = 4096
|
||||||
|
fake.f_blocks = 1000 # total = 4096000 bytes ~ 3.8 MB
|
||||||
|
fake.f_bavail = 250 # free = 1024000 bytes, used = 3072000
|
||||||
|
with patch("os.statvfs", return_value=fake):
|
||||||
|
total, used, pct = self.mod.read_disk_usage("/any")
|
||||||
|
total_b = 4096 * 1000
|
||||||
|
used_b = 4096 * 750
|
||||||
|
self.assertAlmostEqual(total, total_b / (1024**3), delta=0.001)
|
||||||
|
self.assertAlmostEqual(used, used_b / (1024**3), delta=0.001)
|
||||||
|
self.assertAlmostEqual(pct, 75.0, delta=0.1)
|
||||||
|
|
||||||
|
def test_zero_blocks_returns_zero_percent(self):
|
||||||
|
fake = MagicMock()
|
||||||
|
fake.f_frsize = 4096
|
||||||
|
fake.f_blocks = 0
|
||||||
|
fake.f_bavail = 0
|
||||||
|
with patch("os.statvfs", return_value=fake):
|
||||||
|
total, used, pct = self.mod.read_disk_usage("/any")
|
||||||
|
self.assertEqual(pct, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
# read_gpu_load
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestReadGpuLoad(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.mod = _load_mod()
|
||||||
|
|
||||||
|
def _make_tmpfile(self, content: str) -> str:
|
||||||
|
import tempfile
|
||||||
|
f = tempfile.NamedTemporaryFile("w", delete=False, suffix=".txt")
|
||||||
|
f.write(content)
|
||||||
|
f.flush()
|
||||||
|
f.close()
|
||||||
|
return f.name
|
||||||
|
|
||||||
|
def test_normal_percent(self):
|
||||||
|
p = self._make_tmpfile("42\n")
|
||||||
|
try:
|
||||||
|
self.assertEqual(self.mod.read_gpu_load(p), 42.0)
|
||||||
|
finally:
|
||||||
|
os.unlink(p)
|
||||||
|
|
||||||
|
def test_per_mille_converted(self):
|
||||||
|
p = self._make_tmpfile("750\n")
|
||||||
|
try:
|
||||||
|
val = self.mod.read_gpu_load(p)
|
||||||
|
self.assertEqual(val, 75.0)
|
||||||
|
finally:
|
||||||
|
os.unlink(p)
|
||||||
|
|
||||||
|
def test_missing_file_returns_minus_one(self):
|
||||||
|
val = self.mod.read_gpu_load("/nonexistent_gpu_sysfs_xyz")
|
||||||
|
self.assertEqual(val, -1.0)
|
||||||
|
|
||||||
|
def test_non_numeric_returns_minus_one(self):
|
||||||
|
p = self._make_tmpfile("N/A\n")
|
||||||
|
try:
|
||||||
|
val = self.mod.read_gpu_load(p)
|
||||||
|
self.assertEqual(val, -1.0)
|
||||||
|
finally:
|
||||||
|
os.unlink(p)
|
||||||
|
|
||||||
|
def test_clamped_to_100(self):
|
||||||
|
p = self._make_tmpfile("1001\n") # > 1000 → not per-mille → clamp
|
||||||
|
try:
|
||||||
|
val = self.mod.read_gpu_load(p)
|
||||||
|
self.assertLessEqual(val, 100.0)
|
||||||
|
finally:
|
||||||
|
os.unlink(p)
|
||||||
|
|
||||||
|
def test_zero(self):
|
||||||
|
p = self._make_tmpfile("0\n")
|
||||||
|
try:
|
||||||
|
self.assertEqual(self.mod.read_gpu_load(p), 0.0)
|
||||||
|
finally:
|
||||||
|
os.unlink(p)
|
||||||
|
|
||||||
|
def test_100_percent(self):
|
||||||
|
p = self._make_tmpfile("100\n")
|
||||||
|
try:
|
||||||
|
self.assertEqual(self.mod.read_gpu_load(p), 100.0)
|
||||||
|
finally:
|
||||||
|
os.unlink(p)
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
# read_thermal_zones
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestReadThermalZones(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.mod = _load_mod()
|
||||||
|
import tempfile
|
||||||
|
self._tmpdir = tempfile.mkdtemp()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
import shutil
|
||||||
|
shutil.rmtree(self._tmpdir, ignore_errors=True)
|
||||||
|
|
||||||
|
def _make_zone(self, name: str, zone_id: int, milli_c: int) -> str:
|
||||||
|
zone_dir = os.path.join(self._tmpdir, f"thermal_zone{zone_id}")
|
||||||
|
os.makedirs(zone_dir, exist_ok=True)
|
||||||
|
with open(os.path.join(zone_dir, "temp"), "w") as f:
|
||||||
|
f.write(f"{milli_c}\n")
|
||||||
|
with open(os.path.join(zone_dir, "type"), "w") as f:
|
||||||
|
f.write(f"{name}\n")
|
||||||
|
return zone_dir
|
||||||
|
|
||||||
|
def test_reads_zones_with_type(self):
|
||||||
|
self._make_zone("CPU-therm", 0, 47500)
|
||||||
|
self._make_zone("GPU-therm", 1, 43200)
|
||||||
|
temp_glob = os.path.join(self._tmpdir, "thermal_zone*/temp")
|
||||||
|
type_glob = os.path.join(self._tmpdir, "thermal_zone*/type")
|
||||||
|
result = self.mod.read_thermal_zones(temp_glob, type_glob)
|
||||||
|
self.assertAlmostEqual(result["CPU-therm"], 47.5, delta=0.1)
|
||||||
|
self.assertAlmostEqual(result["GPU-therm"], 43.2, delta=0.1)
|
||||||
|
|
||||||
|
def test_fallback_to_zone_index_when_no_type(self):
|
||||||
|
zone_dir = os.path.join(self._tmpdir, "thermal_zone0")
|
||||||
|
os.makedirs(zone_dir, exist_ok=True)
|
||||||
|
with open(os.path.join(zone_dir, "temp"), "w") as f:
|
||||||
|
f.write("35000\n")
|
||||||
|
# No type file
|
||||||
|
temp_glob = os.path.join(self._tmpdir, "thermal_zone*/temp")
|
||||||
|
type_glob = os.path.join(self._tmpdir, "thermal_zone*/type")
|
||||||
|
result = self.mod.read_thermal_zones(temp_glob, type_glob)
|
||||||
|
self.assertIn("zone0", result)
|
||||||
|
self.assertAlmostEqual(result["zone0"], 35.0, delta=0.1)
|
||||||
|
|
||||||
|
def test_empty_glob_returns_empty(self):
|
||||||
|
result = self.mod.read_thermal_zones(
|
||||||
|
"/nonexistent_path/*/temp",
|
||||||
|
"/nonexistent_path/*/type",
|
||||||
|
)
|
||||||
|
self.assertEqual(result, {})
|
||||||
|
|
||||||
|
def test_bad_temp_file_skipped(self):
|
||||||
|
zone_dir = os.path.join(self._tmpdir, "thermal_zone0")
|
||||||
|
os.makedirs(zone_dir, exist_ok=True)
|
||||||
|
with open(os.path.join(zone_dir, "temp"), "w") as f:
|
||||||
|
f.write("INVALID\n")
|
||||||
|
with open(os.path.join(zone_dir, "type"), "w") as f:
|
||||||
|
f.write("CPU-therm\n")
|
||||||
|
temp_glob = os.path.join(self._tmpdir, "thermal_zone*/temp")
|
||||||
|
type_glob = os.path.join(self._tmpdir, "thermal_zone*/type")
|
||||||
|
result = self.mod.read_thermal_zones(temp_glob, type_glob)
|
||||||
|
self.assertEqual(result, {})
|
||||||
|
|
||||||
|
def test_temperature_conversion(self):
|
||||||
|
self._make_zone("SOC-therm", 0, 55000)
|
||||||
|
temp_glob = os.path.join(self._tmpdir, "thermal_zone*/temp")
|
||||||
|
type_glob = os.path.join(self._tmpdir, "thermal_zone*/type")
|
||||||
|
result = self.mod.read_thermal_zones(temp_glob, type_glob)
|
||||||
|
self.assertAlmostEqual(result["SOC-therm"], 55.0, delta=0.1)
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
# SysmonNode init
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestSysmonNodeInit(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.mod = _load_mod()
|
||||||
|
|
||||||
|
def test_node_creates_publisher(self):
|
||||||
|
node = _make_node(self.mod)
|
||||||
|
self.assertIn("/saltybot/system_resources", node._pubs)
|
||||||
|
|
||||||
|
def test_node_creates_timer(self):
|
||||||
|
node = _make_node(self.mod)
|
||||||
|
self.assertEqual(len(node._timers), 1)
|
||||||
|
|
||||||
|
def test_custom_output_topic(self):
|
||||||
|
node = _make_node(self.mod, output_topic="/custom/resources")
|
||||||
|
self.assertIn("/custom/resources", node._pubs)
|
||||||
|
|
||||||
|
def test_timer_period_from_rate(self):
|
||||||
|
# Rate=2 Hz → period=0.5s — timer is created; period verified via mock
|
||||||
|
node = _make_node(self.mod, publish_rate=2.0)
|
||||||
|
self.assertEqual(len(node._timers), 1)
|
||||||
|
|
||||||
|
def test_injectable_read_proc_stat(self):
|
||||||
|
node = _make_node(self.mod)
|
||||||
|
called = []
|
||||||
|
node._read_proc_stat = lambda: (called.append(1) or None)
|
||||||
|
node._tick()
|
||||||
|
self.assertTrue(len(called) >= 1)
|
||||||
|
|
||||||
|
def test_injectable_read_meminfo(self):
|
||||||
|
node = _make_node(self.mod)
|
||||||
|
node._read_meminfo = lambda: _MEMINFO
|
||||||
|
node._read_disk_usage = lambda p: (1.0, 0.5, 50.0)
|
||||||
|
node._read_gpu_load = lambda p: 0.0
|
||||||
|
node._read_thermal = lambda g, t: {}
|
||||||
|
node._tick() # should not raise
|
||||||
|
|
||||||
|
def test_prev_stat_initialised(self):
|
||||||
|
node = _make_node(self.mod)
|
||||||
|
# _read_proc_stat was stubbed to return None during _make_node
|
||||||
|
# so prev_stat may be None — just confirm the property exists
|
||||||
|
_ = node.prev_stat
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
# SysmonNode._tick — JSON payload
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestSysmonNodeTick(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.mod = _load_mod()
|
||||||
|
|
||||||
|
def _wired_node(self, **kwargs):
|
||||||
|
"""Return a node with fully stubbed I/O."""
|
||||||
|
stat_v1 = self.mod.parse_proc_stat(_STAT_2CORE)
|
||||||
|
stat_v2 = self.mod.parse_proc_stat(_STAT_2CORE_V2)
|
||||||
|
call_count = [0]
|
||||||
|
|
||||||
|
def fake_stat():
|
||||||
|
call_count[0] += 1
|
||||||
|
return stat_v2 if call_count[0] > 1 else stat_v1
|
||||||
|
|
||||||
|
node = _make_node(self.mod, **kwargs)
|
||||||
|
node._read_proc_stat = fake_stat
|
||||||
|
node._prev_stat = stat_v1
|
||||||
|
node._read_meminfo = lambda: _MEMINFO
|
||||||
|
node._read_disk_usage = lambda p: (64.0, 12.5, 19.5)
|
||||||
|
node._read_gpu_load = lambda p: 42.0
|
||||||
|
node._read_thermal = lambda g, t: {"CPU-therm": 47.5, "GPU-therm": 43.2}
|
||||||
|
return node
|
||||||
|
|
||||||
|
def _get_payload(self, node) -> dict:
|
||||||
|
node._tick()
|
||||||
|
pub = node._pubs["/saltybot/system_resources"]
|
||||||
|
self.assertTrue(len(pub.msgs) >= 1)
|
||||||
|
return json.loads(pub.msgs[-1].data)
|
||||||
|
|
||||||
|
def test_payload_is_valid_json(self):
|
||||||
|
node = self._wired_node()
|
||||||
|
node._tick()
|
||||||
|
pub = node._pubs["/saltybot/system_resources"]
|
||||||
|
msg = pub.msgs[-1]
|
||||||
|
data = json.loads(msg.data)
|
||||||
|
self.assertIsInstance(data, dict)
|
||||||
|
|
||||||
|
def test_payload_has_required_keys(self):
|
||||||
|
node = self._wired_node()
|
||||||
|
payload = self._get_payload(node)
|
||||||
|
for key in (
|
||||||
|
"ts", "cpu_percent", "cpu_avg_percent",
|
||||||
|
"ram_total_mb", "ram_used_mb", "ram_percent",
|
||||||
|
"disk_total_gb", "disk_used_gb", "disk_percent",
|
||||||
|
"gpu_percent", "thermal",
|
||||||
|
):
|
||||||
|
self.assertIn(key, payload, f"Missing key: {key}")
|
||||||
|
|
||||||
|
def test_ts_is_float(self):
|
||||||
|
node = self._wired_node()
|
||||||
|
payload = self._get_payload(node)
|
||||||
|
self.assertIsInstance(payload["ts"], float)
|
||||||
|
|
||||||
|
def test_cpu_percent_is_list(self):
|
||||||
|
node = self._wired_node()
|
||||||
|
payload = self._get_payload(node)
|
||||||
|
self.assertIsInstance(payload["cpu_percent"], list)
|
||||||
|
|
||||||
|
def test_gpu_percent_forwarded(self):
|
||||||
|
node = self._wired_node()
|
||||||
|
payload = self._get_payload(node)
|
||||||
|
self.assertAlmostEqual(payload["gpu_percent"], 42.0, delta=0.1)
|
||||||
|
|
||||||
|
def test_disk_stats_forwarded(self):
|
||||||
|
node = self._wired_node()
|
||||||
|
payload = self._get_payload(node)
|
||||||
|
self.assertAlmostEqual(payload["disk_total_gb"], 64.0, delta=0.1)
|
||||||
|
self.assertAlmostEqual(payload["disk_used_gb"], 12.5, delta=0.1)
|
||||||
|
self.assertAlmostEqual(payload["disk_percent"], 19.5, delta=0.1)
|
||||||
|
|
||||||
|
def test_ram_stats_forwarded(self):
|
||||||
|
node = self._wired_node()
|
||||||
|
payload = self._get_payload(node)
|
||||||
|
self.assertGreater(payload["ram_total_mb"], 0)
|
||||||
|
self.assertGreaterEqual(payload["ram_used_mb"], 0)
|
||||||
|
|
||||||
|
def test_thermal_is_dict(self):
|
||||||
|
node = self._wired_node()
|
||||||
|
payload = self._get_payload(node)
|
||||||
|
self.assertIsInstance(payload["thermal"], dict)
|
||||||
|
self.assertIn("CPU-therm", payload["thermal"])
|
||||||
|
|
||||||
|
def test_cpu_avg_matches_per_core(self):
|
||||||
|
node = self._wired_node()
|
||||||
|
payload = self._get_payload(node)
|
||||||
|
per_core = payload["cpu_percent"][1:]
|
||||||
|
if per_core:
|
||||||
|
expected_avg = sum(per_core) / len(per_core)
|
||||||
|
self.assertAlmostEqual(payload["cpu_avg_percent"], expected_avg, delta=0.1)
|
||||||
|
|
||||||
|
def test_publishes_once_per_tick(self):
|
||||||
|
node = self._wired_node()
|
||||||
|
node._tick()
|
||||||
|
node._tick()
|
||||||
|
pub = node._pubs["/saltybot/system_resources"]
|
||||||
|
self.assertEqual(len(pub.msgs), 2)
|
||||||
|
|
||||||
|
def test_no_prev_stat_gives_empty_cpu(self):
|
||||||
|
node = self._wired_node()
|
||||||
|
node._prev_stat = None
|
||||||
|
node._read_proc_stat = lambda: None
|
||||||
|
payload = self._get_payload(node)
|
||||||
|
self.assertEqual(payload["cpu_percent"], [])
|
||||||
|
self.assertEqual(payload["cpu_avg_percent"], 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
# CPU delta tracking across ticks
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestSysmonCpuDelta(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.mod = _load_mod()
|
||||||
|
|
||||||
|
def test_prev_stat_updated_after_tick(self):
|
||||||
|
stat_v1 = self.mod.parse_proc_stat(_STAT_2CORE)
|
||||||
|
stat_v2 = self.mod.parse_proc_stat(_STAT_2CORE_V2)
|
||||||
|
calls = [0]
|
||||||
|
|
||||||
|
def fake_stat():
|
||||||
|
calls[0] += 1
|
||||||
|
return stat_v2
|
||||||
|
|
||||||
|
node = _make_node(self.mod)
|
||||||
|
node._prev_stat = stat_v1
|
||||||
|
node._read_proc_stat = fake_stat
|
||||||
|
node._read_meminfo = lambda: ""
|
||||||
|
node._read_disk_usage = lambda p: (1.0, 0.5, 50.0)
|
||||||
|
node._read_gpu_load = lambda p: 0.0
|
||||||
|
node._read_thermal = lambda g, t: {}
|
||||||
|
|
||||||
|
node._tick()
|
||||||
|
self.assertEqual(node.prev_stat, stat_v2)
|
||||||
|
|
||||||
|
def test_second_tick_uses_updated_prev(self):
|
||||||
|
stat_v1 = self.mod.parse_proc_stat(_STAT_2CORE)
|
||||||
|
stat_v2 = self.mod.parse_proc_stat(_STAT_2CORE_V2)
|
||||||
|
seq = [stat_v1, stat_v2, stat_v1]
|
||||||
|
idx = [0]
|
||||||
|
|
||||||
|
def fake_stat():
|
||||||
|
v = seq[idx[0] % len(seq)]
|
||||||
|
idx[0] += 1
|
||||||
|
return v
|
||||||
|
|
||||||
|
node = _make_node(self.mod)
|
||||||
|
node._prev_stat = None
|
||||||
|
node._read_proc_stat = fake_stat
|
||||||
|
node._read_meminfo = lambda: ""
|
||||||
|
node._read_disk_usage = lambda p: (1.0, 0.5, 50.0)
|
||||||
|
node._read_gpu_load = lambda p: 0.0
|
||||||
|
node._read_thermal = lambda g, t: {}
|
||||||
|
|
||||||
|
node._tick()
|
||||||
|
node._tick()
|
||||||
|
pub = node._pubs["/saltybot/system_resources"]
|
||||||
|
self.assertEqual(len(pub.msgs), 2)
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
# Source / entry-point sanity
|
||||||
|
# ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestSysmonSource(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.mod = _load_mod()
|
||||||
|
|
||||||
|
def test_file_exists(self):
|
||||||
|
src = os.path.join(
|
||||||
|
os.path.dirname(__file__),
|
||||||
|
"..", "saltybot_social", "sysmon_node.py"
|
||||||
|
)
|
||||||
|
self.assertTrue(os.path.isfile(os.path.normpath(src)))
|
||||||
|
|
||||||
|
def test_main_callable(self):
|
||||||
|
self.assertTrue(callable(self.mod.main))
|
||||||
|
|
||||||
|
def test_issue_tag_in_source(self):
|
||||||
|
src = os.path.join(
|
||||||
|
os.path.dirname(__file__),
|
||||||
|
"..", "saltybot_social", "sysmon_node.py"
|
||||||
|
)
|
||||||
|
with open(src) as fh:
|
||||||
|
content = fh.read()
|
||||||
|
self.assertIn("355", content)
|
||||||
|
|
||||||
|
def test_status_constants_present(self):
|
||||||
|
# Confirm key pure functions exported
|
||||||
|
self.assertTrue(callable(self.mod.parse_proc_stat))
|
||||||
|
self.assertTrue(callable(self.mod.cpu_percent_from_stats))
|
||||||
|
self.assertTrue(callable(self.mod.parse_meminfo))
|
||||||
|
self.assertTrue(callable(self.mod.compute_ram_stats))
|
||||||
|
self.assertTrue(callable(self.mod.read_disk_usage))
|
||||||
|
self.assertTrue(callable(self.mod.read_gpu_load))
|
||||||
|
self.assertTrue(callable(self.mod.read_thermal_zones))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@ -11,6 +11,7 @@
|
|||||||
#include "battery.h"
|
#include "battery.h"
|
||||||
#include "config.h"
|
#include "config.h"
|
||||||
#include "stm32f7xx_hal.h"
|
#include "stm32f7xx_hal.h"
|
||||||
|
#include <stdbool.h>
|
||||||
|
|
||||||
static ADC_HandleTypeDef s_hadc;
|
static ADC_HandleTypeDef s_hadc;
|
||||||
static bool s_ready = false;
|
static bool s_ready = false;
|
||||||
|
|||||||
@ -31,3 +31,11 @@ int i2c1_init(void) {
|
|||||||
|
|
||||||
return (HAL_I2C_Init(&hi2c1) == HAL_OK) ? 0 : -1;
|
return (HAL_I2C_Init(&hi2c1) == HAL_OK) ? 0 : -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int i2c1_write(uint8_t addr, const uint8_t *data, int len) {
|
||||||
|
return (HAL_I2C_Master_Transmit(&hi2c1, addr << 1, (uint8_t*)data, len, 100) == HAL_OK) ? 0 : -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int i2c1_read(uint8_t addr, uint8_t *data, int len) {
|
||||||
|
return (HAL_I2C_Master_Receive(&hi2c1, addr << 1, data, len, 100) == HAL_OK) ? 0 : -1;
|
||||||
|
}
|
||||||
|
|||||||
19
src/main.c
19
src/main.c
@ -108,6 +108,23 @@ extern PCD_HandleTypeDef hpcd;
|
|||||||
void OTG_FS_IRQHandler(void) { HAL_PCD_IRQHandler(&hpcd); }
|
void OTG_FS_IRQHandler(void) { HAL_PCD_IRQHandler(&hpcd); }
|
||||||
void SysTick_Handler(void) { HAL_IncTick(); }
|
void SysTick_Handler(void) { HAL_IncTick(); }
|
||||||
|
|
||||||
|
/* Determine if BNO055 is active (vs MPU6000) */
|
||||||
|
static bool bno055_active = false;
|
||||||
|
|
||||||
|
/* Helper: Check if IMU is calibrated (MPU6000 gyro bias or BNO055 ready) */
|
||||||
|
static bool imu_calibrated(void) {
|
||||||
|
if (bno055_active) {
|
||||||
|
return bno055_is_ready();
|
||||||
|
}
|
||||||
|
return mpu6000_is_calibrated();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Helper: Check if CRSF receiver has recent signal */
|
||||||
|
static bool crsf_is_active(uint32_t now_ms) {
|
||||||
|
extern volatile CRSFState crsf_state;
|
||||||
|
return crsf_state.last_rx_ms > 0 && (now_ms - crsf_state.last_rx_ms) < 500;
|
||||||
|
}
|
||||||
|
|
||||||
int main(void) {
|
int main(void) {
|
||||||
SCB_EnableICache();
|
SCB_EnableICache();
|
||||||
/* DCache stays ON — MPU Region 0 in usbd_conf.c marks USB buffers non-cacheable. */
|
/* DCache stays ON — MPU Region 0 in usbd_conf.c marks USB buffers non-cacheable. */
|
||||||
@ -157,7 +174,7 @@ int main(void) {
|
|||||||
|
|
||||||
/* Init piezo buzzer driver (TIM4_CH3 PWM on PB2, Issue #189) */
|
/* Init piezo buzzer driver (TIM4_CH3 PWM on PB2, Issue #189) */
|
||||||
buzzer_init();
|
buzzer_init();
|
||||||
buzzer_play(BUZZER_PATTERN_ARM_CHIME);
|
buzzer_play_melody(MELODY_STARTUP);
|
||||||
|
|
||||||
/* Init WS2812B NeoPixel LED ring (TIM3_CH1 PWM on PB4, Issue #193) */
|
/* Init WS2812B NeoPixel LED ring (TIM3_CH1 PWM on PB4, Issue #193) */
|
||||||
led_init();
|
led_init();
|
||||||
|
|||||||
@ -24,7 +24,7 @@
|
|||||||
#define SERVO_PRESCALER 53u /* APB1 54 MHz / 54 = 1 MHz */
|
#define SERVO_PRESCALER 53u /* APB1 54 MHz / 54 = 1 MHz */
|
||||||
#define SERVO_ARR 19999u /* 1 MHz / 20000 = 50 Hz */
|
#define SERVO_ARR 19999u /* 1 MHz / 20000 = 50 Hz */
|
||||||
|
|
||||||
typedef struct {
|
static struct {
|
||||||
uint16_t current_angle_deg[SERVO_COUNT];
|
uint16_t current_angle_deg[SERVO_COUNT];
|
||||||
uint16_t target_angle_deg[SERVO_COUNT];
|
uint16_t target_angle_deg[SERVO_COUNT];
|
||||||
uint16_t pulse_us[SERVO_COUNT];
|
uint16_t pulse_us[SERVO_COUNT];
|
||||||
@ -35,9 +35,7 @@ typedef struct {
|
|||||||
uint16_t sweep_start_deg[SERVO_COUNT];
|
uint16_t sweep_start_deg[SERVO_COUNT];
|
||||||
uint16_t sweep_end_deg[SERVO_COUNT];
|
uint16_t sweep_end_deg[SERVO_COUNT];
|
||||||
bool is_sweeping[SERVO_COUNT];
|
bool is_sweeping[SERVO_COUNT];
|
||||||
} ServoState;
|
} s_servo = {0};
|
||||||
|
|
||||||
static ServoState s_servo = {0};
|
|
||||||
static TIM_HandleTypeDef s_tim_handle = {0};
|
static TIM_HandleTypeDef s_tim_handle = {0};
|
||||||
|
|
||||||
/* ================================================================
|
/* ================================================================
|
||||||
|
|||||||
@ -48,6 +48,9 @@ static UltrasonicState_t s_ultrasonic = {
|
|||||||
.callback = NULL
|
.callback = NULL
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/* TIM1 handle for input capture (shared with interrupt handler) */
|
||||||
|
static TIM_HandleTypeDef s_tim_handle = {0};
|
||||||
|
|
||||||
/* ================================================================
|
/* ================================================================
|
||||||
* Hardware Initialization
|
* Hardware Initialization
|
||||||
* ================================================================ */
|
* ================================================================ */
|
||||||
@ -80,14 +83,13 @@ void ultrasonic_init(void)
|
|||||||
* Use PSC=216 to get 1MHz clock → 1 count = 1µs
|
* Use PSC=216 to get 1MHz clock → 1 count = 1µs
|
||||||
* ARR=0xFFFF for 16-bit capture (max 65535µs ≈ 9.6m)
|
* ARR=0xFFFF for 16-bit capture (max 65535µs ≈ 9.6m)
|
||||||
*/
|
*/
|
||||||
TIM_HandleTypeDef htim1 = {0};
|
s_tim_handle.Instance = ECHO_TIM;
|
||||||
htim1.Instance = ECHO_TIM;
|
s_tim_handle.Init.Prescaler = 216 - 1; /* 216MHz / 216 = 1MHz (1µs per count) */
|
||||||
htim1.Init.Prescaler = 216 - 1; /* 216MHz / 216 = 1MHz (1µs per count) */
|
s_tim_handle.Init.CounterMode = TIM_COUNTERMODE_UP;
|
||||||
htim1.Init.CounterMode = TIM_COUNTERMODE_UP;
|
s_tim_handle.Init.Period = 0xFFFF; /* 16-bit counter */
|
||||||
htim1.Init.Period = 0xFFFF; /* 16-bit counter */
|
s_tim_handle.Init.ClockDivision = TIM_CLOCKDIVISION_DIV1;
|
||||||
htim1.Init.ClockDivision = TIM_CLOCKDIVISION_DIV1;
|
s_tim_handle.Init.RepetitionCounter = 0;
|
||||||
htim1.Init.RepetitionCounter = 0;
|
HAL_TIM_IC_Init(&s_tim_handle);
|
||||||
HAL_TIM_IC_Init(&htim1);
|
|
||||||
|
|
||||||
/* Configure input capture: CH2 on PA1, both rising and falling edges
|
/* Configure input capture: CH2 on PA1, both rising and falling edges
|
||||||
* TIM1_CH2 captures on both edges to measure echo pulse width
|
* TIM1_CH2 captures on both edges to measure echo pulse width
|
||||||
@ -97,15 +99,15 @@ void ultrasonic_init(void)
|
|||||||
ic_init.ICSelection = TIM_ICSELECTION_DIRECTTI;
|
ic_init.ICSelection = TIM_ICSELECTION_DIRECTTI;
|
||||||
ic_init.ICPrescaler = TIM_ICPSC_DIV1; /* No prescaler */
|
ic_init.ICPrescaler = TIM_ICPSC_DIV1; /* No prescaler */
|
||||||
ic_init.ICFilter = 0; /* No filter */
|
ic_init.ICFilter = 0; /* No filter */
|
||||||
HAL_TIM_IC_Init(&htim1);
|
HAL_TIM_IC_ConfigChannel(&s_tim_handle, &ic_init, ECHO_TIM_CHANNEL);
|
||||||
HAL_TIM_IC_Start_IT(ECHO_TIM, ECHO_TIM_CHANNEL);
|
HAL_TIM_IC_Start_IT(&s_tim_handle, ECHO_TIM_CHANNEL);
|
||||||
|
|
||||||
/* Enable input capture interrupt */
|
/* Enable input capture interrupt */
|
||||||
HAL_NVIC_SetPriority(TIM1_CC_IRQn, 6, 0);
|
HAL_NVIC_SetPriority(TIM1_CC_IRQn, 6, 0);
|
||||||
HAL_NVIC_EnableIRQ(TIM1_CC_IRQn);
|
HAL_NVIC_EnableIRQ(TIM1_CC_IRQn);
|
||||||
|
|
||||||
/* Start the timer */
|
/* Start the timer */
|
||||||
HAL_TIM_Base_Start(ECHO_TIM);
|
HAL_TIM_Base_Start(&s_tim_handle);
|
||||||
|
|
||||||
s_ultrasonic.state = ULTRASONIC_IDLE;
|
s_ultrasonic.state = ULTRASONIC_IDLE;
|
||||||
}
|
}
|
||||||
@ -188,10 +190,10 @@ void ultrasonic_tick(uint32_t now_ms)
|
|||||||
void TIM1_CC_IRQHandler(void)
|
void TIM1_CC_IRQHandler(void)
|
||||||
{
|
{
|
||||||
/* Check if capture interrupt on CH2 */
|
/* Check if capture interrupt on CH2 */
|
||||||
if (__HAL_TIM_GET_FLAG(ECHO_TIM, TIM_FLAG_CC2) != RESET) {
|
if (__HAL_TIM_GET_FLAG(&s_tim_handle, TIM_FLAG_CC2) != RESET) {
|
||||||
__HAL_TIM_CLEAR_FLAG(ECHO_TIM, TIM_FLAG_CC2);
|
__HAL_TIM_CLEAR_FLAG(&s_tim_handle, TIM_FLAG_CC2);
|
||||||
|
|
||||||
uint32_t capture_value = HAL_TIM_ReadCapturedValue(ECHO_TIM, ECHO_TIM_CHANNEL);
|
uint32_t capture_value = HAL_TIM_ReadCapturedValue(&s_tim_handle, ECHO_TIM_CHANNEL);
|
||||||
|
|
||||||
if (s_ultrasonic.state == ULTRASONIC_TRIGGERED || s_ultrasonic.state == ULTRASONIC_MEASURING) {
|
if (s_ultrasonic.state == ULTRASONIC_TRIGGERED || s_ultrasonic.state == ULTRASONIC_MEASURING) {
|
||||||
if (s_ultrasonic.echo_start_ticks == 0) {
|
if (s_ultrasonic.echo_start_ticks == 0) {
|
||||||
@ -205,7 +207,7 @@ void TIM1_CC_IRQHandler(void)
|
|||||||
ic_init.ICSelection = TIM_ICSELECTION_DIRECTTI;
|
ic_init.ICSelection = TIM_ICSELECTION_DIRECTTI;
|
||||||
ic_init.ICPrescaler = TIM_ICPSC_DIV1;
|
ic_init.ICPrescaler = TIM_ICPSC_DIV1;
|
||||||
ic_init.ICFilter = 0;
|
ic_init.ICFilter = 0;
|
||||||
HAL_TIM_IC_Init_Compat(ECHO_TIM, ECHO_TIM_CHANNEL, &ic_init);
|
HAL_TIM_IC_ConfigChannel(&s_tim_handle, &ic_init, ECHO_TIM_CHANNEL);
|
||||||
} else {
|
} else {
|
||||||
/* Falling edge: mark end of echo pulse and calculate distance */
|
/* Falling edge: mark end of echo pulse and calculate distance */
|
||||||
s_ultrasonic.echo_end_ticks = capture_value;
|
s_ultrasonic.echo_end_ticks = capture_value;
|
||||||
@ -242,24 +244,5 @@ void TIM1_CC_IRQHandler(void)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
HAL_TIM_IRQHandler(ECHO_TIM);
|
HAL_TIM_IRQHandler(&s_tim_handle);
|
||||||
}
|
|
||||||
|
|
||||||
/* ================================================================
|
|
||||||
* Compatibility Helper (for simplified IC init)
|
|
||||||
* ================================================================ */
|
|
||||||
|
|
||||||
static void HAL_TIM_IC_Init_Compat(TIM_HandleTypeDef *htim, uint32_t Channel, TIM_IC_InitTypeDef *sConfig)
|
|
||||||
{
|
|
||||||
/* Simple implementation for reconfiguring capture polarity */
|
|
||||||
switch (Channel) {
|
|
||||||
case TIM_CHANNEL_2:
|
|
||||||
ECHO_TIM->CCER &= ~TIM_CCER_CC2P; /* Clear polarity bits */
|
|
||||||
if (sConfig->ICPolarity == TIM_ICPOLARITY_RISING) {
|
|
||||||
ECHO_TIM->CCER |= 0;
|
|
||||||
} else {
|
|
||||||
ECHO_TIM->CCER |= TIM_CCER_CC2P;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -32,6 +32,7 @@ typedef struct {
|
|||||||
uint32_t timeout_ms; /* Configured timeout in milliseconds */
|
uint32_t timeout_ms; /* Configured timeout in milliseconds */
|
||||||
uint8_t prescaler; /* IWDG prescaler value */
|
uint8_t prescaler; /* IWDG prescaler value */
|
||||||
uint16_t reload_value; /* IWDG reload register value */
|
uint16_t reload_value; /* IWDG reload register value */
|
||||||
|
IWDG_HandleTypeDef handle; /* IWDG handle for refresh */
|
||||||
} WatchdogState;
|
} WatchdogState;
|
||||||
|
|
||||||
static WatchdogState s_watchdog = {
|
static WatchdogState s_watchdog = {
|
||||||
@ -108,13 +109,12 @@ bool watchdog_init(uint32_t timeout_ms)
|
|||||||
s_watchdog.timeout_ms = timeout_ms;
|
s_watchdog.timeout_ms = timeout_ms;
|
||||||
|
|
||||||
/* Configure and start IWDG */
|
/* Configure and start IWDG */
|
||||||
IWDG_HandleTypeDef hiwdg = {0};
|
s_watchdog.handle.Instance = IWDG;
|
||||||
hiwdg.Instance = IWDG;
|
s_watchdog.handle.Init.Prescaler = prescaler;
|
||||||
hiwdg.Init.Prescaler = prescaler;
|
s_watchdog.handle.Init.Reload = reload;
|
||||||
hiwdg.Init.Reload = reload;
|
s_watchdog.handle.Init.Window = reload; /* Window == Reload means full timeout */
|
||||||
hiwdg.Init.Window = reload; /* Window == Reload means full timeout */
|
|
||||||
|
|
||||||
HAL_IWDG_Init(&hiwdg);
|
HAL_IWDG_Init(&s_watchdog.handle);
|
||||||
|
|
||||||
s_watchdog.is_initialized = true;
|
s_watchdog.is_initialized = true;
|
||||||
s_watchdog.is_running = true;
|
s_watchdog.is_running = true;
|
||||||
@ -125,7 +125,7 @@ bool watchdog_init(uint32_t timeout_ms)
|
|||||||
void watchdog_kick(void)
|
void watchdog_kick(void)
|
||||||
{
|
{
|
||||||
if (s_watchdog.is_running) {
|
if (s_watchdog.is_running) {
|
||||||
HAL_IWDG_Refresh(&IWDG); /* Reset IWDG counter */
|
HAL_IWDG_Refresh(&s_watchdog.handle); /* Reset IWDG counter */
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -22,6 +22,7 @@ import { useRosbridge } from './hooks/useRosbridge.js';
|
|||||||
|
|
||||||
// Social panels
|
// Social panels
|
||||||
import { StatusPanel } from './components/StatusPanel.jsx';
|
import { StatusPanel } from './components/StatusPanel.jsx';
|
||||||
|
import { StatusHeader } from './components/StatusHeader.jsx';
|
||||||
import { FaceGallery } from './components/FaceGallery.jsx';
|
import { FaceGallery } from './components/FaceGallery.jsx';
|
||||||
import { ConversationLog } from './components/ConversationLog.jsx';
|
import { ConversationLog } from './components/ConversationLog.jsx';
|
||||||
import { ConversationHistory } from './components/ConversationHistory.jsx';
|
import { ConversationHistory } from './components/ConversationHistory.jsx';
|
||||||
@ -34,6 +35,7 @@ import PoseViewer from './components/PoseViewer.jsx';
|
|||||||
import { BatteryPanel } from './components/BatteryPanel.jsx';
|
import { BatteryPanel } from './components/BatteryPanel.jsx';
|
||||||
import { BatteryChart } from './components/BatteryChart.jsx';
|
import { BatteryChart } from './components/BatteryChart.jsx';
|
||||||
import { MotorPanel } from './components/MotorPanel.jsx';
|
import { MotorPanel } from './components/MotorPanel.jsx';
|
||||||
|
import { MotorCurrentGraph } from './components/MotorCurrentGraph.jsx';
|
||||||
import { MapViewer } from './components/MapViewer.jsx';
|
import { MapViewer } from './components/MapViewer.jsx';
|
||||||
import { ControlMode } from './components/ControlMode.jsx';
|
import { ControlMode } from './components/ControlMode.jsx';
|
||||||
import { SystemHealth } from './components/SystemHealth.jsx';
|
import { SystemHealth } from './components/SystemHealth.jsx';
|
||||||
@ -53,6 +55,9 @@ import { CameraViewer } from './components/CameraViewer.jsx';
|
|||||||
// Event log (issue #192)
|
// Event log (issue #192)
|
||||||
import { EventLog } from './components/EventLog.jsx';
|
import { EventLog } from './components/EventLog.jsx';
|
||||||
|
|
||||||
|
// Log viewer (issue #275)
|
||||||
|
import { LogViewer } from './components/LogViewer.jsx';
|
||||||
|
|
||||||
// Joystick teleop (issue #212)
|
// Joystick teleop (issue #212)
|
||||||
import JoystickTeleop from './components/JoystickTeleop.jsx';
|
import JoystickTeleop from './components/JoystickTeleop.jsx';
|
||||||
|
|
||||||
@ -71,6 +76,15 @@ import { TempGauge } from './components/TempGauge.jsx';
|
|||||||
// Node list viewer
|
// Node list viewer
|
||||||
import { NodeList } from './components/NodeList.jsx';
|
import { NodeList } from './components/NodeList.jsx';
|
||||||
|
|
||||||
|
// Gamepad teleoperation (issue #319)
|
||||||
|
import { Teleop } from './components/Teleop.jsx';
|
||||||
|
|
||||||
|
// System diagnostics (issue #340)
|
||||||
|
import { Diagnostics } from './components/Diagnostics.jsx';
|
||||||
|
|
||||||
|
// Hand tracking visualization (issue #344)
|
||||||
|
import { HandTracker } from './components/HandTracker.jsx';
|
||||||
|
|
||||||
const TAB_GROUPS = [
|
const TAB_GROUPS = [
|
||||||
{
|
{
|
||||||
label: 'SOCIAL',
|
label: 'SOCIAL',
|
||||||
@ -78,6 +92,7 @@ const TAB_GROUPS = [
|
|||||||
tabs: [
|
tabs: [
|
||||||
{ id: 'status', label: 'Status', },
|
{ id: 'status', label: 'Status', },
|
||||||
{ id: 'faces', label: 'Faces', },
|
{ id: 'faces', label: 'Faces', },
|
||||||
|
{ id: 'hands', label: 'Hands', },
|
||||||
{ id: 'conversation', label: 'Convo', },
|
{ id: 'conversation', label: 'Convo', },
|
||||||
{ id: 'history', label: 'History', },
|
{ id: 'history', label: 'History', },
|
||||||
{ id: 'personality', label: 'Personality', },
|
{ id: 'personality', label: 'Personality', },
|
||||||
@ -97,7 +112,13 @@ const TAB_GROUPS = [
|
|||||||
{ id: 'map', label: 'Map', },
|
{ id: 'map', label: 'Map', },
|
||||||
{ id: 'control', label: 'Control', },
|
{ id: 'control', label: 'Control', },
|
||||||
{ id: 'health', label: 'Health', },
|
{ id: 'health', label: 'Health', },
|
||||||
{ id: 'cameras', label: 'Cameras', },
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
label: 'CAMERAS',
|
||||||
|
color: 'text-rose-600',
|
||||||
|
tabs: [
|
||||||
|
{ id: 'cameras', label: 'Cameras' },
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -119,6 +140,7 @@ const TAB_GROUPS = [
|
|||||||
label: 'MONITORING',
|
label: 'MONITORING',
|
||||||
color: 'text-yellow-600',
|
color: 'text-yellow-600',
|
||||||
tabs: [
|
tabs: [
|
||||||
|
{ id: 'diagnostics', label: 'Diagnostics' },
|
||||||
{ id: 'eventlog', label: 'Events' },
|
{ id: 'eventlog', label: 'Events' },
|
||||||
{ id: 'bandwidth', label: 'Bandwidth' },
|
{ id: 'bandwidth', label: 'Bandwidth' },
|
||||||
{ id: 'nodes', label: 'Nodes' },
|
{ id: 'nodes', label: 'Nodes' },
|
||||||
@ -251,6 +273,7 @@ export default function App() {
|
|||||||
<main className={`flex-1 ${['eventlog', 'control', 'imu'].includes(activeTab) ? 'flex flex-col' : 'overflow-y-auto'} p-4`}>
|
<main className={`flex-1 ${['eventlog', 'control', 'imu'].includes(activeTab) ? 'flex flex-col' : 'overflow-y-auto'} p-4`}>
|
||||||
{activeTab === 'status' && <StatusPanel subscribe={subscribe} />}
|
{activeTab === 'status' && <StatusPanel subscribe={subscribe} />}
|
||||||
{activeTab === 'faces' && <FaceGallery subscribe={subscribe} callService={callService} />}
|
{activeTab === 'faces' && <FaceGallery subscribe={subscribe} callService={callService} />}
|
||||||
|
{activeTab === 'hands' && <HandTracker subscribe={subscribe} />}
|
||||||
{activeTab === 'conversation' && <ConversationLog subscribe={subscribe} />}
|
{activeTab === 'conversation' && <ConversationLog subscribe={subscribe} />}
|
||||||
{activeTab === 'history' && <ConversationHistory subscribe={subscribe} />}
|
{activeTab === 'history' && <ConversationHistory subscribe={subscribe} />}
|
||||||
{activeTab === 'personality' && <PersonalityTuner subscribe={subscribe} setParam={setParam} />}
|
{activeTab === 'personality' && <PersonalityTuner subscribe={subscribe} setParam={setParam} />}
|
||||||
@ -264,16 +287,7 @@ export default function App() {
|
|||||||
{activeTab === 'motor-current-graph' && <MotorCurrentGraph subscribe={subscribe} />}
|
{activeTab === 'motor-current-graph' && <MotorCurrentGraph subscribe={subscribe} />}
|
||||||
{activeTab === 'thermal' && <TempGauge subscribe={subscribe} />}
|
{activeTab === 'thermal' && <TempGauge subscribe={subscribe} />}
|
||||||
{activeTab === 'map' && <MapViewer subscribe={subscribe} />}
|
{activeTab === 'map' && <MapViewer subscribe={subscribe} />}
|
||||||
{activeTab === 'control' && (
|
{activeTab === 'control' && <Teleop publish={publishFn} />}
|
||||||
<div className="flex flex-col h-full gap-4">
|
|
||||||
<div className="flex-1 overflow-y-auto">
|
|
||||||
<ControlMode subscribe={subscribe} />
|
|
||||||
</div>
|
|
||||||
<div className="flex-1 overflow-y-auto">
|
|
||||||
<JoystickTeleop publish={publishFn} />
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
{activeTab === 'health' && <SystemHealth subscribe={subscribe} />}
|
{activeTab === 'health' && <SystemHealth subscribe={subscribe} />}
|
||||||
{activeTab === 'cameras' && <CameraViewer subscribe={subscribe} />}
|
{activeTab === 'cameras' && <CameraViewer subscribe={subscribe} />}
|
||||||
|
|
||||||
@ -282,6 +296,8 @@ export default function App() {
|
|||||||
{activeTab === 'fleet' && <FleetPanel />}
|
{activeTab === 'fleet' && <FleetPanel />}
|
||||||
{activeTab === 'missions' && <MissionPlanner />}
|
{activeTab === 'missions' && <MissionPlanner />}
|
||||||
|
|
||||||
|
{activeTab === 'diagnostics' && <Diagnostics subscribe={subscribe} />}
|
||||||
|
|
||||||
{activeTab === 'eventlog' && <EventLog subscribe={subscribe} />}
|
{activeTab === 'eventlog' && <EventLog subscribe={subscribe} />}
|
||||||
|
|
||||||
{activeTab === 'bandwidth' && <BandwidthMonitor />}
|
{activeTab === 'bandwidth' && <BandwidthMonitor />}
|
||||||
|
|||||||
308
ui/social-bot/src/components/Diagnostics.jsx
Normal file
308
ui/social-bot/src/components/Diagnostics.jsx
Normal file
@ -0,0 +1,308 @@
|
|||||||
|
/**
|
||||||
|
* Diagnostics.jsx — System diagnostics panel with hardware status monitoring
|
||||||
|
*
|
||||||
|
* Features:
|
||||||
|
* - Subscribes to /diagnostics (diagnostic_msgs/DiagnosticArray)
|
||||||
|
* - Hardware status cards per subsystem (color-coded health)
|
||||||
|
* - Real-time error and warning counts
|
||||||
|
* - Diagnostic status timeline
|
||||||
|
* - Error/warning history with timestamps
|
||||||
|
* - Aggregated system health summary
|
||||||
|
* - Status indicators: OK (green), WARNING (yellow), ERROR (red), STALE (gray)
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { useEffect, useRef, useState } from 'react';
|
||||||
|
|
||||||
|
const MAX_HISTORY = 100; // Keep last 100 diagnostic messages
|
||||||
|
const STATUS_COLORS = {
|
||||||
|
0: { bg: 'bg-green-950', border: 'border-green-800', text: 'text-green-400', label: 'OK' },
|
||||||
|
1: { bg: 'bg-yellow-950', border: 'border-yellow-800', text: 'text-yellow-400', label: 'WARN' },
|
||||||
|
2: { bg: 'bg-red-950', border: 'border-red-800', text: 'text-red-400', label: 'ERROR' },
|
||||||
|
3: { bg: 'bg-gray-900', border: 'border-gray-700', text: 'text-gray-400', label: 'STALE' },
|
||||||
|
};
|
||||||
|
|
||||||
|
function getStatusColor(level) {
|
||||||
|
return STATUS_COLORS[level] || STATUS_COLORS[3];
|
||||||
|
}
|
||||||
|
|
||||||
|
function formatTimestamp(timestamp) {
|
||||||
|
const date = new Date(timestamp);
|
||||||
|
return date.toLocaleTimeString('en-US', {
|
||||||
|
hour: '2-digit',
|
||||||
|
minute: '2-digit',
|
||||||
|
second: '2-digit',
|
||||||
|
hour12: false,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function DiagnosticCard({ diagnostic, expanded, onToggle }) {
|
||||||
|
const color = getStatusColor(diagnostic.level);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={`rounded-lg border p-3 space-y-2 cursor-pointer transition-all ${
|
||||||
|
color.bg
|
||||||
|
} ${color.border} ${expanded ? 'ring-2 ring-offset-2 ring-cyan-500' : ''}`}
|
||||||
|
onClick={onToggle}
|
||||||
|
>
|
||||||
|
{/* Header */}
|
||||||
|
<div className="flex items-start justify-between gap-2">
|
||||||
|
<div className="flex-1 min-w-0">
|
||||||
|
<div className="text-xs font-bold text-gray-400 truncate">{diagnostic.name}</div>
|
||||||
|
<div className={`text-sm font-mono font-bold ${color.text}`}>
|
||||||
|
{diagnostic.message}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div
|
||||||
|
className={`px-2 py-1 rounded text-xs font-bold whitespace-nowrap flex-shrink-0 ${
|
||||||
|
color.bg
|
||||||
|
} ${color.border} border ${color.text}`}
|
||||||
|
>
|
||||||
|
{color.label}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Expanded details */}
|
||||||
|
{expanded && (
|
||||||
|
<div className="border-t border-gray-700 pt-2 space-y-2 text-xs">
|
||||||
|
{diagnostic.values && diagnostic.values.length > 0 && (
|
||||||
|
<div className="space-y-1">
|
||||||
|
<div className="text-gray-500 font-bold">KEY VALUES:</div>
|
||||||
|
{diagnostic.values.slice(0, 5).map((val, i) => (
|
||||||
|
<div key={i} className="flex justify-between gap-2 text-gray-400">
|
||||||
|
<span className="font-mono truncate">{val.key}:</span>
|
||||||
|
<span className="font-mono text-gray-300 truncate">{val.value}</span>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
{diagnostic.values.length > 5 && (
|
||||||
|
<div className="text-gray-500">+{diagnostic.values.length - 5} more values</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
{diagnostic.hardware_id && (
|
||||||
|
<div className="text-gray-500">
|
||||||
|
<span className="font-bold">Hardware:</span> {diagnostic.hardware_id}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function HealthTimeline({ statusHistory, maxEntries = 20 }) {
|
||||||
|
const recentHistory = statusHistory.slice(-maxEntries);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="space-y-2">
|
||||||
|
<div className="text-xs font-bold text-gray-400 tracking-widest">TIMELINE</div>
|
||||||
|
<div className="space-y-1">
|
||||||
|
{recentHistory.map((entry, i) => {
|
||||||
|
const color = getStatusColor(entry.level);
|
||||||
|
return (
|
||||||
|
<div key={i} className="flex items-center gap-2">
|
||||||
|
<div className="text-xs text-gray-500 font-mono whitespace-nowrap">
|
||||||
|
{formatTimestamp(entry.timestamp)}
|
||||||
|
</div>
|
||||||
|
<div className={`w-2 h-2 rounded-full flex-shrink-0 ${color.text}`} />
|
||||||
|
<div className="text-xs text-gray-400 truncate">
|
||||||
|
{entry.name}: {entry.message}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function Diagnostics({ subscribe }) {
|
||||||
|
const [diagnostics, setDiagnostics] = useState({});
|
||||||
|
const [statusHistory, setStatusHistory] = useState([]);
|
||||||
|
const [expandedDiags, setExpandedDiags] = useState(new Set());
|
||||||
|
const diagRef = useRef({});
|
||||||
|
|
||||||
|
// Subscribe to diagnostics
|
||||||
|
useEffect(() => {
|
||||||
|
const unsubscribe = subscribe(
|
||||||
|
'/diagnostics',
|
||||||
|
'diagnostic_msgs/DiagnosticArray',
|
||||||
|
(msg) => {
|
||||||
|
try {
|
||||||
|
const diags = {};
|
||||||
|
const now = Date.now();
|
||||||
|
|
||||||
|
// Process each diagnostic status
|
||||||
|
(msg.status || []).forEach((status) => {
|
||||||
|
const name = status.name || 'unknown';
|
||||||
|
diags[name] = {
|
||||||
|
name,
|
||||||
|
level: status.level,
|
||||||
|
message: status.message || '',
|
||||||
|
hardware_id: status.hardware_id || '',
|
||||||
|
values: status.values || [],
|
||||||
|
timestamp: now,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Add to history
|
||||||
|
setStatusHistory((prev) => [
|
||||||
|
...prev,
|
||||||
|
{
|
||||||
|
name,
|
||||||
|
level: status.level,
|
||||||
|
message: status.message || '',
|
||||||
|
timestamp: now,
|
||||||
|
},
|
||||||
|
].slice(-MAX_HISTORY));
|
||||||
|
});
|
||||||
|
|
||||||
|
setDiagnostics(diags);
|
||||||
|
diagRef.current = diags;
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Error parsing diagnostics:', e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
return unsubscribe;
|
||||||
|
}, [subscribe]);
|
||||||
|
|
||||||
|
// Calculate statistics
|
||||||
|
const stats = {
|
||||||
|
total: Object.keys(diagnostics).length,
|
||||||
|
ok: Object.values(diagnostics).filter((d) => d.level === 0).length,
|
||||||
|
warning: Object.values(diagnostics).filter((d) => d.level === 1).length,
|
||||||
|
error: Object.values(diagnostics).filter((d) => d.level === 2).length,
|
||||||
|
stale: Object.values(diagnostics).filter((d) => d.level === 3).length,
|
||||||
|
};
|
||||||
|
|
||||||
|
const overallHealth =
|
||||||
|
stats.error > 0 ? 2 : stats.warning > 0 ? 1 : stats.stale > 0 ? 3 : 0;
|
||||||
|
const overallColor = getStatusColor(overallHealth);
|
||||||
|
|
||||||
|
const sortedDiags = Object.values(diagnostics).sort((a, b) => {
|
||||||
|
// Sort by level (errors first), then by name
|
||||||
|
if (a.level !== b.level) return b.level - a.level;
|
||||||
|
return a.name.localeCompare(b.name);
|
||||||
|
});
|
||||||
|
|
||||||
|
const toggleExpanded = (name) => {
|
||||||
|
const updated = new Set(expandedDiags);
|
||||||
|
if (updated.has(name)) {
|
||||||
|
updated.delete(name);
|
||||||
|
} else {
|
||||||
|
updated.add(name);
|
||||||
|
}
|
||||||
|
setExpandedDiags(updated);
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex flex-col h-full space-y-3">
|
||||||
|
{/* Health Summary */}
|
||||||
|
<div className={`rounded-lg border p-3 space-y-3 ${overallColor.bg} ${overallColor.border}`}>
|
||||||
|
<div className="flex justify-between items-center">
|
||||||
|
<div className={`text-xs font-bold tracking-widest ${overallColor.text}`}>
|
||||||
|
SYSTEM HEALTH
|
||||||
|
</div>
|
||||||
|
<div className={`text-xs font-bold px-3 py-1 rounded ${overallColor.text}`}>
|
||||||
|
{overallColor.label}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Stats Grid */}
|
||||||
|
<div className="grid grid-cols-5 gap-2">
|
||||||
|
<div className="bg-gray-900 rounded p-2">
|
||||||
|
<div className="text-gray-600 text-xs">TOTAL</div>
|
||||||
|
<div className="text-lg font-mono text-cyan-300 font-bold">{stats.total}</div>
|
||||||
|
</div>
|
||||||
|
<div className="bg-green-950 rounded p-2">
|
||||||
|
<div className="text-gray-600 text-xs">OK</div>
|
||||||
|
<div className="text-lg font-mono text-green-400 font-bold">{stats.ok}</div>
|
||||||
|
</div>
|
||||||
|
<div className="bg-yellow-950 rounded p-2">
|
||||||
|
<div className="text-gray-600 text-xs">WARN</div>
|
||||||
|
<div className="text-lg font-mono text-yellow-400 font-bold">{stats.warning}</div>
|
||||||
|
</div>
|
||||||
|
<div className="bg-red-950 rounded p-2">
|
||||||
|
<div className="text-gray-600 text-xs">ERROR</div>
|
||||||
|
<div className="text-lg font-mono text-red-400 font-bold">{stats.error}</div>
|
||||||
|
</div>
|
||||||
|
<div className="bg-gray-900 rounded p-2">
|
||||||
|
<div className="text-gray-600 text-xs">STALE</div>
|
||||||
|
<div className="text-lg font-mono text-gray-400 font-bold">{stats.stale}</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Diagnostics Grid */}
|
||||||
|
<div className="flex-1 overflow-y-auto space-y-2">
|
||||||
|
{sortedDiags.length === 0 ? (
|
||||||
|
<div className="flex items-center justify-center h-32 text-gray-600">
|
||||||
|
<div className="text-center">
|
||||||
|
<div className="text-sm mb-2">Waiting for diagnostics</div>
|
||||||
|
<div className="text-xs text-gray-700">
|
||||||
|
Messages from /diagnostics will appear here
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
sortedDiags.map((diag) => (
|
||||||
|
<DiagnosticCard
|
||||||
|
key={diag.name}
|
||||||
|
diagnostic={diag}
|
||||||
|
expanded={expandedDiags.has(diag.name)}
|
||||||
|
onToggle={() => toggleExpanded(diag.name)}
|
||||||
|
/>
|
||||||
|
))
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Timeline and Info */}
|
||||||
|
<div className="grid grid-cols-2 gap-2">
|
||||||
|
{/* Timeline */}
|
||||||
|
<div className="bg-gray-950 rounded-lg border border-cyan-950 p-3">
|
||||||
|
<HealthTimeline statusHistory={statusHistory} maxEntries={10} />
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Legend */}
|
||||||
|
<div className="bg-gray-950 rounded-lg border border-cyan-950 p-3 space-y-2">
|
||||||
|
<div className="text-xs font-bold text-gray-400 tracking-widest mb-2">STATUS LEGEND</div>
|
||||||
|
<div className="space-y-1 text-xs">
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<div className="w-2 h-2 rounded-full bg-green-500" />
|
||||||
|
<span className="text-gray-400">OK — System nominal</span>
|
||||||
|
</div>
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<div className="w-2 h-2 rounded-full bg-yellow-500" />
|
||||||
|
<span className="text-gray-400">WARN — Check required</span>
|
||||||
|
</div>
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<div className="w-2 h-2 rounded-full bg-red-500" />
|
||||||
|
<span className="text-gray-400">ERROR — Immediate action</span>
|
||||||
|
</div>
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<div className="w-2 h-2 rounded-full bg-gray-500" />
|
||||||
|
<span className="text-gray-400">STALE — No recent data</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Topic Info */}
|
||||||
|
<div className="bg-gray-950 rounded border border-gray-800 p-2 text-xs text-gray-600 space-y-1">
|
||||||
|
<div className="flex justify-between">
|
||||||
|
<span>Topic:</span>
|
||||||
|
<span className="text-gray-500">/diagnostics (diagnostic_msgs/DiagnosticArray)</span>
|
||||||
|
</div>
|
||||||
|
<div className="flex justify-between">
|
||||||
|
<span>Status Levels:</span>
|
||||||
|
<span className="text-gray-500">0=OK, 1=WARN, 2=ERROR, 3=STALE</span>
|
||||||
|
</div>
|
||||||
|
<div className="flex justify-between">
|
||||||
|
<span>History Limit:</span>
|
||||||
|
<span className="text-gray-500">{MAX_HISTORY} events</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
331
ui/social-bot/src/components/HandTracker.jsx
Normal file
331
ui/social-bot/src/components/HandTracker.jsx
Normal file
@ -0,0 +1,331 @@
|
|||||||
|
/**
|
||||||
|
* HandTracker.jsx — Hand pose and gesture visualization
|
||||||
|
*
|
||||||
|
* Features:
|
||||||
|
* - Subscribes to /saltybot/hands (21 landmarks per hand)
|
||||||
|
* - Subscribes to /saltybot/hand_gesture (String gesture label)
|
||||||
|
* - Canvas-based hand skeleton rendering
|
||||||
|
* - Bone connections between landmarks
|
||||||
|
* - Support for dual hands (left and right)
|
||||||
|
* - Handedness indicator
|
||||||
|
* - Real-time gesture display
|
||||||
|
* - Confidence-based landmark rendering
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { useEffect, useRef, useState } from 'react';
|
||||||
|
|
||||||
|
// MediaPipe hand landmark connections (bones)
|
||||||
|
const HAND_CONNECTIONS = [
|
||||||
|
// Thumb
|
||||||
|
[0, 1], [1, 2], [2, 3], [3, 4],
|
||||||
|
// Index finger
|
||||||
|
[0, 5], [5, 6], [6, 7], [7, 8],
|
||||||
|
// Middle finger
|
||||||
|
[0, 9], [9, 10], [10, 11], [11, 12],
|
||||||
|
// Ring finger
|
||||||
|
[0, 13], [13, 14], [14, 15], [15, 16],
|
||||||
|
// Pinky finger
|
||||||
|
[0, 17], [17, 18], [18, 19], [19, 20],
|
||||||
|
// Palm connections
|
||||||
|
[5, 9], [9, 13], [13, 17],
|
||||||
|
];
|
||||||
|
|
||||||
|
const LANDMARK_NAMES = [
|
||||||
|
'Wrist',
|
||||||
|
'Thumb CMC', 'Thumb MCP', 'Thumb IP', 'Thumb Tip',
|
||||||
|
'Index MCP', 'Index PIP', 'Index DIP', 'Index Tip',
|
||||||
|
'Middle MCP', 'Middle PIP', 'Middle DIP', 'Middle Tip',
|
||||||
|
'Ring MCP', 'Ring PIP', 'Ring DIP', 'Ring Tip',
|
||||||
|
'Pinky MCP', 'Pinky PIP', 'Pinky DIP', 'Pinky Tip',
|
||||||
|
];
|
||||||
|
|
||||||
|
function HandCanvas({ hand, color, label }) {
|
||||||
|
const canvasRef = useRef(null);
|
||||||
|
const [flipped, setFlipped] = useState(false);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const canvas = canvasRef.current;
|
||||||
|
if (!canvas || !hand || !hand.landmarks || hand.landmarks.length === 0) return;
|
||||||
|
|
||||||
|
const ctx = canvas.getContext('2d');
|
||||||
|
const width = canvas.width;
|
||||||
|
const height = canvas.height;
|
||||||
|
|
||||||
|
// Clear canvas
|
||||||
|
ctx.fillStyle = '#1f2937';
|
||||||
|
ctx.fillRect(0, 0, width, height);
|
||||||
|
|
||||||
|
// Find min/max coordinates for scaling
|
||||||
|
let minX = Infinity, maxX = -Infinity;
|
||||||
|
let minY = Infinity, maxY = -Infinity;
|
||||||
|
|
||||||
|
hand.landmarks.forEach((lm) => {
|
||||||
|
minX = Math.min(minX, lm.x);
|
||||||
|
maxX = Math.max(maxX, lm.x);
|
||||||
|
minY = Math.min(minY, lm.y);
|
||||||
|
maxY = Math.max(maxY, lm.y);
|
||||||
|
});
|
||||||
|
|
||||||
|
const padding = 20;
|
||||||
|
const rangeX = maxX - minX || 1;
|
||||||
|
const rangeY = maxY - minY || 1;
|
||||||
|
const scaleX = (width - padding * 2) / rangeX;
|
||||||
|
const scaleY = (height - padding * 2) / rangeY;
|
||||||
|
const scale = Math.min(scaleX, scaleY);
|
||||||
|
|
||||||
|
// Convert landmark coordinates to canvas positions
|
||||||
|
const getCanvasPos = (lm) => {
|
||||||
|
const x = padding + (lm.x - minX) * scale;
|
||||||
|
const y = padding + (lm.y - minY) * scale;
|
||||||
|
return { x, y };
|
||||||
|
};
|
||||||
|
|
||||||
|
// Draw bones
|
||||||
|
ctx.strokeStyle = color;
|
||||||
|
ctx.lineWidth = 2;
|
||||||
|
ctx.lineCap = 'round';
|
||||||
|
ctx.lineJoin = 'round';
|
||||||
|
|
||||||
|
HAND_CONNECTIONS.forEach(([start, end]) => {
|
||||||
|
if (start < hand.landmarks.length && end < hand.landmarks.length) {
|
||||||
|
const startLm = hand.landmarks[start];
|
||||||
|
const endLm = hand.landmarks[end];
|
||||||
|
|
||||||
|
if (startLm.confidence > 0.1 && endLm.confidence > 0.1) {
|
||||||
|
const startPos = getCanvasPos(startLm);
|
||||||
|
const endPos = getCanvasPos(endLm);
|
||||||
|
|
||||||
|
ctx.globalAlpha = Math.min(startLm.confidence, endLm.confidence);
|
||||||
|
ctx.beginPath();
|
||||||
|
ctx.moveTo(startPos.x, startPos.y);
|
||||||
|
ctx.lineTo(endPos.x, endPos.y);
|
||||||
|
ctx.stroke();
|
||||||
|
ctx.globalAlpha = 1.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Draw landmarks
|
||||||
|
hand.landmarks.forEach((lm, i) => {
|
||||||
|
if (lm.confidence > 0.1) {
|
||||||
|
const pos = getCanvasPos(lm);
|
||||||
|
const radius = 4 + lm.confidence * 2;
|
||||||
|
|
||||||
|
// Landmark glow
|
||||||
|
ctx.fillStyle = color;
|
||||||
|
ctx.globalAlpha = lm.confidence * 0.3;
|
||||||
|
ctx.beginPath();
|
||||||
|
ctx.arc(pos.x, pos.y, radius * 2, 0, Math.PI * 2);
|
||||||
|
ctx.fill();
|
||||||
|
|
||||||
|
// Landmark point
|
||||||
|
ctx.fillStyle = color;
|
||||||
|
ctx.globalAlpha = lm.confidence;
|
||||||
|
ctx.beginPath();
|
||||||
|
ctx.arc(pos.x, pos.y, radius, 0, Math.PI * 2);
|
||||||
|
ctx.fill();
|
||||||
|
|
||||||
|
// Joint type marker
|
||||||
|
const isJoint = i > 0 && i % 4 === 0; // Tip joints
|
||||||
|
if (isJoint) {
|
||||||
|
ctx.strokeStyle = '#fff';
|
||||||
|
ctx.lineWidth = 1;
|
||||||
|
ctx.globalAlpha = lm.confidence;
|
||||||
|
ctx.beginPath();
|
||||||
|
ctx.arc(pos.x, pos.y, radius + 2, 0, Math.PI * 2);
|
||||||
|
ctx.stroke();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
ctx.globalAlpha = 1.0;
|
||||||
|
|
||||||
|
// Draw handedness label on canvas
|
||||||
|
ctx.fillStyle = color;
|
||||||
|
ctx.font = 'bold 12px monospace';
|
||||||
|
ctx.textAlign = 'left';
|
||||||
|
ctx.fillText(label, 10, height - 10);
|
||||||
|
}, [hand, color, label]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex flex-col items-center gap-2">
|
||||||
|
<div className="text-xs font-bold text-gray-400">{label}</div>
|
||||||
|
<canvas
|
||||||
|
ref={canvasRef}
|
||||||
|
width={280}
|
||||||
|
height={320}
|
||||||
|
className="border-2 border-gray-800 rounded-lg bg-gray-900"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function HandTracker({ subscribe }) {
|
||||||
|
const [leftHand, setLeftHand] = useState(null);
|
||||||
|
const [rightHand, setRightHand] = useState(null);
|
||||||
|
const [gesture, setGesture] = useState('');
|
||||||
|
const [confidence, setConfidence] = useState(0);
|
||||||
|
const handsRef = useRef({});
|
||||||
|
|
||||||
|
// Subscribe to hand poses
|
||||||
|
useEffect(() => {
|
||||||
|
const unsubscribe = subscribe(
|
||||||
|
'/saltybot/hands',
|
||||||
|
'saltybot_msgs/HandPose',
|
||||||
|
(msg) => {
|
||||||
|
try {
|
||||||
|
if (!msg) return;
|
||||||
|
|
||||||
|
// Handle both single hand and multi-hand formats
|
||||||
|
if (Array.isArray(msg)) {
|
||||||
|
// Multi-hand format: array of hands
|
||||||
|
const left = msg.find((h) => h.handedness === 'Left' || h.handedness === 0);
|
||||||
|
const right = msg.find((h) => h.handedness === 'Right' || h.handedness === 1);
|
||||||
|
if (left) setLeftHand(left);
|
||||||
|
if (right) setRightHand(right);
|
||||||
|
} else if (msg.handedness) {
|
||||||
|
// Single hand format
|
||||||
|
if (msg.handedness === 'Left' || msg.handedness === 0) {
|
||||||
|
setLeftHand(msg);
|
||||||
|
} else {
|
||||||
|
setRightHand(msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
handsRef.current = { left: leftHand, right: rightHand };
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Error parsing hand pose:', e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
return unsubscribe;
|
||||||
|
}, [subscribe, leftHand, rightHand]);
|
||||||
|
|
||||||
|
// Subscribe to gesture
|
||||||
|
useEffect(() => {
|
||||||
|
const unsubscribe = subscribe(
|
||||||
|
'/saltybot/hand_gesture',
|
||||||
|
'std_msgs/String',
|
||||||
|
(msg) => {
|
||||||
|
try {
|
||||||
|
if (msg.data) {
|
||||||
|
// Parse gesture data (format: "gesture_name confidence")
|
||||||
|
const parts = msg.data.split(' ');
|
||||||
|
setGesture(parts[0] || '');
|
||||||
|
if (parts[1]) {
|
||||||
|
setConfidence(parseFloat(parts[1]) || 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Error parsing gesture:', e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
return unsubscribe;
|
||||||
|
}, [subscribe]);
|
||||||
|
|
||||||
|
const hasLeftHand = leftHand && leftHand.landmarks && leftHand.landmarks.length > 0;
|
||||||
|
const hasRightHand = rightHand && rightHand.landmarks && rightHand.landmarks.length > 0;
|
||||||
|
const gestureConfidentColor =
|
||||||
|
confidence > 0.8 ? 'text-green-400' : confidence > 0.5 ? 'text-yellow-400' : 'text-gray-400';
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex flex-col h-full space-y-3">
|
||||||
|
{/* Gesture Display */}
|
||||||
|
<div className="bg-gray-950 rounded-lg border border-cyan-950 p-3 space-y-2">
|
||||||
|
<div className="flex justify-between items-center">
|
||||||
|
<div className="text-cyan-700 text-xs font-bold tracking-widest">GESTURE</div>
|
||||||
|
<div className={`text-sm font-mono font-bold ${gestureConfidentColor}`}>
|
||||||
|
{gesture || 'Detecting...'}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Confidence bar */}
|
||||||
|
{gesture && (
|
||||||
|
<div className="space-y-1">
|
||||||
|
<div className="text-xs text-gray-600">Confidence</div>
|
||||||
|
<div className="w-full bg-gray-900 rounded-full h-2 overflow-hidden">
|
||||||
|
<div
|
||||||
|
className={`h-full transition-all ${
|
||||||
|
confidence > 0.8
|
||||||
|
? 'bg-green-500'
|
||||||
|
: confidence > 0.5
|
||||||
|
? 'bg-yellow-500'
|
||||||
|
: 'bg-blue-500'
|
||||||
|
}`}
|
||||||
|
style={{ width: `${Math.round(confidence * 100)}%` }}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div className="text-right text-xs text-gray-500">
|
||||||
|
{Math.round(confidence * 100)}%
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Hand Renders */}
|
||||||
|
<div className="flex-1 overflow-y-auto">
|
||||||
|
{hasLeftHand || hasRightHand ? (
|
||||||
|
<div className="flex gap-3 justify-center flex-wrap">
|
||||||
|
{hasLeftHand && (
|
||||||
|
<HandCanvas hand={leftHand} color="#10b981" label="LEFT HAND" />
|
||||||
|
)}
|
||||||
|
{hasRightHand && (
|
||||||
|
<HandCanvas hand={rightHand} color="#f59e0b" label="RIGHT HAND" />
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<div className="flex items-center justify-center h-full text-gray-600">
|
||||||
|
<div className="text-center">
|
||||||
|
<div className="text-sm mb-2">Waiting for hand detection</div>
|
||||||
|
<div className="text-xs text-gray-700">
|
||||||
|
Hands will appear when detected on /saltybot/hands
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Hand Info */}
|
||||||
|
<div className="bg-gray-950 rounded-lg border border-cyan-950 p-3 space-y-2">
|
||||||
|
<div className="text-xs font-bold text-gray-400 tracking-widest mb-2">HAND STATUS</div>
|
||||||
|
<div className="grid grid-cols-2 gap-2">
|
||||||
|
<div className="bg-green-950 rounded p-2">
|
||||||
|
<div className="text-xs text-gray-600">LEFT</div>
|
||||||
|
<div className="text-lg font-mono text-green-400">
|
||||||
|
{hasLeftHand ? '✓' : '◯'}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="bg-yellow-950 rounded p-2">
|
||||||
|
<div className="text-xs text-gray-600">RIGHT</div>
|
||||||
|
<div className="text-lg font-mono text-yellow-400">
|
||||||
|
{hasRightHand ? '✓' : '◯'}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Landmark Info */}
|
||||||
|
<div className="bg-gray-950 rounded border border-gray-800 p-2 text-xs text-gray-600 space-y-1">
|
||||||
|
<div className="flex justify-between">
|
||||||
|
<span>Hand Format:</span>
|
||||||
|
<span className="text-gray-500">21 landmarks per hand (MediaPipe)</span>
|
||||||
|
</div>
|
||||||
|
<div className="flex justify-between">
|
||||||
|
<span>Topics:</span>
|
||||||
|
<span className="text-gray-500">/saltybot/hands, /saltybot/hand_gesture</span>
|
||||||
|
</div>
|
||||||
|
<div className="flex justify-between">
|
||||||
|
<span>Landmarks:</span>
|
||||||
|
<span className="text-gray-500">Wrist + fingers (5×4 joints)</span>
|
||||||
|
</div>
|
||||||
|
<div className="flex justify-between">
|
||||||
|
<span>Color Code:</span>
|
||||||
|
<span className="text-gray-500">🟢 Left | 🟠 Right</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@ -15,7 +15,7 @@ import {
|
|||||||
simulateStepResponse, validatePID,
|
simulateStepResponse, validatePID,
|
||||||
} from '../hooks/useSettings.js';
|
} from '../hooks/useSettings.js';
|
||||||
|
|
||||||
const VIEWS = ['PID', 'Sensors', 'Network', 'Firmware', 'Diagnostics', 'Backup'];
|
const VIEWS = ['Parameters', 'PID', 'Sensors', 'Network', 'Firmware', 'Diagnostics', 'Backup'];
|
||||||
|
|
||||||
function ValidationBadges({ warnings }) {
|
function ValidationBadges({ warnings }) {
|
||||||
if (!warnings?.length) return (
|
if (!warnings?.length) return (
|
||||||
@ -377,6 +377,204 @@ function DiagnosticsView({ exportDiagnosticsBundle, subscribe, connected }) {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function ParametersView({ callService, subscribe, connected }) {
|
||||||
|
const [params, setParams] = useState({});
|
||||||
|
const [loading, setLoading] = useState(true);
|
||||||
|
const [updating, setUpdating] = useState(null);
|
||||||
|
const [result, setResult] = useState(null);
|
||||||
|
const [searchFilter, setSearchFilter] = useState('');
|
||||||
|
|
||||||
|
// Fetch all ROS parameters on mount
|
||||||
|
useEffect(() => {
|
||||||
|
if (!connected || !callService) return;
|
||||||
|
|
||||||
|
setLoading(true);
|
||||||
|
// Call get_parameters service to list all params
|
||||||
|
callService('/rcl_interfaces/srv/GetParameters', {
|
||||||
|
names: [] // Empty list means get all parameters
|
||||||
|
}).then((resp) => {
|
||||||
|
if (resp.values) {
|
||||||
|
const newParams = {};
|
||||||
|
resp.names.forEach((name, i) => {
|
||||||
|
newParams[name] = resp.values[i];
|
||||||
|
});
|
||||||
|
setParams(newParams);
|
||||||
|
}
|
||||||
|
setLoading(false);
|
||||||
|
}).catch((err) => {
|
||||||
|
console.error('Failed to fetch parameters:', err);
|
||||||
|
setLoading(false);
|
||||||
|
});
|
||||||
|
}, [connected, callService]);
|
||||||
|
|
||||||
|
// Handle parameter edit
|
||||||
|
const handleParamChange = (paramName, newValue) => {
|
||||||
|
setParams(p => ({ ...p, [paramName]: newValue }));
|
||||||
|
};
|
||||||
|
|
||||||
|
// Apply parameter update
|
||||||
|
const applyParam = async (paramName, value) => {
|
||||||
|
if (!callService) return;
|
||||||
|
setUpdating(paramName);
|
||||||
|
try {
|
||||||
|
const resp = await callService('/rcl_interfaces/srv/SetParameters', {
|
||||||
|
parameters: [{
|
||||||
|
name: paramName,
|
||||||
|
value: {
|
||||||
|
type: detectParamType(value),
|
||||||
|
...(detectParamType(value) === 4 ? { integer_value: parseInt(value) } :
|
||||||
|
detectParamType(value) === 1 ? { double_value: parseFloat(value) } :
|
||||||
|
detectParamType(value) === 5 ? { bool_value: Boolean(value) } :
|
||||||
|
detectParamType(value) === 3 ? { string_value: String(value) } : {})
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
});
|
||||||
|
if (resp.results?.[0]?.successful) {
|
||||||
|
setResult({ ok: true, msg: `${paramName} updated` });
|
||||||
|
} else {
|
||||||
|
setResult({ ok: false, msg: `Failed to update ${paramName}` });
|
||||||
|
}
|
||||||
|
setTimeout(() => setResult(null), 3000);
|
||||||
|
} catch (err) {
|
||||||
|
setResult({ ok: false, msg: 'Update failed: ' + err.message });
|
||||||
|
setTimeout(() => setResult(null), 3000);
|
||||||
|
} finally {
|
||||||
|
setUpdating(null);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Group parameters by node name (part before /)
|
||||||
|
const grouped = {};
|
||||||
|
Object.keys(params).forEach(name => {
|
||||||
|
const parts = name.split('/');
|
||||||
|
const node = parts.length > 1 ? parts[1] : 'root';
|
||||||
|
if (!grouped[node]) grouped[node] = [];
|
||||||
|
grouped[node].push(name);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Filter by search
|
||||||
|
const filteredGroups = {};
|
||||||
|
Object.entries(grouped).forEach(([node, names]) => {
|
||||||
|
const filtered = names.filter(n => n.toLowerCase().includes(searchFilter.toLowerCase()));
|
||||||
|
if (filtered.length > 0) {
|
||||||
|
filteredGroups[node] = filtered;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
const detectParamType = (val) => {
|
||||||
|
if (typeof val === 'boolean') return 5; // bool
|
||||||
|
if (Number.isInteger(val)) return 4; // int64
|
||||||
|
if (typeof val === 'number') return 1; // double
|
||||||
|
return 3; // string
|
||||||
|
};
|
||||||
|
|
||||||
|
const renderParamInput = (name, value) => {
|
||||||
|
const type = detectParamType(value);
|
||||||
|
if (type === 5) { // bool
|
||||||
|
return (
|
||||||
|
<label className="flex items-center gap-2 text-xs cursor-pointer">
|
||||||
|
<div onClick={() => handleParamChange(name, !value)}
|
||||||
|
className={`w-6 h-3 rounded-full relative cursor-pointer transition-colors ${value ? 'bg-cyan-700' : 'bg-gray-700'}`}>
|
||||||
|
<span className={`absolute top-0.5 w-2 h-2 rounded-full bg-white transition-all ${value ? 'left-3' : 'left-0.5'}`}/>
|
||||||
|
</div>
|
||||||
|
<span className="text-gray-400">{String(value)}</span>
|
||||||
|
</label>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<input type={type === 1 ? 'number' : 'text'} step={type === 1 ? '0.01' : undefined}
|
||||||
|
value={value} onChange={(e) => handleParamChange(name, e.target.value)}
|
||||||
|
className="flex-1 bg-gray-900 border border-gray-700 rounded px-2 py-1 text-xs text-cyan-200 focus:outline-none focus:border-cyan-700" />
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="space-y-4 flex flex-col h-full">
|
||||||
|
<div className="flex items-center gap-2 flex-wrap">
|
||||||
|
<div className="text-cyan-700 text-xs font-bold tracking-widest">ROS PARAMETERS</div>
|
||||||
|
<span className={`text-xs px-1.5 py-0.5 rounded border ml-auto ${connected ? 'text-green-400 border-green-800' : 'text-gray-600 border-gray-700'}`}>
|
||||||
|
{connected ? 'LIVE' : 'OFFLINE'}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<input type="text" placeholder="Search parameters..."
|
||||||
|
value={searchFilter} onChange={(e) => setSearchFilter(e.target.value)}
|
||||||
|
className="w-full bg-gray-900 border border-gray-700 rounded px-2 py-1.5 text-xs text-gray-200 focus:outline-none focus:border-cyan-700" />
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{loading && (
|
||||||
|
<div className="flex items-center justify-center py-8 text-gray-600">
|
||||||
|
<div>Loading parameters…</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{!loading && Object.keys(filteredGroups).length === 0 && (
|
||||||
|
<div className="flex items-center justify-center py-8 text-gray-600">
|
||||||
|
<div className="text-center">
|
||||||
|
<div className="text-sm mb-1">No parameters found</div>
|
||||||
|
<div className="text-xs text-gray-700">{searchFilter ? 'Try a different search term' : 'Ensure robot is connected'}</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<div className="flex-1 overflow-y-auto space-y-3">
|
||||||
|
{Object.entries(filteredGroups).map(([node, names]) => (
|
||||||
|
<div key={node} className="bg-gray-950 border border-gray-800 rounded-lg p-3 space-y-2">
|
||||||
|
<div className="text-gray-500 text-xs font-bold font-mono uppercase">{node}</div>
|
||||||
|
<div className="space-y-2">
|
||||||
|
{names.sort().map(name => {
|
||||||
|
const shortName = name.split('/').pop();
|
||||||
|
const value = params[name];
|
||||||
|
const type = detectParamType(value);
|
||||||
|
const isUpdating = updating === name;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div key={name} className="flex items-center gap-2 text-xs">
|
||||||
|
<span className="text-gray-600 w-32 truncate" title={shortName}>{shortName}</span>
|
||||||
|
<div className="flex-1 flex items-center gap-1">
|
||||||
|
{renderParamInput(name, value)}
|
||||||
|
<button onClick={() => applyParam(name, params[name])} disabled={isUpdating || !connected}
|
||||||
|
className="px-2 py-1 rounded bg-cyan-950 border border-cyan-700 text-cyan-300 hover:bg-cyan-900 text-xs font-bold disabled:opacity-40 whitespace-nowrap">
|
||||||
|
{isUpdating ? '…' : 'SET'}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<span className={`text-xs px-1 py-0.5 rounded font-mono ${
|
||||||
|
type === 5 ? 'bg-blue-950 text-blue-400' :
|
||||||
|
type === 4 ? 'bg-yellow-950 text-yellow-400' :
|
||||||
|
type === 1 ? 'bg-green-950 text-green-400' :
|
||||||
|
'bg-gray-800 text-gray-400'
|
||||||
|
}`}>
|
||||||
|
{type === 5 ? 'bool' : type === 4 ? 'int' : type === 1 ? 'float' : 'str'}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{result && (
|
||||||
|
<div className={`text-xs rounded px-2 py-1 border ${
|
||||||
|
result.ok ? 'bg-green-950 border-green-800 text-green-400' : 'bg-red-950 border-red-800 text-red-400'
|
||||||
|
}`}>{result.ok ? '✓ ' : '✕ '}{result.msg}</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<div className="bg-gray-950 rounded border border-gray-800 p-2 text-xs text-gray-600 space-y-1">
|
||||||
|
<div className="flex justify-between">
|
||||||
|
<span>Total Parameters:</span>
|
||||||
|
<span className="text-gray-500">{Object.keys(params).length}</span>
|
||||||
|
</div>
|
||||||
|
<div className="flex justify-between">
|
||||||
|
<span>Grouped by Node:</span>
|
||||||
|
<span className="text-gray-500">{Object.keys(grouped).length} nodes</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
function BackupView({ exportSettingsJSON, importSettingsJSON }) {
|
function BackupView({ exportSettingsJSON, importSettingsJSON }) {
|
||||||
const [importText, setImportText] = useState('');
|
const [importText, setImportText] = useState('');
|
||||||
const [showImport, setShowImport] = useState(false);
|
const [showImport, setShowImport] = useState(false);
|
||||||
@ -441,6 +639,7 @@ export function SettingsPanel({ subscribe, callService, connected = false, wsUrl
|
|||||||
}`}>{v.toUpperCase()}</button>
|
}`}>{v.toUpperCase()}</button>
|
||||||
))}
|
))}
|
||||||
</div>
|
</div>
|
||||||
|
{view==='Parameters' && <ParametersView callService={callService} subscribe={subscribe} connected={connected} />}
|
||||||
{view==='PID' && <PIDView gains={settings.gains} setGains={settings.setGains} applyPIDGains={settings.applyPIDGains} applying={settings.applying} applyResult={settings.applyResult} connected={connected} />}
|
{view==='PID' && <PIDView gains={settings.gains} setGains={settings.setGains} applyPIDGains={settings.applyPIDGains} applying={settings.applying} applyResult={settings.applyResult} connected={connected} />}
|
||||||
{view==='Sensors' && <SensorsView sensors={settings.sensors} setSensors={settings.setSensors} applySensorParams={settings.applySensorParams} applying={settings.applying} applyResult={settings.applyResult} connected={connected} />}
|
{view==='Sensors' && <SensorsView sensors={settings.sensors} setSensors={settings.setSensors} applySensorParams={settings.applySensorParams} applying={settings.applying} applyResult={settings.applyResult} connected={connected} />}
|
||||||
{view==='Network' && <NetworkView wsUrl={wsUrl} connected={connected} />}
|
{view==='Network' && <NetworkView wsUrl={wsUrl} connected={connected} />}
|
||||||
|
|||||||
384
ui/social-bot/src/components/Teleop.jsx
Normal file
384
ui/social-bot/src/components/Teleop.jsx
Normal file
@ -0,0 +1,384 @@
|
|||||||
|
/**
|
||||||
|
* Teleop.jsx — Gamepad and keyboard teleoperation controller
|
||||||
|
*
|
||||||
|
* Features:
|
||||||
|
* - Virtual dual-stick gamepad (left=linear, right=angular velocity)
|
||||||
|
* - WASD keyboard fallback for manual driving
|
||||||
|
* - Speed limiter slider for safe operation
|
||||||
|
* - E-stop button for emergency stop
|
||||||
|
* - Real-time velocity display (m/s and rad/s)
|
||||||
|
* - Publishes geometry_msgs/Twist to /cmd_vel
|
||||||
|
* - Visual feedback with stick position and velocity vectors
|
||||||
|
* - 10% deadzone on both axes
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { useEffect, useRef, useState } from 'react';
|
||||||
|
|
||||||
|
const MAX_LINEAR_VELOCITY = 0.5; // m/s
|
||||||
|
const MAX_ANGULAR_VELOCITY = 1.0; // rad/s
|
||||||
|
const DEADZONE = 0.1; // 10% deadzone
|
||||||
|
const STICK_UPDATE_RATE = 50; // ms
|
||||||
|
|
||||||
|
function VirtualStick({
|
||||||
|
position,
|
||||||
|
onMove,
|
||||||
|
label,
|
||||||
|
color,
|
||||||
|
maxValue = 1.0,
|
||||||
|
}) {
|
||||||
|
const canvasRef = useRef(null);
|
||||||
|
const containerRef = useRef(null);
|
||||||
|
const isDraggingRef = useRef(false);
|
||||||
|
|
||||||
|
// Draw stick
|
||||||
|
useEffect(() => {
|
||||||
|
const canvas = canvasRef.current;
|
||||||
|
if (!canvas) return;
|
||||||
|
|
||||||
|
const ctx = canvas.getContext('2d');
|
||||||
|
const width = canvas.width;
|
||||||
|
const height = canvas.height;
|
||||||
|
const centerX = width / 2;
|
||||||
|
const centerY = height / 2;
|
||||||
|
const baseRadius = Math.min(width, height) * 0.35;
|
||||||
|
const knobRadius = Math.min(width, height) * 0.15;
|
||||||
|
|
||||||
|
// Clear canvas
|
||||||
|
ctx.fillStyle = '#1f2937';
|
||||||
|
ctx.fillRect(0, 0, width, height);
|
||||||
|
|
||||||
|
// Draw base circle
|
||||||
|
ctx.strokeStyle = '#374151';
|
||||||
|
ctx.lineWidth = 2;
|
||||||
|
ctx.beginPath();
|
||||||
|
ctx.arc(centerX, centerY, baseRadius, 0, Math.PI * 2);
|
||||||
|
ctx.stroke();
|
||||||
|
|
||||||
|
// Draw center crosshair
|
||||||
|
ctx.strokeStyle = '#4b5563';
|
||||||
|
ctx.lineWidth = 1;
|
||||||
|
ctx.beginPath();
|
||||||
|
ctx.moveTo(centerX - baseRadius * 0.3, centerY);
|
||||||
|
ctx.lineTo(centerX + baseRadius * 0.3, centerY);
|
||||||
|
ctx.moveTo(centerX, centerY - baseRadius * 0.3);
|
||||||
|
ctx.lineTo(centerX, centerY + baseRadius * 0.3);
|
||||||
|
ctx.stroke();
|
||||||
|
|
||||||
|
// Draw deadzone circle
|
||||||
|
ctx.strokeStyle = '#4b5563';
|
||||||
|
ctx.lineWidth = 1;
|
||||||
|
ctx.globalAlpha = 0.5;
|
||||||
|
ctx.beginPath();
|
||||||
|
ctx.arc(centerX, centerY, baseRadius * DEADZONE, 0, Math.PI * 2);
|
||||||
|
ctx.stroke();
|
||||||
|
ctx.globalAlpha = 1.0;
|
||||||
|
|
||||||
|
// Draw knob at current position
|
||||||
|
const knobX = centerX + (position.x / maxValue) * baseRadius;
|
||||||
|
const knobY = centerY - (position.y / maxValue) * baseRadius;
|
||||||
|
|
||||||
|
// Knob shadow
|
||||||
|
ctx.fillStyle = 'rgba(0, 0, 0, 0.3)';
|
||||||
|
ctx.beginPath();
|
||||||
|
ctx.arc(knobX + 2, knobY + 2, knobRadius, 0, Math.PI * 2);
|
||||||
|
ctx.fill();
|
||||||
|
|
||||||
|
// Knob
|
||||||
|
ctx.fillStyle = color;
|
||||||
|
ctx.beginPath();
|
||||||
|
ctx.arc(knobX, knobY, knobRadius, 0, Math.PI * 2);
|
||||||
|
ctx.fill();
|
||||||
|
|
||||||
|
// Knob border
|
||||||
|
ctx.strokeStyle = '#fff';
|
||||||
|
ctx.lineWidth = 2;
|
||||||
|
ctx.beginPath();
|
||||||
|
ctx.arc(knobX, knobY, knobRadius, 0, Math.PI * 2);
|
||||||
|
ctx.stroke();
|
||||||
|
|
||||||
|
// Draw velocity vector
|
||||||
|
if (Math.abs(position.x) > DEADZONE || Math.abs(position.y) > DEADZONE) {
|
||||||
|
ctx.strokeStyle = color;
|
||||||
|
ctx.lineWidth = 2;
|
||||||
|
ctx.globalAlpha = 0.7;
|
||||||
|
ctx.beginPath();
|
||||||
|
ctx.moveTo(centerX, centerY);
|
||||||
|
ctx.lineTo(knobX, knobY);
|
||||||
|
ctx.stroke();
|
||||||
|
ctx.globalAlpha = 1.0;
|
||||||
|
}
|
||||||
|
}, [position, color, maxValue]);
|
||||||
|
|
||||||
|
const handlePointerDown = (e) => {
|
||||||
|
isDraggingRef.current = true;
|
||||||
|
updateStickPosition(e);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handlePointerMove = (e) => {
|
||||||
|
if (!isDraggingRef.current) return;
|
||||||
|
updateStickPosition(e);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handlePointerUp = () => {
|
||||||
|
isDraggingRef.current = false;
|
||||||
|
onMove({ x: 0, y: 0 });
|
||||||
|
};
|
||||||
|
|
||||||
|
const updateStickPosition = (e) => {
|
||||||
|
const canvas = canvasRef.current;
|
||||||
|
const rect = canvas.getBoundingClientRect();
|
||||||
|
const centerX = rect.width / 2;
|
||||||
|
const centerY = rect.height / 2;
|
||||||
|
const baseRadius = Math.min(rect.width, rect.height) * 0.35;
|
||||||
|
|
||||||
|
const x = e.clientX - rect.left - centerX;
|
||||||
|
const y = -(e.clientY - rect.top - centerY);
|
||||||
|
|
||||||
|
const magnitude = Math.sqrt(x * x + y * y);
|
||||||
|
const angle = Math.atan2(y, x);
|
||||||
|
|
||||||
|
let clampedMagnitude = Math.min(magnitude, baseRadius) / baseRadius;
|
||||||
|
|
||||||
|
// Apply deadzone
|
||||||
|
if (clampedMagnitude < DEADZONE) {
|
||||||
|
clampedMagnitude = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
onMove({
|
||||||
|
x: Math.cos(angle) * clampedMagnitude,
|
||||||
|
y: Math.sin(angle) * clampedMagnitude,
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div ref={containerRef} className="flex flex-col items-center gap-2">
|
||||||
|
<div className="text-xs font-bold text-gray-400 tracking-widest">{label}</div>
|
||||||
|
<canvas
|
||||||
|
ref={canvasRef}
|
||||||
|
width={160}
|
||||||
|
height={160}
|
||||||
|
className="border-2 border-gray-800 rounded-lg bg-gray-900 cursor-grab active:cursor-grabbing"
|
||||||
|
onPointerDown={handlePointerDown}
|
||||||
|
onPointerMove={handlePointerMove}
|
||||||
|
onPointerUp={handlePointerUp}
|
||||||
|
onPointerLeave={handlePointerUp}
|
||||||
|
style={{ touchAction: 'none' }}
|
||||||
|
/>
|
||||||
|
<div className="text-xs text-gray-500 font-mono">
|
||||||
|
X: {position.x.toFixed(2)} Y: {position.y.toFixed(2)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function Teleop({ publish }) {
|
||||||
|
const [leftStick, setLeftStick] = useState({ x: 0, y: 0 });
|
||||||
|
const [rightStick, setRightStick] = useState({ x: 0, y: 0 });
|
||||||
|
const [speedLimit, setSpeedLimit] = useState(1.0);
|
||||||
|
const [isEstopped, setIsEstopped] = useState(false);
|
||||||
|
const [linearVel, setLinearVel] = useState(0);
|
||||||
|
const [angularVel, setAngularVel] = useState(0);
|
||||||
|
|
||||||
|
const keysPressed = useRef({});
|
||||||
|
const publishIntervalRef = useRef(null);
|
||||||
|
|
||||||
|
// Keyboard handling
|
||||||
|
useEffect(() => {
|
||||||
|
const handleKeyDown = (e) => {
|
||||||
|
const key = e.key.toLowerCase();
|
||||||
|
if (['w', 'a', 's', 'd', ' '].includes(key)) {
|
||||||
|
keysPressed.current[key] = true;
|
||||||
|
e.preventDefault();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleKeyUp = (e) => {
|
||||||
|
const key = e.key.toLowerCase();
|
||||||
|
if (['w', 'a', 's', 'd', ' '].includes(key)) {
|
||||||
|
keysPressed.current[key] = false;
|
||||||
|
e.preventDefault();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
window.addEventListener('keydown', handleKeyDown);
|
||||||
|
window.addEventListener('keyup', handleKeyUp);
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
window.removeEventListener('keydown', handleKeyDown);
|
||||||
|
window.removeEventListener('keyup', handleKeyUp);
|
||||||
|
};
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// Calculate velocities from input
|
||||||
|
useEffect(() => {
|
||||||
|
const interval = setInterval(() => {
|
||||||
|
let linear = leftStick.y;
|
||||||
|
let angular = rightStick.x;
|
||||||
|
|
||||||
|
// WASD fallback
|
||||||
|
if (keysPressed.current['w']) linear = Math.min(1, linear + 0.5);
|
||||||
|
if (keysPressed.current['s']) linear = Math.max(-1, linear - 0.5);
|
||||||
|
if (keysPressed.current['d']) angular = Math.min(1, angular + 0.5);
|
||||||
|
if (keysPressed.current['a']) angular = Math.max(-1, angular - 0.5);
|
||||||
|
|
||||||
|
// Clamp to [-1, 1]
|
||||||
|
linear = Math.max(-1, Math.min(1, linear));
|
||||||
|
angular = Math.max(-1, Math.min(1, angular));
|
||||||
|
|
||||||
|
// Apply speed limit
|
||||||
|
const speedFactor = isEstopped ? 0 : speedLimit;
|
||||||
|
const finalLinear = linear * MAX_LINEAR_VELOCITY * speedFactor;
|
||||||
|
const finalAngular = angular * MAX_ANGULAR_VELOCITY * speedFactor;
|
||||||
|
|
||||||
|
setLinearVel(finalLinear);
|
||||||
|
setAngularVel(finalAngular);
|
||||||
|
|
||||||
|
// Publish Twist
|
||||||
|
if (publish) {
|
||||||
|
publish('/cmd_vel', 'geometry_msgs/Twist', {
|
||||||
|
linear: { x: finalLinear, y: 0, z: 0 },
|
||||||
|
angular: { x: 0, y: 0, z: finalAngular },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, STICK_UPDATE_RATE);
|
||||||
|
|
||||||
|
return () => clearInterval(interval);
|
||||||
|
}, [leftStick, rightStick, speedLimit, isEstopped, publish]);
|
||||||
|
|
||||||
|
const handleEstop = () => {
|
||||||
|
setIsEstopped(!isEstopped);
|
||||||
|
// Publish stop immediately
|
||||||
|
if (publish) {
|
||||||
|
publish('/cmd_vel', 'geometry_msgs/Twist', {
|
||||||
|
linear: { x: 0, y: 0, z: 0 },
|
||||||
|
angular: { x: 0, y: 0, z: 0 },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex flex-col h-full space-y-3">
|
||||||
|
{/* Status bar */}
|
||||||
|
<div className={`rounded-lg border p-3 space-y-2 ${
|
||||||
|
isEstopped
|
||||||
|
? 'bg-red-950 border-red-900'
|
||||||
|
: 'bg-gray-950 border-cyan-950'
|
||||||
|
}`}>
|
||||||
|
<div className="flex justify-between items-center">
|
||||||
|
<div className={`text-xs font-bold tracking-widest ${
|
||||||
|
isEstopped ? 'text-red-700' : 'text-cyan-700'
|
||||||
|
}`}>
|
||||||
|
{isEstopped ? '🛑 E-STOP ACTIVE' : '⚡ TELEOP READY'}
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
onClick={handleEstop}
|
||||||
|
className={`px-3 py-1 text-xs font-bold rounded border transition-colors ${
|
||||||
|
isEstopped
|
||||||
|
? 'bg-green-950 border-green-800 text-green-400 hover:bg-green-900'
|
||||||
|
: 'bg-red-950 border-red-800 text-red-400 hover:bg-red-900'
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
{isEstopped ? 'RESUME' : 'E-STOP'}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Velocity display */}
|
||||||
|
<div className="grid grid-cols-2 gap-2">
|
||||||
|
<div className="bg-gray-900 rounded p-2">
|
||||||
|
<div className="text-gray-600 text-xs">LINEAR</div>
|
||||||
|
<div className="text-lg font-mono text-cyan-300">
|
||||||
|
{linearVel.toFixed(2)} m/s
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="bg-gray-900 rounded p-2">
|
||||||
|
<div className="text-gray-600 text-xs">ANGULAR</div>
|
||||||
|
<div className="text-lg font-mono text-amber-300">
|
||||||
|
{angularVel.toFixed(2)} rad/s
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Gamepad area */}
|
||||||
|
<div className="flex-1 bg-gray-950 rounded-lg border border-cyan-950 p-4 flex justify-center items-center gap-8">
|
||||||
|
<VirtualStick
|
||||||
|
position={leftStick}
|
||||||
|
onMove={setLeftStick}
|
||||||
|
label="LEFT — LINEAR"
|
||||||
|
color="#10b981"
|
||||||
|
maxValue={1.0}
|
||||||
|
/>
|
||||||
|
<VirtualStick
|
||||||
|
position={rightStick}
|
||||||
|
onMove={setRightStick}
|
||||||
|
label="RIGHT — ANGULAR"
|
||||||
|
color="#f59e0b"
|
||||||
|
maxValue={1.0}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Speed limiter */}
|
||||||
|
<div className="bg-gray-950 rounded-lg border border-cyan-950 p-3 space-y-2">
|
||||||
|
<div className="flex justify-between items-center">
|
||||||
|
<div className="text-cyan-700 text-xs font-bold tracking-widest">
|
||||||
|
SPEED LIMITER
|
||||||
|
</div>
|
||||||
|
<div className="text-gray-400 text-xs font-mono">
|
||||||
|
{(speedLimit * 100).toFixed(0)}%
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<input
|
||||||
|
type="range"
|
||||||
|
min="0"
|
||||||
|
max="1"
|
||||||
|
step="0.05"
|
||||||
|
value={speedLimit}
|
||||||
|
onChange={(e) => setSpeedLimit(parseFloat(e.target.value))}
|
||||||
|
disabled={isEstopped}
|
||||||
|
className="w-full cursor-pointer"
|
||||||
|
style={{
|
||||||
|
accentColor: isEstopped ? '#6b7280' : '#06b6d4',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<div className="grid grid-cols-3 gap-2 text-xs text-gray-600">
|
||||||
|
<div className="text-center">0%</div>
|
||||||
|
<div className="text-center">50%</div>
|
||||||
|
<div className="text-center">100%</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Control info */}
|
||||||
|
<div className="bg-gray-950 rounded border border-gray-800 p-2 text-xs text-gray-600 space-y-1">
|
||||||
|
<div className="font-bold text-cyan-700 mb-2">KEYBOARD FALLBACK</div>
|
||||||
|
<div className="grid grid-cols-2 gap-2">
|
||||||
|
<div>
|
||||||
|
<span className="text-gray-500">W/S:</span> <span className="text-gray-400 font-mono">Forward/Back</span>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<span className="text-gray-500">A/D:</span> <span className="text-gray-400 font-mono">Turn Left/Right</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Topic info */}
|
||||||
|
<div className="bg-gray-950 rounded border border-gray-800 p-2 text-xs text-gray-600 space-y-1">
|
||||||
|
<div className="flex justify-between">
|
||||||
|
<span>Topic:</span>
|
||||||
|
<span className="text-gray-500">/cmd_vel (geometry_msgs/Twist)</span>
|
||||||
|
</div>
|
||||||
|
<div className="flex justify-between">
|
||||||
|
<span>Max Linear:</span>
|
||||||
|
<span className="text-gray-500">{MAX_LINEAR_VELOCITY} m/s</span>
|
||||||
|
</div>
|
||||||
|
<div className="flex justify-between">
|
||||||
|
<span>Max Angular:</span>
|
||||||
|
<span className="text-gray-500">{MAX_ANGULAR_VELOCITY} rad/s</span>
|
||||||
|
</div>
|
||||||
|
<div className="flex justify-between">
|
||||||
|
<span>Deadzone:</span>
|
||||||
|
<span className="text-gray-500">{(DEADZONE * 100).toFixed(0)}%</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user