Compare commits

..

2 Commits

Author SHA1 Message Date
41040f8bbd feat(firmware): Pan-tilt servo driver for camera head (Issue #206)
Implements TIM4 PWM driver for 2-servo camera mount with:
- 50 Hz PWM frequency (standard servo control)
- CH1 (PB6) pan servo, CH2 (PB7) tilt servo
- 0-180° angle range → 500-2500 µs pulse width mapping
- Non-blocking servo_set_angle() for immediate positioning
- servo_sweep() for smooth pan-tilt animation (linear interpolation)
- Independent sweep control per servo (pan and tilt move simultaneously)
- 15 comprehensive unit tests covering all scenarios

Integration:
- servo_init() called at startup after power_mgmt_init()
- servo_tick(now_ms) called every 1ms in main loop
- Ready for camera/gimbal control automation

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
2026-03-02 11:21:39 -05:00
a4371f2a1d feat: Add Issue #203 - Node watchdog monitor (20Hz heartbeat detection)
Implements node watchdog ROS2 node that monitors heartbeats from critical
systems and triggers safety fallback when motor driver is lost >2s.

Features:
  - Monitor heartbeats from: balance, motor_driver, emergency, docking
  - Alert on /saltybot/node_watchdog (JSON) if heartbeat lost >1s
  - Safety fallback: zero cmd_vel if motor driver lost >2s
  - Republish cmd_vel on /saltybot/cmd_vel_safe with safety checks
  - 20Hz monitoring and publishing frequency
  - Configurable heartbeat timeout thresholds

Subscribed Topics:
  - /saltybot/balance_heartbeat (std_msgs/UInt32)
  - /saltybot/motor_driver_heartbeat (std_msgs/UInt32)
  - /saltybot/emergency_heartbeat (std_msgs/UInt32)
  - /saltybot/docking_heartbeat (std_msgs/UInt32)
  - /cmd_vel (geometry_msgs/Twist)

Published Topics:
  - /saltybot/node_watchdog (std_msgs/String) - JSON status
  - /saltybot/cmd_vel_safe (geometry_msgs/Twist) - Safety-checked velocity

Package: saltybot_node_watchdog
Entry point: node_watchdog_node
Launch file: node_watchdog.launch.py

Tests: 20+ unit tests covering:
  - Heartbeat reception and timeout detection
  - Motor driver critical timeout (>2s)
  - Safety fallback logic
  - cmd_vel zeroing on motor driver loss
  - Health status JSON serialization
  - Multi-node failure scenarios

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
2026-03-02 11:18:50 -05:00
31 changed files with 653 additions and 1176 deletions

View File

@ -0,0 +1,14 @@
# Node watchdog configuration
node_watchdog:
ros__parameters:
# Publishing frequency in Hz
frequency: 20 # 20 Hz = 50ms cycle
# General heartbeat timeout (seconds)
# Alert if any heartbeat lost for this duration
heartbeat_timeout: 1.0
# Motor driver critical timeout (seconds)
# Trigger safety fallback (zero cmd_vel) if motor driver down this long
motor_driver_critical_timeout: 2.0

View File

@ -0,0 +1,36 @@
"""Launch file for node_watchdog_node."""
from launch import LaunchDescription
from launch_ros.actions import Node
from launch.substitutions import LaunchConfiguration
from launch.actions import DeclareLaunchArgument
import os
from ament_index_python.packages import get_package_share_directory
def generate_launch_description():
    """Build the launch description that starts the node watchdog.

    Declares a ``config_file`` launch argument whose default is the
    ``config/watchdog_config.yaml`` file inside the installed
    ``saltybot_node_watchdog`` share directory, and starts the
    ``node_watchdog_node`` executable with that file as its parameters.
    """
    # Default parameter file shipped with the package.
    default_config = os.path.join(
        get_package_share_directory("saltybot_node_watchdog"),
        "config",
        "watchdog_config.yaml",
    )

    # Launch argument so the parameter file can be overridden.
    config_arg = DeclareLaunchArgument(
        "config_file",
        default_value=default_config,
        description="Path to configuration YAML file",
    )

    # The watchdog node itself, parameterised by the argument above.
    watchdog_node = Node(
        package="saltybot_node_watchdog",
        executable="node_watchdog_node",
        name="node_watchdog",
        output="screen",
        parameters=[LaunchConfiguration("config_file")],
    )

    return LaunchDescription([config_arg, watchdog_node])

View File

@ -1,16 +1,18 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_thermal</name>
<name>saltybot_node_watchdog</name>
<version>0.1.0</version>
<description>
Jetson thermal monitor (Issue #205). Reads /sys/class/thermal/thermal_zone*,
publishes /saltybot/thermal JSON at 1 Hz, warns at 75 °C, throttles at 85 °C.
Node watchdog monitor for SaltyBot critical systems. Monitors heartbeats from balance,
motor driver, emergency, and docking nodes. Publishes alerts on heartbeat loss >1s.
Implements safety fallback: zeros cmd_vel if motor driver lost >2s. Runs at 20Hz.
</description>
<maintainer email="sl-jetson@saltylab.local">sl-jetson</maintainer>
<maintainer email="sl-controls@saltylab.local">sl-controls</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>geometry_msgs</depend>
<depend>std_msgs</depend>
<buildtool_depend>ament_python</buildtool_depend>

View File

@ -0,0 +1,235 @@
#!/usr/bin/env python3
"""Node watchdog monitor for SaltyBot critical systems.
Monitors heartbeats from balance, motor driver, emergency, and docking nodes.
Publishes alerts on heartbeat loss >1s. Implements safety fallback: zeros cmd_vel
if motor driver lost >2s.
Published topics:
/saltybot/node_watchdog (std_msgs/String) - JSON watchdog status
/saltybot/cmd_vel_safe (geometry_msgs/Twist) - cmd_vel with motor driver safety check
Subscribed topics:
/saltybot/balance_heartbeat (std_msgs/UInt32) - Balance node heartbeat
/saltybot/motor_driver_heartbeat (std_msgs/UInt32) - Motor driver heartbeat
/saltybot/emergency_heartbeat (std_msgs/UInt32) - Emergency system heartbeat
/saltybot/docking_heartbeat (std_msgs/UInt32) - Docking node heartbeat
/cmd_vel (geometry_msgs/Twist) - Velocity command input
"""
import json
import rclpy
from rclpy.node import Node
from rclpy.timer import Timer
from geometry_msgs.msg import Twist
from std_msgs.msg import UInt32, String
class NodeWatchdogNode(Node):
    """ROS2 watchdog monitor for critical system nodes.

    Records the arrival time of heartbeat messages from the balance, motor
    driver, emergency and docking nodes. On every timer tick it publishes a
    JSON health report on /saltybot/node_watchdog and republishes the most
    recently received /cmd_vel on /saltybot/cmd_vel_safe, replaced by an
    all-zero Twist while the motor driver heartbeat has been missing for
    longer than ``motor_driver_critical_timeout``.

    NOTE(review): a node that has never sent a heartbeat is reported as
    "unknown" and is never flagged critical, so the fallback does not engage
    if the motor driver never starts. Confirm this is intended.
    NOTE(review): the cached cmd_vel is republished on every tick with no
    age check, so a stale command keeps being forwarded after the /cmd_vel
    publisher stops. Confirm this is intended.
    """

    def __init__(self):
        """Declare parameters, create subscriptions/publishers and start the timer.

        Raises:
            ValueError: If the ``frequency`` parameter is not positive.
        """
        super().__init__("node_watchdog")

        # Parameters
        self.declare_parameter("frequency", 20)  # Hz
        self.declare_parameter("heartbeat_timeout", 1.0)  # seconds, general timeout
        self.declare_parameter("motor_driver_critical_timeout", 2.0)  # seconds

        frequency = self.get_parameter("frequency").value
        if frequency is None or frequency <= 0:
            # Fail with a clear message instead of a ZeroDivisionError (or an
            # invalid negative timer period) further down.
            raise ValueError(f"frequency must be > 0 Hz, got {frequency!r}")
        self.heartbeat_timeout = self.get_parameter("heartbeat_timeout").value
        self.motor_driver_critical_timeout = self.get_parameter(
            "motor_driver_critical_timeout"
        ).value

        # Heartbeat tracking. The keys of critical_nodes define which nodes
        # are monitored; last_heartbeat_time holds the node-clock time of the
        # latest heartbeat (None until the first one arrives).
        self.critical_nodes = {
            "balance": None,
            "motor_driver": None,
            "emergency": None,
            "docking": None,
        }
        self.last_heartbeat_time = {
            "balance": None,
            "motor_driver": None,
            "emergency": None,
            "docking": None,
        }
        self.last_cmd_vel = None
        self.motor_driver_down = False

        # Subscriptions for heartbeats
        self.create_subscription(
            UInt32, "/saltybot/balance_heartbeat", self._on_balance_heartbeat, 10
        )
        self.create_subscription(
            UInt32,
            "/saltybot/motor_driver_heartbeat",
            self._on_motor_driver_heartbeat,
            10,
        )
        self.create_subscription(
            UInt32, "/saltybot/emergency_heartbeat", self._on_emergency_heartbeat, 10
        )
        self.create_subscription(
            UInt32, "/saltybot/docking_heartbeat", self._on_docking_heartbeat, 10
        )

        # cmd_vel subscription and safe republishing
        self.create_subscription(Twist, "/cmd_vel", self._on_cmd_vel, 10)

        # Publications
        self.pub_watchdog = self.create_publisher(String, "/saltybot/node_watchdog", 10)
        self.pub_cmd_vel_safe = self.create_publisher(
            Twist, "/saltybot/cmd_vel_safe", 10
        )

        # Timer for periodic monitoring at the configured frequency
        period = 1.0 / frequency
        self.timer: Timer = self.create_timer(period, self._timer_callback)

        self.get_logger().info(
            f"Node watchdog initialized at {frequency}Hz. "
            f"Heartbeat timeout: {self.heartbeat_timeout}s, "
            f"Motor driver critical: {self.motor_driver_critical_timeout}s"
        )

    def _record_heartbeat(self, node_name: str) -> None:
        """Stamp *node_name* with the current node-clock time."""
        self.last_heartbeat_time[node_name] = self.get_clock().now()

    def _on_balance_heartbeat(self, msg: UInt32) -> None:
        """Update balance node heartbeat timestamp."""
        self._record_heartbeat("balance")

    def _on_motor_driver_heartbeat(self, msg: UInt32) -> None:
        """Update motor driver heartbeat timestamp and clear the critical flag."""
        self._record_heartbeat("motor_driver")
        self.motor_driver_down = False

    def _on_emergency_heartbeat(self, msg: UInt32) -> None:
        """Update emergency system heartbeat timestamp."""
        self._record_heartbeat("emergency")

    def _on_docking_heartbeat(self, msg: UInt32) -> None:
        """Update docking node heartbeat timestamp."""
        self._record_heartbeat("docking")

    def _on_cmd_vel(self, msg: Twist) -> None:
        """Cache the last received cmd_vel."""
        self.last_cmd_vel = msg

    def _check_node_health(self) -> dict:
        """Check health of all monitored nodes.

        Returns:
            dict: One entry per monitored node. Nodes that never sent a
            heartbeat get ``status`` "unknown", ``elapsed_s`` None and
            ``timeout_s`` equal to the general heartbeat timeout. All other
            nodes get ``status`` "up"/"down" (against the general timeout),
            ``elapsed_s``, ``timeout_s`` and a ``critical`` flag that is only
            ever True for the motor driver.
        """
        now = self.get_clock().now()
        health = {}
        for node_name in self.critical_nodes:
            last_time = self.last_heartbeat_time[node_name]
            if last_time is None:
                # No heartbeat received yet
                health[node_name] = {
                    "status": "unknown",
                    "elapsed_s": None,
                    "timeout_s": self.heartbeat_timeout,
                }
                continue

            # Seconds since the last heartbeat
            elapsed = (now - last_time).nanoseconds / 1e9
            # Only the motor driver has a second, longer "critical" threshold
            is_motor_driver = node_name == "motor_driver"
            health[node_name] = {
                "status": "down" if elapsed > self.heartbeat_timeout else "up",
                "elapsed_s": elapsed,
                "timeout_s": (
                    self.motor_driver_critical_timeout
                    if is_motor_driver
                    else self.heartbeat_timeout
                ),
                "critical": (
                    is_motor_driver and elapsed > self.motor_driver_critical_timeout
                ),
            }
        return health

    def _timer_callback(self) -> None:
        """Evaluate node health, publish the JSON status and the safe cmd_vel."""
        health = self._check_node_health()

        # Latch the critical flag; it is cleared again by the next motor
        # driver heartbeat.
        motor_driver_health = health.get("motor_driver", {})
        if motor_driver_health.get("critical", False):
            self.motor_driver_down = True
            # Throttled: this branch runs on every tick while the driver is
            # down, which would otherwise flood the log.
            self.get_logger().warn(
                f"MOTOR DRIVER DOWN >{self.motor_driver_critical_timeout}s "
                f"({motor_driver_health['elapsed_s']:.1f}s). "
                "Applying safety fallback: zeroing cmd_vel.",
                throttle_duration_sec=1.0,
            )

        # Nodes whose heartbeat is older than the general timeout
        nodes_with_timeout = {
            name: status
            for name, status in health.items()
            if status.get("status") == "down"
        }

        # Publish watchdog status
        watchdog_status = {
            "timestamp": self.get_clock().now().nanoseconds / 1e9,
            "all_healthy": not nodes_with_timeout and not self.motor_driver_down,
            "health": health,
            "motor_driver_critical": self.motor_driver_down,
        }
        self.pub_watchdog.publish(String(data=json.dumps(watchdog_status)))

        # Publish cmd_vel with safety checks (nothing until a cmd_vel arrived)
        if self.last_cmd_vel is not None:
            self.pub_cmd_vel_safe.publish(
                self._apply_safety_checks(self.last_cmd_vel)
            )

    def _apply_safety_checks(self, cmd_vel: Twist) -> Twist:
        """Apply safety checks to cmd_vel based on system state.

        Args:
            cmd_vel: Original velocity command

        Returns:
            Twist: A new message; all-zero while the motor driver is flagged
            critically down, otherwise a field-by-field copy of *cmd_vel*.
        """
        safe_cmd = Twist()

        # If motor driver is critically down, zero all velocities
        if self.motor_driver_down:
            return safe_cmd

        # Otherwise, pass through unchanged
        safe_cmd.linear.x = cmd_vel.linear.x
        safe_cmd.linear.y = cmd_vel.linear.y
        safe_cmd.linear.z = cmd_vel.linear.z
        safe_cmd.angular.x = cmd_vel.angular.x
        safe_cmd.angular.y = cmd_vel.angular.y
        safe_cmd.angular.z = cmd_vel.angular.z
        return safe_cmd
def main(args=None):
    """Run the node watchdog until interrupted.

    Args:
        args: Command-line arguments forwarded to ``rclpy.init``.
    """
    rclpy.init(args=args)
    node = NodeWatchdogNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        # On Ctrl-C the rclpy signal handler may already have shut the
        # context down; calling shutdown() a second time raises, so only
        # shut down a context that is still valid.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,4 @@
[develop]
script_dir=$base/lib/saltybot_node_watchdog
# Install console scripts under lib/<package> as well, so a regular
# (non-develop) install places node_watchdog_node where the develop
# install above does.
[install]
install_scripts=$base/lib/saltybot_node_watchdog
[egg_info]
tag_date = 0

View File

@ -0,0 +1,29 @@
from setuptools import setup
package_name = "saltybot_node_watchdog"

# Everything except the ament resource marker is installed under
# share/<package>.
share_dir = f"share/{package_name}"
data_files = [
    ("share/ament_index/resource_index/packages", [f"resource/{package_name}"]),
    (share_dir, ["package.xml"]),
    (f"{share_dir}/launch", ["launch/node_watchdog.launch.py"]),
    (f"{share_dir}/config", ["config/watchdog_config.yaml"]),
]

# Executable name -> module:function
console_scripts = [
    "node_watchdog_node = saltybot_node_watchdog.node_watchdog_node:main",
]

setup(
    name=package_name,
    version="0.1.0",
    packages=[package_name],
    data_files=data_files,
    install_requires=["setuptools"],
    zip_safe=True,
    maintainer="sl-controls",
    maintainer_email="sl-controls@saltylab.local",
    description="Node watchdog: heartbeat monitoring with safety fallback for critical systems",
    license="MIT",
    tests_require=["pytest"],
    entry_points={"console_scripts": console_scripts},
)

View File

@ -0,0 +1,329 @@
"""Unit tests for node_watchdog_node."""
import pytest
import json
import time
from geometry_msgs.msg import Twist
from std_msgs.msg import UInt32, String
import rclpy
from rclpy.time import Time
# Import the node under test
from saltybot_node_watchdog.node_watchdog_node import NodeWatchdogNode
@pytest.fixture
def rclpy_fixture():
    """Bring rclpy up before a test and shut it down again afterwards."""
    rclpy.init()
    # Hand control to the test; the line after the yield is the teardown.
    yield
    rclpy.shutdown()
@pytest.fixture
def node(rclpy_fixture):
    """Provide a fresh NodeWatchdogNode and destroy it after the test."""
    watchdog = NodeWatchdogNode()
    yield watchdog
    watchdog.destroy_node()
class TestNodeWatchdogNode:
    """Test suite for NodeWatchdogNode."""

    @staticmethod
    def _seconds_ago(now, seconds):
        """Return a Time *seconds* before *now*, on the same clock type.

        Time() defaults to the system clock type, while timestamps taken
        from the node clock carry the node's clock type. rclpy refuses to
        subtract two Time objects whose clock types differ, so the clock
        type of *now* is passed through explicitly.
        """
        return Time(
            nanoseconds=now.nanoseconds - int(seconds * 1e9),
            clock_type=now.clock_type,
        )

    def test_node_initialization(self, node):
        """Test that node initializes with correct defaults."""
        assert node.heartbeat_timeout == 1.0
        assert node.motor_driver_critical_timeout == 2.0
        assert node.last_cmd_vel is None
        assert node.motor_driver_down is False
        # All heartbeat times should be None initially
        for node_name in node.critical_nodes:
            assert node.last_heartbeat_time[node_name] is None

    def test_balance_heartbeat_received(self, node):
        """Test balance node heartbeat recording."""
        msg = UInt32(data=1)
        node._on_balance_heartbeat(msg)
        assert node.last_heartbeat_time["balance"] is not None

    def test_motor_driver_heartbeat_received(self, node):
        """Test motor driver heartbeat recording."""
        msg = UInt32(data=1)
        node._on_motor_driver_heartbeat(msg)
        assert node.last_heartbeat_time["motor_driver"] is not None
        # Motor driver heartbeat should clear the down flag
        node.motor_driver_down = True
        node._on_motor_driver_heartbeat(msg)
        assert node.motor_driver_down is False

    def test_emergency_heartbeat_received(self, node):
        """Test emergency system heartbeat recording."""
        msg = UInt32(data=1)
        node._on_emergency_heartbeat(msg)
        assert node.last_heartbeat_time["emergency"] is not None

    def test_docking_heartbeat_received(self, node):
        """Test docking node heartbeat recording."""
        msg = UInt32(data=1)
        node._on_docking_heartbeat(msg)
        assert node.last_heartbeat_time["docking"] is not None

    def test_cmd_vel_caching(self, node):
        """Test that cmd_vel messages are cached."""
        msg = Twist()
        msg.linear.x = 1.0
        node._on_cmd_vel(msg)
        assert node.last_cmd_vel is not None
        assert node.last_cmd_vel.linear.x == 1.0

    def test_health_check_all_unknown(self, node):
        """Test health check when no heartbeats received."""
        health = node._check_node_health()
        assert len(health) == 4
        for node_name in node.critical_nodes:
            assert health[node_name]["status"] == "unknown"
            assert health[node_name]["elapsed_s"] is None
            assert health[node_name]["timeout_s"] == 1.0

    def test_health_check_just_received(self, node):
        """Test health check just after heartbeat received."""
        # Record a heartbeat for balance node
        node.last_heartbeat_time["balance"] = node.get_clock().now()
        health = node._check_node_health()
        # Balance should be up, others unknown
        assert health["balance"]["status"] == "up"
        assert health["balance"]["elapsed_s"] < 0.1
        assert health["emergency"]["status"] == "unknown"

    def test_health_check_timeout_general(self, node):
        """Test that heartbeat timeout is detected (>1s)."""
        # Simulate a heartbeat that arrived 1.5 seconds ago
        now = node.get_clock().now()
        node.last_heartbeat_time["balance"] = self._seconds_ago(now, 1.5)
        health = node._check_node_health()
        assert health["balance"]["status"] == "down"
        assert health["balance"]["elapsed_s"] > 1.4
        assert health["balance"]["elapsed_s"] < 2.0

    def test_health_check_motor_driver_critical(self, node):
        """Test motor driver critical timeout (>2s)."""
        # Simulate motor driver heartbeat 2.5 seconds ago
        now = node.get_clock().now()
        node.last_heartbeat_time["motor_driver"] = self._seconds_ago(now, 2.5)
        health = node._check_node_health()
        motor_health = health["motor_driver"]
        assert motor_health["status"] == "down"
        assert motor_health.get("critical", False) is True
        assert motor_health["elapsed_s"] > 2.4

    def test_safety_check_normal_operation(self, node):
        """Test safety check passes through cmd_vel normally."""
        node.motor_driver_down = False
        cmd = Twist()
        cmd.linear.x = 1.5
        cmd.angular.z = 0.3
        safe_cmd = node._apply_safety_checks(cmd)
        assert abs(safe_cmd.linear.x - 1.5) < 1e-6
        assert abs(safe_cmd.angular.z - 0.3) < 1e-6

    def test_safety_check_motor_driver_down(self, node):
        """Test safety check zeros cmd_vel when motor driver is down."""
        node.motor_driver_down = True
        cmd = Twist()
        cmd.linear.x = 1.5
        cmd.linear.y = 0.2
        cmd.angular.z = 0.3
        safe_cmd = node._apply_safety_checks(cmd)
        # All velocities should be zero
        assert safe_cmd.linear.x == 0.0
        assert safe_cmd.linear.y == 0.0
        assert safe_cmd.linear.z == 0.0
        assert safe_cmd.angular.x == 0.0
        assert safe_cmd.angular.y == 0.0
        assert safe_cmd.angular.z == 0.0

    def test_timer_callback_publishes(self, node):
        """Test that timer callback runs and leaves the fallback disengaged."""
        # Record a heartbeat
        node.last_heartbeat_time["balance"] = node.get_clock().now()
        node.last_cmd_vel = Twist()
        node.last_cmd_vel.linear.x = 1.0
        # Call timer callback; actual publishing is tested via integration
        node._timer_callback()
        # No motor driver heartbeat was ever recorded, so nothing is critical
        assert node.motor_driver_down is False

    def test_watchdog_status_json_all_healthy(self, node):
        """Test watchdog status JSON when all nodes healthy."""
        # Record all heartbeats
        now = node.get_clock().now()
        for node_name in node.critical_nodes:
            node.last_heartbeat_time[node_name] = now
        health = node._check_node_health()
        watchdog_status = {
            "timestamp": now.nanoseconds / 1e9,
            "all_healthy": all(s["status"] == "up" for s in health.values()),
            "health": health,
            "motor_driver_critical": False,
        }
        # Verify it's valid JSON
        json_str = json.dumps(watchdog_status)
        parsed = json.loads(json_str)
        assert parsed["all_healthy"] is True
        assert parsed["motor_driver_critical"] is False

    def test_watchdog_status_json_with_timeout(self, node):
        """Test watchdog status JSON when node has timed out."""
        # Balance heartbeat 1.5 seconds ago
        now = node.get_clock().now()
        node.last_heartbeat_time["balance"] = self._seconds_ago(now, 1.5)
        # Others are current
        for name in ["motor_driver", "emergency", "docking"]:
            node.last_heartbeat_time[name] = now
        health = node._check_node_health()
        watchdog_status = {
            "timestamp": now.nanoseconds / 1e9,
            "all_healthy": all(s["status"] == "up" for s in health.values()),
            "health": health,
            "motor_driver_critical": False,
        }
        json_str = json.dumps(watchdog_status)
        parsed = json.loads(json_str)
        assert parsed["all_healthy"] is False
        assert parsed["health"]["balance"]["status"] == "down"
class TestNodeWatchdogScenarios:
    """Integration-style tests for realistic scenarios."""

    @staticmethod
    def _seconds_ago(now, seconds):
        """Return a Time *seconds* before *now*, on the same clock type.

        Time() defaults to the system clock type; rclpy refuses to subtract
        Time objects with different clock types, so the clock type of *now*
        (taken from the node clock) is passed through explicitly.
        """
        return Time(
            nanoseconds=now.nanoseconds - int(seconds * 1e9),
            clock_type=now.clock_type,
        )

    def test_scenario_all_nodes_healthy(self, node):
        """Scenario: all critical nodes sending heartbeats."""
        now = node.get_clock().now()
        # All nodes sending heartbeats
        for name in node.critical_nodes:
            node.last_heartbeat_time[name] = now
        health = node._check_node_health()
        all_up = all(h["status"] == "up" for h in health.values())
        assert all_up is True

    def test_scenario_motor_driver_loss_below_critical(self, node):
        """Scenario: motor driver offline 1.5s (below 2s critical)."""
        now = node.get_clock().now()
        # Motor driver down 1.5s, others healthy
        node.last_heartbeat_time["motor_driver"] = self._seconds_ago(now, 1.5)
        for name in ["balance", "emergency", "docking"]:
            node.last_heartbeat_time[name] = now
        health = node._check_node_health()
        motor = health["motor_driver"]
        assert motor["status"] == "down"
        assert motor.get("critical", False) is False
        # The flag is only ever set by the timer callback, so run it before
        # checking that the safety fallback has NOT been triggered yet.
        node._timer_callback()
        assert node.motor_driver_down is False

    def test_scenario_motor_driver_critical_loss(self, node):
        """Scenario: motor driver offline >2s (triggers critical)."""
        now = node.get_clock().now()
        node.last_heartbeat_time["motor_driver"] = self._seconds_ago(now, 2.5)
        node.last_heartbeat_time["balance"] = now
        node.last_heartbeat_time["emergency"] = now
        node.last_heartbeat_time["docking"] = now
        health = node._check_node_health()
        motor = health["motor_driver"]
        assert motor["status"] == "down"
        assert motor.get("critical", False) is True
        # One monitoring tick must latch the safety fallback flag
        node._timer_callback()
        assert node.motor_driver_down is True

    def test_scenario_cascading_node_failures(self, node):
        """Scenario: multiple nodes failing in sequence."""
        now = node.get_clock().now()
        # Balance down 1.2s, motor driver down 2.5s, others healthy
        node.last_heartbeat_time["balance"] = self._seconds_ago(now, 1.2)
        node.last_heartbeat_time["motor_driver"] = self._seconds_ago(now, 2.5)
        node.last_heartbeat_time["emergency"] = now
        node.last_heartbeat_time["docking"] = now
        health = node._check_node_health()
        assert health["balance"]["status"] == "down"
        assert health["balance"].get("critical", False) is False
        assert health["motor_driver"]["status"] == "down"
        assert health["motor_driver"].get("critical", False) is True

    def test_scenario_cmd_vel_safety_fallback(self, node):
        """Scenario: motor driver down triggers safety zeroing of cmd_vel."""
        # Motor driver is critically down
        node.motor_driver_down = True
        cmd = Twist()
        cmd.linear.x = 2.0
        cmd.angular.z = 0.5
        safe_cmd = node._apply_safety_checks(cmd)
        # All should be zeroed
        assert safe_cmd.linear.x == 0.0
        assert safe_cmd.linear.y == 0.0
        assert safe_cmd.linear.z == 0.0
        assert safe_cmd.angular.x == 0.0
        assert safe_cmd.angular.y == 0.0
        assert safe_cmd.angular.z == 0.0

    def test_scenario_motor_driver_recovery(self, node):
        """Scenario: motor driver comes back online after being down."""
        # Motor driver was down
        node.motor_driver_down = True
        # Motor driver sends heartbeat
        node._on_motor_driver_heartbeat(UInt32(data=1))
        # Should clear the down flag
        assert node.motor_driver_down is False
        # cmd_vel should pass through
        cmd = Twist()
        cmd.linear.x = 1.0
        safe_cmd = node._apply_safety_checks(cmd)
        assert safe_cmd.linear.x == 1.0

View File

@ -1,6 +0,0 @@
person_reid:
ros__parameters:
model_path: '' # path to MobileNetV2+projection ONNX file (empty = histogram fallback)
match_threshold: 0.75 # cosine similarity threshold for re-ID match
max_identity_age_s: 300.0 # seconds before unseen identity is pruned
publish_hz: 5.0 # publication rate (Hz)

View File

@ -1,28 +0,0 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_person_reid</name>
<version>0.1.0</version>
<description>
Person re-identification node — cross-camera appearance matching using
MobileNetV2 ONNX embeddings (128-dim, cosine similarity gallery).
</description>
<maintainer email="robot@saltylab.local">SaltyLab</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>sensor_msgs</depend>
<depend>vision_msgs</depend>
<depend>cv_bridge</depend>
<depend>message_filters</depend>
<depend>saltybot_person_reid_msgs</depend>
<exec_depend>python3-numpy</exec_depend>
<exec_depend>python3-opencv</exec_depend>
<test_depend>pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -1,95 +0,0 @@
"""
_embedding_model.py Appearance embedding extractor (no ROS2 deps).
Primary: MobileNetV2 ONNX torso crop 128-dim L2-normalised embedding.
Fallback: 128-bin HSV histogram (H:16 × S:8) when no model file is available.
"""
from __future__ import annotations
import numpy as np
import cv2
# Network input resolution: crops are resized to this size before inference
_INPUT_SIZE = (128, 256) # (W, H) fed to MobileNetV2
class EmbeddingModel:
    """
    Extract a 128-dim L2-normalised appearance embedding from a BGR crop.

    Parameters
    ----------
    model_path : str or None
        Path to a MobileNetV2+projection ONNX file. When None (or the file
        cannot be loaded), falls back to a 128-bin HSV colour histogram.
    """

    # Length of every embedding returned by embed().
    _EMBED_DIM = 128

    def __init__(self, model_path: str | None = None):
        self._net = None
        if model_path:
            try:
                self._net = cv2.dnn.readNetFromONNX(model_path)
            except Exception:
                # Deliberate best-effort: any load failure leaves _net as
                # None, which selects the histogram fallback in embed().
                self._net = None

    def embed(self, bgr_crop: np.ndarray) -> np.ndarray:
        """
        Compute the appearance embedding of one crop.

        Parameters
        ----------
        bgr_crop : np.ndarray  shape (H, W, 3)  uint8

        Returns
        -------
        np.ndarray  shape (128,)  float32, L2-normalised
            (all zeros for an empty crop)
        """
        if bgr_crop.size == 0:
            return np.zeros(self._EMBED_DIM, dtype=np.float32)
        if self._net is not None:
            return self._mobilenet_embed(bgr_crop)
        return self._histogram_embed(bgr_crop)

    # ── MobileNetV2 path ──────────────────────────────────────────────────────

    def _mobilenet_embed(self, bgr: np.ndarray) -> np.ndarray:
        """Run the ONNX network on *bgr* and return a unit-length 128-dim vector."""
        resized = cv2.resize(bgr, _INPUT_SIZE)
        blob = cv2.dnn.blobFromImage(
            resized,
            scalefactor=1.0 / 255.0,
            size=_INPUT_SIZE,
            mean=(0.485 * 255, 0.456 * 255, 0.406 * 255),
            swapRB=True,
            crop=False,
        )
        # Std normalisation: divide channel-wise
        blob[:, 0] /= 0.229
        blob[:, 1] /= 0.224
        blob[:, 2] /= 0.225
        self._net.setInput(blob)
        feat = self._net.forward().flatten().astype(np.float32)
        return _l2_norm(self._fit_embedding_dim(feat))

    @classmethod
    def _fit_embedding_dim(cls, feat: np.ndarray) -> np.ndarray:
        """Bring a flat feature vector to exactly ``_EMBED_DIM`` elements.

        Longer vectors are average-pooled in equal blocks (trailing elements
        that do not fill a block are dropped). Shorter vectors are
        zero-padded; previously they crashed in ``reshape`` because fewer
        than 128 values were available.
        """
        dim = cls._EMBED_DIM
        n = feat.shape[0]
        if n == dim:
            return feat
        if n < dim:
            padded = np.zeros(dim, dtype=feat.dtype)
            padded[:n] = feat
            return padded
        block = n // dim
        return feat[: block * dim].reshape(dim, block).mean(axis=1)

    # ── HSV histogram fallback ────────────────────────────────────────────────

    def _histogram_embed(self, bgr: np.ndarray) -> np.ndarray:
        """128-bin HSV histogram: 16 H-bins × 8 S-bins, flattened."""
        hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
        hist = cv2.calcHist(
            [hsv], [0, 1], None,
            [16, 8], [0, 180, 0, 256],
        ).flatten().astype(np.float32)
        return _l2_norm(hist)
def _l2_norm(v: np.ndarray) -> np.ndarray:
    """Return *v* scaled to unit Euclidean length.

    When the norm is not greater than 1e-6 the input is returned as-is to
    avoid dividing by (almost) zero.
    """
    magnitude = float(np.linalg.norm(v))
    if magnitude > 1e-6:
        return v / magnitude
    return v

View File

@ -1,105 +0,0 @@
"""
_reid_gallery.py Appearance gallery for person re-identification (no ROS2 deps).
Matches an incoming embedding against stored identities using cosine similarity.
New identities are created when the best match falls below the threshold.
"""
from __future__ import annotations
import time
from dataclasses import dataclass, field
from typing import List, Tuple
import numpy as np
@dataclass
class Identity:
    """One stored person: an id, its running embedding and usage statistics."""

    identity_id: int
    embedding: np.ndarray  # shape (D,) L2-normalised
    # Monotonic-clock time of creation / most recent update().
    last_seen: float = field(default_factory=time.monotonic)
    # Incremented by every update(); starts at 1 for the creating observation.
    hit_count: int = 1

    def update(self, new_embedding: np.ndarray, alpha: float = 0.1) -> None:
        """Blend *new_embedding* into the stored one with weight *alpha*.

        The blend is re-normalised to unit length unless its norm is not
        greater than 1e-6. Also refreshes ``last_seen`` and bumps
        ``hit_count``.
        """
        keep = 1.0 - alpha
        blended = keep * self.embedding + alpha * new_embedding
        length = float(np.linalg.norm(blended))
        if length > 1e-6:
            blended = blended / length
        self.embedding = blended
        self.last_seen = time.monotonic()
        self.hit_count += 1
class ReidGallery:
    """
    Small in-memory re-ID gallery scored by dot product.

    Parameters
    ----------
    match_threshold : float
        Minimum score (dot product of the query with a stored embedding)
        for the query to be accepted as an existing identity.
    max_age_s : float
        Identities not seen for this many seconds are pruned.
    """

    def __init__(
        self,
        match_threshold: float = 0.75,
        max_age_s: float = 300.0,
    ):
        self._threshold = match_threshold
        self._max_age_s = max_age_s
        self._identities: List[Identity] = []
        # Ids are handed out sequentially, starting at 1.
        self._next_id = 1

    def match(self, embedding: np.ndarray) -> Tuple[int, float, bool]:
        """
        Look *embedding* up in the gallery, registering it if nothing matches.

        Expired identities are pruned first. A match also updates the stored
        identity with the new embedding.

        Returns
        -------
        (identity_id, match_score, is_new)
            identity_id : assigned ID (new or existing)
            match_score : score of the accepted match (0.0 if new)
            is_new      : True if a new identity was created
        """
        self._prune()
        gallery = self._identities
        if gallery:
            similarities = np.array(
                [float(np.dot(embedding, known.embedding)) for known in gallery]
            )
            winner = int(np.argmax(similarities))
            top_score = float(similarities[winner])
            if top_score >= self._threshold:
                matched = gallery[winner]
                matched.update(embedding)
                return matched.identity_id, top_score, False
        # Empty gallery or best score below the threshold.
        return self._add_identity(embedding)

    # ── Internal helpers ──────────────────────────────────────────────────────

    def _add_identity(self, embedding: np.ndarray) -> Tuple[int, float, bool]:
        """Store a copy of *embedding* under the next free id."""
        assigned = self._next_id
        self._next_id = assigned + 1
        self._identities.append(
            Identity(identity_id=assigned, embedding=embedding.copy())
        )
        return assigned, 0.0, True

    def _prune(self) -> None:
        """Drop identities whose last sighting is max_age_s or more ago."""
        current = time.monotonic()
        self._identities = [
            known
            for known in self._identities
            if current - known.last_seen < self._max_age_s
        ]

    @property
    def size(self) -> int:
        """Number of identities currently stored."""
        return len(self._identities)

View File

@ -1,174 +0,0 @@
"""
person_reid_node.py Person re-identification for cross-camera tracking.
Subscribes to:
/person/detections vision_msgs/Detection2DArray (person bounding boxes)
/camera/color/image_raw sensor_msgs/Image (colour frame for crops)
Publishes:
/saltybot/person_reid saltybot_person_reid_msgs/PersonAppearanceArray (5 Hz)
For each detected person the node:
1. Crops the torso region (top 65 % of the bounding box height).
2. Extracts a 128-dim L2-normalised embedding via MobileNetV2 ONNX (if the
model file is provided) or a 128-bin HSV colour histogram (fallback).
3. Matches against a cosine-similarity gallery.
4. Assigns a persistent identity_id (new or existing).
Parameters:
model_path str '' Path to MobileNetV2+projection ONNX file
match_threshold float 0.75 Cosine similarity threshold for matching
max_identity_age_s float 300.0 Seconds before an unseen identity is pruned
publish_hz float 5.0 Publication rate (Hz)
"""
from __future__ import annotations
from typing import List
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
import message_filters
import cv2
import numpy as np
from cv_bridge import CvBridge
from sensor_msgs.msg import Image
from vision_msgs.msg import Detection2DArray
from saltybot_person_reid_msgs.msg import PersonAppearance, PersonAppearanceArray
from ._embedding_model import EmbeddingModel
from ._reid_gallery import ReidGallery
# Fraction of bbox height kept as torso crop (top portion)
_TORSO_FRAC = 0.65
_BEST_EFFORT_QOS = QoSProfile(
reliability=ReliabilityPolicy.BEST_EFFORT,
history=HistoryPolicy.KEEP_LAST,
depth=4,
)
class PersonReidNode(Node):
    """Assign identity ids to detected people and publish them on a timer.

    Synchronised detection/image pairs are turned into PersonAppearance
    entries (torso crop -> embedding -> gallery match). The latest set is
    buffered and republished at ``publish_hz``.
    """

    def __init__(self):
        super().__init__('person_reid')

        # Declare every parameter with its default, then read them back.
        for param_name, default in (
            ('model_path', ''),
            ('match_threshold', 0.75),
            ('max_identity_age_s', 300.0),
            ('publish_hz', 5.0),
        ):
            self.declare_parameter(param_name, default)

        model_path = self.get_parameter('model_path').value
        match_thr = self.get_parameter('match_threshold').value
        max_age = self.get_parameter('max_identity_age_s').value
        publish_hz = self.get_parameter('publish_hz').value

        self._bridge = CvBridge()
        # An empty model_path is passed on as None.
        self._embedder = EmbeddingModel(model_path or None)
        self._gallery = ReidGallery(match_threshold=match_thr, max_age_s=max_age)

        # Latest result set: written by _on_frame, read by _tick.
        self._pending: List[PersonAppearance] = []
        self._pending_header = None

        # Detections and colour frames, paired by approximate timestamp.
        det_sub = message_filters.Subscriber(
            self, Detection2DArray, '/person/detections',
            qos_profile=_BEST_EFFORT_QOS)
        img_sub = message_filters.Subscriber(
            self, Image, '/camera/color/image_raw',
            qos_profile=_BEST_EFFORT_QOS)
        self._sync = message_filters.ApproximateTimeSynchronizer(
            [det_sub, img_sub], queue_size=4, slop=0.1)
        self._sync.registerCallback(self._on_frame)

        self._pub = self.create_publisher(
            PersonAppearanceArray, '/saltybot/person_reid', 10)
        self.create_timer(1.0 / publish_hz, self._tick)

        backend = 'ONNX' if self._embedder._net else 'histogram'
        self.get_logger().info(
            f'person_reid ready — backend={backend} '
            f'threshold={match_thr} max_age={max_age}s'
        )

    # ── Frame callback ─────────────────────────────────────────────────────────

    @staticmethod
    def _torso_box(bbox, w_img, h_img):
        """Return (x1, y1, x2, y2) of the torso crop, clamped to the image.

        The crop keeps the full box width and the top ``_TORSO_FRAC`` of its
        height.
        """
        cx = bbox.center.position.x
        cy = bbox.center.position.y
        bw = bbox.size_x
        bh = bbox.size_y
        left = max(0, int(cx - bw / 2.0))
        top = max(0, int(cy - bh / 2.0))
        right = min(w_img, int(cx + bw / 2.0))
        bottom = min(h_img, int(cy - bh / 2.0 + bh * _TORSO_FRAC))
        return left, top, right, bottom

    def _on_frame(self, det_msg: Detection2DArray, img_msg: Image) -> None:
        """Rebuild the pending appearance list from one detection/image pair."""
        if not det_msg.detections:
            # Nothing detected: publish an empty set stamped with this header.
            self._pending = []
            self._pending_header = det_msg.header
            return

        try:
            bgr = self._bridge.imgmsg_to_cv2(img_msg, desired_encoding='bgr8')
        except Exception as exc:
            self.get_logger().error(
                f'imgmsg_to_cv2 failed: {exc}', throttle_duration_sec=5.0)
            return

        h_img, w_img = bgr.shape[:2]
        results: List[PersonAppearance] = []
        for det in det_msg.detections:
            conf = det.results[0].hypothesis.score if det.results else 0.0

            left, top, right, bottom = self._torso_box(det.bbox, w_img, h_img)
            # Skip crops narrower or shorter than 8 pixels.
            if right - left < 8 or bottom - top < 8:
                continue

            emb = self._embedder.embed(bgr[top:bottom, left:right])
            identity_id, match_score, is_new = self._gallery.match(emb)

            entry = PersonAppearance()
            entry.header = det_msg.header
            entry.track_id = identity_id
            entry.embedding = emb.tolist()
            entry.bbox = det.bbox
            entry.confidence = float(conf)
            entry.match_score = float(match_score)
            entry.is_new_identity = is_new
            results.append(entry)

        self._pending = results
        self._pending_header = det_msg.header

    # ── Periodic publish tick ─────────────────────────────────────────────────

    def _tick(self) -> None:
        """Publish the buffered appearances; no-op until a frame was processed."""
        if self._pending_header is None:
            return
        out = PersonAppearanceArray()
        out.header = self._pending_header
        out.appearances = self._pending
        self._pub.publish(out)
def main(args=None):
    """Entry point: spin ``PersonReidNode`` until shutdown or Ctrl-C."""
    rclpy.init(args=args)
    node = PersonReidNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the node; exit without a traceback.
        pass
    finally:
        node.destroy_node()
        # Only shut down a context that is still valid, so an interrupt that
        # already shut it down does not trigger a second shutdown call.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()

View File

@ -1,4 +0,0 @@
[develop]
script_dir=$base/lib/saltybot_person_reid
[install]
install_scripts=$base/lib/saltybot_person_reid

View File

@ -1,29 +0,0 @@
# setuptools packaging script for the saltybot_person_reid package.
from setuptools import setup, find_packages
from glob import glob
package_name = 'saltybot_person_reid'
setup(
name=package_name,
version='0.1.0',
# Ship every Python package found here except the unit tests.
packages=find_packages(exclude=['test']),
data_files=[
# Marker file installed into the ament resource index.
('share/ament_index/resource_index/packages',
['resource/' + package_name]),
('share/' + package_name, ['package.xml']),
# Every YAML file under config/ is installed into share/<package>/config.
('share/' + package_name + '/config',
glob('config/*.yaml')),
],
install_requires=['setuptools'],
zip_safe=True,
maintainer='SaltyLab',
maintainer_email='robot@saltylab.local',
description='Person re-identification — cross-camera appearance matching',
license='MIT',
tests_require=['pytest'],
entry_points={
'console_scripts': [
# `person_reid` executable → main() in saltybot_person_reid/person_reid_node.py
'person_reid = saltybot_person_reid.person_reid_node:main',
],
},
)

View File

@ -1,163 +0,0 @@
"""
test_person_reid.py — Unit tests for person re-ID helpers (no ROS2 required).
Covers:
- _l2_norm helper
- EmbeddingModel (histogram fallback — no model file needed)
- ReidGallery cosine-similarity matching
"""
import sys
import os
import time
import numpy as np
import pytest
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from saltybot_person_reid._embedding_model import EmbeddingModel, _l2_norm
from saltybot_person_reid._reid_gallery import ReidGallery
# ── _l2_norm ──────────────────────────────────────────────────────────────────
class TestL2Norm:
    """Checks for the ``_l2_norm`` vector-normalisation helper."""

    def test_unit_vector_unchanged(self):
        """A vector that already has unit length is returned unchanged."""
        unit = np.array([1.0, 0.0, 0.0], dtype=np.float32)
        normalised = _l2_norm(unit)
        assert np.allclose(normalised, unit)

    def test_normalised_to_unit_norm(self):
        """A 3-4-5 vector is scaled to length 1."""
        vec = np.array([3.0, 4.0], dtype=np.float32)
        length = np.linalg.norm(_l2_norm(vec))
        assert abs(length - 1.0) < 1e-6

    def test_zero_vector_does_not_crash(self):
        """The all-zero vector is accepted and keeps its shape."""
        zeros = np.zeros(4, dtype=np.float32)
        out = _l2_norm(zeros)
        assert out.shape == (4,)
# ── EmbeddingModel ────────────────────────────────────────────────────────────
class TestEmbeddingModel:
    """Tests for ``EmbeddingModel`` using the histogram fallback (no model file).

    Random crops come from a seeded generator so every run sees the same
    pixels and a failure can be reproduced.
    """

    @staticmethod
    def _random_bgr(height: int, width: int, seed: int = 0) -> np.ndarray:
        """Return a reproducible random ``(height, width, 3)`` uint8 image."""
        rng = np.random.default_rng(seed)
        # Upper bound is exclusive, so 256 covers the full 0..255 pixel range.
        return rng.integers(0, 256, size=(height, width, 3), dtype=np.uint8)

    def test_histogram_fallback_shape(self):
        m = EmbeddingModel(model_path=None)
        emb = m.embed(self._random_bgr(100, 50, seed=1))
        assert emb.shape == (128,)

    def test_embedding_is_unit_norm(self):
        m = EmbeddingModel(model_path=None)
        emb = m.embed(self._random_bgr(80, 40, seed=2))
        assert abs(np.linalg.norm(emb) - 1.0) < 1e-5

    def test_empty_crop_returns_zero_vector(self):
        m = EmbeddingModel(model_path=None)
        emb = m.embed(np.zeros((0, 0, 3), dtype=np.uint8))
        assert emb.shape == (128,)
        assert np.all(emb == 0.0)

    def test_red_and_blue_crops_differ(self):
        m = EmbeddingModel(model_path=None)
        # Channel order is BGR: (0, 0, 200) is red, (200, 0, 0) is blue.
        red = np.full((80, 40, 3), (0, 0, 200), dtype=np.uint8)
        blue = np.full((80, 40, 3), (200, 0, 0), dtype=np.uint8)
        sim = float(np.dot(m.embed(red), m.embed(blue)))
        assert sim < 0.99

    def test_same_crop_deterministic(self):
        m = EmbeddingModel(model_path=None)
        bgr = self._random_bgr(80, 40, seed=3)
        assert np.allclose(m.embed(bgr), m.embed(bgr))

    def test_embedding_float32(self):
        m = EmbeddingModel(model_path=None)
        emb = m.embed(self._random_bgr(60, 30, seed=4))
        assert emb.dtype == np.float32
# ── ReidGallery ───────────────────────────────────────────────────────────────
def _unit(dim: int = 128, seed: int | None = None) -> np.ndarray:
rng = np.random.default_rng(seed)
v = rng.standard_normal(dim).astype(np.float32)
return v / np.linalg.norm(v)
class TestReidGallery:
    """Behavioural tests for ``ReidGallery.match`` and gallery pruning."""

    def test_first_match_creates_identity(self):
        g = ReidGallery(match_threshold=0.75)
        uid, score, is_new = g.match(_unit(seed=0))
        assert uid == 1
        assert is_new is True
        assert score == pytest.approx(0.0)

    def test_identical_embedding_matches(self):
        g = ReidGallery(match_threshold=0.75)
        emb = _unit(seed=1)
        g.match(emb)
        uid2, score2, is_new2 = g.match(emb)
        assert uid2 == 1
        assert is_new2 is False
        assert score2 > 0.99

    def test_orthogonal_embeddings_create_new_id(self):
        g = ReidGallery(match_threshold=0.75)
        e1 = np.zeros(128, dtype=np.float32)
        e1[0] = 1.0
        e2 = np.zeros(128, dtype=np.float32)
        e2[64] = 1.0
        uid1, _, new1 = g.match(e1)
        uid2, _, new2 = g.match(e2)
        assert uid1 != uid2
        assert new2 is True

    def test_ids_are_monotonically_increasing(self):
        # threshold > 1.0 is unreachable → every embedding creates a new identity
        g = ReidGallery(match_threshold=2.0)
        ids = [g.match(_unit(seed=i))[0] for i in range(5)]
        assert ids == list(range(1, 6))

    def test_gallery_size_increments_for_new_ids(self):
        g = ReidGallery(match_threshold=2.0)
        for i in range(4):
            g.match(_unit(seed=i))
        assert g.size == 4

    def test_prune_removes_stale_identities(self):
        g = ReidGallery(match_threshold=0.75, max_age_s=0.01)
        g.match(_unit(seed=0))
        # Sleep well past max_age_s so the single identity is stale.
        time.sleep(0.05)
        g._prune()
        assert g.size == 0

    def test_empty_gallery_prune_is_safe(self):
        g = ReidGallery()
        g._prune()
        assert g.size == 0

    def test_match_below_threshold_increments_id(self):
        g = ReidGallery(match_threshold=0.99)
        e1, e2 = _unit(seed=10), _unit(seed=20)
        # Premise of the test: the two fixed-seed vectors really are less
        # similar than the threshold (both are unit norm, so dot == cosine).
        assert float(np.dot(e1, e2)) < 0.99
        uid1, _, _ = g.match(e1)
        uid2, _, is_new2 = g.match(e2)
        assert uid1 >= 1
        # Below-threshold match must create a fresh, higher identity id.
        assert is_new2 is True
        assert uid2 > uid1

    def test_identity_update_does_not_change_id(self):
        g = ReidGallery(match_threshold=0.5)
        emb = _unit(seed=5)
        uid_first, _, _ = g.match(emb)
        for _ in range(10):
            g.match(emb)
        uid_last, _, _ = g.match(emb)
        assert uid_last == uid_first
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@ -1,16 +0,0 @@
# Build configuration for the saltybot_person_reid_msgs interface package.
cmake_minimum_required(VERSION 3.8)
project(saltybot_person_reid_msgs)
find_package(ament_cmake REQUIRED)
find_package(rosidl_default_generators REQUIRED)
# Packages listed as DEPENDENCIES of the generated interfaces below.
find_package(std_msgs REQUIRED)
find_package(vision_msgs REQUIRED)
# Generate interfaces for the two message definitions of this package.
rosidl_generate_interfaces(${PROJECT_NAME}
"msg/PersonAppearance.msg"
"msg/PersonAppearanceArray.msg"
DEPENDENCIES std_msgs vision_msgs
)
ament_export_dependencies(rosidl_default_runtime)
ament_package()

View File

@ -1,7 +0,0 @@
# One re-identified person, as filled in by the person_reid node for each
# detection whose torso crop it embeds.

# Header copied from the detection message that produced this entry.
std_msgs/Header header
# Identity id returned by the re-ID gallery for this embedding.
uint32 track_id
# Appearance embedding computed from the torso crop.
float32[] embedding
# Bounding box copied from the source detection.
vision_msgs/BoundingBox2D bbox
# Score of the detection's first hypothesis (0.0 when it has no results).
float32 confidence
# Match score returned by the re-ID gallery.
float32 match_score
# True when the gallery reported this embedding as a new identity.
bool is_new_identity

View File

@ -1,2 +0,0 @@
# Batch of PersonAppearance entries published together on /saltybot/person_reid.

# Header of the detection message the batch was built from.
std_msgs/Header header
# One entry per detection that produced a usable torso crop.
PersonAppearance[] appearances

View File

@ -1,22 +0,0 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_person_reid_msgs</name>
<version>0.1.0</version>
<description>Message types for person re-identification.</description>
<maintainer email="robot@saltylab.local">SaltyLab</maintainer>
<license>MIT</license>
<buildtool_depend>ament_cmake</buildtool_depend>
<buildtool_depend>rosidl_default_generators</buildtool_depend>
<depend>std_msgs</depend>
<depend>vision_msgs</depend>
<exec_depend>rosidl_default_runtime</exec_depend>
<member_of_group>rosidl_interface_packages</member_of_group>
<export>
<build_type>ament_cmake</build_type>
</export>
</package>

View File

@ -1,6 +0,0 @@
thermal_node:
ros__parameters:
publish_rate_hz: 1.0 # Hz — publish rate for /saltybot/thermal
warn_temp_c: 75.0 # Log WARN above this temperature (°C)
throttle_temp_c: 85.0 # Log ERROR + set throttled=true above this (°C)
thermal_root: "/sys/class/thermal" # Sysfs thermal root; override for tests

View File

@ -1,42 +0,0 @@
"""thermal.launch.py — Launch the Jetson thermal monitor (Issue #205).
Usage:
ros2 launch saltybot_thermal thermal.launch.py
ros2 launch saltybot_thermal thermal.launch.py warn_temp_c:=70.0
"""
import os
from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
    """Build the launch description for the thermal monitor node.

    Declares one launch argument per tunable (publish rate, WARN threshold,
    THROTTLE threshold) and starts ``thermal_node`` with the packaged YAML
    parameter file followed by the launch-argument values.
    """
    share_dir = get_package_share_directory("saltybot_thermal")
    params_file = os.path.join(share_dir, "config", "thermal_params.yaml")

    # (argument name, default value, description), in declaration order.
    arg_specs = (
        ("publish_rate_hz", "1.0", "Publish rate (Hz)"),
        ("warn_temp_c", "75.0", "WARN threshold (°C)"),
        ("throttle_temp_c", "85.0", "THROTTLE threshold (°C)"),
    )

    declared_args = [
        DeclareLaunchArgument(arg, default_value=default, description=text)
        for arg, default, text in arg_specs
    ]
    # Listed after the YAML file in the node's parameter list.
    arg_values = {arg: LaunchConfiguration(arg) for arg, _, _ in arg_specs}

    thermal = Node(
        package="saltybot_thermal",
        executable="thermal_node",
        name="thermal_node",
        output="screen",
        parameters=[params_file, arg_values],
    )

    return LaunchDescription([*declared_args, thermal])

View File

@ -1,139 +0,0 @@
"""thermal_node.py — Jetson CPU/GPU thermal monitor.
Issue #205
Reads every /sys/class/thermal/thermal_zone* sysfs entry, publishes a JSON
blob on /saltybot/thermal at a configurable rate (default 1 Hz), and logs
ROS2 WARN / ERROR when zone temperatures exceed configurable thresholds.
Published topic:
/saltybot/thermal (std_msgs/String, JSON)
JSON schema:
{
"ts": <float unix seconds>,
"zones": [
{"zone": "CPU-therm", "index": 0, "temp_c": 42.5},
...
],
"max_temp_c": 55.0,
"throttled": false,
"warn": false
}
Parameters:
publish_rate_hz (float, 1.0) publish rate
warn_temp_c (float, 75.0) log WARN above this temperature
throttle_temp_c (float, 85.0) log ERROR and set throttled=true above this
thermal_root (str, "/sys/class/thermal") sysfs thermal root (override for tests)
"""
from __future__ import annotations
import json
import os
import time
from typing import List, Optional
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile
from std_msgs.msg import String
def read_thermal_zones(root: str) -> List[dict]:
    """Return a list of {zone, index, temp_c} dicts from sysfs.

    Args:
        root: directory containing ``thermal_zone<N>`` sub-directories.

    Returns:
        One dict per readable zone, ordered by numeric zone index.  ``zone``
        is the content of the zone's ``type`` file (the directory name when
        that file is unreadable), ``index`` is ``N`` and ``temp_c`` is the
        ``temp`` file value converted from millidegrees to °C, rounded to one
        decimal.  Zones whose ``temp`` file is missing or not an integer are
        skipped.  An unreadable ``root`` yields an empty list.
    """
    prefix = "thermal_zone"
    try:
        entries = os.listdir(root)
    except OSError:
        return []

    # Collect (index, name) pairs first so zones can be ordered numerically:
    # a plain name sort would place thermal_zone10 before thermal_zone2.
    indexed = []
    for entry in entries:
        if not entry.startswith(prefix):
            continue
        try:
            idx = int(entry[len(prefix):])
        except ValueError:
            continue
        indexed.append((idx, entry))

    zones = []
    for idx, entry in sorted(indexed):
        zone_dir = os.path.join(root, entry)

        try:
            with open(os.path.join(zone_dir, "type"), encoding="utf-8") as f:
                zone_type = f.read().strip()
        except OSError:
            # No readable type file: fall back to the directory name.
            zone_type = entry

        try:
            with open(os.path.join(zone_dir, "temp"), encoding="utf-8") as f:
                temp_mc = int(f.read().strip())  # millidegrees Celsius
        except (OSError, ValueError):
            continue

        temp_c = round(temp_mc / 1000.0, 1)
        zones.append({"zone": zone_type, "index": idx, "temp_c": temp_c})
    return zones
class ThermalNode(Node):
    """Reads Jetson thermal zones and publishes /saltybot/thermal.

    The publish rate comes from the ``publish_rate_hz`` parameter (default
    1 Hz).  Each tick reads all zones under ``thermal_root``, publishes a JSON
    payload and logs a warning or error when the hottest zone reaches the
    ``warn_temp_c`` or ``throttle_temp_c`` threshold.
    """

    # Rate used when ``publish_rate_hz`` is not a positive number.
    _FALLBACK_RATE_HZ = 1.0

    def __init__(self) -> None:
        super().__init__("thermal_node")

        self.declare_parameter("publish_rate_hz", 1.0)
        self.declare_parameter("warn_temp_c", 75.0)
        self.declare_parameter("throttle_temp_c", 85.0)
        self.declare_parameter("thermal_root", "/sys/class/thermal")

        self._rate = self.get_parameter("publish_rate_hz").value
        self._warn_t = self.get_parameter("warn_temp_c").value
        self._throttle_t = self.get_parameter("throttle_temp_c").value
        self._root = self.get_parameter("thermal_root").value

        # The timer period is 1 / rate: a rate of 0 would raise
        # ZeroDivisionError and a negative rate would give a negative period.
        # Keep the monitor running at the default rate instead.
        if self._rate <= 0.0:
            self.get_logger().warn(
                f"publish_rate_hz={self._rate} is not positive; "
                f"using {self._FALLBACK_RATE_HZ} Hz instead"
            )
            self._rate = self._FALLBACK_RATE_HZ

        qos = QoSProfile(depth=10)
        self._pub = self.create_publisher(String, "/saltybot/thermal", qos)
        self._timer = self.create_timer(1.0 / self._rate, self._publish)

        self.get_logger().info(
            f"ThermalNode ready (rate={self._rate} Hz, "
            f"warn={self._warn_t}°C, throttle={self._throttle_t}°C, "
            f"root={self._root})"
        )

    def _publish(self) -> None:
        """Read all zones, publish the JSON payload and log threshold hits.

        Nothing is published when no zone could be read; a warning is logged
        instead.
        """
        zones = read_thermal_zones(self._root)
        if not zones:
            self.get_logger().warn("No thermal zones found — check thermal_root param")
            return

        # Status flags are driven by the hottest zone.
        max_temp = max(z["temp_c"] for z in zones)
        throttled = max_temp >= self._throttle_t
        warn = max_temp >= self._warn_t

        payload = {
            "ts": time.time(),
            "zones": zones,
            "max_temp_c": max_temp,
            "throttled": throttled,
            "warn": warn,
        }
        msg = String()
        msg.data = json.dumps(payload)
        self._pub.publish(msg)

        # Throttle takes precedence so only one message is logged per tick.
        if throttled:
            self.get_logger().error(
                f"THERMAL THROTTLE: {max_temp}°C >= {self._throttle_t}°C"
            )
        elif warn:
            self.get_logger().warn(
                f"Thermal warning: {max_temp}°C >= {self._warn_t}°C"
            )
def main(args: Optional[list] = None) -> None:
    """Entry point: spin ``ThermalNode`` until shutdown or Ctrl-C."""
    rclpy.init(args=args)
    node = ThermalNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        # Only shut down a context that is still valid, so an interrupt that
        # already shut it down does not trigger a second shutdown call.
        if rclpy.ok():
            rclpy.shutdown()

View File

@ -1,4 +0,0 @@
[develop]
script_dir=$base/lib/saltybot_thermal
[egg_info]
tag_date = 0

View File

@ -1,27 +0,0 @@
# setuptools packaging script for the saltybot_thermal package.
from setuptools import setup
package_name = "saltybot_thermal"
setup(
name=package_name,
version="0.1.0",
packages=[package_name],
data_files=[
# Marker file installed into the ament resource index.
("share/ament_index/resource_index/packages", [f"resource/{package_name}"]),
(f"share/{package_name}", ["package.xml"]),
# Launch file and default parameters; thermal.launch.py looks the YAML up
# under the package share directory's config/ folder.
(f"share/{package_name}/launch", ["launch/thermal.launch.py"]),
(f"share/{package_name}/config", ["config/thermal_params.yaml"]),
],
install_requires=["setuptools"],
zip_safe=True,
maintainer="sl-jetson",
maintainer_email="sl-jetson@saltylab.local",
description="Jetson thermal monitor — /saltybot/thermal JSON at 1 Hz",
license="MIT",
tests_require=["pytest"],
entry_points={
"console_scripts": [
# `thermal_node` executable → main() in saltybot_thermal/thermal_node.py
"thermal_node = saltybot_thermal.thermal_node:main",
],
},
)
)

View File

@ -1,303 +0,0 @@
"""test_thermal.py -- Unit tests for Issue #205 Jetson thermal monitor."""
from __future__ import annotations
import json, os, time
import pytest
def _pkg_root():
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def _read_src(rel_path):
    """Return the text of a file addressed relative to the package root.

    The file is decoded as UTF-8 explicitly: the package sources contain
    non-ASCII characters (for example "°C"), so relying on the locale's
    default encoding could fail or mis-decode them.
    """
    with open(os.path.join(_pkg_root(), rel_path), encoding="utf-8") as f:
        return f.read()
# ── Import the sysfs reader (no ROS required) ─────────────────────────────────
def _import_reader():
    """Load ``thermal_node.py`` as a standalone module without a ROS install.

    Minimal stand-ins for ``rclpy``, ``rclpy.node``, ``rclpy.qos``,
    ``std_msgs`` and ``std_msgs.msg`` are placed in ``sys.modules`` only while
    the module body executes its imports.  The previous ``sys.modules``
    entries are restored afterwards so the stubs cannot shadow the real
    packages for other tests in the same session.

    Returns:
        The loaded module object (exposes ``read_thermal_zones``).
    """
    import importlib.util
    import sys
    import types

    class _Node:
        def __init__(self, *a, **kw): pass
        def declare_parameter(self, *a, **kw): pass
        def get_parameter(self, name):
            class _P:
                value = None
            return _P()
        def create_publisher(self, *a, **kw): return None
        def create_timer(self, *a, **kw): return None
        def get_logger(self):
            class _L:
                def info(self, *a): pass
                def warn(self, *a): pass
                def error(self, *a): pass
            return _L()
        def destroy_node(self): pass

    class _QoSProfile:
        def __init__(self, **kw): pass

    class _String:
        data = ""

    # Build minimal ROS2 stubs so thermal_node.py imports without a ROS install
    stub_names = ("rclpy", "rclpy.node", "rclpy.qos", "std_msgs", "std_msgs.msg")
    stubs = {name: types.ModuleType(name) for name in stub_names}
    stubs["rclpy.node"].Node = _Node
    stubs["rclpy.qos"].QoSProfile = _QoSProfile
    stubs["std_msgs.msg"].String = _String
    stubs["rclpy"].init = lambda *a, **kw: None
    stubs["rclpy"].spin = lambda node: None
    stubs["rclpy"].ok = lambda: True
    stubs["rclpy"].shutdown = lambda: None

    spec = importlib.util.spec_from_file_location(
        "thermal_node_testmod",
        os.path.join(_pkg_root(), "saltybot_thermal", "thermal_node.py"),
    )
    mod = importlib.util.module_from_spec(spec)

    # Remember what was registered before so it can be put back afterwards.
    previous = {name: sys.modules[name] for name in stub_names if name in sys.modules}
    sys.modules.update(stubs)
    try:
        spec.loader.exec_module(mod)
    finally:
        # The loaded module keeps its own references to the stubbed names, so
        # the global registry can be restored safely.
        for name in stub_names:
            if name in previous:
                sys.modules[name] = previous[name]
            else:
                sys.modules.pop(name, None)
    return mod
# ── Sysfs fixture helpers ─────────────────────────────────────────────────────
def _make_zone(root, idx, zone_type, temp_mc):
"""Create a fake thermal_zone<idx> directory under root."""
zdir = os.path.join(str(root), "thermal_zone{}".format(idx))
os.makedirs(zdir, exist_ok=True)
with open(os.path.join(zdir, "type"), "w") as f:
f.write(zone_type + "\n")
with open(os.path.join(zdir, "temp"), "w") as f:
f.write(str(temp_mc) + "\n")
# ── read_thermal_zones ────────────────────────────────────────────────────────
class TestReadThermalZones:
    """Exercises ``read_thermal_zones`` against fake sysfs trees in ``tmp_path``."""

    @pytest.fixture(scope="class")
    def mod(self):
        """Load the thermal node module once for the whole class."""
        return _import_reader()

    def test_empty_dir(self, mod, tmp_path):
        found = mod.read_thermal_zones(str(tmp_path))
        assert found == []

    def test_missing_dir(self, mod):
        found = mod.read_thermal_zones("/nonexistent/path/xyz")
        assert found == []

    def test_single_zone(self, mod, tmp_path):
        _make_zone(tmp_path, 0, "CPU-therm", 45000)
        found = mod.read_thermal_zones(str(tmp_path))
        assert len(found) == 1
        only = found[0]
        assert only["zone"] == "CPU-therm"
        assert only["temp_c"] == 45.0
        assert only["index"] == 0

    def test_temp_millidegrees_conversion(self, mod, tmp_path):
        _make_zone(tmp_path, 0, "GPU-therm", 72500)
        found = mod.read_thermal_zones(str(tmp_path))
        assert found[0]["temp_c"] == 72.5

    def test_multiple_zones(self, mod, tmp_path):
        fixtures = ((0, "CPU-therm", 40000), (1, "GPU-therm", 55000), (2, "PMIC-Die", 38000))
        for idx, name, temp_mc in fixtures:
            _make_zone(tmp_path, idx, name, temp_mc)
        found = mod.read_thermal_zones(str(tmp_path))
        assert len(found) == 3

    def test_sorted_by_index(self, mod, tmp_path):
        # Created out of order on purpose.
        for idx, name, temp_mc in ((2, "Z2", 20000), (0, "Z0", 10000), (1, "Z1", 15000)):
            _make_zone(tmp_path, idx, name, temp_mc)
        found = mod.read_thermal_zones(str(tmp_path))
        order = [zone["index"] for zone in found]
        assert order == sorted(order)

    def test_skips_non_zone_entries(self, mod, tmp_path):
        (tmp_path / "cooling_device0").mkdir()
        _make_zone(tmp_path, 0, "CPU-therm", 40000)
        found = mod.read_thermal_zones(str(tmp_path))
        assert len(found) == 1

    def test_skips_zone_without_temp(self, mod, tmp_path):
        zone_dir = tmp_path / "thermal_zone0"
        zone_dir.mkdir()
        (zone_dir / "type").write_text("CPU-therm\n")
        # No temp file — should be skipped
        found = mod.read_thermal_zones(str(tmp_path))
        assert found == []

    def test_zone_type_fallback(self, mod, tmp_path):
        """Zone without type file falls back to directory name."""
        zone_dir = tmp_path / "thermal_zone0"
        zone_dir.mkdir()
        (zone_dir / "temp").write_text("40000\n")
        found = mod.read_thermal_zones(str(tmp_path))
        assert len(found) == 1
        assert found[0]["zone"] == "thermal_zone0"

    def test_temp_rounding(self, mod, tmp_path):
        _make_zone(tmp_path, 0, "CPU-therm", 72333)
        found = mod.read_thermal_zones(str(tmp_path))
        assert found[0]["temp_c"] == 72.3
# ── Threshold logic (pure Python) ────────────────────────────────────────────
class TestThresholds:
    """Pure-Python checks of the warn / throttle threshold comparisons."""

    def _classify(self, temp_c, warn_t=75.0, throttle_t=85.0):
        """Return ``(throttled, warn)`` for one temperature; bounds are inclusive."""
        return temp_c >= throttle_t, temp_c >= warn_t

    def test_normal(self):
        throttled, warn = self._classify(50.0)
        assert (throttled, warn) == (False, False)

    def test_warn_boundary(self):
        throttled, warn = self._classify(75.0)
        assert (throttled, warn) == (False, True)

    def test_below_warn(self):
        throttled, warn = self._classify(74.9)
        assert (throttled, warn) == (False, False)

    def test_throttle_boundary(self):
        throttled, warn = self._classify(85.0)
        assert (throttled, warn) == (True, True)

    def test_above_throttle(self):
        throttled, warn = self._classify(90.0)
        assert (throttled, warn) == (True, True)

    def test_custom_thresholds(self):
        throttled, warn = self._classify(70.0, warn_t=70.0, throttle_t=80.0)
        assert (throttled, warn) == (False, True)

    def test_max_temp_drives_status(self):
        readings = [{"temp_c": 40.0}, {"temp_c": 86.0}, {"temp_c": 55.0}]
        hottest = max(reading["temp_c"] for reading in readings)
        assert hottest == 86.0
        throttled, warn = self._classify(hottest)
        assert (throttled, warn) == (True, True)
# ── JSON payload schema ───────────────────────────────────────────────────────
class TestJsonPayload:
    """Checks the shape of the JSON payload using a local payload builder."""

    def _make_payload(self, zones, warn_t=75.0, throttle_t=85.0):
        """Build a payload dict from zone dicts; no zones gives a 0.0 maximum."""
        if zones:
            hottest = max(zone["temp_c"] for zone in zones)
        else:
            hottest = 0.0
        payload = {
            "ts": time.time(),
            "zones": zones,
            "max_temp_c": hottest,
            "throttled": hottest >= throttle_t,
            "warn": hottest >= warn_t,
        }
        return payload

    def test_has_ts(self):
        payload = self._make_payload([{"zone": "CPU", "index": 0, "temp_c": 40.0}])
        assert "ts" in payload and isinstance(payload["ts"], float)

    def test_has_zones(self):
        payload = self._make_payload([{"zone": "CPU", "index": 0, "temp_c": 40.0}])
        assert "zones" in payload and len(payload["zones"]) == 1

    def test_has_max_temp(self):
        payload = self._make_payload([{"zone": "CPU", "index": 0, "temp_c": 55.0}])
        assert payload["max_temp_c"] == 55.0

    def test_throttled_false_below(self):
        payload = self._make_payload([{"zone": "CPU", "index": 0, "temp_c": 60.0}])
        assert payload["throttled"] is False

    def test_warn_true_at_threshold(self):
        payload = self._make_payload([{"zone": "CPU", "index": 0, "temp_c": 75.0}])
        assert payload["warn"] is True and payload["throttled"] is False

    def test_throttled_true_above(self):
        payload = self._make_payload([{"zone": "CPU", "index": 0, "temp_c": 90.0}])
        assert payload["throttled"] is True

    def test_json_serializable(self):
        payload = self._make_payload([{"zone": "CPU", "index": 0, "temp_c": 50.0}])
        round_tripped = json.loads(json.dumps(payload))
        assert round_tripped["max_temp_c"] == 50.0

    def test_multi_zone_max(self):
        readings = [
            {"zone": "CPU-therm", "index": 0, "temp_c": 55.0},
            {"zone": "GPU-therm", "index": 1, "temp_c": 78.0},
            {"zone": "PMIC-Die", "index": 2, "temp_c": 38.0},
        ]
        payload = self._make_payload(readings)
        assert payload["max_temp_c"] == 78.0
        assert payload["warn"] is True
        assert payload["throttled"] is False
# ── Node source checks ────────────────────────────────────────────────────────
class TestNodeSrc:
    """Substring checks on the text of ``saltybot_thermal/thermal_node.py``."""

    @pytest.fixture(scope="class")
    def src(self):
        """Read the node source once for the whole class."""
        return _read_src("saltybot_thermal/thermal_node.py")

    def test_class_defined(self, src):
        assert "class ThermalNode" in src

    def test_publish_rate_param(self, src):
        assert '"publish_rate_hz"' in src

    def test_warn_param(self, src):
        assert '"warn_temp_c"' in src

    def test_throttle_param(self, src):
        assert '"throttle_temp_c"' in src

    def test_thermal_root_param(self, src):
        assert '"thermal_root"' in src

    def test_topic(self, src):
        assert '"/saltybot/thermal"' in src

    def test_read_fn(self, src):
        assert "read_thermal_zones" in src

    def test_warn_log(self, src):
        lowered = src.lower()
        assert "warn" in lowered

    def test_error_log(self, src):
        lowered = src.lower()
        assert "error" in lowered

    def test_throttled_flag(self, src):
        assert '"throttled"' in src

    def test_warn_flag(self, src):
        assert '"warn"' in src

    def test_max_temp(self, src):
        assert '"max_temp_c"' in src

    def test_millidegrees(self, src):
        assert "1000" in src

    def test_json_dumps(self, src):
        assert "json.dumps" in src

    def test_issue_tag(self, src):
        assert "205" in src

    def test_main(self, src):
        assert "def main" in src

    def test_sysfs_path(self, src):
        assert "/sys/class/thermal" in src
# ── Package metadata ──────────────────────────────────────────────────────────
class TestPackageMeta:
    """Substring checks on package.xml, setup.py and the default parameter file."""

    @pytest.fixture(scope="class")
    def pkg_xml(self):
        """Text of the package manifest."""
        return _read_src("package.xml")

    @pytest.fixture(scope="class")
    def setup_py(self):
        """Text of the packaging script."""
        return _read_src("setup.py")

    @pytest.fixture(scope="class")
    def cfg(self):
        """Text of the default parameter YAML."""
        return _read_src("config/thermal_params.yaml")

    def test_pkg_name(self, pkg_xml):
        assert "saltybot_thermal" in pkg_xml

    def test_issue_tag(self, pkg_xml):
        assert "205" in pkg_xml

    def test_entry_point(self, setup_py):
        expected = "thermal_node = saltybot_thermal.thermal_node:main"
        assert expected in setup_py

    def test_cfg_node_name(self, cfg):
        assert "thermal_node:" in cfg

    def test_cfg_warn(self, cfg):
        assert "warn_temp_c" in cfg

    def test_cfg_throttle(self, cfg):
        assert "throttle_temp_c" in cfg

    def test_cfg_rate(self, cfg):
        assert "publish_rate_hz" in cfg

    def test_cfg_defaults(self, cfg):
        for default in ("75.0", "85.0", "1.0"):
            assert default in cfg