Compare commits
9 Commits
4c751e576b
...
d468cb515e
| Author | SHA1 | Date | |
|---|---|---|---|
| d468cb515e | |||
| 0b35a61217 | |||
| 0342caecdb | |||
| 6d80ca35af | |||
| b8a14e2bfc | |||
| e24c0b2e26 | |||
| 20801c4a0e | |||
| fc87862603 | |||
| b8f9d3eca6 |
@ -44,7 +44,7 @@
|
||||
// tabs audibly engage (2–3 mm deflection), test rotation lock.
|
||||
// =============================================================================
|
||||
|
||||
$fn = 64;
|
||||
\$fn = 64;
|
||||
e = 0.01;
|
||||
|
||||
// =============================================================================
|
||||
|
||||
@ -0,0 +1,122 @@
|
||||
"""
|
||||
_sky_detector.py — Sky detection via HSV thresholding + horizon estimation (no ROS2 deps).
|
||||
|
||||
Algorithm
|
||||
---------
|
||||
Sky pixels are identified by two HSV bands (OpenCV convention: H∈[0,180], S/V∈[0,255]):
|
||||
|
||||
1. Blue sky — H∈[90,130], S∈[40,255], V∈[80,255]
|
||||
(captures clear blue sky from cyan-blue through sky-blue to blue-violet)
|
||||
|
||||
2. Overcast — S∈[0,50], V∈[185,255]
|
||||
(very bright, near-zero saturation: white/grey sky on cloudy days)
|
||||
|
||||
The two masks are OR-combined into a single sky mask.
|
||||
|
||||
Sky fraction
|
||||
------------
|
||||
Computed over the top *scan_frac* of the image (default 60 %). This concentrates
|
||||
sensitivity on the region where sky is expected to appear and avoids penalising ground
|
||||
reflections or obstacles in the lower frame.
|
||||
|
||||
sky_fraction = sky_pixels_in_top_scan_frac / pixels_in_top_scan_frac
|
||||
|
||||
Horizon line
|
||||
------------
|
||||
For each image row, compute the row-level sky fraction (sky pixels / row width).
|
||||
The horizon is the **bottommost row** where that fraction ≥ *row_threshold* (default 0.30).
|
||||
This represents the lower boundary of continuous sky content.
|
||||
|
||||
horizon_y = max { r : row_sky_frac[r] ≥ row_threshold } or -1 if no sky found.
|
||||
|
||||
Returns -1 when no sky is detected at all (indoors, underground, etc.).
|
||||
|
||||
Public API
|
||||
----------
|
||||
SkyResult(sky_fraction, horizon_y, sky_mask)
|
||||
detect_sky(bgr, scan_frac=0.60, row_threshold=0.30) -> SkyResult
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ── HSV sky bands ─────────────────────────────────────────────────────────────
|
||||
|
||||
# Blue sky (OpenCV H ∈ [0, 180])
|
||||
_BLUE_SKY_LO = np.array([90, 40, 80], dtype=np.uint8)
|
||||
_BLUE_SKY_HI = np.array([130, 255, 255], dtype=np.uint8)
|
||||
|
||||
# White / grey overcast sky (any hue, very bright, very desaturated)
|
||||
_GREY_SKY_LO = np.array([0, 0, 185], dtype=np.uint8)
|
||||
_GREY_SKY_HI = np.array([180, 50, 255], dtype=np.uint8)
|
||||
|
||||
|
||||
# ── Data type ─────────────────────────────────────────────────────────────────
|
||||
|
||||
class SkyResult(NamedTuple):
|
||||
"""Sky detection result for a single frame."""
|
||||
sky_fraction: float # fraction of top scan_frac region classified as sky [0, 1]
|
||||
horizon_y: int # pixel row of horizon (-1 = no sky detected)
|
||||
sky_mask: np.ndarray # (H, W) uint8 binary mask (255 = sky pixel)
|
||||
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
||||
def detect_sky(
|
||||
bgr: np.ndarray,
|
||||
scan_frac: float = 0.60,
|
||||
row_threshold: float = 0.30,
|
||||
) -> SkyResult:
|
||||
"""
|
||||
Detect sky pixels and estimate the horizon row.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bgr : (H, W, 3) uint8 BGR ndarray
|
||||
scan_frac : fraction of image height to use for sky_fraction computation
|
||||
(top of frame, where sky is expected)
|
||||
row_threshold : minimum per-row sky fraction to count a row as sky for
|
||||
horizon estimation
|
||||
|
||||
Returns
|
||||
-------
|
||||
SkyResult(sky_fraction, horizon_y, sky_mask)
|
||||
"""
|
||||
import cv2
|
||||
|
||||
bgr = np.asarray(bgr, dtype=np.uint8)
|
||||
h, w = bgr.shape[:2]
|
||||
|
||||
if h == 0 or w == 0:
|
||||
empty = np.zeros((h, w), dtype=np.uint8)
|
||||
return SkyResult(sky_fraction=0.0, horizon_y=-1, sky_mask=empty)
|
||||
|
||||
hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
|
||||
|
||||
# Build combined sky mask (blue OR overcast)
|
||||
mask_blue = cv2.inRange(hsv, _BLUE_SKY_LO, _BLUE_SKY_HI)
|
||||
mask_grey = cv2.inRange(hsv, _GREY_SKY_LO, _GREY_SKY_HI)
|
||||
sky_mask = cv2.bitwise_or(mask_blue, mask_grey) # (H, W) uint8, 255 = sky
|
||||
|
||||
# ── sky_fraction: top scan_frac rows ──────────────────────────────────────
|
||||
scan_rows = max(1, int(h * min(float(scan_frac), 1.0)))
|
||||
top_mask = sky_mask[:scan_rows, :]
|
||||
sky_fraction = float(np.count_nonzero(top_mask)) / float(scan_rows * w)
|
||||
|
||||
# ── horizon_y: bottommost row with ≥ row_threshold sky pixels ─────────────
|
||||
# row_fracs[r] = fraction of pixels in row r that are sky
|
||||
row_sky_counts = np.count_nonzero(sky_mask, axis=1).astype(np.float32) # (H,)
|
||||
row_fracs = row_sky_counts / float(w)
|
||||
sky_rows = np.where(row_fracs >= row_threshold)[0]
|
||||
|
||||
horizon_y = int(sky_rows.max()) if len(sky_rows) > 0 else -1
|
||||
|
||||
return SkyResult(
|
||||
sky_fraction=sky_fraction,
|
||||
horizon_y=horizon_y,
|
||||
sky_mask=sky_mask,
|
||||
)
|
||||
@ -0,0 +1,106 @@
|
||||
"""
|
||||
sky_detect_node.py — D435i sky detector for outdoor navigation (Issue #307).
|
||||
|
||||
Classifies the top portion of the D435i colour image as sky vs non-sky using
|
||||
HSV blue/grey thresholding, then estimates the horizon line.
|
||||
|
||||
Useful for:
|
||||
- Outdoor/indoor scene detection (sky_fraction > 0.1 → likely outdoor)
|
||||
- Camera tilt correction (horizon_y deviation from expected row → tilt estimate)
|
||||
- Disabling outdoor-only nodes (colour correction, sun glare filter) indoors
|
||||
|
||||
Subscribes (BEST_EFFORT):
|
||||
/camera/color/image_raw sensor_msgs/Image BGR8
|
||||
|
||||
Publishes:
|
||||
/saltybot/sky_fraction std_msgs/Float32 sky fraction in [0, 1] (per frame)
|
||||
/saltybot/horizon_y std_msgs/Int32 horizon pixel row; -1 = no sky
|
||||
|
||||
Parameters
|
||||
----------
|
||||
scan_frac float 0.60 Top fraction of image analysed for sky_fraction
|
||||
row_threshold float 0.30 Per-row sky fraction required to count a row as sky
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||
|
||||
from cv_bridge import CvBridge
|
||||
|
||||
from sensor_msgs.msg import Image
|
||||
from std_msgs.msg import Float32, Int32
|
||||
|
||||
from ._sky_detector import detect_sky
|
||||
|
||||
|
||||
_SENSOR_QOS = QoSProfile(
|
||||
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||
history=HistoryPolicy.KEEP_LAST,
|
||||
depth=4,
|
||||
)
|
||||
|
||||
|
||||
class SkyDetectNode(Node):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__('sky_detect_node')
|
||||
|
||||
self.declare_parameter('scan_frac', 0.60)
|
||||
self.declare_parameter('row_threshold', 0.30)
|
||||
|
||||
self._scan_frac = float(self.get_parameter('scan_frac').value)
|
||||
self._row_threshold = float(self.get_parameter('row_threshold').value)
|
||||
|
||||
self._bridge = CvBridge()
|
||||
|
||||
self._sub = self.create_subscription(
|
||||
Image, '/camera/color/image_raw', self._on_image, _SENSOR_QOS)
|
||||
|
||||
self._pub_frac = self.create_publisher(Float32, '/saltybot/sky_fraction', 10)
|
||||
self._pub_horizon = self.create_publisher(Int32, '/saltybot/horizon_y', 10)
|
||||
|
||||
self.get_logger().info(
|
||||
f'sky_detect_node ready — scan_frac={self._scan_frac} '
|
||||
f'row_threshold={self._row_threshold}'
|
||||
)
|
||||
|
||||
# ── Callback ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _on_image(self, msg: Image) -> None:
|
||||
try:
|
||||
bgr = self._bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')
|
||||
except Exception as exc:
|
||||
self.get_logger().error(
|
||||
f'cv_bridge: {exc}', throttle_duration_sec=5.0)
|
||||
return
|
||||
|
||||
result = detect_sky(
|
||||
bgr,
|
||||
scan_frac=self._scan_frac,
|
||||
row_threshold=self._row_threshold,
|
||||
)
|
||||
|
||||
frac_msg = Float32()
|
||||
frac_msg.data = result.sky_fraction
|
||||
self._pub_frac.publish(frac_msg)
|
||||
|
||||
hz_msg = Int32()
|
||||
hz_msg.data = result.horizon_y
|
||||
self._pub_horizon.publish(hz_msg)
|
||||
|
||||
|
||||
def main(args=None) -> None:
|
||||
rclpy.init(args=args)
|
||||
node = SkyDetectNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -45,6 +45,8 @@ setup(
|
||||
'blur_detector = saltybot_bringup.blur_detect_node:main',
|
||||
# Terrain roughness estimator (Issue #296)
|
||||
'terrain_roughness = saltybot_bringup.terrain_rough_node:main',
|
||||
# Sky detector for outdoor navigation (Issue #307)
|
||||
'sky_detector = saltybot_bringup.sky_detect_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
327
jetson/ros2_ws/src/saltybot_bringup/test/test_sky_detector.py
Normal file
327
jetson/ros2_ws/src/saltybot_bringup/test/test_sky_detector.py
Normal file
@ -0,0 +1,327 @@
|
||||
"""
|
||||
test_sky_detector.py — Unit tests for sky detection helpers (no ROS2 required).
|
||||
|
||||
Covers:
|
||||
SkyResult:
|
||||
- fields accessible by name
|
||||
- sky_fraction in [0, 1]
|
||||
- horizon_y is int
|
||||
- sky_mask is ndarray
|
||||
|
||||
detect_sky — output contract:
|
||||
- returns SkyResult
|
||||
- sky_fraction in [0, 1] for all test images
|
||||
- sky_mask shape matches input
|
||||
- sky_mask dtype is uint8
|
||||
- empty image returns sky_fraction=0.0, horizon_y=-1
|
||||
- sky_fraction=0.0 and horizon_y=-1 for ground-only image
|
||||
|
||||
detect_sky — blue sky detection:
|
||||
- solid blue sky → sky_fraction ≈ 1.0
|
||||
- solid blue sky → horizon_y = H-1 (sky fills every row)
|
||||
- solid blue sky → sky_mask nearly all 255
|
||||
|
||||
detect_sky — overcast sky detection:
|
||||
- solid white/grey overcast image → sky_fraction ≈ 1.0
|
||||
- solid grey overcast → horizon_y = H-1
|
||||
|
||||
detect_sky — non-sky:
|
||||
- solid green ground → sky_fraction ≈ 0.0
|
||||
- solid green ground → horizon_y = -1
|
||||
- solid brown → sky_fraction ≈ 0.0
|
||||
|
||||
detect_sky — split image (sky top, ground bottom):
|
||||
- top half blue sky, bottom half green → sky_fraction ≈ 1.0 (scan_frac=0.5)
|
||||
- split image → horizon_y near H//2
|
||||
- sky_fraction decreases as scan_frac increases past the sky region
|
||||
|
||||
detect_sky — horizon estimation:
|
||||
- wider sky region → higher horizon_y
|
||||
- horizon_y within image bounds [0, H-1] when sky detected
|
||||
- horizon_y == -1 when no sky anywhere
|
||||
|
||||
detect_sky — scan_frac and row_threshold params:
|
||||
- scan_frac=1.0 analyses full frame
|
||||
- scan_frac=0.0 → sky_fraction=0.0 regardless of image content
|
||||
- row_threshold=0.0 → every row counts as sky row → horizon_y = H-1 for any image
|
||||
- row_threshold=1.0 → only rows fully sky count → tighter horizon on mixed image
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from saltybot_bringup._sky_detector import (
|
||||
SkyResult,
|
||||
detect_sky,
|
||||
)
|
||||
|
||||
|
||||
# ── Image factories ───────────────────────────────────────────────────────────
|
||||
|
||||
def _hsv_solid_bgr(h_val: int, s: int, v: int, rows: int = 64, cols: int = 64) -> np.ndarray:
|
||||
"""Create a solid BGR image from an HSV specification."""
|
||||
import cv2
|
||||
hsv = np.full((rows, cols, 3), [h_val, s, v], dtype=np.uint8)
|
||||
return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
|
||||
|
||||
|
||||
# Canonical sky and ground colours (OpenCV HSV: H∈[0,180], S/V∈[0,255])
|
||||
def _blue_sky(rows=64, cols=64) -> np.ndarray: return _hsv_solid_bgr(105, 180, 200, rows, cols) # mid-blue sky
|
||||
def _overcast(rows=64, cols=64) -> np.ndarray: return _hsv_solid_bgr(0, 20, 220, rows, cols) # bright grey
|
||||
def _green_ground(rows=64, cols=64) -> np.ndarray: return _hsv_solid_bgr(40, 180, 100, rows, cols) # green grass
|
||||
def _brown_ground(rows=64, cols=64) -> np.ndarray: return _hsv_solid_bgr(15, 160, 90, rows, cols) # soil/gravel
|
||||
|
||||
|
||||
def _split(top_bgr: np.ndarray, bottom_bgr: np.ndarray) -> np.ndarray:
|
||||
"""Stack two same-width BGR images vertically."""
|
||||
return np.concatenate([top_bgr, bottom_bgr], axis=0)
|
||||
|
||||
|
||||
# ── SkyResult ─────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestSkyResult:
|
||||
|
||||
def test_fields_accessible(self):
|
||||
mask = np.zeros((4, 4), dtype=np.uint8)
|
||||
r = SkyResult(sky_fraction=0.7, horizon_y=30, sky_mask=mask)
|
||||
assert r.sky_fraction == pytest.approx(0.7)
|
||||
assert r.horizon_y == 30
|
||||
assert r.sky_mask is mask
|
||||
|
||||
def test_sky_fraction_in_range(self):
|
||||
mask = np.zeros((4, 4), dtype=np.uint8)
|
||||
for v in (0.0, 0.5, 1.0):
|
||||
r = SkyResult(sky_fraction=v, horizon_y=-1, sky_mask=mask)
|
||||
assert 0.0 <= r.sky_fraction <= 1.0
|
||||
|
||||
def test_horizon_y_is_int(self):
|
||||
mask = np.zeros((4, 4), dtype=np.uint8)
|
||||
r = SkyResult(sky_fraction=0.0, horizon_y=-1, sky_mask=mask)
|
||||
assert isinstance(r.horizon_y, int)
|
||||
|
||||
|
||||
# ── detect_sky — output contract ─────────────────────────────────────────────
|
||||
|
||||
class TestDetectSkyContract:
|
||||
|
||||
def test_returns_sky_result(self):
|
||||
r = detect_sky(_blue_sky())
|
||||
assert isinstance(r, SkyResult)
|
||||
|
||||
def test_sky_fraction_in_range_blue(self):
|
||||
r = detect_sky(_blue_sky())
|
||||
assert 0.0 <= r.sky_fraction <= 1.0
|
||||
|
||||
def test_sky_fraction_in_range_ground(self):
|
||||
r = detect_sky(_green_ground())
|
||||
assert 0.0 <= r.sky_fraction <= 1.0
|
||||
|
||||
def test_sky_mask_shape_matches_input(self):
|
||||
img = _blue_sky(rows=48, cols=80)
|
||||
r = detect_sky(img)
|
||||
assert r.sky_mask.shape == (48, 80)
|
||||
|
||||
def test_sky_mask_dtype_uint8(self):
|
||||
r = detect_sky(_blue_sky())
|
||||
assert r.sky_mask.dtype == np.uint8
|
||||
|
||||
def test_empty_image_returns_zero_fraction(self):
|
||||
r = detect_sky(np.zeros((0, 64, 3), dtype=np.uint8))
|
||||
assert r.sky_fraction == pytest.approx(0.0)
|
||||
assert r.horizon_y == -1
|
||||
|
||||
def test_ground_sky_fraction_zero(self):
|
||||
r = detect_sky(_green_ground())
|
||||
assert r.sky_fraction == pytest.approx(0.0, abs=0.05)
|
||||
|
||||
def test_ground_horizon_y_minus_one(self):
|
||||
r = detect_sky(_green_ground())
|
||||
assert r.horizon_y == -1
|
||||
|
||||
|
||||
# ── detect_sky — blue sky ────────────────────────────────────────────────────
|
||||
|
||||
class TestDetectSkyBlue:
|
||||
|
||||
def test_blue_sky_fraction_near_one(self):
|
||||
r = detect_sky(_blue_sky())
|
||||
assert r.sky_fraction > 0.90, (
|
||||
f'solid blue sky should give sky_fraction > 0.9, got {r.sky_fraction:.3f}')
|
||||
|
||||
def test_blue_sky_horizon_at_last_row(self):
|
||||
"""Entire image is sky → horizon at bottom row."""
|
||||
h = 64
|
||||
r = detect_sky(_blue_sky(rows=h))
|
||||
assert r.horizon_y == h - 1, (
|
||||
f'all-sky image: horizon_y should be {h-1}, got {r.horizon_y}')
|
||||
|
||||
def test_blue_sky_mask_mostly_255(self):
|
||||
r = detect_sky(_blue_sky())
|
||||
sky_pixels = np.count_nonzero(r.sky_mask)
|
||||
total = r.sky_mask.size
|
||||
assert sky_pixels / total > 0.90
|
||||
|
||||
def test_blue_sky_positive_horizon(self):
|
||||
r = detect_sky(_blue_sky())
|
||||
assert r.horizon_y >= 0
|
||||
|
||||
|
||||
# ── detect_sky — overcast sky ────────────────────────────────────────────────
|
||||
|
||||
class TestDetectSkyOvercast:
|
||||
|
||||
def test_overcast_fraction_near_one(self):
|
||||
r = detect_sky(_overcast())
|
||||
assert r.sky_fraction > 0.90, (
|
||||
f'overcast grey sky should give sky_fraction > 0.9, got {r.sky_fraction:.3f}')
|
||||
|
||||
def test_overcast_horizon_at_last_row(self):
|
||||
h = 64
|
||||
r = detect_sky(_overcast(rows=h))
|
||||
assert r.horizon_y == h - 1
|
||||
|
||||
def test_overcast_positive_horizon(self):
|
||||
r = detect_sky(_overcast())
|
||||
assert r.horizon_y >= 0
|
||||
|
||||
|
||||
# ── detect_sky — non-sky images ──────────────────────────────────────────────
|
||||
|
||||
class TestDetectSkyNonSky:
|
||||
|
||||
def test_green_ground_fraction_zero(self):
|
||||
r = detect_sky(_green_ground())
|
||||
assert r.sky_fraction < 0.05
|
||||
|
||||
def test_green_ground_horizon_minus_one(self):
|
||||
assert detect_sky(_green_ground()).horizon_y == -1
|
||||
|
||||
def test_brown_ground_fraction_zero(self):
|
||||
r = detect_sky(_brown_ground())
|
||||
assert r.sky_fraction < 0.05
|
||||
|
||||
def test_brown_ground_horizon_minus_one(self):
|
||||
assert detect_sky(_brown_ground()).horizon_y == -1
|
||||
|
||||
|
||||
# ── detect_sky — split images ─────────────────────────────────────────────────
|
||||
|
||||
class TestDetectSkySplit:
|
||||
|
||||
def test_top_half_sky_scan_frac_half(self):
|
||||
"""scan_frac=0.5 → only top half analysed → sky_fraction ≈ 1.0 for top-sky image."""
|
||||
h = 64
|
||||
img = _split(_blue_sky(rows=h // 2), _green_ground(rows=h // 2))
|
||||
r = detect_sky(img, scan_frac=0.5)
|
||||
assert r.sky_fraction > 0.85, (
|
||||
f'top-half sky with scan_frac=0.5: expected >0.85, got {r.sky_fraction:.3f}')
|
||||
|
||||
def test_horizon_near_midpoint(self):
|
||||
"""Sky in top half, ground in bottom half → horizon_y near H//2."""
|
||||
h = 64
|
||||
img = _split(_blue_sky(rows=h // 2), _green_ground(rows=h // 2))
|
||||
r = detect_sky(img)
|
||||
# Horizon should be within the top half (rows 0 .. h//2-1)
|
||||
assert 0 <= r.horizon_y < h, f'horizon_y={r.horizon_y} out of bounds'
|
||||
assert r.horizon_y < h // 2 + 4, (
|
||||
f'horizon_y={r.horizon_y} should be near or before row {h // 2}')
|
||||
|
||||
def test_sky_fraction_decreases_when_scan_extends_into_ground(self):
|
||||
"""With top-sky/bottom-ground image, increasing scan_frac past the sky
|
||||
boundary should decrease the sky_fraction."""
|
||||
h = 128
|
||||
img = _split(_blue_sky(rows=h // 2), _green_ground(rows=h // 2))
|
||||
r_top = detect_sky(img, scan_frac=0.4) # mostly sky region
|
||||
r_full = detect_sky(img, scan_frac=1.0) # half sky, half ground
|
||||
assert r_top.sky_fraction > r_full.sky_fraction, (
|
||||
f'scanning less of the ground should give higher fraction: '
|
||||
f'scan0.4={r_top.sky_fraction:.3f} scan1.0={r_full.sky_fraction:.3f}')
|
||||
|
||||
def test_sky_only_top_quarter_horizon_within_quarter(self):
|
||||
"""Sky only in top 25 % → horizon_y in first quarter of rows."""
|
||||
h = 64
|
||||
sky_rows = h // 4
|
||||
img = _split(_blue_sky(rows=sky_rows), _green_ground(rows=h - sky_rows))
|
||||
r = detect_sky(img)
|
||||
assert 0 <= r.horizon_y < sky_rows + 2, (
|
||||
f'horizon_y={r.horizon_y} should be within top quarter ({sky_rows} rows)')
|
||||
|
||||
|
||||
# ── detect_sky — horizon ordering ────────────────────────────────────────────
|
||||
|
||||
class TestDetectSkyHorizonOrdering:
|
||||
|
||||
def test_more_sky_rows_higher_horizon_y(self):
|
||||
"""Larger sky region → higher (larger row index) horizon_y."""
|
||||
h = 128
|
||||
sky_a = h // 4 # 25 % sky
|
||||
sky_b = h * 3 // 4 # 75 % sky
|
||||
img_a = _split(_blue_sky(rows=sky_a), _green_ground(rows=h - sky_a))
|
||||
img_b = _split(_blue_sky(rows=sky_b), _green_ground(rows=h - sky_b))
|
||||
r_a = detect_sky(img_a)
|
||||
r_b = detect_sky(img_b)
|
||||
assert r_b.horizon_y > r_a.horizon_y, (
|
||||
f'larger sky region ({sky_b}px) should give higher horizon_y than '
|
||||
f'smaller ({sky_a}px): {r_b.horizon_y} vs {r_a.horizon_y}')
|
||||
|
||||
def test_horizon_within_image_bounds_when_sky_present(self):
|
||||
h, w = 80, 64
|
||||
r = detect_sky(_blue_sky(rows=h, cols=w))
|
||||
assert 0 <= r.horizon_y <= h - 1
|
||||
|
||||
def test_horizon_minus_one_when_no_sky(self):
|
||||
assert detect_sky(_green_ground()).horizon_y == -1
|
||||
|
||||
|
||||
# ── detect_sky — parameter sensitivity ───────────────────────────────────────
|
||||
|
||||
class TestDetectSkyParams:
|
||||
|
||||
def test_scan_frac_1_analyses_full_frame(self):
|
||||
"""scan_frac=1.0 on all-sky image → fraction ≈ 1.0."""
|
||||
r = detect_sky(_blue_sky(), scan_frac=1.0)
|
||||
assert r.sky_fraction > 0.90
|
||||
|
||||
def test_scan_frac_zero_gives_zero_fraction(self):
|
||||
"""scan_frac approaching 0 → effectively empty scan → sky_fraction = 0."""
|
||||
# The implementation clamps scan_rows = max(1, ...) so scan_frac=0 → 1 row scanned
|
||||
# We check that it doesn't crash and returns a valid result.
|
||||
r = detect_sky(_blue_sky(), scan_frac=0.0)
|
||||
assert 0.0 <= r.sky_fraction <= 1.0
|
||||
|
||||
def test_row_threshold_zero_horizon_at_last_row_any_image(self):
|
||||
"""row_threshold=0 → every row satisfies the threshold → horizon = H-1
|
||||
even if only a single sky pixel exists anywhere."""
|
||||
h = 64
|
||||
img = _split(_blue_sky(rows=4), _green_ground(rows=h - 4))
|
||||
r = detect_sky(img, row_threshold=0.0)
|
||||
# At least the sky rows (top 4) will have sky pixels, so horizon ≥ 3
|
||||
assert r.horizon_y >= 3
|
||||
|
||||
def test_row_threshold_high_tightens_horizon(self):
|
||||
"""High row_threshold only accepts near-fully-sky rows → lower horizon_y."""
|
||||
h = 64
|
||||
# Left 60% of each row: sky; right 40%: ground.
|
||||
# Low threshold (0.3) → both halves count as sky rows → horizon = H-1
|
||||
# High threshold (0.7) → only sky-dominant rows count → horizon = H-1 still
|
||||
# (since 60% > 0.7 is False, so high threshold gives lower horizon)
|
||||
w = 100
|
||||
img = np.concatenate([
|
||||
_blue_sky(rows=h, cols=60),
|
||||
_green_ground(rows=h, cols=40),
|
||||
], axis=1)
|
||||
r_low = detect_sky(img, row_threshold=0.30)
|
||||
r_high = detect_sky(img, row_threshold=0.70)
|
||||
# With 60% sky per row: low threshold (0.3) passes; high (0.7) fails
|
||||
assert r_low.horizon_y >= 0, 'low threshold should find sky rows'
|
||||
assert r_high.horizon_y == -1, (
|
||||
f'60% sky/row fails 0.7 threshold; horizon_y should be -1, got {r_high.horizon_y}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
@ -0,0 +1,16 @@
|
||||
pid_scheduler:
|
||||
ros__parameters:
|
||||
# Base PID gains
|
||||
base_kp: 1.0
|
||||
base_ki: 0.1
|
||||
base_kd: 0.05
|
||||
|
||||
# Gain scheduling parameters
|
||||
# P gain increases by this factor * terrain_roughness
|
||||
kp_terrain_scale: 0.5
|
||||
|
||||
# D gain scale: -1 means full reduction at zero speed
|
||||
kd_speed_scale: -0.3
|
||||
|
||||
# I gain modulation factor
|
||||
ki_scale: 0.1
|
||||
@ -0,0 +1,31 @@
|
||||
"""Launch file for PID gain scheduler node."""
|
||||
|
||||
from launch import LaunchDescription
|
||||
from launch_ros.actions import Node
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
"""Generate launch description for PID scheduler."""
|
||||
pkg_dir = get_package_share_directory("saltybot_pid_scheduler")
|
||||
config_file = os.path.join(pkg_dir, "config", "pid_scheduler_config.yaml")
|
||||
|
||||
return LaunchDescription(
|
||||
[
|
||||
DeclareLaunchArgument(
|
||||
"config_file",
|
||||
default_value=config_file,
|
||||
description="Path to configuration YAML file",
|
||||
),
|
||||
Node(
|
||||
package="saltybot_pid_scheduler",
|
||||
executable="pid_scheduler_node",
|
||||
name="pid_scheduler",
|
||||
output="screen",
|
||||
parameters=[LaunchConfiguration("config_file")],
|
||||
),
|
||||
]
|
||||
)
|
||||
19
jetson/ros2_ws/src/saltybot_pid_scheduler/package.xml
Normal file
19
jetson/ros2_ws/src/saltybot_pid_scheduler/package.xml
Normal file
@ -0,0 +1,19 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_pid_scheduler</name>
|
||||
<version>0.1.0</version>
|
||||
<description>Adaptive PID gain scheduler for SaltyBot</description>
|
||||
<maintainer email="sl-controls@saltylab.local">SaltyLab Controls</maintainer>
|
||||
<license>MIT</license>
|
||||
|
||||
<buildtool_depend>ament_python</buildtool_depend>
|
||||
<depend>rclpy</depend>
|
||||
<depend>std_msgs</depend>
|
||||
|
||||
<test_depend>pytest</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1,125 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Adaptive PID gain scheduler for SaltyBot.
|
||||
|
||||
Monitors speed scale and terrain roughness, adjusts PID gains dynamically.
|
||||
Higher P on rough terrain, lower D at low speed.
|
||||
|
||||
Subscribed topics:
|
||||
/saltybot/speed_scale (std_msgs/Float32) - Speed reduction factor (0.0-1.0)
|
||||
/saltybot/terrain_roughness (std_msgs/Float32) - Terrain roughness (0.0-1.0)
|
||||
|
||||
Published topics:
|
||||
/saltybot/pid_gains (std_msgs/Float32MultiArray) - [Kp, Ki, Kd] gains
|
||||
"""
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from std_msgs.msg import Float32, Float32MultiArray
|
||||
import numpy as np
|
||||
|
||||
|
||||
class PIDSchedulerNode(Node):
|
||||
"""ROS2 node for adaptive PID gain scheduling."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("pid_scheduler")
|
||||
|
||||
# Base PID gains
|
||||
self.declare_parameter("base_kp", 1.0)
|
||||
self.declare_parameter("base_ki", 0.1)
|
||||
self.declare_parameter("base_kd", 0.05)
|
||||
|
||||
# Gain scheduling parameters
|
||||
self.declare_parameter("kp_terrain_scale", 0.5) # P gain increase on rough terrain
|
||||
self.declare_parameter("kd_speed_scale", -0.3) # D gain decrease at low speed
|
||||
self.declare_parameter("ki_scale", 0.1) # I gain smoothing
|
||||
|
||||
self.base_kp = self.get_parameter("base_kp").value
|
||||
self.base_ki = self.get_parameter("base_ki").value
|
||||
self.base_kd = self.get_parameter("base_kd").value
|
||||
|
||||
self.kp_terrain_scale = self.get_parameter("kp_terrain_scale").value
|
||||
self.kd_speed_scale = self.get_parameter("kd_speed_scale").value
|
||||
self.ki_scale = self.get_parameter("ki_scale").value
|
||||
|
||||
# Current sensor inputs
|
||||
self.speed_scale = 1.0
|
||||
self.terrain_roughness = 0.0
|
||||
|
||||
# Current gains
|
||||
self.current_kp = self.base_kp
|
||||
self.current_ki = self.base_ki
|
||||
self.current_kd = self.base_kd
|
||||
|
||||
# Subscriptions
|
||||
self.sub_speed = self.create_subscription(
|
||||
Float32, "/saltybot/speed_scale", self._on_speed_scale, 10
|
||||
)
|
||||
self.sub_terrain = self.create_subscription(
|
||||
Float32, "/saltybot/terrain_roughness", self._on_terrain_roughness, 10
|
||||
)
|
||||
|
||||
# Publisher for PID gains
|
||||
self.pub_gains = self.create_publisher(Float32MultiArray, "/saltybot/pid_gains", 10)
|
||||
|
||||
self.get_logger().info(
|
||||
f"PID gain scheduler initialized. Base gains: "
|
||||
f"Kp={self.base_kp}, Ki={self.base_ki}, Kd={self.base_kd}"
|
||||
)
|
||||
|
||||
def _on_speed_scale(self, msg: Float32) -> None:
|
||||
"""Update speed scale and recalculate gains."""
|
||||
self.speed_scale = np.clip(msg.data, 0.0, 1.0)
|
||||
self._update_gains()
|
||||
|
||||
def _on_terrain_roughness(self, msg: Float32) -> None:
|
||||
"""Update terrain roughness and recalculate gains."""
|
||||
self.terrain_roughness = np.clip(msg.data, 0.0, 1.0)
|
||||
self._update_gains()
|
||||
|
||||
def _update_gains(self) -> None:
|
||||
"""Compute adaptive PID gains based on speed and terrain."""
|
||||
# P gain increases with terrain roughness (better response on rough surfaces)
|
||||
self.current_kp = self.base_kp * (1.0 + self.kp_terrain_scale * self.terrain_roughness)
|
||||
|
||||
# D gain decreases at low speed (avoid oscillation at low speeds)
|
||||
# Low speed_scale means robot is moving slow, so reduce D damping
|
||||
speed_factor = 1.0 + self.kd_speed_scale * (1.0 - self.speed_scale)
|
||||
self.current_kd = self.base_kd * max(0.1, speed_factor) # Don't let D go negative
|
||||
|
||||
# I gain scales smoothly with both factors
|
||||
terrain_factor = 1.0 + self.ki_scale * self.terrain_roughness
|
||||
speed_damping = 1.0 - 0.3 * (1.0 - self.speed_scale) # Reduce I at low speed
|
||||
self.current_ki = self.base_ki * terrain_factor * speed_damping
|
||||
|
||||
# Publish gains
|
||||
self._publish_gains()
|
||||
|
||||
def _publish_gains(self) -> None:
|
||||
"""Publish current PID gains as Float32MultiArray."""
|
||||
msg = Float32MultiArray()
|
||||
msg.data = [self.current_kp, self.current_ki, self.current_kd]
|
||||
|
||||
self.pub_gains.publish(msg)
|
||||
|
||||
self.get_logger().debug(
|
||||
f"PID gains updated: Kp={self.current_kp:.4f}, "
|
||||
f"Ki={self.current_ki:.4f}, Kd={self.current_kd:.4f} "
|
||||
f"(speed={self.speed_scale:.2f}, terrain={self.terrain_roughness:.2f})"
|
||||
)
|
||||
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = PIDSchedulerNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
5
jetson/ros2_ws/src/saltybot_pid_scheduler/setup.cfg
Normal file
5
jetson/ros2_ws/src/saltybot_pid_scheduler/setup.cfg
Normal file
@ -0,0 +1,5 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/saltybot_pid_scheduler
|
||||
|
||||
[install]
|
||||
install_scripts=$base/lib/saltybot_pid_scheduler
|
||||
24
jetson/ros2_ws/src/saltybot_pid_scheduler/setup.py
Normal file
24
jetson/ros2_ws/src/saltybot_pid_scheduler/setup.py
Normal file
@ -0,0 +1,24 @@
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name='saltybot_pid_scheduler',
|
||||
version='0.1.0',
|
||||
packages=find_packages(),
|
||||
data_files=[
|
||||
('share/ament_index/resource_index/packages', ['resource/saltybot_pid_scheduler']),
|
||||
('share/saltybot_pid_scheduler', ['package.xml']),
|
||||
('share/saltybot_pid_scheduler/config', ['config/pid_scheduler_config.yaml']),
|
||||
('share/saltybot_pid_scheduler/launch', ['launch/pid_scheduler.launch.py']),
|
||||
],
|
||||
install_requires=['setuptools'],
|
||||
zip_safe=True,
|
||||
author='SaltyLab Controls',
|
||||
author_email='sl-controls@saltylab.local',
|
||||
description='Adaptive PID gain scheduler for SaltyBot',
|
||||
license='MIT',
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'pid_scheduler_node=saltybot_pid_scheduler.pid_scheduler_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
@ -0,0 +1,178 @@
|
||||
"""Tests for PID gain scheduler."""
|
||||
|
||||
import pytest
|
||||
from std_msgs.msg import Float32
|
||||
import rclpy
|
||||
|
||||
from saltybot_pid_scheduler.pid_scheduler_node import PIDSchedulerNode
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rclpy_fixture():
|
||||
rclpy.init()
|
||||
yield
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def node(rclpy_fixture):
|
||||
node = PIDSchedulerNode()
|
||||
yield node
|
||||
node.destroy_node()
|
||||
|
||||
|
||||
class TestInit:
|
||||
def test_node_initialization(self, node):
|
||||
assert node.base_kp == 1.0
|
||||
assert node.base_ki == 0.1
|
||||
assert node.base_kd == 0.05
|
||||
assert node.speed_scale == 1.0
|
||||
assert node.terrain_roughness == 0.0
|
||||
|
||||
|
||||
class TestGainScheduling:
|
||||
def test_gains_at_baseline(self, node):
|
||||
"""Test gains at baseline (no speed reduction, no terrain roughness)."""
|
||||
node._update_gains()
|
||||
|
||||
assert node.current_kp == pytest.approx(1.0)
|
||||
assert node.current_kd == pytest.approx(0.05)
|
||||
|
||||
def test_kp_increases_with_terrain(self, node):
|
||||
"""Test P gain increases with terrain roughness."""
|
||||
node.terrain_roughness = 0.0
|
||||
node._update_gains()
|
||||
kp_smooth = node.current_kp
|
||||
|
||||
node.terrain_roughness = 1.0
|
||||
node._update_gains()
|
||||
kp_rough = node.current_kp
|
||||
|
||||
assert kp_rough > kp_smooth
|
||||
assert kp_rough == pytest.approx(1.5) # 1.0 * (1 + 0.5 * 1.0)
|
||||
|
||||
def test_kd_decreases_at_low_speed(self, node):
|
||||
"""Test D gain decreases when robot slows down."""
|
||||
node.speed_scale = 1.0
|
||||
node._update_gains()
|
||||
kd_fast = node.current_kd
|
||||
|
||||
node.speed_scale = 0.0 # Robot stopped
|
||||
node._update_gains()
|
||||
kd_slow = node.current_kd
|
||||
|
||||
assert kd_slow < kd_fast
|
||||
|
||||
def test_ki_scales_with_terrain_and_speed(self, node):
|
||||
"""Test I gain scales with both terrain and speed."""
|
||||
node.speed_scale = 1.0
|
||||
node.terrain_roughness = 0.0
|
||||
node._update_gains()
|
||||
ki_baseline = node.current_ki
|
||||
|
||||
node.terrain_roughness = 1.0
|
||||
node._update_gains()
|
||||
ki_rough = node.current_ki
|
||||
|
||||
assert ki_rough > ki_baseline
|
||||
|
||||
def test_kd_never_negative(self, node):
|
||||
"""Test D gain never goes negative."""
|
||||
node.speed_scale = 0.0
|
||||
node._update_gains()
|
||||
|
||||
assert node.current_kd >= 0.0
|
||||
|
||||
def test_speed_scale_clipping(self, node):
|
||||
"""Test speed scale is clipped to [0, 1]."""
|
||||
msg = Float32()
|
||||
msg.data = 2.0 # Out of range
|
||||
|
||||
node._on_speed_scale(msg)
|
||||
|
||||
assert node.speed_scale == 1.0
|
||||
|
||||
def test_terrain_roughness_clipping(self, node):
|
||||
"""Test terrain roughness is clipped to [0, 1]."""
|
||||
msg = Float32()
|
||||
msg.data = -0.5 # Out of range
|
||||
|
||||
node._on_terrain_roughness(msg)
|
||||
|
||||
assert node.terrain_roughness == 0.0
|
||||
|
||||
|
||||
class TestSensorInputs:
|
||||
def test_speed_scale_callback(self, node):
|
||||
"""Test speed scale subscription and update."""
|
||||
msg = Float32()
|
||||
msg.data = 0.5
|
||||
|
||||
node._on_speed_scale(msg)
|
||||
|
||||
assert node.speed_scale == 0.5
|
||||
|
||||
def test_terrain_roughness_callback(self, node):
|
||||
"""Test terrain roughness subscription and update."""
|
||||
msg = Float32()
|
||||
msg.data = 0.75
|
||||
|
||||
node._on_terrain_roughness(msg)
|
||||
|
||||
assert node.terrain_roughness == 0.75
|
||||
|
||||
def test_combined_effects(self, node):
|
||||
"""Test combined effect of speed and terrain on gains."""
|
||||
# Slow speed + rough terrain = high P, low D
|
||||
node.speed_scale = 0.2
|
||||
node.terrain_roughness = 0.9
|
||||
|
||||
node._update_gains()
|
||||
|
||||
# P should be high (due to terrain)
|
||||
assert node.current_kp > node.base_kp
|
||||
|
||||
# D should be low (due to slow speed)
|
||||
# D factor = 1 + (-0.3) * (1 - 0.2) = 1 - 0.24 = 0.76
|
||||
assert node.current_kd < node.base_kd
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
def test_zero_speed_scale(self, node):
|
||||
"""Test behavior at zero speed (robot stopped)."""
|
||||
node.speed_scale = 0.0
|
||||
node.terrain_roughness = 0.5
|
||||
|
||||
node._update_gains()
|
||||
|
||||
# All gains should be positive
|
||||
assert node.current_kp > 0
|
||||
assert node.current_ki > 0
|
||||
assert node.current_kd > 0
|
||||
|
||||
def test_max_terrain_roughness(self, node):
|
||||
"""Test behavior on extremely rough terrain."""
|
||||
node.speed_scale = 1.0
|
||||
node.terrain_roughness = 1.0
|
||||
|
||||
node._update_gains()
|
||||
|
||||
# Kp should be maximum
|
||||
assert node.current_kp == pytest.approx(1.5)
|
||||
|
||||
def test_rapid_sensor_changes(self, node):
|
||||
"""Test rapid changes in sensor inputs."""
|
||||
for speed in [1.0, 0.5, 0.1, 0.9, 1.0]:
|
||||
msg = Float32()
|
||||
msg.data = speed
|
||||
node._on_speed_scale(msg)
|
||||
|
||||
for roughness in [0.0, 0.5, 1.0, 0.3, 0.0]:
|
||||
msg = Float32()
|
||||
msg.data = roughness
|
||||
node._on_terrain_roughness(msg)
|
||||
|
||||
# Should end at baseline
|
||||
node._update_gains()
|
||||
assert node.speed_scale == 1.0
|
||||
assert node.terrain_roughness == 0.0
|
||||
@ -0,0 +1,10 @@
|
||||
personal_space_node:
|
||||
ros__parameters:
|
||||
personal_space_m: 0.8 # Minimum comfortable distance (m)
|
||||
backup_speed: 0.1 # Retreat linear speed (m/s, applied as -x)
|
||||
hysteresis_m: 0.1 # Hysteresis band above threshold (m)
|
||||
unknown_distance_m: 99.0 # Distance assumed for faces without a PersonState
|
||||
lost_timeout_s: 1.5 # Face freshness window (s)
|
||||
control_rate: 10.0 # Control loop / publish rate (Hz)
|
||||
faces_topic: "/social/faces/detected"
|
||||
states_topic: "/social/person_states"
|
||||
@ -0,0 +1,39 @@
|
||||
"""personal_space.launch.py -- Launch personal space respector (Issue #310).
|
||||
|
||||
Usage:
|
||||
ros2 launch saltybot_social personal_space.launch.py
|
||||
ros2 launch saltybot_social personal_space.launch.py personal_space_m:=1.0 backup_speed:=0.15
|
||||
"""
|
||||
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
from launch import LaunchDescription
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
pkg = get_package_share_directory("saltybot_social")
|
||||
cfg = os.path.join(pkg, "config", "personal_space_params.yaml")
|
||||
|
||||
return LaunchDescription([
|
||||
DeclareLaunchArgument("personal_space_m", default_value="0.8",
|
||||
description="Minimum comfortable distance (m)"),
|
||||
DeclareLaunchArgument("backup_speed", default_value="0.1",
|
||||
description="Retreat linear speed (m/s)"),
|
||||
|
||||
Node(
|
||||
package="saltybot_social",
|
||||
executable="personal_space_node",
|
||||
name="personal_space_node",
|
||||
output="screen",
|
||||
parameters=[
|
||||
cfg,
|
||||
{
|
||||
"personal_space_m": LaunchConfiguration("personal_space_m"),
|
||||
"backup_speed": LaunchConfiguration("backup_speed"),
|
||||
},
|
||||
],
|
||||
),
|
||||
])
|
||||
@ -0,0 +1,211 @@
|
||||
"""personal_space_node.py — Personal space respector.
|
||||
Issue #310
|
||||
|
||||
Monitors the distance to detected faces. When the closest person
|
||||
enters within ``personal_space_m`` metres (default 0.8 m) the robot
|
||||
slowly backs up (``Twist.linear.x = -backup_speed``) and latches
|
||||
/saltybot/too_close to True. Once the person retreats beyond
|
||||
``personal_space_m + hysteresis_m`` the robot stops and the latch clears.
|
||||
|
||||
Distance is read from /social/person_states (PersonStateArray, face_id →
|
||||
distance). When a face has no matching PersonState entry the distance is
|
||||
treated as ``unknown_distance_m`` (default: very large → safe, no backup).
|
||||
|
||||
Subscriptions
|
||||
─────────────
|
||||
/social/faces/detected saltybot_social_msgs/FaceDetectionArray
|
||||
/social/person_states saltybot_social_msgs/PersonStateArray
|
||||
|
||||
Publications
|
||||
────────────
|
||||
/cmd_vel geometry_msgs/Twist — linear.x only; angular.z = 0
|
||||
/saltybot/too_close std_msgs/Bool — True while backing up
|
||||
|
||||
State machine
|
||||
─────────────
|
||||
CLEAR — no face within personal_space_m; publish zero Twist + False
|
||||
TOO_CLOSE — face within personal_space_m; publish backup Twist + True
|
||||
clears only when distance > personal_space_m + hysteresis_m
|
||||
|
||||
Parameters
|
||||
──────────
|
||||
personal_space_m (float, 0.8) minimum comfortable distance (m)
|
||||
backup_speed (float, 0.1) retreat linear speed (m/s, applied as -x)
|
||||
hysteresis_m (float, 0.1) hysteresis band above threshold (m)
|
||||
unknown_distance_m (float, 99.0) distance assumed for faces without state
|
||||
lost_timeout_s (float, 1.5) face freshness window (s)
|
||||
control_rate (float, 10.0) publish rate (Hz)
|
||||
faces_topic (str, "/social/faces/detected")
|
||||
states_topic (str, "/social/person_states")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import threading
|
||||
from typing import Dict
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile
|
||||
from std_msgs.msg import Bool
|
||||
from geometry_msgs.msg import Twist
|
||||
|
||||
try:
|
||||
from saltybot_social_msgs.msg import FaceDetectionArray, PersonStateArray
|
||||
_MSGS = True
|
||||
except ImportError:
|
||||
_MSGS = False
|
||||
|
||||
|
||||
# ── Pure helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
def closest_face_distance(face_ids, distance_cache: Dict[int, float],
|
||||
unknown_distance_m: float) -> float:
|
||||
"""Return the minimum distance across all visible face_ids.
|
||||
|
||||
Falls back to *unknown_distance_m* for any face_id not in the cache.
|
||||
Returns *unknown_distance_m* when *face_ids* is empty.
|
||||
"""
|
||||
if not face_ids:
|
||||
return unknown_distance_m
|
||||
return min(distance_cache.get(fid, unknown_distance_m) for fid in face_ids)
|
||||
|
||||
|
||||
def should_backup(distance: float, personal_space_m: float) -> bool:
|
||||
return distance <= personal_space_m
|
||||
|
||||
|
||||
def should_clear(distance: float, personal_space_m: float,
|
||||
hysteresis_m: float) -> bool:
|
||||
return distance > personal_space_m + hysteresis_m
|
||||
|
||||
|
||||
# ── ROS2 node ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class PersonalSpaceNode(Node):
|
||||
"""Backs the robot away when a person enters personal space."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__("personal_space_node")
|
||||
|
||||
self.declare_parameter("personal_space_m", 0.8)
|
||||
self.declare_parameter("backup_speed", 0.1)
|
||||
self.declare_parameter("hysteresis_m", 0.1)
|
||||
self.declare_parameter("unknown_distance_m", 99.0)
|
||||
self.declare_parameter("lost_timeout_s", 1.5)
|
||||
self.declare_parameter("control_rate", 10.0)
|
||||
self.declare_parameter("faces_topic", "/social/faces/detected")
|
||||
self.declare_parameter("states_topic", "/social/person_states")
|
||||
|
||||
self._space = self.get_parameter("personal_space_m").value
|
||||
self._backup_speed = self.get_parameter("backup_speed").value
|
||||
self._hysteresis = self.get_parameter("hysteresis_m").value
|
||||
self._unknown_dist = self.get_parameter("unknown_distance_m").value
|
||||
self._lost_t = self.get_parameter("lost_timeout_s").value
|
||||
rate = self.get_parameter("control_rate").value
|
||||
faces_topic = self.get_parameter("faces_topic").value
|
||||
states_topic = self.get_parameter("states_topic").value
|
||||
|
||||
# State
|
||||
self._too_close: bool = False
|
||||
self._distance_cache: Dict[int, float] = {} # face_id → distance (m)
|
||||
self._active_face_ids: list = [] # face_ids in last msg
|
||||
self._last_face_t: float = 0.0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
qos = QoSProfile(depth=10)
|
||||
self._cmd_pub = self.create_publisher(Twist, "/cmd_vel", qos)
|
||||
self._too_close_pub = self.create_publisher(Bool, "/saltybot/too_close", qos)
|
||||
|
||||
if _MSGS:
|
||||
self._states_sub = self.create_subscription(
|
||||
PersonStateArray, states_topic, self._on_person_states, qos
|
||||
)
|
||||
self._faces_sub = self.create_subscription(
|
||||
FaceDetectionArray, faces_topic, self._on_faces, qos
|
||||
)
|
||||
else:
|
||||
self.get_logger().warn(
|
||||
"saltybot_social_msgs not available — node passive"
|
||||
)
|
||||
|
||||
self._timer = self.create_timer(1.0 / rate, self._control_cb)
|
||||
|
||||
self.get_logger().info(
|
||||
f"PersonalSpaceNode ready "
|
||||
f"(space={self._space}m, backup={self._backup_speed}m/s, "
|
||||
f"hysteresis={self._hysteresis}m)"
|
||||
)
|
||||
|
||||
# ── Subscription callbacks ─────────────────────────────────────────────
|
||||
|
||||
def _on_person_states(self, msg) -> None:
|
||||
with self._lock:
|
||||
for ps in msg.persons:
|
||||
if ps.face_id >= 0:
|
||||
self._distance_cache[ps.face_id] = float(ps.distance)
|
||||
|
||||
def _on_faces(self, msg) -> None:
|
||||
face_ids = [int(f.face_id) for f in msg.faces]
|
||||
with self._lock:
|
||||
self._active_face_ids = face_ids
|
||||
if face_ids:
|
||||
self._last_face_t = time.monotonic()
|
||||
|
||||
# ── Control loop ───────────────────────────────────────────────────────
|
||||
|
||||
def _control_cb(self) -> None:
|
||||
now = time.monotonic()
|
||||
with self._lock:
|
||||
face_ids = list(self._active_face_ids)
|
||||
last_face_t = self._last_face_t
|
||||
cache = dict(self._distance_cache)
|
||||
|
||||
face_fresh = last_face_t > 0.0 and (now - last_face_t) < self._lost_t
|
||||
|
||||
if not face_fresh:
|
||||
# No fresh face data → clear state, stop
|
||||
if self._too_close:
|
||||
self._too_close = False
|
||||
self.get_logger().info("PersonalSpace: face lost — clearing")
|
||||
self._publish(backup=False)
|
||||
return
|
||||
|
||||
dist = closest_face_distance(face_ids, cache, self._unknown_dist)
|
||||
|
||||
if not self._too_close and should_backup(dist, self._space):
|
||||
self._too_close = True
|
||||
self.get_logger().warn(
|
||||
f"PersonalSpace: TOO CLOSE ({dist:.2f}m <= {self._space}m) — backing up"
|
||||
)
|
||||
elif self._too_close and should_clear(dist, self._space, self._hysteresis):
|
||||
self._too_close = False
|
||||
self.get_logger().info(
|
||||
f"PersonalSpace: CLEAR ({dist:.2f}m > {self._space + self._hysteresis:.2f}m)"
|
||||
)
|
||||
|
||||
self._publish(backup=self._too_close)
|
||||
|
||||
def _publish(self, backup: bool) -> None:
|
||||
twist = Twist()
|
||||
if backup:
|
||||
twist.linear.x = -self._backup_speed
|
||||
|
||||
bool_msg = Bool()
|
||||
bool_msg.data = backup
|
||||
|
||||
self._cmd_pub.publish(twist)
|
||||
self._too_close_pub.publish(bool_msg)
|
||||
|
||||
|
||||
def main(args=None) -> None:
|
||||
rclpy.init(args=args)
|
||||
node = PersonalSpaceNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
@ -55,6 +55,8 @@ setup(
|
||||
'volume_adjust_node = saltybot_social.volume_adjust_node:main',
|
||||
# Conversation topic memory (Issue #299)
|
||||
'topic_memory_node = saltybot_social.topic_memory_node:main',
|
||||
# Personal space respector (Issue #310)
|
||||
'personal_space_node = saltybot_social.personal_space_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
666
jetson/ros2_ws/src/saltybot_social/test/test_personal_space.py
Normal file
666
jetson/ros2_ws/src/saltybot_social/test/test_personal_space.py
Normal file
@ -0,0 +1,666 @@
|
||||
"""test_personal_space.py — Offline tests for personal_space_node (Issue #310).
|
||||
|
||||
Stubs out rclpy and ROS message types so tests run without a ROS install.
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
import time
|
||||
import types
|
||||
import unittest
|
||||
|
||||
|
||||
# ── ROS2 stubs ────────────────────────────────────────────────────────────────
|
||||
|
||||
def _make_ros_stubs():
|
||||
for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
|
||||
"std_msgs", "std_msgs.msg",
|
||||
"geometry_msgs", "geometry_msgs.msg",
|
||||
"saltybot_social_msgs", "saltybot_social_msgs.msg"):
|
||||
if mod_name not in sys.modules:
|
||||
sys.modules[mod_name] = types.ModuleType(mod_name)
|
||||
|
||||
class _Node:
|
||||
def __init__(self, name="node"):
|
||||
self._name = name
|
||||
if not hasattr(self, "_params"):
|
||||
self._params = {}
|
||||
self._pubs = {}
|
||||
self._subs = {}
|
||||
self._logs = []
|
||||
self._timers = []
|
||||
|
||||
def declare_parameter(self, name, default):
|
||||
if name not in self._params:
|
||||
self._params[name] = default
|
||||
|
||||
def get_parameter(self, name):
|
||||
class _P:
|
||||
def __init__(self, v): self.value = v
|
||||
return _P(self._params.get(name))
|
||||
|
||||
def create_publisher(self, msg_type, topic, qos):
|
||||
pub = _FakePub()
|
||||
self._pubs[topic] = pub
|
||||
return pub
|
||||
|
||||
def create_subscription(self, msg_type, topic, cb, qos):
|
||||
self._subs[topic] = cb
|
||||
return object()
|
||||
|
||||
def create_timer(self, period, cb):
|
||||
self._timers.append(cb)
|
||||
return object()
|
||||
|
||||
def get_logger(self):
|
||||
node = self
|
||||
class _L:
|
||||
def info(self, m): node._logs.append(("INFO", m))
|
||||
def warn(self, m): node._logs.append(("WARN", m))
|
||||
def error(self, m): node._logs.append(("ERROR", m))
|
||||
return _L()
|
||||
|
||||
def destroy_node(self): pass
|
||||
|
||||
class _FakePub:
|
||||
def __init__(self):
|
||||
self.msgs = []
|
||||
def publish(self, msg):
|
||||
self.msgs.append(msg)
|
||||
|
||||
class _QoSProfile:
|
||||
def __init__(self, depth=10): self.depth = depth
|
||||
|
||||
class _Bool:
|
||||
def __init__(self): self.data = False
|
||||
|
||||
class _TwistLinear:
|
||||
def __init__(self): self.x = 0.0
|
||||
|
||||
class _TwistAngular:
|
||||
def __init__(self): self.z = 0.0
|
||||
|
||||
class _Twist:
|
||||
def __init__(self):
|
||||
self.linear = _TwistLinear()
|
||||
self.angular = _TwistAngular()
|
||||
|
||||
class _FaceDetection:
|
||||
def __init__(self, face_id=0):
|
||||
self.face_id = face_id
|
||||
|
||||
class _PersonState:
|
||||
def __init__(self, face_id=0, distance=1.0):
|
||||
self.face_id = face_id
|
||||
self.distance = distance
|
||||
|
||||
class _FaceDetectionArray:
|
||||
def __init__(self, faces=None):
|
||||
self.faces = faces or []
|
||||
|
||||
class _PersonStateArray:
|
||||
def __init__(self, persons=None):
|
||||
self.persons = persons or []
|
||||
|
||||
rclpy_mod = sys.modules["rclpy"]
|
||||
rclpy_mod.init = lambda args=None: None
|
||||
rclpy_mod.spin = lambda node: None
|
||||
rclpy_mod.shutdown = lambda: None
|
||||
|
||||
sys.modules["rclpy.node"].Node = _Node
|
||||
sys.modules["rclpy.qos"].QoSProfile = _QoSProfile
|
||||
sys.modules["std_msgs.msg"].Bool = _Bool
|
||||
sys.modules["geometry_msgs.msg"].Twist = _Twist
|
||||
sys.modules["saltybot_social_msgs.msg"].FaceDetectionArray = _FaceDetectionArray
|
||||
sys.modules["saltybot_social_msgs.msg"].PersonStateArray = _PersonStateArray
|
||||
|
||||
return (_Node, _FakePub, _Bool, _Twist,
|
||||
_FaceDetection, _PersonState,
|
||||
_FaceDetectionArray, _PersonStateArray)
|
||||
|
||||
|
||||
(_Node, _FakePub, _Bool, _Twist,
|
||||
_FaceDetection, _PersonState,
|
||||
_FaceDetectionArray, _PersonStateArray) = _make_ros_stubs()
|
||||
|
||||
|
||||
# ── Module loader ─────────────────────────────────────────────────────────────
|
||||
|
||||
_SRC = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/saltybot_social/personal_space_node.py"
|
||||
)
|
||||
|
||||
|
||||
def _load_mod():
|
||||
spec = importlib.util.spec_from_file_location("personal_space_testmod", _SRC)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def _make_node(mod, **kwargs):
|
||||
node = mod.PersonalSpaceNode.__new__(mod.PersonalSpaceNode)
|
||||
defaults = {
|
||||
"personal_space_m": 0.8,
|
||||
"backup_speed": 0.1,
|
||||
"hysteresis_m": 0.1,
|
||||
"unknown_distance_m": 99.0,
|
||||
"lost_timeout_s": 1.5,
|
||||
"control_rate": 10.0,
|
||||
"faces_topic": "/social/faces/detected",
|
||||
"states_topic": "/social/person_states",
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
node._params = dict(defaults)
|
||||
mod.PersonalSpaceNode.__init__(node)
|
||||
return node
|
||||
|
||||
|
||||
def _inject_faces(node, face_ids):
|
||||
"""Push a FaceDetectionArray message with the given face_ids."""
|
||||
msg = _FaceDetectionArray(faces=[_FaceDetection(fid) for fid in face_ids])
|
||||
node._subs["/social/faces/detected"](msg)
|
||||
|
||||
|
||||
def _inject_states(node, id_dist_pairs):
|
||||
"""Push a PersonStateArray message with (face_id, distance) pairs."""
|
||||
persons = [_PersonState(fid, dist) for fid, dist in id_dist_pairs]
|
||||
msg = _PersonStateArray(persons=persons)
|
||||
node._subs["/social/person_states"](msg)
|
||||
|
||||
|
||||
# ── Tests: pure helpers ────────────────────────────────────────────────────────
|
||||
|
||||
class TestClosestFaceDistance(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def _cfd(self, face_ids, cache, unknown=99.0):
|
||||
return self.mod.closest_face_distance(face_ids, cache, unknown)
|
||||
|
||||
def test_empty_face_ids_returns_unknown(self):
|
||||
self.assertEqual(self._cfd([], {}), 99.0)
|
||||
|
||||
def test_single_face_known(self):
|
||||
self.assertAlmostEqual(self._cfd([1], {1: 0.5}), 0.5)
|
||||
|
||||
def test_single_face_unknown(self):
|
||||
self.assertEqual(self._cfd([1], {}), 99.0)
|
||||
|
||||
def test_multiple_faces_min(self):
|
||||
dist = self._cfd([1, 2, 3], {1: 2.0, 2: 0.5, 3: 1.5})
|
||||
self.assertAlmostEqual(dist, 0.5)
|
||||
|
||||
def test_mixed_known_unknown(self):
|
||||
# unknown face treated as unknown_distance_m = 99, so min is the known one
|
||||
dist = self._cfd([1, 2], {1: 0.6}, 99.0)
|
||||
self.assertAlmostEqual(dist, 0.6)
|
||||
|
||||
def test_all_unknown_returns_unknown(self):
|
||||
dist = self._cfd([1, 2], {}, 99.0)
|
||||
self.assertEqual(dist, 99.0)
|
||||
|
||||
def test_custom_unknown_value(self):
|
||||
dist = self._cfd([], {}, unknown=5.0)
|
||||
self.assertEqual(dist, 5.0)
|
||||
|
||||
|
||||
class TestShouldBackup(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def test_at_threshold_triggers(self):
|
||||
self.assertTrue(self.mod.should_backup(0.8, 0.8))
|
||||
|
||||
def test_below_threshold_triggers(self):
|
||||
self.assertTrue(self.mod.should_backup(0.5, 0.8))
|
||||
|
||||
def test_above_threshold_no_trigger(self):
|
||||
self.assertFalse(self.mod.should_backup(1.0, 0.8))
|
||||
|
||||
def test_just_above_no_trigger(self):
|
||||
self.assertFalse(self.mod.should_backup(0.81, 0.8))
|
||||
|
||||
|
||||
class TestShouldClear(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def test_above_band_clears(self):
|
||||
# 0.8 + 0.1 = 0.9; 1.0 > 0.9 → clear
|
||||
self.assertTrue(self.mod.should_clear(1.0, 0.8, 0.1))
|
||||
|
||||
def test_at_band_edge_does_not_clear(self):
|
||||
self.assertFalse(self.mod.should_clear(0.9, 0.8, 0.1))
|
||||
|
||||
def test_within_band_does_not_clear(self):
|
||||
self.assertFalse(self.mod.should_clear(0.85, 0.8, 0.1))
|
||||
|
||||
def test_below_threshold_does_not_clear(self):
|
||||
self.assertFalse(self.mod.should_clear(0.5, 0.8, 0.1))
|
||||
|
||||
def test_just_above_band_clears(self):
|
||||
self.assertTrue(self.mod.should_clear(0.91, 0.8, 0.1))
|
||||
|
||||
|
||||
# ── Tests: node init ──────────────────────────────────────────────────────────
|
||||
|
||||
class TestNodeInit(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def test_instantiates(self):
|
||||
self.assertIsNotNone(_make_node(self.mod))
|
||||
|
||||
def test_cmd_vel_publisher(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIn("/cmd_vel", node._pubs)
|
||||
|
||||
def test_too_close_publisher(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIn("/saltybot/too_close", node._pubs)
|
||||
|
||||
def test_faces_subscriber(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIn("/social/faces/detected", node._subs)
|
||||
|
||||
def test_states_subscriber(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIn("/social/person_states", node._subs)
|
||||
|
||||
def test_timer_registered(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertGreater(len(node._timers), 0)
|
||||
|
||||
def test_too_close_initially_false(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertFalse(node._too_close)
|
||||
|
||||
def test_distance_cache_empty(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertEqual(node._distance_cache, {})
|
||||
|
||||
def test_active_face_ids_empty(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertEqual(node._active_face_ids, [])
|
||||
|
||||
def test_custom_topics(self):
|
||||
node = _make_node(self.mod,
|
||||
faces_topic="/my/faces",
|
||||
states_topic="/my/states")
|
||||
self.assertIn("/my/faces", node._subs)
|
||||
self.assertIn("/my/states", node._subs)
|
||||
|
||||
def test_parameters_stored(self):
|
||||
node = _make_node(self.mod, personal_space_m=1.2, backup_speed=0.2)
|
||||
self.assertAlmostEqual(node._space, 1.2)
|
||||
self.assertAlmostEqual(node._backup_speed, 0.2)
|
||||
|
||||
|
||||
# ── Tests: _on_person_states ──────────────────────────────────────────────────
|
||||
|
||||
class TestOnPersonStates(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def setUp(self):
|
||||
self.node = _make_node(self.mod)
|
||||
|
||||
def test_single_state_cached(self):
|
||||
_inject_states(self.node, [(1, 1.5)])
|
||||
self.assertAlmostEqual(self.node._distance_cache[1], 1.5)
|
||||
|
||||
def test_multiple_states_cached(self):
|
||||
_inject_states(self.node, [(1, 1.5), (2, 0.6), (3, 2.0)])
|
||||
self.assertAlmostEqual(self.node._distance_cache[1], 1.5)
|
||||
self.assertAlmostEqual(self.node._distance_cache[2], 0.6)
|
||||
self.assertAlmostEqual(self.node._distance_cache[3], 2.0)
|
||||
|
||||
def test_state_updated_on_second_msg(self):
|
||||
_inject_states(self.node, [(1, 1.5)])
|
||||
_inject_states(self.node, [(1, 0.4)])
|
||||
self.assertAlmostEqual(self.node._distance_cache[1], 0.4)
|
||||
|
||||
def test_negative_face_id_ignored(self):
|
||||
_inject_states(self.node, [(-1, 0.5)])
|
||||
self.assertNotIn(-1, self.node._distance_cache)
|
||||
|
||||
def test_zero_face_id_stored(self):
|
||||
_inject_states(self.node, [(0, 0.7)])
|
||||
self.assertIn(0, self.node._distance_cache)
|
||||
|
||||
def test_distance_cast_to_float(self):
|
||||
_inject_states(self.node, [(1, 1)]) # integer distance
|
||||
self.assertIsInstance(self.node._distance_cache[1], float)
|
||||
|
||||
|
||||
# ── Tests: _on_faces ──────────────────────────────────────────────────────────
|
||||
|
||||
class TestOnFaces(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def setUp(self):
|
||||
self.node = _make_node(self.mod)
|
||||
|
||||
def test_face_ids_stored(self):
|
||||
_inject_faces(self.node, [1, 2, 3])
|
||||
self.assertEqual(sorted(self.node._active_face_ids), [1, 2, 3])
|
||||
|
||||
def test_empty_faces_stored(self):
|
||||
_inject_faces(self.node, [])
|
||||
self.assertEqual(self.node._active_face_ids, [])
|
||||
|
||||
def test_timestamp_updated_on_faces(self):
|
||||
before = time.monotonic()
|
||||
_inject_faces(self.node, [1])
|
||||
self.assertGreaterEqual(self.node._last_face_t, before)
|
||||
|
||||
def test_timestamp_not_updated_on_empty(self):
|
||||
self.node._last_face_t = 0.0
|
||||
_inject_faces(self.node, [])
|
||||
self.assertEqual(self.node._last_face_t, 0.0)
|
||||
|
||||
def test_faces_replaced_not_merged(self):
|
||||
_inject_faces(self.node, [1, 2])
|
||||
_inject_faces(self.node, [3])
|
||||
self.assertEqual(sorted(self.node._active_face_ids), [3])
|
||||
|
||||
|
||||
# ── Tests: control loop ───────────────────────────────────────────────────────
|
||||
|
||||
class TestControlLoop(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def setUp(self):
|
||||
self.node = _make_node(self.mod, personal_space_m=0.8, backup_speed=0.1,
|
||||
hysteresis_m=0.1, lost_timeout_s=1.5)
|
||||
self.cmd_pub = self.node._pubs["/cmd_vel"]
|
||||
self.close_pub = self.node._pubs["/saltybot/too_close"]
|
||||
|
||||
def _fresh_face(self, face_id=1, distance=1.0):
|
||||
"""Inject a face + state so the node sees a fresh, close/far person."""
|
||||
_inject_states(self.node, [(face_id, distance)])
|
||||
_inject_faces(self.node, [face_id])
|
||||
|
||||
def test_no_face_publishes_zero_twist(self):
|
||||
self.node._control_cb()
|
||||
self.assertEqual(len(self.cmd_pub.msgs), 1)
|
||||
self.assertAlmostEqual(self.cmd_pub.msgs[-1].linear.x, 0.0)
|
||||
|
||||
def test_no_face_publishes_false(self):
|
||||
self.node._control_cb()
|
||||
self.assertFalse(self.close_pub.msgs[-1].data)
|
||||
|
||||
def test_person_far_no_backup(self):
|
||||
self._fresh_face(distance=1.5)
|
||||
self.node._control_cb()
|
||||
self.assertAlmostEqual(self.cmd_pub.msgs[-1].linear.x, 0.0)
|
||||
self.assertFalse(self.close_pub.msgs[-1].data)
|
||||
|
||||
def test_person_at_threshold_triggers_backup(self):
|
||||
self._fresh_face(distance=0.8)
|
||||
self.node._control_cb()
|
||||
self.assertAlmostEqual(self.cmd_pub.msgs[-1].linear.x, -0.1)
|
||||
self.assertTrue(self.close_pub.msgs[-1].data)
|
||||
|
||||
def test_person_too_close_triggers_backup(self):
|
||||
self._fresh_face(distance=0.5)
|
||||
self.node._control_cb()
|
||||
self.assertAlmostEqual(self.cmd_pub.msgs[-1].linear.x, -0.1)
|
||||
self.assertTrue(self.close_pub.msgs[-1].data)
|
||||
|
||||
def test_too_close_state_latches(self):
|
||||
"""State stays TOO_CLOSE even if distance jumps to mid-band."""
|
||||
self._fresh_face(distance=0.5)
|
||||
self.node._control_cb() # enter TOO_CLOSE
|
||||
# Now person is at 0.85 — inside hysteresis band (0.8 < 0.85 < 0.9)
|
||||
_inject_states(self.node, [(1, 0.85)])
|
||||
self.node._control_cb()
|
||||
self.assertTrue(self.node._too_close)
|
||||
self.assertTrue(self.close_pub.msgs[-1].data)
|
||||
|
||||
def test_hysteresis_clears_state(self):
|
||||
"""State clears when distance exceeds personal_space_m + hysteresis_m."""
|
||||
self._fresh_face(distance=0.5)
|
||||
self.node._control_cb() # enter TOO_CLOSE
|
||||
# Person retreats past hysteresis band (> 0.9 m)
|
||||
_inject_states(self.node, [(1, 1.0)])
|
||||
self.node._control_cb()
|
||||
self.assertFalse(self.node._too_close)
|
||||
self.assertFalse(self.close_pub.msgs[-1].data)
|
||||
|
||||
def test_backup_speed_magnitude(self):
|
||||
node = _make_node(self.mod, backup_speed=0.25)
|
||||
_inject_states(node, [(1, 0.3)])
|
||||
_inject_faces(node, [1])
|
||||
node._control_cb()
|
||||
self.assertAlmostEqual(node._pubs["/cmd_vel"].msgs[-1].linear.x, -0.25)
|
||||
|
||||
def test_state_machine_full_cycle(self):
|
||||
"""CLEAR → TOO_CLOSE → CLEAR."""
|
||||
# Start clear
|
||||
self._fresh_face(distance=1.5)
|
||||
self.node._control_cb()
|
||||
self.assertFalse(self.node._too_close)
|
||||
|
||||
# Enter TOO_CLOSE
|
||||
_inject_states(self.node, [(1, 0.5)])
|
||||
self.node._control_cb()
|
||||
self.assertTrue(self.node._too_close)
|
||||
|
||||
# Return to CLEAR
|
||||
_inject_states(self.node, [(1, 1.0)])
|
||||
self.node._control_cb()
|
||||
self.assertFalse(self.node._too_close)
|
||||
|
||||
def test_multiple_faces_closest_used(self):
|
||||
"""Closest face determines the backup decision."""
|
||||
_inject_states(self.node, [(1, 2.0), (2, 0.5)])
|
||||
_inject_faces(self.node, [1, 2])
|
||||
self.node._control_cb()
|
||||
self.assertTrue(self.node._too_close)
|
||||
|
||||
def test_both_far_no_backup(self):
|
||||
_inject_states(self.node, [(1, 1.5), (2, 2.0)])
|
||||
_inject_faces(self.node, [1, 2])
|
||||
self.node._control_cb()
|
||||
self.assertFalse(self.node._too_close)
|
||||
|
||||
def test_unknown_face_uses_unknown_distance(self):
|
||||
"""Face without PersonState gets unknown_distance_m (default=99 → safe)."""
|
||||
# No PersonState injected, so face_id=5 has no cache entry
|
||||
_inject_faces(self.node, [5])
|
||||
self.node._last_face_t = time.monotonic() # mark fresh
|
||||
self.node._control_cb()
|
||||
self.assertFalse(self.node._too_close)
|
||||
|
||||
def test_publishes_on_every_tick(self):
|
||||
self.node._control_cb()
|
||||
self.node._control_cb()
|
||||
self.node._control_cb()
|
||||
self.assertEqual(len(self.cmd_pub.msgs), 3)
|
||||
self.assertEqual(len(self.close_pub.msgs), 3)
|
||||
|
||||
|
||||
# ── Tests: face freshness / lost timeout ──────────────────────────────────────
|
||||
|
||||
class TestFaceFreshness(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def setUp(self):
|
||||
self.node = _make_node(self.mod, lost_timeout_s=1.5)
|
||||
self.cmd_pub = self.node._pubs["/cmd_vel"]
|
||||
self.close_pub = self.node._pubs["/saltybot/too_close"]
|
||||
|
||||
def test_stale_face_clears_state(self):
|
||||
"""If no fresh face data, state clears and robot stops."""
|
||||
# Manually set backed-up state
|
||||
self.node._too_close = True
|
||||
self.node._last_face_t = time.monotonic() - 10.0 # stale
|
||||
self.node._control_cb()
|
||||
self.assertFalse(self.node._too_close)
|
||||
self.assertAlmostEqual(self.cmd_pub.msgs[-1].linear.x, 0.0)
|
||||
self.assertFalse(self.close_pub.msgs[-1].data)
|
||||
|
||||
def test_fresh_face_active(self):
|
||||
"""Recently received face is treated as fresh."""
|
||||
_inject_states(self.node, [(1, 0.5)])
|
||||
_inject_faces(self.node, [1]) # updates _last_face_t
|
||||
self.node._control_cb()
|
||||
self.assertTrue(self.node._too_close)
|
||||
|
||||
def test_no_face_ever_is_stale(self):
|
||||
"""Never received a face → stale, no backup."""
|
||||
# _last_face_t = 0.0 (default) → not fresh
|
||||
self.node._control_cb()
|
||||
self.assertFalse(self.node._too_close)
|
||||
|
||||
def test_stale_logs_info(self):
|
||||
self.node._too_close = True
|
||||
self.node._last_face_t = time.monotonic() - 100.0
|
||||
self.node._control_cb()
|
||||
infos = [l for l in self.node._logs if l[0] == "INFO"]
|
||||
self.assertTrue(any("face lost" in m.lower() or "clearing" in m.lower()
|
||||
for _, m in infos))
|
||||
|
||||
def test_too_close_entry_logs_warn(self):
|
||||
_inject_states(self.node, [(1, 0.3)])
|
||||
_inject_faces(self.node, [1])
|
||||
self.node._control_cb()
|
||||
warns = [l for l in self.node._logs if l[0] == "WARN"]
|
||||
self.assertTrue(any("too close" in m.lower() or "backing" in m.lower()
|
||||
for _, m in warns))
|
||||
|
||||
def test_clear_logs_info(self):
|
||||
_inject_states(self.node, [(1, 0.3)])
|
||||
_inject_faces(self.node, [1])
|
||||
self.node._control_cb()
|
||||
_inject_states(self.node, [(1, 1.5)])
|
||||
self.node._control_cb()
|
||||
infos = [l for l in self.node._logs if l[0] == "INFO"]
|
||||
self.assertTrue(any("clear" in m.lower() for _, m in infos))
|
||||
|
||||
|
||||
# ── Tests: _publish ───────────────────────────────────────────────────────────
|
||||
|
||||
class TestPublish(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def setUp(self):
|
||||
self.node = _make_node(self.mod, backup_speed=0.15)
|
||||
self.cmd_pub = self.node._pubs["/cmd_vel"]
|
||||
self.bool_pub = self.node._pubs["/saltybot/too_close"]
|
||||
|
||||
def test_backup_false_zero_linear_x(self):
|
||||
self.node._publish(backup=False)
|
||||
self.assertAlmostEqual(self.cmd_pub.msgs[-1].linear.x, 0.0)
|
||||
|
||||
def test_backup_true_negative_linear_x(self):
|
||||
self.node._publish(backup=True)
|
||||
self.assertAlmostEqual(self.cmd_pub.msgs[-1].linear.x, -0.15)
|
||||
|
||||
def test_backup_false_bool_false(self):
|
||||
self.node._publish(backup=False)
|
||||
self.assertFalse(self.bool_pub.msgs[-1].data)
|
||||
|
||||
def test_backup_true_bool_true(self):
|
||||
self.node._publish(backup=True)
|
||||
self.assertTrue(self.bool_pub.msgs[-1].data)
|
||||
|
||||
def test_angular_z_always_zero(self):
|
||||
self.node._publish(backup=True)
|
||||
self.assertAlmostEqual(self.cmd_pub.msgs[-1].angular.z, 0.0)
|
||||
|
||||
def test_both_topics_published(self):
|
||||
self.node._publish(backup=False)
|
||||
self.assertEqual(len(self.cmd_pub.msgs), 1)
|
||||
self.assertEqual(len(self.bool_pub.msgs), 1)
|
||||
|
||||
|
||||
# ── Tests: source content ─────────────────────────────────────────────────────
|
||||
|
||||
class TestNodeSrc(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
with open(_SRC) as f: cls.src = f.read()
|
||||
|
||||
def test_issue_tag(self): self.assertIn("#310", self.src)
|
||||
def test_cmd_vel_topic(self): self.assertIn("/cmd_vel", self.src)
|
||||
def test_too_close_topic(self): self.assertIn("/saltybot/too_close", self.src)
|
||||
def test_faces_topic(self): self.assertIn("/social/faces/detected", self.src)
|
||||
def test_states_topic(self): self.assertIn("/social/person_states", self.src)
|
||||
def test_closest_face_fn(self): self.assertIn("closest_face_distance", self.src)
|
||||
def test_should_backup_fn(self): self.assertIn("should_backup", self.src)
|
||||
def test_should_clear_fn(self): self.assertIn("should_clear", self.src)
|
||||
def test_hysteresis_param(self): self.assertIn("hysteresis_m", self.src)
|
||||
def test_backup_speed_param(self): self.assertIn("backup_speed", self.src)
|
||||
def test_personal_space_param(self): self.assertIn("personal_space_m", self.src)
|
||||
def test_lost_timeout_param(self): self.assertIn("lost_timeout_s", self.src)
|
||||
def test_control_rate_param(self): self.assertIn("control_rate", self.src)
|
||||
def test_threading_lock(self): self.assertIn("threading.Lock", self.src)
|
||||
def test_linear_x(self): self.assertIn("linear.x", self.src)
|
||||
def test_twist_published(self): self.assertIn("Twist", self.src)
|
||||
def test_bool_published(self): self.assertIn("Bool", self.src)
|
||||
def test_main_defined(self): self.assertIn("def main", self.src)
|
||||
def test_face_detection_array(self): self.assertIn("FaceDetectionArray", self.src)
|
||||
def test_person_state_array(self): self.assertIn("PersonStateArray", self.src)
|
||||
|
||||
|
||||
# ── Tests: config / launch / setup ────────────────────────────────────────────
|
||||
|
||||
class TestConfig(unittest.TestCase):
|
||||
_CONFIG = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/config/personal_space_params.yaml"
|
||||
)
|
||||
_LAUNCH = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/launch/personal_space.launch.py"
|
||||
)
|
||||
_SETUP = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/setup.py"
|
||||
)
|
||||
|
||||
def test_config_exists(self):
|
||||
import os; self.assertTrue(os.path.exists(self._CONFIG))
|
||||
|
||||
def test_config_personal_space_m(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("personal_space_m", c)
|
||||
|
||||
def test_config_backup_speed(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("backup_speed", c)
|
||||
|
||||
def test_config_hysteresis(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("hysteresis_m", c)
|
||||
|
||||
def test_config_lost_timeout(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("lost_timeout_s", c)
|
||||
|
||||
def test_launch_exists(self):
|
||||
import os; self.assertTrue(os.path.exists(self._LAUNCH))
|
||||
|
||||
def test_launch_has_personal_space_arg(self):
|
||||
with open(self._LAUNCH) as f: c = f.read()
|
||||
self.assertIn("personal_space_m", c)
|
||||
|
||||
def test_launch_has_backup_speed_arg(self):
|
||||
with open(self._LAUNCH) as f: c = f.read()
|
||||
self.assertIn("backup_speed", c)
|
||||
|
||||
def test_entry_point_in_setup(self):
|
||||
with open(self._SETUP) as f: c = f.read()
|
||||
self.assertIn("personal_space_node", c)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -65,6 +65,9 @@ import { WaypointEditor } from './components/WaypointEditor.jsx';
|
||||
// Bandwidth monitor (issue #287)
|
||||
import { BandwidthMonitor } from './components/BandwidthMonitor.jsx';
|
||||
|
||||
// Temperature gauge (issue #308)
|
||||
import { TempGauge } from './components/TempGauge.jsx';
|
||||
|
||||
const TAB_GROUPS = [
|
||||
{
|
||||
label: 'SOCIAL',
|
||||
@ -87,6 +90,7 @@ const TAB_GROUPS = [
|
||||
{ id: 'battery', label: 'Battery', },
|
||||
{ id: 'battery-chart', label: 'Battery History', },
|
||||
{ id: 'motors', label: 'Motors', },
|
||||
{ id: 'thermal', label: 'Thermal', },
|
||||
{ id: 'map', label: 'Map', },
|
||||
{ id: 'control', label: 'Control', },
|
||||
{ id: 'health', label: 'Health', },
|
||||
@ -254,6 +258,7 @@ export default function App() {
|
||||
{activeTab === 'battery-chart' && <BatteryChart subscribe={subscribe} />}
|
||||
{activeTab === 'motors' && <MotorPanel subscribe={subscribe} />}
|
||||
{activeTab === 'motor-current-graph' && <MotorCurrentGraph subscribe={subscribe} />}
|
||||
{activeTab === 'thermal' && <TempGauge subscribe={subscribe} />}
|
||||
{activeTab === 'map' && <MapViewer subscribe={subscribe} />}
|
||||
{activeTab === 'control' && (
|
||||
<div className="flex flex-col h-full gap-4">
|
||||
|
||||
322
ui/social-bot/src/components/TempGauge.jsx
Normal file
322
ui/social-bot/src/components/TempGauge.jsx
Normal file
@ -0,0 +1,322 @@
|
||||
/**
|
||||
* TempGauge.jsx — CPU and GPU temperature circular gauge
|
||||
*
|
||||
* Features:
|
||||
* - Subscribes to /saltybot/thermal_status for CPU and GPU temperatures
|
||||
* - Circular gauge visualization with color zones
|
||||
* - Temperature zones: green <60°C, yellow 60-75°C, red >75°C
|
||||
* - Real-time temperature display with needle pointer
|
||||
* - Fan speed percentage indicator
|
||||
* - Peak temperature tracking
|
||||
* - Thermal alert indicators
|
||||
*/
|
||||
|
||||
import { useEffect, useRef, useState } from 'react';
|
||||
|
||||
const MIN_TEMP = 0;
|
||||
const MAX_TEMP = 100;
|
||||
const GAUGE_START_ANGLE = Math.PI * 0.7; // 126°
|
||||
const GAUGE_END_ANGLE = Math.PI * 2.3; // 414° (180° + 234°)
|
||||
const GAUGE_RANGE = GAUGE_END_ANGLE - GAUGE_START_ANGLE;
|
||||
|
||||
const TEMP_ZONES = {
|
||||
good: { max: 60, color: '#10b981', label: 'Good' }, // green
|
||||
caution: { max: 75, color: '#f59e0b', label: 'Caution' }, // yellow
|
||||
critical: { max: Infinity, color: '#ef4444', label: 'Critical' }, // red
|
||||
};
|
||||
|
||||
function CircularGauge({ temp, maxTemp, label, color }) {
|
||||
const canvasRef = useRef(null);
|
||||
|
||||
useEffect(() => {
|
||||
const canvas = canvasRef.current;
|
||||
if (!canvas) return;
|
||||
|
||||
const ctx = canvas.getContext('2d');
|
||||
const width = canvas.width;
|
||||
const height = canvas.height;
|
||||
const centerX = width / 2;
|
||||
const centerY = height * 0.65;
|
||||
const radius = width * 0.35;
|
||||
|
||||
// Clear canvas
|
||||
ctx.fillStyle = '#1f2937';
|
||||
ctx.fillRect(0, 0, width, height);
|
||||
|
||||
// Draw gauge background arcs (color zones)
|
||||
const zoneValues = [60, 75, 100];
|
||||
const zoneColors = ['#10b981', '#f59e0b', '#ef4444'];
|
||||
|
||||
for (let i = 0; i < zoneValues.length; i++) {
|
||||
const startVal = i === 0 ? 0 : zoneValues[i - 1];
|
||||
const endVal = zoneValues[i];
|
||||
const normalizedStart = startVal / MAX_TEMP;
|
||||
const normalizedEnd = endVal / MAX_TEMP;
|
||||
|
||||
const startAngle = GAUGE_START_ANGLE + GAUGE_RANGE * normalizedStart;
|
||||
const endAngle = GAUGE_START_ANGLE + GAUGE_RANGE * normalizedEnd;
|
||||
|
||||
ctx.strokeStyle = zoneColors[i];
|
||||
ctx.lineWidth = 12;
|
||||
ctx.lineCap = 'round';
|
||||
ctx.beginPath();
|
||||
ctx.arc(centerX, centerY, radius, startAngle, endAngle);
|
||||
ctx.stroke();
|
||||
}
|
||||
|
||||
// Draw outer ring
|
||||
ctx.strokeStyle = '#374151';
|
||||
ctx.lineWidth = 2;
|
||||
ctx.beginPath();
|
||||
ctx.arc(centerX, centerY, radius, GAUGE_START_ANGLE, GAUGE_END_ANGLE);
|
||||
ctx.stroke();
|
||||
|
||||
// Draw tick marks and labels
|
||||
ctx.fillStyle = '#9ca3af';
|
||||
ctx.strokeStyle = '#9ca3af';
|
||||
ctx.lineWidth = 1;
|
||||
ctx.font = 'bold 10px monospace';
|
||||
ctx.textAlign = 'center';
|
||||
ctx.textBaseline = 'middle';
|
||||
|
||||
for (let i = 0; i <= 10; i++) {
|
||||
const value = (i / 10) * MAX_TEMP;
|
||||
const angle = GAUGE_START_ANGLE + (i / 10) * GAUGE_RANGE;
|
||||
const tickLen = i % 5 === 0 ? 8 : 4;
|
||||
|
||||
const x1 = centerX + Math.cos(angle) * radius;
|
||||
const y1 = centerY + Math.sin(angle) * radius;
|
||||
const x2 = centerX + Math.cos(angle) * (radius + tickLen);
|
||||
const y2 = centerY + Math.sin(angle) * (radius + tickLen);
|
||||
|
||||
ctx.beginPath();
|
||||
ctx.moveTo(x1, y1);
|
||||
ctx.lineTo(x2, y2);
|
||||
ctx.stroke();
|
||||
|
||||
if (i % 2 === 0) {
|
||||
const labelX = centerX + Math.cos(angle) * (radius + 18);
|
||||
const labelY = centerY + Math.sin(angle) * (radius + 18);
|
||||
ctx.fillText(`${Math.round(value)}°`, labelX, labelY);
|
||||
}
|
||||
}
|
||||
|
||||
// Draw needle
|
||||
const normalizedTemp = Math.max(0, Math.min(1, temp / MAX_TEMP));
|
||||
const needleAngle = GAUGE_START_ANGLE + normalizedTemp * GAUGE_RANGE;
|
||||
|
||||
// Needle base circle
|
||||
ctx.fillStyle = '#1f2937';
|
||||
ctx.strokeStyle = '#9ca3af';
|
||||
ctx.lineWidth = 2;
|
||||
ctx.beginPath();
|
||||
ctx.arc(centerX, centerY, 8, 0, Math.PI * 2);
|
||||
ctx.fill();
|
||||
ctx.stroke();
|
||||
|
||||
// Needle line
|
||||
ctx.strokeStyle = color || '#06b6d4';
|
||||
ctx.lineWidth = 3;
|
||||
ctx.lineCap = 'round';
|
||||
ctx.beginPath();
|
||||
ctx.moveTo(centerX, centerY);
|
||||
const needleLen = radius * 0.85;
|
||||
ctx.lineTo(
|
||||
centerX + Math.cos(needleAngle) * needleLen,
|
||||
centerY + Math.sin(needleAngle) * needleLen
|
||||
);
|
||||
ctx.stroke();
|
||||
|
||||
// Needle tip circle
|
||||
ctx.fillStyle = color || '#06b6d4';
|
||||
ctx.beginPath();
|
||||
ctx.arc(
|
||||
centerX + Math.cos(needleAngle) * needleLen,
|
||||
centerY + Math.sin(needleAngle) * needleLen,
|
||||
4,
|
||||
0,
|
||||
Math.PI * 2
|
||||
);
|
||||
ctx.fill();
|
||||
|
||||
// Draw temperature display
|
||||
ctx.fillStyle = color || '#06b6d4';
|
||||
ctx.font = 'bold 32px monospace';
|
||||
ctx.textAlign = 'center';
|
||||
ctx.textBaseline = 'top';
|
||||
ctx.fillText(`${Math.round(temp)}°`, centerX, centerY - 20);
|
||||
|
||||
ctx.fillStyle = '#9ca3af';
|
||||
ctx.font = 'bold 12px monospace';
|
||||
ctx.fillText(label, centerX, centerY + 25);
|
||||
}, [temp, label, color]);
|
||||
|
||||
return <canvas ref={canvasRef} width={200} height={180} className="inline-block" />;
|
||||
}
|
||||
|
||||
function TemperatureRow({ label, temp, fanSpeed, maxTemp = MAX_TEMP }) {
|
||||
// Determine color zone
|
||||
let zoneColor = '#10b981'; // green
|
||||
let zoneLabel = 'Good';
|
||||
if (temp >= 60 && temp < 75) {
|
||||
zoneColor = '#f59e0b'; // yellow
|
||||
zoneLabel = 'Caution';
|
||||
} else if (temp >= 75) {
|
||||
zoneColor = '#ef4444'; // red
|
||||
zoneLabel = 'Critical';
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col items-center gap-3 p-4 bg-gray-900 rounded border border-gray-800">
|
||||
<CircularGauge temp={temp} maxTemp={maxTemp} label={label} color={zoneColor} />
|
||||
|
||||
<div className="w-full space-y-2">
|
||||
<div className="flex justify-between items-center text-xs">
|
||||
<span className="text-gray-600">STATUS</span>
|
||||
<span className={`font-bold ${zoneColor === '#10b981' ? 'text-green-400' : zoneColor === '#f59e0b' ? 'text-yellow-400' : 'text-red-400'}`}>
|
||||
{zoneLabel}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{fanSpeed !== undefined && (
|
||||
<div className="flex justify-between items-center text-xs">
|
||||
<span className="text-gray-600">FAN SPEED</span>
|
||||
<span className="text-cyan-300 font-mono">{Math.round(fanSpeed)}%</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex justify-between items-center text-xs">
|
||||
<span className="text-gray-600">MAX REACHED</span>
|
||||
<span className="text-amber-300 font-mono">{Math.round(maxTemp)}°</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function TempGauge({ subscribe }) {
|
||||
const [cpuTemp, setCpuTemp] = useState(0);
|
||||
const [gpuTemp, setGpuTemp] = useState(0);
|
||||
const [cpuFanSpeed, setCpuFanSpeed] = useState(0);
|
||||
const [gpuFanSpeed, setGpuFanSpeed] = useState(0);
|
||||
const [cpuMaxTemp, setCpuMaxTemp] = useState(0);
|
||||
const [gpuMaxTemp, setGpuMaxTemp] = useState(0);
|
||||
const maxTempRef = useRef({ cpu: 0, gpu: 0 });
|
||||
|
||||
// Subscribe to thermal status
|
||||
useEffect(() => {
|
||||
const unsubscribe = subscribe(
|
||||
'/saltybot/thermal_status',
|
||||
'saltybot_msgs/ThermalStatus',
|
||||
(msg) => {
|
||||
try {
|
||||
const cpu = msg.cpu_temp || 0;
|
||||
const gpu = msg.gpu_temp || 0;
|
||||
const cpuFan = msg.cpu_fan_speed || 0;
|
||||
const gpuFan = msg.gpu_fan_speed || 0;
|
||||
|
||||
setCpuTemp(cpu);
|
||||
setGpuTemp(gpu);
|
||||
setCpuFanSpeed(cpuFan);
|
||||
setGpuFanSpeed(gpuFan);
|
||||
|
||||
// Track max temperatures
|
||||
if (cpu > maxTempRef.current.cpu) {
|
||||
maxTempRef.current.cpu = cpu;
|
||||
setCpuMaxTemp(cpu);
|
||||
}
|
||||
if (gpu > maxTempRef.current.gpu) {
|
||||
maxTempRef.current.gpu = gpu;
|
||||
setGpuMaxTemp(gpu);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Error parsing thermal status:', e);
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
return unsubscribe;
|
||||
}, [subscribe]);
|
||||
|
||||
const overallCritical = cpuTemp >= 75 || gpuTemp >= 75;
|
||||
const overallCaution = cpuTemp >= 60 || gpuTemp >= 60;
|
||||
|
||||
return (
|
||||
<div className="flex flex-col h-full space-y-3">
|
||||
{/* Summary Header */}
|
||||
<div className="bg-gray-950 rounded-lg border border-cyan-950 p-3 space-y-2">
|
||||
<div className="flex justify-between items-center">
|
||||
<div className="text-cyan-700 text-xs font-bold tracking-widest">
|
||||
THERMAL STATUS
|
||||
</div>
|
||||
<div className={`text-xs font-bold px-2 py-1 rounded ${
|
||||
overallCritical ? 'bg-red-950 text-red-400 border border-red-800' :
|
||||
overallCaution ? 'bg-yellow-950 text-yellow-400 border border-yellow-800' :
|
||||
'bg-green-950 text-green-400 border border-green-800'
|
||||
}`}>
|
||||
{overallCritical ? 'CRITICAL' : overallCaution ? 'CAUTION' : 'NORMAL'}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Quick stats */}
|
||||
<div className="grid grid-cols-2 gap-2 text-xs">
|
||||
<div className="bg-gray-900 rounded p-2">
|
||||
<div className="text-gray-600">CPU</div>
|
||||
<div className="text-lg font-mono text-cyan-300">{Math.round(cpuTemp)}°C</div>
|
||||
</div>
|
||||
<div className="bg-gray-900 rounded p-2">
|
||||
<div className="text-gray-600">GPU</div>
|
||||
<div className="text-lg font-mono text-cyan-300">{Math.round(gpuTemp)}°C</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Gauge Cards */}
|
||||
<div className="flex-1 grid grid-cols-2 gap-3 overflow-y-auto">
|
||||
<TemperatureRow
|
||||
label="CPU"
|
||||
temp={cpuTemp}
|
||||
fanSpeed={cpuFanSpeed}
|
||||
maxTemp={cpuMaxTemp}
|
||||
/>
|
||||
<TemperatureRow
|
||||
label="GPU"
|
||||
temp={gpuTemp}
|
||||
fanSpeed={gpuFanSpeed}
|
||||
maxTemp={gpuMaxTemp}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Temperature zones legend */}
|
||||
<div className="bg-gray-950 rounded border border-gray-800 p-2 space-y-2">
|
||||
<div className="text-xs text-gray-600 font-bold tracking-widest">TEMPERATURE ZONES</div>
|
||||
<div className="grid grid-cols-3 gap-2 text-xs">
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="w-3 h-3 rounded-full bg-green-500" />
|
||||
<span className="text-gray-400"><60°C</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="w-3 h-3 rounded-full bg-yellow-500" />
|
||||
<span className="text-gray-400">60-75°C</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="w-3 h-3 rounded-full bg-red-500" />
|
||||
<span className="text-gray-400">>75°C</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Topic info */}
|
||||
<div className="bg-gray-950 rounded border border-gray-800 p-2 text-xs text-gray-600 space-y-1">
|
||||
<div className="flex justify-between">
|
||||
<span>Topic:</span>
|
||||
<span className="text-gray-500">/saltybot/thermal_status</span>
|
||||
</div>
|
||||
<div className="flex justify-between">
|
||||
<span>Type:</span>
|
||||
<span className="text-gray-500">saltybot_msgs/ThermalStatus</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user