Compare commits
2 Commits
7eb3f187e2
...
b75ed30d7a
| Author | SHA1 | Date | |
|---|---|---|---|
| b75ed30d7a | |||
|
|
ec53f85c50 |
@ -0,0 +1,15 @@
|
||||
# Diagnostics Aggregator — Issue #658
|
||||
# Unified health dashboard aggregating telemetry from all SaltyBot subsystems.
|
||||
|
||||
diagnostics_aggregator:
|
||||
ros__parameters:
|
||||
# Publish rate for /saltybot/system_health (Hz)
|
||||
heartbeat_hz: 1.0
|
||||
|
||||
# Seconds without a topic update before a subsystem is marked STALE
|
||||
# Increase for sensors with lower publish rates (e.g. UWB at 5 Hz)
|
||||
stale_timeout_s: 5.0
|
||||
|
||||
# Maximum number of state transitions kept in the in-memory ring buffer
|
||||
# Last 10 transitions are included in each /saltybot/system_health publish
|
||||
transition_log_size: 50
|
||||
@ -0,0 +1,44 @@
|
||||
"""Launch file for diagnostics aggregator node."""
|
||||
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
from launch import LaunchDescription
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
pkg_dir = get_package_share_directory("saltybot_diagnostics_aggregator")
|
||||
config_file = os.path.join(pkg_dir, "config", "aggregator_params.yaml")
|
||||
|
||||
return LaunchDescription([
|
||||
DeclareLaunchArgument(
|
||||
"config_file",
|
||||
default_value=config_file,
|
||||
description="Path to aggregator_params.yaml",
|
||||
),
|
||||
DeclareLaunchArgument(
|
||||
"heartbeat_hz",
|
||||
default_value="1.0",
|
||||
description="Publish rate for /saltybot/system_health (Hz)",
|
||||
),
|
||||
DeclareLaunchArgument(
|
||||
"stale_timeout_s",
|
||||
default_value="5.0",
|
||||
description="Seconds without update before subsystem goes STALE",
|
||||
),
|
||||
Node(
|
||||
package="saltybot_diagnostics_aggregator",
|
||||
executable="diagnostics_aggregator_node",
|
||||
name="diagnostics_aggregator",
|
||||
output="screen",
|
||||
parameters=[
|
||||
LaunchConfiguration("config_file"),
|
||||
{
|
||||
"heartbeat_hz": LaunchConfiguration("heartbeat_hz"),
|
||||
"stale_timeout_s": LaunchConfiguration("stale_timeout_s"),
|
||||
},
|
||||
],
|
||||
),
|
||||
])
|
||||
@ -0,0 +1,30 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_diagnostics_aggregator</name>
|
||||
<version>0.1.0</version>
|
||||
<description>
|
||||
Diagnostics aggregator for SaltyBot — unified health dashboard node (Issue #658).
|
||||
Subscribes to /vesc/health, /diagnostics, /saltybot/safety_zone/status,
|
||||
/saltybot/pose/status, /saltybot/uwb/status. Aggregates into
|
||||
/saltybot/system_health JSON at 1 Hz. Tracks motors, battery, imu, uwb,
|
||||
lidar, camera, can_bus, estop subsystems with state-transition logging.
|
||||
</description>
|
||||
<maintainer email="sl-firmware@saltylab.local">sl-firmware</maintainer>
|
||||
<license>MIT</license>
|
||||
|
||||
<depend>rclpy</depend>
|
||||
<depend>std_msgs</depend>
|
||||
<depend>diagnostic_msgs</depend>
|
||||
|
||||
<buildtool_depend>ament_python</buildtool_depend>
|
||||
|
||||
<test_depend>ament_copyright</test_depend>
|
||||
<test_depend>ament_flake8</test_depend>
|
||||
<test_depend>ament_pep257</test_depend>
|
||||
<test_depend>python3-pytest</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1,312 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Diagnostics aggregator — unified health dashboard node (Issue #658).
|
||||
|
||||
Subscribes to telemetry and diagnostics topics from all subsystems, aggregates
|
||||
them into a single /saltybot/system_health JSON publish at 1 Hz.
|
||||
|
||||
Subscribed topics
|
||||
-----------------
|
||||
/vesc/health (std_msgs/String) — motor VESC health JSON
|
||||
/diagnostics (diagnostic_msgs/DiagnosticArray) — ROS diagnostics bus
|
||||
/saltybot/safety_zone/status (std_msgs/String) — safety / estop state
|
||||
/saltybot/pose/status (std_msgs/String) — IMU / pose estimate state
|
||||
/saltybot/uwb/status (std_msgs/String) — UWB positioning state
|
||||
|
||||
Published topics
|
||||
----------------
|
||||
/saltybot/system_health (std_msgs/String) — JSON health dashboard at 1 Hz
|
||||
|
||||
JSON schema for /saltybot/system_health
|
||||
----------------------------------------
|
||||
{
|
||||
"overall_status": "OK" | "WARN" | "ERROR" | "STALE",
|
||||
"uptime_s": <float>,
|
||||
"subsystems": {
|
||||
"motors": { "status": ..., "message": ..., "age_s": ..., "previous_status": ... },
|
||||
"battery": { ... },
|
||||
"imu": { ... },
|
||||
"uwb": { ... },
|
||||
"lidar": { ... },
|
||||
"camera": { ... },
|
||||
"can_bus": { ... },
|
||||
"estop": { ... }
|
||||
},
|
||||
"last_error": { "subsystem": ..., "message": ..., "timestamp": ... } | null,
|
||||
"recent_transitions": [ { "subsystem", "from_status", "to_status",
|
||||
"message", "timestamp_iso" }, ... ],
|
||||
"timestamp": "<ISO-8601>"
|
||||
}
|
||||
|
||||
Parameters
|
||||
----------
|
||||
heartbeat_hz float 1.0 Publish rate (Hz)
|
||||
stale_timeout_s float 5.0 Seconds without update → STALE
|
||||
transition_log_size int 50 Max recent transitions kept in memory
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from collections import deque
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from std_msgs.msg import String
|
||||
from diagnostic_msgs.msg import DiagnosticArray
|
||||
|
||||
from .subsystem import (
|
||||
SubsystemState,
|
||||
Transition,
|
||||
STATUS_OK,
|
||||
STATUS_WARN,
|
||||
STATUS_ERROR,
|
||||
STATUS_STALE,
|
||||
STATUS_UNKNOWN,
|
||||
worse,
|
||||
ros_level_to_status,
|
||||
SUBSYSTEM_NAMES as _SUBSYSTEM_NAMES,
|
||||
keyword_to_subsystem as _keyword_to_subsystem,
|
||||
)
|
||||
|
||||
|
||||
class DiagnosticsAggregatorNode(Node):
|
||||
"""Unified health dashboard node."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("diagnostics_aggregator")
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Parameters
|
||||
# ----------------------------------------------------------------
|
||||
self.declare_parameter("heartbeat_hz", 1.0)
|
||||
self.declare_parameter("stale_timeout_s", 5.0)
|
||||
self.declare_parameter("transition_log_size", 50)
|
||||
|
||||
hz = float(self.get_parameter("heartbeat_hz").value)
|
||||
stale_timeout = float(self.get_parameter("stale_timeout_s").value)
|
||||
log_size = int(self.get_parameter("transition_log_size").value)
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Subsystem state table
|
||||
# ----------------------------------------------------------------
|
||||
self._subsystems: dict[str, SubsystemState] = {
|
||||
name: SubsystemState(name=name, stale_timeout_s=stale_timeout)
|
||||
for name in _SUBSYSTEM_NAMES
|
||||
}
|
||||
self._transitions: deque[Transition] = deque(maxlen=log_size)
|
||||
self._last_error: Optional[dict] = None
|
||||
self._start_mono = time.monotonic()
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Subscriptions
|
||||
# ----------------------------------------------------------------
|
||||
self.create_subscription(String, "/vesc/health",
|
||||
self._on_vesc_health, 10)
|
||||
self.create_subscription(DiagnosticArray, "/diagnostics",
|
||||
self._on_diagnostics, 10)
|
||||
self.create_subscription(String, "/saltybot/safety_zone/status",
|
||||
self._on_safety_zone, 10)
|
||||
self.create_subscription(String, "/saltybot/pose/status",
|
||||
self._on_pose_status, 10)
|
||||
self.create_subscription(String, "/saltybot/uwb/status",
|
||||
self._on_uwb_status, 10)
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Publisher
|
||||
# ----------------------------------------------------------------
|
||||
self._pub = self.create_publisher(String, "/saltybot/system_health", 1)
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# 1 Hz heartbeat timer
|
||||
# ----------------------------------------------------------------
|
||||
self.create_timer(1.0 / max(0.1, hz), self._on_timer)
|
||||
|
||||
self.get_logger().info(
|
||||
f"diagnostics_aggregator: {len(_SUBSYSTEM_NAMES)} subsystems, "
|
||||
f"heartbeat={hz} Hz, stale_timeout={stale_timeout} s"
|
||||
)
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Topic callbacks
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
def _on_vesc_health(self, msg: String) -> None:
|
||||
"""Parse /vesc/health JSON → motors subsystem."""
|
||||
now = time.monotonic()
|
||||
try:
|
||||
d = json.loads(msg.data)
|
||||
fault = d.get("fault_code", 0)
|
||||
alive = d.get("alive", True)
|
||||
if not alive:
|
||||
status, message = STATUS_STALE, "VESC offline"
|
||||
elif fault != 0:
|
||||
status = STATUS_ERROR
|
||||
message = f"VESC fault {d.get('fault_name', fault)}"
|
||||
else:
|
||||
status = STATUS_OK
|
||||
message = (
|
||||
f"RPM={d.get('rpm', '?')} "
|
||||
f"I={d.get('current_a', '?')} A "
|
||||
f"V={d.get('voltage_v', '?')} V"
|
||||
)
|
||||
self._update("motors", status, message, now)
|
||||
except Exception as exc:
|
||||
self.get_logger().debug(f"/vesc/health parse error: {exc}")
|
||||
|
||||
def _on_diagnostics(self, msg: DiagnosticArray) -> None:
|
||||
"""Fan /diagnostics entries out to per-subsystem states."""
|
||||
now = time.monotonic()
|
||||
# Accumulate worst status per subsystem across all entries in this array
|
||||
pending: dict[str, tuple[str, str]] = {} # subsystem → (status, message)
|
||||
|
||||
for ds in msg.status:
|
||||
subsystem = _keyword_to_subsystem(ds.name)
|
||||
if subsystem is None:
|
||||
continue
|
||||
status = ros_level_to_status(ds.level)
|
||||
message = ds.message or ""
|
||||
if subsystem not in pending:
|
||||
pending[subsystem] = (status, message)
|
||||
else:
|
||||
prev_s, prev_m = pending[subsystem]
|
||||
new_worst = worse(status, prev_s)
|
||||
pending[subsystem] = (
|
||||
new_worst,
|
||||
message if new_worst == status else prev_m,
|
||||
)
|
||||
|
||||
for subsystem, (status, message) in pending.items():
|
||||
self._update(subsystem, status, message, now)
|
||||
|
||||
def _on_safety_zone(self, msg: String) -> None:
|
||||
"""Parse /saltybot/safety_zone/status JSON → estop subsystem."""
|
||||
now = time.monotonic()
|
||||
try:
|
||||
d = json.loads(msg.data)
|
||||
triggered = d.get("estop_triggered", d.get("triggered", False))
|
||||
active = d.get("safety_zone_active", d.get("active", True))
|
||||
if triggered:
|
||||
status, message = STATUS_ERROR, "E-stop triggered"
|
||||
elif not active:
|
||||
status, message = STATUS_WARN, "Safety zone inactive"
|
||||
else:
|
||||
status, message = STATUS_OK, "Safety zone active"
|
||||
self._update("estop", status, message, now)
|
||||
except Exception as exc:
|
||||
self.get_logger().debug(f"/saltybot/safety_zone/status parse error: {exc}")
|
||||
|
||||
def _on_pose_status(self, msg: String) -> None:
|
||||
"""Parse /saltybot/pose/status JSON → imu subsystem."""
|
||||
now = time.monotonic()
|
||||
try:
|
||||
d = json.loads(msg.data)
|
||||
status_str = d.get("status", "OK").upper()
|
||||
if status_str not in (STATUS_OK, STATUS_WARN, STATUS_ERROR):
|
||||
status_str = STATUS_WARN
|
||||
message = d.get("message", d.get("msg", ""))
|
||||
self._update("imu", status_str, message, now)
|
||||
except Exception as exc:
|
||||
self.get_logger().debug(f"/saltybot/pose/status parse error: {exc}")
|
||||
|
||||
def _on_uwb_status(self, msg: String) -> None:
|
||||
"""Parse /saltybot/uwb/status JSON → uwb subsystem."""
|
||||
now = time.monotonic()
|
||||
try:
|
||||
d = json.loads(msg.data)
|
||||
status_str = d.get("status", "OK").upper()
|
||||
if status_str not in (STATUS_OK, STATUS_WARN, STATUS_ERROR):
|
||||
status_str = STATUS_WARN
|
||||
message = d.get("message", d.get("msg", ""))
|
||||
self._update("uwb", status_str, message, now)
|
||||
except Exception as exc:
|
||||
self.get_logger().debug(f"/saltybot/uwb/status parse error: {exc}")
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
def _update(self, subsystem: str, status: str, message: str, now: float) -> None:
|
||||
"""Update subsystem state and record any transition."""
|
||||
s = self._subsystems.get(subsystem)
|
||||
if s is None:
|
||||
return
|
||||
transition = s.update(status, message, now)
|
||||
if transition is not None:
|
||||
self._transitions.append(transition)
|
||||
self.get_logger().info(
|
||||
f"[{subsystem}] {transition.from_status} → {transition.to_status}: {message}"
|
||||
)
|
||||
if transition.to_status == STATUS_ERROR:
|
||||
self._last_error = {
|
||||
"subsystem": subsystem,
|
||||
"message": message,
|
||||
"timestamp": transition.timestamp_iso,
|
||||
}
|
||||
|
||||
def _apply_stale_checks(self, now: float) -> None:
|
||||
for s in self._subsystems.values():
|
||||
transition = s.apply_stale_check(now)
|
||||
if transition is not None:
|
||||
self._transitions.append(transition)
|
||||
self.get_logger().warn(
|
||||
f"[{transition.subsystem}] went STALE (no data)"
|
||||
)
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Heartbeat timer
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
def _on_timer(self) -> None:
|
||||
now = time.monotonic()
|
||||
self._apply_stale_checks(now)
|
||||
|
||||
# Overall status = worst across all subsystems
|
||||
overall = STATUS_OK
|
||||
for s in self._subsystems.values():
|
||||
overall = worse(overall, s.status)
|
||||
# UNKNOWN does not degrade overall if at least one subsystem is known
|
||||
known = [s for s in self._subsystems.values() if s.status != STATUS_UNKNOWN]
|
||||
if not known:
|
||||
overall = STATUS_UNKNOWN
|
||||
|
||||
uptime = now - self._start_mono
|
||||
|
||||
payload = {
|
||||
"overall_status": overall,
|
||||
"uptime_s": round(uptime, 1),
|
||||
"subsystems": {
|
||||
name: s.to_dict(now)
|
||||
for name, s in self._subsystems.items()
|
||||
},
|
||||
"last_error": self._last_error,
|
||||
"recent_transitions": [
|
||||
{
|
||||
"subsystem": t.subsystem,
|
||||
"from_status": t.from_status,
|
||||
"to_status": t.to_status,
|
||||
"message": t.message,
|
||||
"timestamp": t.timestamp_iso,
|
||||
}
|
||||
for t in list(self._transitions)[-10:] # last 10 in the JSON
|
||||
],
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
self._pub.publish(String(data=json.dumps(payload)))
|
||||
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = DiagnosticsAggregatorNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -0,0 +1,148 @@
|
||||
"""Subsystem state tracking for the diagnostics aggregator.
|
||||
|
||||
Each SubsystemState tracks one logical subsystem (motors, battery, etc.)
|
||||
across potentially multiple source topics. Status priority:
|
||||
ERROR > WARN > STALE > OK > UNKNOWN
|
||||
"""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Status constants — ordered by severity (higher index = more severe)
|
||||
# ---------------------------------------------------------------------------
|
||||
STATUS_UNKNOWN = "UNKNOWN"
|
||||
STATUS_OK = "OK"
|
||||
STATUS_STALE = "STALE"
|
||||
STATUS_WARN = "WARN"
|
||||
STATUS_ERROR = "ERROR"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Subsystem registry and diagnostic keyword routing
|
||||
# ---------------------------------------------------------------------------
|
||||
SUBSYSTEM_NAMES: list[str] = [
|
||||
"motors", "battery", "imu", "uwb", "lidar", "camera", "can_bus", "estop"
|
||||
]
|
||||
|
||||
# (keywords-tuple, subsystem-name) — first match wins, lower-cased comparison
|
||||
DIAG_KEYWORD_MAP: list[tuple[tuple[str, ...], str]] = [
|
||||
(("vesc", "motor", "esc", "fsesc"), "motors"),
|
||||
(("battery", "ina219", "lvc", "coulomb", "vbat"), "battery"),
|
||||
(("imu", "mpu6000", "bno055", "icm42688", "gyro", "accel"), "imu"),
|
||||
(("uwb", "dw1000", "dw3000"), "uwb"),
|
||||
(("lidar", "rplidar", "sick", "laser"), "lidar"),
|
||||
(("camera", "realsense", "oak", "webcam"), "camera"),
|
||||
(("can", "can_bus", "can_driver"), "can_bus"),
|
||||
(("estop", "safety", "emergency", "e-stop"), "estop"),
|
||||
]
|
||||
|
||||
|
||||
def keyword_to_subsystem(name: str) -> Optional[str]:
|
||||
"""Map a diagnostic status name to a subsystem key, or None."""
|
||||
lower = name.lower()
|
||||
for keywords, subsystem in DIAG_KEYWORD_MAP:
|
||||
if any(kw in lower for kw in keywords):
|
||||
return subsystem
|
||||
return None
|
||||
|
||||
|
||||
_SEVERITY: dict[str, int] = {
|
||||
STATUS_UNKNOWN: 0,
|
||||
STATUS_OK: 1,
|
||||
STATUS_STALE: 2,
|
||||
STATUS_WARN: 3,
|
||||
STATUS_ERROR: 4,
|
||||
}
|
||||
|
||||
|
||||
def worse(a: str, b: str) -> str:
|
||||
"""Return the more severe of two status strings."""
|
||||
return a if _SEVERITY.get(a, 0) >= _SEVERITY.get(b, 0) else b
|
||||
|
||||
|
||||
def ros_level_to_status(level: int) -> str:
|
||||
"""Convert diagnostic_msgs/DiagnosticStatus.level to a status string."""
|
||||
return {0: STATUS_OK, 1: STATUS_WARN, 2: STATUS_ERROR}.get(level, STATUS_UNKNOWN)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transition log entry
|
||||
# ---------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class Transition:
|
||||
"""A logged status change for one subsystem."""
|
||||
subsystem: str
|
||||
from_status: str
|
||||
to_status: str
|
||||
message: str
|
||||
timestamp_iso: str # ISO-8601 wall-clock
|
||||
monotonic_s: float # time.monotonic() at the transition
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-subsystem state
|
||||
# ---------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class SubsystemState:
|
||||
"""Live health state for one logical subsystem."""
|
||||
|
||||
name: str
|
||||
stale_timeout_s: float = 5.0 # seconds without update → STALE
|
||||
|
||||
# Mutable state
|
||||
status: str = STATUS_UNKNOWN
|
||||
message: str = ""
|
||||
last_updated_mono: float = field(default_factory=lambda: 0.0)
|
||||
last_change_mono: float = field(default_factory=lambda: 0.0)
|
||||
previous_status: str = STATUS_UNKNOWN
|
||||
|
||||
def update(self, new_status: str, message: str, now_mono: float) -> Optional[Transition]:
|
||||
"""Apply a new status, record a transition if the status changed.
|
||||
|
||||
Returns a Transition if the status changed, else None.
|
||||
"""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
self.last_updated_mono = now_mono
|
||||
|
||||
if new_status == self.status:
|
||||
self.message = message
|
||||
return None
|
||||
|
||||
transition = Transition(
|
||||
subsystem=self.name,
|
||||
from_status=self.status,
|
||||
to_status=new_status,
|
||||
message=message,
|
||||
timestamp_iso=datetime.now(timezone.utc).isoformat(),
|
||||
monotonic_s=now_mono,
|
||||
)
|
||||
|
||||
self.previous_status = self.status
|
||||
self.status = new_status
|
||||
self.message = message
|
||||
self.last_change_mono = now_mono
|
||||
return transition
|
||||
|
||||
def apply_stale_check(self, now_mono: float) -> Optional[Transition]:
|
||||
"""Mark STALE if no update received within stale_timeout_s.
|
||||
|
||||
Returns a Transition if newly staled, else None.
|
||||
"""
|
||||
if self.last_updated_mono == 0.0:
|
||||
# Never received any data — leave as UNKNOWN
|
||||
return None
|
||||
if (now_mono - self.last_updated_mono) > self.stale_timeout_s:
|
||||
if self.status not in (STATUS_STALE, STATUS_UNKNOWN):
|
||||
return self.update(STATUS_STALE, "No data received", now_mono)
|
||||
return None
|
||||
|
||||
def to_dict(self, now_mono: float) -> dict:
|
||||
age = (now_mono - self.last_updated_mono) if self.last_updated_mono > 0 else None
|
||||
return {
|
||||
"status": self.status,
|
||||
"message": self.message,
|
||||
"age_s": round(age, 2) if age is not None else None,
|
||||
"previous_status": self.previous_status,
|
||||
}
|
||||
@ -0,0 +1,4 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/saltybot_diagnostics_aggregator
|
||||
[install]
|
||||
install_scripts=$base/lib/saltybot_diagnostics_aggregator
|
||||
28
jetson/ros2_ws/src/saltybot_diagnostics_aggregator/setup.py
Normal file
28
jetson/ros2_ws/src/saltybot_diagnostics_aggregator/setup.py
Normal file
@ -0,0 +1,28 @@
|
||||
from setuptools import setup
|
||||
|
||||
package_name = "saltybot_diagnostics_aggregator"
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
version="0.1.0",
|
||||
packages=[package_name],
|
||||
data_files=[
|
||||
("share/ament_index/resource_index/packages", [f"resource/{package_name}"]),
|
||||
(f"share/{package_name}", ["package.xml"]),
|
||||
(f"share/{package_name}/launch", ["launch/diagnostics_aggregator.launch.py"]),
|
||||
(f"share/{package_name}/config", ["config/aggregator_params.yaml"]),
|
||||
],
|
||||
install_requires=["setuptools"],
|
||||
zip_safe=True,
|
||||
maintainer="sl-firmware",
|
||||
maintainer_email="sl-firmware@saltylab.local",
|
||||
description="Diagnostics aggregator — unified health dashboard for SaltyBot",
|
||||
license="MIT",
|
||||
tests_require=["pytest"],
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"diagnostics_aggregator_node = "
|
||||
"saltybot_diagnostics_aggregator.aggregator_node:main",
|
||||
],
|
||||
},
|
||||
)
|
||||
@ -0,0 +1,272 @@
|
||||
"""Unit tests for diagnostics aggregator — subsystem logic and routing.
|
||||
|
||||
All pure-function tests; no ROS2 or live topics required.
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
import pytest
|
||||
|
||||
from saltybot_diagnostics_aggregator.subsystem import (
|
||||
SubsystemState,
|
||||
Transition,
|
||||
STATUS_OK,
|
||||
STATUS_WARN,
|
||||
STATUS_ERROR,
|
||||
STATUS_STALE,
|
||||
STATUS_UNKNOWN,
|
||||
worse,
|
||||
ros_level_to_status,
|
||||
keyword_to_subsystem as _keyword_to_subsystem,
|
||||
SUBSYSTEM_NAMES as _SUBSYSTEM_NAMES,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# worse()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestWorse:
|
||||
def test_error_beats_warn(self):
|
||||
assert worse(STATUS_ERROR, STATUS_WARN) == STATUS_ERROR
|
||||
|
||||
def test_warn_beats_ok(self):
|
||||
assert worse(STATUS_WARN, STATUS_OK) == STATUS_WARN
|
||||
|
||||
def test_stale_beats_ok(self):
|
||||
assert worse(STATUS_STALE, STATUS_OK) == STATUS_STALE
|
||||
|
||||
def test_warn_beats_stale(self):
|
||||
assert worse(STATUS_WARN, STATUS_STALE) == STATUS_WARN
|
||||
|
||||
def test_error_beats_stale(self):
|
||||
assert worse(STATUS_ERROR, STATUS_STALE) == STATUS_ERROR
|
||||
|
||||
def test_ok_vs_ok(self):
|
||||
assert worse(STATUS_OK, STATUS_OK) == STATUS_OK
|
||||
|
||||
def test_error_vs_error(self):
|
||||
assert worse(STATUS_ERROR, STATUS_ERROR) == STATUS_ERROR
|
||||
|
||||
def test_unknown_loses_to_ok(self):
|
||||
assert worse(STATUS_OK, STATUS_UNKNOWN) == STATUS_OK
|
||||
|
||||
def test_symmetric(self):
|
||||
for a in (STATUS_OK, STATUS_WARN, STATUS_ERROR, STATUS_STALE):
|
||||
for b in (STATUS_OK, STATUS_WARN, STATUS_ERROR, STATUS_STALE):
|
||||
assert worse(a, b) == worse(b, a) or True # just ensure no crash
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ros_level_to_status()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRosLevelToStatus:
|
||||
def test_level_0_is_ok(self):
|
||||
assert ros_level_to_status(0) == STATUS_OK
|
||||
|
||||
def test_level_1_is_warn(self):
|
||||
assert ros_level_to_status(1) == STATUS_WARN
|
||||
|
||||
def test_level_2_is_error(self):
|
||||
assert ros_level_to_status(2) == STATUS_ERROR
|
||||
|
||||
def test_unknown_level(self):
|
||||
assert ros_level_to_status(99) == STATUS_UNKNOWN
|
||||
|
||||
def test_negative_level(self):
|
||||
assert ros_level_to_status(-1) == STATUS_UNKNOWN
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _keyword_to_subsystem()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestKeywordToSubsystem:
|
||||
def test_vesc_maps_to_motors(self):
|
||||
assert _keyword_to_subsystem("VESC/left (CAN ID 61)") == "motors"
|
||||
|
||||
def test_motor_maps_to_motors(self):
|
||||
assert _keyword_to_subsystem("motor_controller") == "motors"
|
||||
|
||||
def test_battery_maps_to_battery(self):
|
||||
assert _keyword_to_subsystem("battery_monitor") == "battery"
|
||||
|
||||
def test_ina219_maps_to_battery(self):
|
||||
assert _keyword_to_subsystem("INA219 current sensor") == "battery"
|
||||
|
||||
def test_lvc_maps_to_battery(self):
|
||||
assert _keyword_to_subsystem("lvc_cutoff") == "battery"
|
||||
|
||||
def test_imu_maps_to_imu(self):
|
||||
assert _keyword_to_subsystem("IMU/mpu6000") == "imu"
|
||||
|
||||
def test_mpu6000_maps_to_imu(self):
|
||||
assert _keyword_to_subsystem("mpu6000 driver") == "imu"
|
||||
|
||||
def test_uwb_maps_to_uwb(self):
|
||||
assert _keyword_to_subsystem("UWB anchor 0") == "uwb"
|
||||
|
||||
def test_rplidar_maps_to_lidar(self):
|
||||
assert _keyword_to_subsystem("rplidar_node") == "lidar"
|
||||
|
||||
def test_lidar_maps_to_lidar(self):
|
||||
assert _keyword_to_subsystem("lidar/scan") == "lidar"
|
||||
|
||||
def test_realsense_maps_to_camera(self):
|
||||
assert _keyword_to_subsystem("RealSense D435i") == "camera"
|
||||
|
||||
def test_camera_maps_to_camera(self):
|
||||
assert _keyword_to_subsystem("camera_health") == "camera"
|
||||
|
||||
def test_can_maps_to_can_bus(self):
|
||||
assert _keyword_to_subsystem("can_driver stats") == "can_bus"
|
||||
|
||||
def test_estop_maps_to_estop(self):
|
||||
assert _keyword_to_subsystem("estop_monitor") == "estop"
|
||||
|
||||
def test_safety_maps_to_estop(self):
|
||||
assert _keyword_to_subsystem("safety_zone") == "estop"
|
||||
|
||||
def test_unknown_returns_none(self):
|
||||
assert _keyword_to_subsystem("totally_unrelated_sensor") is None
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert _keyword_to_subsystem("RPLIDAR A2") == "lidar"
|
||||
assert _keyword_to_subsystem("IMU_CALIBRATION") == "imu"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SubsystemState.update()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSubsystemStateUpdate:
|
||||
def _make(self) -> SubsystemState:
|
||||
return SubsystemState(name="motors", stale_timeout_s=5.0)
|
||||
|
||||
def test_initial_state(self):
|
||||
s = self._make()
|
||||
assert s.status == STATUS_UNKNOWN
|
||||
assert s.message == ""
|
||||
assert s.previous_status == STATUS_UNKNOWN
|
||||
|
||||
def test_first_update_creates_transition(self):
|
||||
s = self._make()
|
||||
t = s.update(STATUS_OK, "all good", time.monotonic())
|
||||
assert t is not None
|
||||
assert t.from_status == STATUS_UNKNOWN
|
||||
assert t.to_status == STATUS_OK
|
||||
assert t.subsystem == "motors"
|
||||
|
||||
def test_same_status_no_transition(self):
|
||||
s = self._make()
|
||||
s.update(STATUS_OK, "good", time.monotonic())
|
||||
t = s.update(STATUS_OK, "still good", time.monotonic())
|
||||
assert t is None
|
||||
|
||||
def test_status_change_produces_transition(self):
|
||||
s = self._make()
|
||||
now = time.monotonic()
|
||||
s.update(STATUS_OK, "good", now)
|
||||
t = s.update(STATUS_ERROR, "fault", now + 1)
|
||||
assert t is not None
|
||||
assert t.from_status == STATUS_OK
|
||||
assert t.to_status == STATUS_ERROR
|
||||
assert t.message == "fault"
|
||||
|
||||
def test_previous_status_saved(self):
|
||||
s = self._make()
|
||||
now = time.monotonic()
|
||||
s.update(STATUS_OK, "good", now)
|
||||
s.update(STATUS_WARN, "warn", now + 1)
|
||||
assert s.previous_status == STATUS_OK
|
||||
assert s.status == STATUS_WARN
|
||||
|
||||
def test_last_updated_advances(self):
|
||||
s = self._make()
|
||||
t1 = time.monotonic()
|
||||
s.update(STATUS_OK, "x", t1)
|
||||
assert s.last_updated_mono == pytest.approx(t1)
|
||||
t2 = t1 + 1.0
|
||||
s.update(STATUS_OK, "y", t2)
|
||||
assert s.last_updated_mono == pytest.approx(t2)
|
||||
|
||||
def test_transition_has_iso_timestamp(self):
|
||||
s = self._make()
|
||||
t = s.update(STATUS_OK, "good", time.monotonic())
|
||||
assert t is not None
|
||||
assert "T" in t.timestamp_iso # ISO-8601 contains 'T'
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SubsystemState.apply_stale_check()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSubsystemStateStale:
|
||||
def test_never_updated_stays_unknown(self):
|
||||
s = SubsystemState(name="imu", stale_timeout_s=2.0)
|
||||
t = s.apply_stale_check(time.monotonic() + 100)
|
||||
assert t is None
|
||||
assert s.status == STATUS_UNKNOWN
|
||||
|
||||
def test_fresh_data_not_stale(self):
|
||||
s = SubsystemState(name="imu", stale_timeout_s=5.0)
|
||||
now = time.monotonic()
|
||||
s.update(STATUS_OK, "good", now)
|
||||
t = s.apply_stale_check(now + 3.0) # 3s < 5s timeout
|
||||
assert t is None
|
||||
assert s.status == STATUS_OK
|
||||
|
||||
def test_old_data_goes_stale(self):
|
||||
s = SubsystemState(name="imu", stale_timeout_s=5.0)
|
||||
now = time.monotonic()
|
||||
s.update(STATUS_OK, "good", now)
|
||||
t = s.apply_stale_check(now + 6.0) # 6s > 5s timeout
|
||||
assert t is not None
|
||||
assert t.to_status == STATUS_STALE
|
||||
|
||||
def test_already_stale_no_duplicate_transition(self):
|
||||
s = SubsystemState(name="imu", stale_timeout_s=5.0)
|
||||
now = time.monotonic()
|
||||
s.update(STATUS_OK, "good", now)
|
||||
s.apply_stale_check(now + 6.0) # → STALE
|
||||
t2 = s.apply_stale_check(now + 7.0) # already STALE
|
||||
assert t2 is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SubsystemState.to_dict()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSubsystemStateToDict:
|
||||
def test_unknown_state(self):
|
||||
s = SubsystemState(name="uwb")
|
||||
d = s.to_dict(time.monotonic())
|
||||
assert d["status"] == STATUS_UNKNOWN
|
||||
assert d["age_s"] is None
|
||||
|
||||
def test_known_state_has_age(self):
|
||||
s = SubsystemState(name="uwb", stale_timeout_s=5.0)
|
||||
now = time.monotonic()
|
||||
s.update(STATUS_OK, "ok", now)
|
||||
d = s.to_dict(now + 1.5)
|
||||
assert d["status"] == STATUS_OK
|
||||
assert d["age_s"] == pytest.approx(1.5, abs=0.01)
|
||||
assert d["message"] == "ok"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _SUBSYSTEM_NAMES completeness
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSubsystemNames:
|
||||
def test_all_required_subsystems_present(self):
|
||||
required = {"motors", "battery", "imu", "uwb", "lidar", "camera", "can_bus", "estop"}
|
||||
assert required.issubset(set(_SUBSYSTEM_NAMES))
|
||||
|
||||
def test_no_duplicates(self):
|
||||
assert len(_SUBSYSTEM_NAMES) == len(set(_SUBSYSTEM_NAMES))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@ -0,0 +1,305 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
vesc_mqtt_relay_node.py — ROS2 → MQTT relay for VESC CAN telemetry (Issue #656)
|
||||
|
||||
Subscribes to VESC telemetry ROS2 topics and republishes as MQTT JSON payloads
|
||||
for the sensor dashboard. Each per-motor topic is rate-limited to 5 Hz.
|
||||
|
||||
ROS2 topics consumed
|
||||
─────────────────────
|
||||
/vesc/left/state std_msgs/String JSON from vesc_telemetry_node
|
||||
/vesc/right/state std_msgs/String JSON from vesc_telemetry_node
|
||||
/vesc/combined std_msgs/String JSON from vesc_telemetry_node
|
||||
|
||||
MQTT topics published
|
||||
──────────────────────
|
||||
saltybot/phone/vesc_left — per-motor telemetry (left)
|
||||
saltybot/phone/vesc_right — per-motor telemetry (right)
|
||||
saltybot/phone/vesc_combined — combined voltage + total current + RPMs
|
||||
|
||||
MQTT payload (per-motor)
|
||||
────────────────────────
|
||||
{
|
||||
"rpm": int,
|
||||
"current_a": float, # phase current
|
||||
"voltage_v": float, # bus voltage
|
||||
"temperature_c": float, # FET temperature
|
||||
"duty_cycle": float, # -1.0 … 1.0
|
||||
"fault_code": int,
|
||||
"ts": float # epoch seconds
|
||||
}
|
||||
|
||||
Parameters
|
||||
──────────
|
||||
mqtt_host str Broker IP/hostname (default: localhost)
|
||||
mqtt_port int Broker port (default: 1883)
|
||||
mqtt_keepalive int Keepalive seconds (default: 60)
|
||||
reconnect_delay_s float Delay between reconnects (default: 5.0)
|
||||
motor_rate_hz float Max publish rate per motor (default: 5.0)
|
||||
|
||||
Dependencies
|
||||
────────────
|
||||
pip install paho-mqtt
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from std_msgs.msg import String
|
||||
|
||||
try:
|
||||
import paho.mqtt.client as mqtt
|
||||
_MQTT_OK = True
|
||||
except ImportError:
|
||||
_MQTT_OK = False
|
||||
|
||||
# ── MQTT topic constants ──────────────────────────────────────────────────────
|
||||
|
||||
_MQTT_VESC_LEFT = "saltybot/phone/vesc_left"
|
||||
_MQTT_VESC_RIGHT = "saltybot/phone/vesc_right"
|
||||
_MQTT_VESC_COMBINED = "saltybot/phone/vesc_combined"
|
||||
|
||||
# ── ROS2 topic constants ──────────────────────────────────────────────────────
|
||||
|
||||
_ROS_VESC_LEFT = "/vesc/left/state"
|
||||
_ROS_VESC_RIGHT = "/vesc/right/state"
|
||||
_ROS_VESC_COMBINED = "/vesc/combined"
|
||||
|
||||
|
||||
def _extract_motor_payload(data: dict) -> dict:
|
||||
"""
|
||||
Extract the required fields from a vesc_telemetry_node per-motor JSON dict.
|
||||
|
||||
Upstream JSON keys (from vesc_telemetry_node._state_to_dict):
|
||||
rpm, current_a, voltage_v, temp_fet_c, duty_cycle, fault_code, ...
|
||||
|
||||
Returns a dashboard-friendly payload with a stable key set.
|
||||
"""
|
||||
return {
|
||||
"rpm": int(data["rpm"]),
|
||||
"current_a": float(data["current_a"]),
|
||||
"voltage_v": float(data["voltage_v"]),
|
||||
"temperature_c": float(data["temp_fet_c"]),
|
||||
"duty_cycle": float(data["duty_cycle"]),
|
||||
"fault_code": int(data["fault_code"]),
|
||||
"ts": float(data.get("stamp", time.time())),
|
||||
}
|
||||
|
||||
|
||||
def _extract_combined_payload(data: dict) -> dict:
|
||||
"""
|
||||
Extract fields from the /vesc/combined JSON dict.
|
||||
|
||||
Upstream keys: voltage_v, total_current_a, left_rpm, right_rpm,
|
||||
left_alive, right_alive, stamp
|
||||
"""
|
||||
return {
|
||||
"voltage_v": float(data["voltage_v"]),
|
||||
"total_current_a": float(data["total_current_a"]),
|
||||
"left_rpm": int(data["left_rpm"]),
|
||||
"right_rpm": int(data["right_rpm"]),
|
||||
"left_alive": bool(data["left_alive"]),
|
||||
"right_alive": bool(data["right_alive"]),
|
||||
"ts": float(data.get("stamp", time.time())),
|
||||
}
|
||||
|
||||
|
||||
class VescMqttRelayNode(Node):
|
||||
"""
|
||||
Subscribes to VESC ROS2 topics and relays telemetry to MQTT.
|
||||
|
||||
Rate limiting: each per-motor topic maintains a last-publish timestamp;
|
||||
messages arriving faster than motor_rate_hz are silently dropped.
|
||||
The /vesc/combined topic is also rate-limited at the same rate.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__("vesc_mqtt_relay")
|
||||
|
||||
# ── Parameters ────────────────────────────────────────────────────────
|
||||
self.declare_parameter("mqtt_host", "localhost")
|
||||
self.declare_parameter("mqtt_port", 1883)
|
||||
self.declare_parameter("mqtt_keepalive", 60)
|
||||
self.declare_parameter("reconnect_delay_s", 5.0)
|
||||
self.declare_parameter("motor_rate_hz", 5.0)
|
||||
|
||||
self._mqtt_host = self.get_parameter("mqtt_host").value
|
||||
self._mqtt_port = self.get_parameter("mqtt_port").value
|
||||
self._mqtt_keepalive = self.get_parameter("mqtt_keepalive").value
|
||||
reconnect_delay = self.get_parameter("reconnect_delay_s").value
|
||||
rate_hz = max(0.1, float(self.get_parameter("motor_rate_hz").value))
|
||||
self._min_interval = 1.0 / rate_hz
|
||||
|
||||
if not _MQTT_OK:
|
||||
self.get_logger().error(
|
||||
"paho-mqtt not installed — run: pip install paho-mqtt"
|
||||
)
|
||||
|
||||
# ── Rate-limit state (last publish epoch per MQTT topic) ──────────────
|
||||
self._last_pub: dict[str, float] = {
|
||||
_MQTT_VESC_LEFT: 0.0,
|
||||
_MQTT_VESC_RIGHT: 0.0,
|
||||
_MQTT_VESC_COMBINED: 0.0,
|
||||
}
|
||||
|
||||
# ── Stats ─────────────────────────────────────────────────────────────
|
||||
self._rx_count: dict[str, int] = {
|
||||
_MQTT_VESC_LEFT: 0, _MQTT_VESC_RIGHT: 0, _MQTT_VESC_COMBINED: 0
|
||||
}
|
||||
self._pub_count: dict[str, int] = {
|
||||
_MQTT_VESC_LEFT: 0, _MQTT_VESC_RIGHT: 0, _MQTT_VESC_COMBINED: 0
|
||||
}
|
||||
self._drop_count: dict[str, int] = {
|
||||
_MQTT_VESC_LEFT: 0, _MQTT_VESC_RIGHT: 0, _MQTT_VESC_COMBINED: 0
|
||||
}
|
||||
self._err_count: dict[str, int] = {
|
||||
_MQTT_VESC_LEFT: 0, _MQTT_VESC_RIGHT: 0, _MQTT_VESC_COMBINED: 0
|
||||
}
|
||||
self._mqtt_connected = False
|
||||
|
||||
# ── ROS2 subscriptions ────────────────────────────────────────────────
|
||||
self.create_subscription(
|
||||
String, _ROS_VESC_LEFT,
|
||||
lambda msg: self._on_vesc(msg, _MQTT_VESC_LEFT, _extract_motor_payload),
|
||||
10,
|
||||
)
|
||||
self.create_subscription(
|
||||
String, _ROS_VESC_RIGHT,
|
||||
lambda msg: self._on_vesc(msg, _MQTT_VESC_RIGHT, _extract_motor_payload),
|
||||
10,
|
||||
)
|
||||
self.create_subscription(
|
||||
String, _ROS_VESC_COMBINED,
|
||||
lambda msg: self._on_vesc(msg, _MQTT_VESC_COMBINED, _extract_combined_payload),
|
||||
10,
|
||||
)
|
||||
|
||||
# ── MQTT client ───────────────────────────────────────────────────────
|
||||
self._mqtt_client: "mqtt.Client | None" = None
|
||||
if _MQTT_OK:
|
||||
self._mqtt_client = mqtt.Client(
|
||||
client_id="saltybot-vesc-mqtt-relay", clean_session=True
|
||||
)
|
||||
self._mqtt_client.on_connect = self._on_mqtt_connect
|
||||
self._mqtt_client.on_disconnect = self._on_mqtt_disconnect
|
||||
self._mqtt_client.reconnect_delay_set(
|
||||
min_delay=int(reconnect_delay),
|
||||
max_delay=int(reconnect_delay) * 4,
|
||||
)
|
||||
self._mqtt_connect()
|
||||
|
||||
self.get_logger().info(
|
||||
"VESC→MQTT relay started — broker=%s:%d rate=%.1f Hz",
|
||||
self._mqtt_host, self._mqtt_port, rate_hz,
|
||||
)
|
||||
|
||||
# ── MQTT connection ───────────────────────────────────────────────────────
|
||||
|
||||
def _mqtt_connect(self) -> None:
|
||||
try:
|
||||
self._mqtt_client.connect_async(
|
||||
self._mqtt_host, self._mqtt_port,
|
||||
keepalive=self._mqtt_keepalive,
|
||||
)
|
||||
self._mqtt_client.loop_start()
|
||||
except Exception as exc:
|
||||
self.get_logger().warning("MQTT initial connect error: %s", str(exc))
|
||||
|
||||
def _on_mqtt_connect(self, client, userdata, flags, rc) -> None:
|
||||
if rc == 0:
|
||||
self._mqtt_connected = True
|
||||
self.get_logger().info(
|
||||
"MQTT connected to %s:%d", self._mqtt_host, self._mqtt_port
|
||||
)
|
||||
else:
|
||||
self.get_logger().warning("MQTT connect failed rc=%d", rc)
|
||||
|
||||
def _on_mqtt_disconnect(self, client, userdata, rc) -> None:
|
||||
self._mqtt_connected = False
|
||||
if rc != 0:
|
||||
self.get_logger().warning(
|
||||
"MQTT disconnected (rc=%d) — paho will reconnect", rc
|
||||
)
|
||||
|
||||
# ── ROS2 subscriber callback ──────────────────────────────────────────────
|
||||
|
||||
def _on_vesc(self, ros_msg: String, mqtt_topic: str, extractor) -> None:
|
||||
"""
|
||||
Handle an incoming VESC ROS2 message.
|
||||
|
||||
1. Parse JSON from the String payload.
|
||||
2. Check rate limit — drop if too soon.
|
||||
3. Extract dashboard fields.
|
||||
4. Publish to MQTT.
|
||||
"""
|
||||
self._rx_count[mqtt_topic] += 1
|
||||
|
||||
# Rate limit
|
||||
now = time.monotonic()
|
||||
if now - self._last_pub[mqtt_topic] < self._min_interval:
|
||||
self._drop_count[mqtt_topic] += 1
|
||||
return
|
||||
|
||||
# Parse
|
||||
try:
|
||||
data = json.loads(ros_msg.data)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
|
||||
self._err_count[mqtt_topic] += 1
|
||||
self.get_logger().debug("JSON error on %s: %s", mqtt_topic, exc)
|
||||
return
|
||||
|
||||
# Extract
|
||||
try:
|
||||
payload = extractor(data)
|
||||
except (KeyError, TypeError, ValueError) as exc:
|
||||
self._err_count[mqtt_topic] += 1
|
||||
self.get_logger().debug("Payload error on %s: %s — %s", mqtt_topic, exc, data)
|
||||
return
|
||||
|
||||
# Publish
|
||||
if self._mqtt_client is not None and self._mqtt_connected:
|
||||
try:
|
||||
self._mqtt_client.publish(
|
||||
mqtt_topic,
|
||||
json.dumps(payload, separators=(",", ":")),
|
||||
qos=0,
|
||||
retain=False,
|
||||
)
|
||||
self._last_pub[mqtt_topic] = now
|
||||
self._pub_count[mqtt_topic] += 1
|
||||
except Exception as exc:
|
||||
self._err_count[mqtt_topic] += 1
|
||||
self.get_logger().debug("MQTT publish error on %s: %s", mqtt_topic, exc)
|
||||
else:
|
||||
# Still update last_pub to avoid log spam when broker is down
|
||||
self._last_pub[mqtt_topic] = now
|
||||
self.get_logger().debug("MQTT not connected — dropped %s", mqtt_topic)
|
||||
|
||||
# ── Cleanup ───────────────────────────────────────────────────────────────
|
||||
|
||||
def destroy_node(self) -> None:
|
||||
if self._mqtt_client is not None:
|
||||
self._mqtt_client.loop_stop()
|
||||
self._mqtt_client.disconnect()
|
||||
super().destroy_node()
|
||||
|
||||
|
||||
# ── Entry point ───────────────────────────────────────────────────────────────
|
||||
|
||||
def main(args: Any = None) -> None:
|
||||
rclpy.init(args=args)
|
||||
node = VescMqttRelayNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.try_shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -30,6 +30,7 @@ setup(
|
||||
'phone_bridge = saltybot_phone.ws_bridge:main',
|
||||
'phone_camera_node = saltybot_phone.phone_camera_node:main',
|
||||
'mqtt_ros2_bridge = saltybot_phone.mqtt_ros2_bridge_node:main',
|
||||
'vesc_mqtt_relay = saltybot_phone.vesc_mqtt_relay_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
0
jetson/ros2_ws/src/saltybot_phone/test/__init__.py
Normal file
0
jetson/ros2_ws/src/saltybot_phone/test/__init__.py
Normal file
343
jetson/ros2_ws/src/saltybot_phone/test/test_vesc_mqtt_relay.py
Normal file
343
jetson/ros2_ws/src/saltybot_phone/test/test_vesc_mqtt_relay.py
Normal file
@ -0,0 +1,343 @@
|
||||
"""Unit tests for vesc_mqtt_relay_node — pure-logic helpers.
|
||||
|
||||
Does not require ROS2, paho-mqtt, or a live VESC/broker.
|
||||
Tests cover the two payload extractors and the rate-limiting logic.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import types
|
||||
|
||||
# ── Stub out ROS2 / paho so the module can be imported without them ───────────
|
||||
|
||||
def _make_rclpy_stub():
|
||||
rclpy = types.ModuleType("rclpy")
|
||||
node_mod = types.ModuleType("rclpy.node")
|
||||
|
||||
class _Node:
|
||||
def __init__(self, *a, **kw): pass
|
||||
def declare_parameter(self, *a, **kw): pass
|
||||
def get_parameter(self, name):
|
||||
class _P:
|
||||
value = None
|
||||
return _P()
|
||||
def create_subscription(self, *a, **kw): pass
|
||||
def create_publisher(self, *a, **kw): pass
|
||||
def create_timer(self, *a, **kw): pass
|
||||
def get_clock(self): return None
|
||||
def get_logger(self):
|
||||
class _L:
|
||||
def info(self, *a, **kw): pass
|
||||
def warning(self, *a, **kw): pass
|
||||
def error(self, *a, **kw): pass
|
||||
def debug(self, *a, **kw): pass
|
||||
return _L()
|
||||
def destroy_node(self): pass
|
||||
|
||||
node_mod.Node = _Node
|
||||
rclpy.node = node_mod
|
||||
rclpy.init = lambda *a, **kw: None
|
||||
rclpy.spin = lambda *a, **kw: None
|
||||
rclpy.try_shutdown = lambda *a, **kw: None
|
||||
|
||||
std_msgs_mod = types.ModuleType("std_msgs")
|
||||
std_msgs_msg = types.ModuleType("std_msgs.msg")
|
||||
class _String:
|
||||
data: str = ""
|
||||
std_msgs_msg.String = _String
|
||||
std_msgs_mod.msg = std_msgs_msg
|
||||
|
||||
paho_mod = types.ModuleType("paho")
|
||||
paho_mqtt = types.ModuleType("paho.mqtt")
|
||||
paho_client = types.ModuleType("paho.mqtt.client")
|
||||
class _Client:
|
||||
def __init__(self, *a, **kw): pass
|
||||
def connect_async(self, *a, **kw): pass
|
||||
def loop_start(self): pass
|
||||
def loop_stop(self): pass
|
||||
def disconnect(self): pass
|
||||
def publish(self, *a, **kw): pass
|
||||
def reconnect_delay_set(self, *a, **kw): pass
|
||||
on_connect = None
|
||||
on_disconnect = None
|
||||
paho_client.Client = _Client
|
||||
paho_mqtt.client = paho_client
|
||||
paho_mod.mqtt = paho_mqtt
|
||||
|
||||
for name, mod in [
|
||||
("rclpy", rclpy),
|
||||
("rclpy.node", node_mod),
|
||||
("std_msgs", std_msgs_mod),
|
||||
("std_msgs.msg", std_msgs_msg),
|
||||
("paho", paho_mod),
|
||||
("paho.mqtt", paho_mqtt),
|
||||
("paho.mqtt.client", paho_client),
|
||||
]:
|
||||
sys.modules.setdefault(name, mod)
|
||||
|
||||
|
||||
_make_rclpy_stub()
|
||||
|
||||
from saltybot_phone.vesc_mqtt_relay_node import (
|
||||
_MQTT_VESC_LEFT,
|
||||
_MQTT_VESC_RIGHT,
|
||||
_MQTT_VESC_COMBINED,
|
||||
_extract_combined_payload,
|
||||
_extract_motor_payload,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _extract_motor_payload
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractMotorPayload:
|
||||
"""Covers field extraction and type coercion for per-motor JSON."""
|
||||
|
||||
def _sample(self, **overrides):
|
||||
base = {
|
||||
"can_id": 61,
|
||||
"rpm": 1500,
|
||||
"current_a": 12.34,
|
||||
"current_in_a": 10.0,
|
||||
"duty_cycle": 0.4500,
|
||||
"voltage_v": 25.20,
|
||||
"temp_fet_c": 45.5,
|
||||
"temp_motor_c": 62.0,
|
||||
"fault_code": 0,
|
||||
"fault_name": "NONE",
|
||||
"alive": True,
|
||||
"stamp": 1000.0,
|
||||
}
|
||||
base.update(overrides)
|
||||
return base
|
||||
|
||||
def test_required_keys_present(self):
|
||||
p = _extract_motor_payload(self._sample())
|
||||
for key in ("rpm", "current_a", "voltage_v", "temperature_c",
|
||||
"duty_cycle", "fault_code", "ts"):
|
||||
assert key in p, f"missing key: {key}"
|
||||
|
||||
def test_rpm_is_int(self):
|
||||
p = _extract_motor_payload(self._sample(rpm=1500))
|
||||
assert isinstance(p["rpm"], int)
|
||||
assert p["rpm"] == 1500
|
||||
|
||||
def test_temperature_maps_to_temp_fet_c(self):
|
||||
p = _extract_motor_payload(self._sample(temp_fet_c=55.5))
|
||||
assert p["temperature_c"] == 55.5
|
||||
|
||||
def test_voltage_v(self):
|
||||
p = _extract_motor_payload(self._sample(voltage_v=24.8))
|
||||
assert p["voltage_v"] == 24.8
|
||||
|
||||
def test_duty_cycle(self):
|
||||
p = _extract_motor_payload(self._sample(duty_cycle=0.75))
|
||||
assert p["duty_cycle"] == 0.75
|
||||
|
||||
def test_fault_code_zero(self):
|
||||
p = _extract_motor_payload(self._sample(fault_code=0))
|
||||
assert p["fault_code"] == 0
|
||||
|
||||
def test_fault_code_nonzero(self):
|
||||
p = _extract_motor_payload(self._sample(fault_code=3))
|
||||
assert p["fault_code"] == 3
|
||||
|
||||
def test_ts_from_stamp(self):
|
||||
p = _extract_motor_payload(self._sample(stamp=12345.678))
|
||||
assert p["ts"] == 12345.678
|
||||
|
||||
def test_negative_rpm(self):
|
||||
p = _extract_motor_payload(self._sample(rpm=-3000))
|
||||
assert p["rpm"] == -3000
|
||||
|
||||
def test_negative_current(self):
|
||||
p = _extract_motor_payload(self._sample(current_a=-5.0))
|
||||
assert p["current_a"] == -5.0
|
||||
|
||||
def test_negative_duty_cycle(self):
|
||||
p = _extract_motor_payload(self._sample(duty_cycle=-0.5))
|
||||
assert p["duty_cycle"] == -0.5
|
||||
|
||||
def test_missing_required_key_raises(self):
|
||||
bad = self._sample()
|
||||
del bad["rpm"]
|
||||
try:
|
||||
_extract_motor_payload(bad)
|
||||
assert False, "expected KeyError"
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def test_missing_stamp_uses_current_time(self):
|
||||
data = self._sample()
|
||||
del data["stamp"]
|
||||
before = time.time()
|
||||
p = _extract_motor_payload(data)
|
||||
after = time.time()
|
||||
assert before <= p["ts"] <= after
|
||||
|
||||
def test_json_roundtrip(self):
|
||||
p = _extract_motor_payload(self._sample())
|
||||
raw = json.dumps(p)
|
||||
restored = json.loads(raw)
|
||||
assert restored["rpm"] == p["rpm"]
|
||||
assert restored["fault_code"] == p["fault_code"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _extract_combined_payload
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractCombinedPayload:
|
||||
def _sample(self, **overrides):
|
||||
base = {
|
||||
"voltage_v": 25.2,
|
||||
"total_current_a": 18.5,
|
||||
"left_rpm": 1400,
|
||||
"right_rpm": 1420,
|
||||
"left_alive": True,
|
||||
"right_alive": True,
|
||||
"stamp": 2000.0,
|
||||
}
|
||||
base.update(overrides)
|
||||
return base
|
||||
|
||||
def test_required_keys_present(self):
|
||||
p = _extract_combined_payload(self._sample())
|
||||
for key in ("voltage_v", "total_current_a", "left_rpm", "right_rpm",
|
||||
"left_alive", "right_alive", "ts"):
|
||||
assert key in p, f"missing key: {key}"
|
||||
|
||||
def test_voltage_v(self):
|
||||
p = _extract_combined_payload(self._sample(voltage_v=24.0))
|
||||
assert p["voltage_v"] == 24.0
|
||||
|
||||
def test_total_current_a(self):
|
||||
p = _extract_combined_payload(self._sample(total_current_a=30.5))
|
||||
assert p["total_current_a"] == 30.5
|
||||
|
||||
def test_rpms_are_int(self):
|
||||
p = _extract_combined_payload(self._sample(left_rpm=1000, right_rpm=1050))
|
||||
assert isinstance(p["left_rpm"], int)
|
||||
assert isinstance(p["right_rpm"], int)
|
||||
|
||||
def test_alive_flags(self):
|
||||
p = _extract_combined_payload(self._sample(left_alive=True, right_alive=False))
|
||||
assert p["left_alive"] is True
|
||||
assert p["right_alive"] is False
|
||||
|
||||
def test_ts_from_stamp(self):
|
||||
p = _extract_combined_payload(self._sample(stamp=9999.1))
|
||||
assert p["ts"] == 9999.1
|
||||
|
||||
def test_missing_stamp_uses_current_time(self):
|
||||
data = self._sample()
|
||||
del data["stamp"]
|
||||
before = time.time()
|
||||
p = _extract_combined_payload(data)
|
||||
after = time.time()
|
||||
assert before <= p["ts"] <= after
|
||||
|
||||
def test_json_roundtrip(self):
|
||||
p = _extract_combined_payload(self._sample())
|
||||
raw = json.dumps(p)
|
||||
restored = json.loads(raw)
|
||||
assert restored["voltage_v"] == p["voltage_v"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rate-limit logic (isolated, no ROS2 / paho)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRateLimit:
|
||||
"""
|
||||
Exercise the rate-limiting gate that lives inside VescMqttRelayNode._on_vesc.
|
||||
We test the guard condition directly without instantiating the ROS2 node.
|
||||
"""
|
||||
|
||||
def _make_gate(self, rate_hz: float):
|
||||
"""
|
||||
Return a stateful callable that mirrors the rate-limit check in the node.
|
||||
Returns True if a publish should proceed, False if the message is dropped.
|
||||
"""
|
||||
min_interval = 1.0 / rate_hz
|
||||
state = {"last": 0.0}
|
||||
|
||||
def gate(now: float) -> bool:
|
||||
if now - state["last"] < min_interval:
|
||||
return False
|
||||
state["last"] = now
|
||||
return True
|
||||
|
||||
return gate
|
||||
|
||||
def test_first_message_always_passes(self):
|
||||
gate = self._make_gate(5.0)
|
||||
assert gate(time.monotonic()) is True
|
||||
|
||||
def test_immediate_second_message_dropped(self):
|
||||
gate = self._make_gate(5.0)
|
||||
t = time.monotonic()
|
||||
gate(t)
|
||||
assert gate(t + 0.001) is False # 1 ms < 200 ms interval
|
||||
|
||||
def test_message_after_interval_passes(self):
|
||||
gate = self._make_gate(5.0)
|
||||
t = time.monotonic()
|
||||
gate(t)
|
||||
assert gate(t + 0.201) is True # 201 ms > 200 ms interval
|
||||
|
||||
def test_exactly_at_interval_dropped(self):
|
||||
gate = self._make_gate(5.0)
|
||||
t = time.monotonic()
|
||||
gate(t)
|
||||
# Exactly at the boundary is strictly less-than, so it's dropped
|
||||
assert gate(t + 0.2) is False
|
||||
|
||||
def test_10hz_rate(self):
|
||||
gate = self._make_gate(10.0)
|
||||
t = time.monotonic()
|
||||
gate(t)
|
||||
assert gate(t + 0.09) is False # 90 ms < 100 ms
|
||||
assert gate(t + 0.101) is True # 101 ms > 100 ms
|
||||
|
||||
def test_1hz_rate(self):
|
||||
gate = self._make_gate(1.0)
|
||||
t = time.monotonic()
|
||||
gate(t)
|
||||
assert gate(t + 0.999) is False
|
||||
assert gate(t + 1.001) is True
|
||||
|
||||
def test_multiple_topics_independent(self):
|
||||
gate_left = self._make_gate(5.0)
|
||||
gate_right = self._make_gate(5.0)
|
||||
t = time.monotonic()
|
||||
gate_left(t)
|
||||
gate_right(t)
|
||||
# left: drop, right: drop
|
||||
assert gate_left(t + 0.05) is False
|
||||
assert gate_right(t + 0.05) is False
|
||||
# advance only left past interval
|
||||
assert gate_left(t + 0.21) is True
|
||||
assert gate_right(t + 0.21) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MQTT topic constant checks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTopicConstants:
|
||||
def test_left_topic(self):
|
||||
assert _MQTT_VESC_LEFT == "saltybot/phone/vesc_left"
|
||||
|
||||
def test_right_topic(self):
|
||||
assert _MQTT_VESC_RIGHT == "saltybot/phone/vesc_right"
|
||||
|
||||
def test_combined_topic(self):
|
||||
assert _MQTT_VESC_COMBINED == "saltybot/phone/vesc_combined"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
pytest.main([__file__, "-v"])
|
||||
@ -8,3 +8,13 @@ vesc_can_driver:
|
||||
wheel_radius: 0.1
|
||||
max_speed: 5.0
|
||||
command_timeout: 1.0
|
||||
|
||||
velocity_smoother:
|
||||
ros__parameters:
|
||||
publish_rate: 50.0
|
||||
max_linear_accel: 1.0
|
||||
max_linear_decel: 2.0
|
||||
max_angular_accel: 1.5
|
||||
max_angular_decel: 3.0
|
||||
max_linear_jerk: 0.0
|
||||
max_angular_jerk: 0.0
|
||||
|
||||
@ -20,6 +20,13 @@ def generate_launch_description():
|
||||
default_value=config_file,
|
||||
description="Path to configuration YAML file",
|
||||
),
|
||||
Node(
|
||||
package="saltybot_vesc_driver",
|
||||
executable="velocity_smoother_node",
|
||||
name="velocity_smoother",
|
||||
output="screen",
|
||||
parameters=[LaunchConfiguration("config_file")],
|
||||
),
|
||||
Node(
|
||||
package="saltybot_vesc_driver",
|
||||
executable="vesc_driver_node",
|
||||
|
||||
@ -0,0 +1,171 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Smooth velocity controller with accel/decel ramp.
|
||||
|
||||
Subscribes to /cmd_vel (geometry_msgs/Twist), applies acceleration
|
||||
and deceleration ramps, and publishes smoothed commands to
|
||||
/cmd_vel_smoothed at 50 Hz.
|
||||
|
||||
E-stop (std_msgs/Bool on /e_stop): immediate zero, bypasses ramp.
|
||||
Optional jerk limiting via max_linear_jerk / max_angular_jerk params.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import rclpy
|
||||
from geometry_msgs.msg import Twist
|
||||
from rclpy.node import Node
|
||||
from std_msgs.msg import Bool
|
||||
|
||||
|
||||
class VelocitySmoother(Node):
|
||||
def __init__(self):
|
||||
super().__init__('velocity_smoother')
|
||||
|
||||
# Declare parameters
|
||||
self.declare_parameter('publish_rate', 50.0) # Hz
|
||||
self.declare_parameter('max_linear_accel', 1.0) # m/s²
|
||||
self.declare_parameter('max_linear_decel', 2.0) # m/s²
|
||||
self.declare_parameter('max_angular_accel', 1.5) # rad/s²
|
||||
self.declare_parameter('max_angular_decel', 3.0) # rad/s²
|
||||
self.declare_parameter('max_linear_jerk', 0.0) # m/s³, 0=disabled
|
||||
self.declare_parameter('max_angular_jerk', 0.0) # rad/s³, 0=disabled
|
||||
|
||||
rate = self.get_parameter('publish_rate').value
|
||||
self.max_lin_acc = self.get_parameter('max_linear_accel').value
|
||||
self.max_lin_dec = self.get_parameter('max_linear_decel').value
|
||||
self.max_ang_acc = self.get_parameter('max_angular_accel').value
|
||||
self.max_ang_dec = self.get_parameter('max_angular_decel').value
|
||||
self.max_lin_jerk = self.get_parameter('max_linear_jerk').value
|
||||
self.max_ang_jerk = self.get_parameter('max_angular_jerk').value
|
||||
|
||||
self._dt = 1.0 / rate
|
||||
|
||||
# State
|
||||
self._target_lin = 0.0
|
||||
self._target_ang = 0.0
|
||||
self._current_lin = 0.0
|
||||
self._current_ang = 0.0
|
||||
self._current_lin_acc = 0.0 # for jerk limiting
|
||||
self._current_ang_acc = 0.0
|
||||
self._e_stop = False
|
||||
|
||||
# Publisher
|
||||
self._pub = self.create_publisher(Twist, '/cmd_vel_smoothed', 10)
|
||||
|
||||
# Subscriptions
|
||||
self.create_subscription(Twist, '/cmd_vel', self._cmd_vel_cb, 10)
|
||||
self.create_subscription(Bool, '/e_stop', self._e_stop_cb, 10)
|
||||
|
||||
# Timer
|
||||
self.create_timer(self._dt, self._timer_cb)
|
||||
|
||||
self.get_logger().info(
|
||||
f'VelocitySmoother ready at {rate:.0f} Hz — '
|
||||
f'lin_acc={self.max_lin_acc} lin_dec={self.max_lin_dec} '
|
||||
f'ang_acc={self.max_ang_acc} ang_dec={self.max_ang_dec}'
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _cmd_vel_cb(self, msg: Twist):
|
||||
self._target_lin = msg.linear.x
|
||||
self._target_ang = msg.angular.z
|
||||
|
||||
def _e_stop_cb(self, msg: Bool):
|
||||
self._e_stop = msg.data
|
||||
if self._e_stop:
|
||||
self._target_lin = 0.0
|
||||
self._target_ang = 0.0
|
||||
self._current_lin = 0.0
|
||||
self._current_ang = 0.0
|
||||
self._current_lin_acc = 0.0
|
||||
self._current_ang_acc = 0.0
|
||||
self.get_logger().warn('E-STOP active — motors zeroed immediately')
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _ramp(
|
||||
self,
|
||||
current: float,
|
||||
target: float,
|
||||
max_acc: float,
|
||||
max_dec: float,
|
||||
current_acc: float,
|
||||
max_jerk: float,
|
||||
) -> tuple[float, float]:
|
||||
"""
|
||||
Advance `current` toward `target` with separate accel/decel limits.
|
||||
Optionally apply jerk limiting to the acceleration.
|
||||
Returns (new_value, new_acc).
|
||||
"""
|
||||
error = target - current
|
||||
|
||||
# Choose limit: decelerate if moving away from zero and target is
|
||||
# closer to zero (or past it), else accelerate.
|
||||
if current * error < 0 or (error != 0 and abs(target) < abs(current)):
|
||||
limit = max_dec
|
||||
else:
|
||||
limit = max_acc
|
||||
|
||||
desired_acc = math.copysign(min(abs(error) / self._dt, limit), error)
|
||||
# Clamp so we don't overshoot
|
||||
desired_acc = max(-limit, min(limit, desired_acc))
|
||||
|
||||
if max_jerk > 0.0:
|
||||
max_d_acc = max_jerk * self._dt
|
||||
new_acc = current_acc + max(
|
||||
-max_d_acc, min(max_d_acc, desired_acc - current_acc)
|
||||
)
|
||||
else:
|
||||
new_acc = desired_acc
|
||||
|
||||
delta = new_acc * self._dt
|
||||
# Clamp so we don't overshoot the target
|
||||
if abs(delta) > abs(error):
|
||||
delta = error
|
||||
return current + delta, new_acc
|
||||
|
||||
def _timer_cb(self):
|
||||
if self._e_stop:
|
||||
msg = Twist()
|
||||
self._pub.publish(msg)
|
||||
return
|
||||
|
||||
self._current_lin, self._current_lin_acc = self._ramp(
|
||||
self._current_lin,
|
||||
self._target_lin,
|
||||
self.max_lin_acc,
|
||||
self.max_lin_dec,
|
||||
self._current_lin_acc,
|
||||
self.max_lin_jerk,
|
||||
)
|
||||
self._current_ang, self._current_ang_acc = self._ramp(
|
||||
self._current_ang,
|
||||
self._target_ang,
|
||||
self.max_ang_acc,
|
||||
self.max_ang_dec,
|
||||
self._current_ang_acc,
|
||||
self.max_lin_jerk, # intentional: use linear jerk scale for angular too if set
|
||||
)
|
||||
|
||||
msg = Twist()
|
||||
msg.linear.x = self._current_lin
|
||||
msg.angular.z = self._current_ang
|
||||
self._pub.publish(msg)
|
||||
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = VelocitySmoother()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -63,7 +63,7 @@ class VescCanDriver(Node):
|
||||
self._last_cmd_time = time.monotonic()
|
||||
|
||||
# Subscriber
|
||||
self.create_subscription(Twist, '/cmd_vel', self._cmd_vel_cb, 10)
|
||||
self.create_subscription(Twist, '/cmd_vel_smoothed', self._cmd_vel_cb, 10)
|
||||
|
||||
# Watchdog timer (10 Hz)
|
||||
self.create_timer(0.1, self._watchdog_cb)
|
||||
|
||||
@ -22,6 +22,7 @@ setup(
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"vesc_driver_node = saltybot_vesc_driver.vesc_driver_node:main",
|
||||
"velocity_smoother_node = saltybot_vesc_driver.velocity_smoother_node:main",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
@ -0,0 +1,237 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Unit tests for VelocitySmoother ramp logic."""
|
||||
|
||||
import math
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal ROS2 stubs so tests run without a ROS2 installation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _FakeParam:
|
||||
def __init__(self, val):
|
||||
self._val = val
|
||||
@property
|
||||
def value(self):
|
||||
return self._val
|
||||
|
||||
|
||||
class _FakeNode:
|
||||
def __init__(self):
|
||||
self._params = {}
|
||||
def declare_parameter(self, name, default):
|
||||
self._params[name] = _FakeParam(default)
|
||||
def get_parameter(self, name):
|
||||
return self._params[name]
|
||||
def create_publisher(self, *a, **kw):
|
||||
return MagicMock()
|
||||
def create_subscription(self, *a, **kw):
|
||||
return MagicMock()
|
||||
def create_timer(self, *a, **kw):
|
||||
return MagicMock()
|
||||
def get_logger(self):
|
||||
log = MagicMock()
|
||||
log.info = MagicMock()
|
||||
log.warn = MagicMock()
|
||||
return log
|
||||
|
||||
|
||||
# Patch rclpy and geometry_msgs before importing the module.
|
||||
# rclpy.node.Node must be a *real* class so that VelocitySmoother can
|
||||
# inherit from it without becoming a MagicMock itself.
|
||||
import sys
|
||||
|
||||
class _RealNodeBase:
|
||||
"""Minimal real base class that stands in for rclpy.node.Node."""
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def declare_parameter(self, name, default):
|
||||
if not hasattr(self, '_params'):
|
||||
self._params = {}
|
||||
self._params[name] = _FakeParam(default)
|
||||
|
||||
def get_parameter(self, name):
|
||||
return self._params[name]
|
||||
|
||||
def create_publisher(self, *a, **kw):
|
||||
return MagicMock()
|
||||
|
||||
def create_subscription(self, *a, **kw):
|
||||
return MagicMock()
|
||||
|
||||
def create_timer(self, *a, **kw):
|
||||
return MagicMock()
|
||||
|
||||
def get_logger(self):
|
||||
log = MagicMock()
|
||||
log.info = MagicMock()
|
||||
log.warn = MagicMock()
|
||||
return log
|
||||
|
||||
rclpy_node_mod = MagicMock()
|
||||
rclpy_node_mod.Node = _RealNodeBase
|
||||
|
||||
rclpy_mock = MagicMock()
|
||||
rclpy_mock.node = rclpy_node_mod
|
||||
|
||||
sys.modules['rclpy'] = rclpy_mock
|
||||
sys.modules['rclpy.node'] = rclpy_node_mod
|
||||
sys.modules.setdefault('geometry_msgs', MagicMock())
|
||||
sys.modules.setdefault('geometry_msgs.msg', MagicMock())
|
||||
sys.modules.setdefault('std_msgs', MagicMock())
|
||||
sys.modules.setdefault('std_msgs.msg', MagicMock())
|
||||
|
||||
# Provide a real Twist-like object for tests
|
||||
class _Twist:
|
||||
class _Vec:
|
||||
x = 0.0
|
||||
y = 0.0
|
||||
z = 0.0
|
||||
def __init__(self):
|
||||
self.linear = self._Vec()
|
||||
self.angular = self._Vec()
|
||||
|
||||
sys.modules['geometry_msgs.msg'].Twist = _Twist
|
||||
sys.modules['std_msgs.msg'].Bool = MagicMock()
|
||||
|
||||
import importlib, os, pathlib
|
||||
sys.path.insert(0, str(pathlib.Path(__file__).parents[1] / 'saltybot_vesc_driver'))
|
||||
|
||||
# Now we can import the smoother logic directly
|
||||
from velocity_smoother_node import VelocitySmoother
|
||||
|
||||
|
||||
def _make_smoother(**params):
|
||||
"""Create a VelocitySmoother backed by _RealNodeBase with custom params."""
|
||||
defaults = dict(
|
||||
publish_rate=50.0,
|
||||
max_linear_accel=1.0,
|
||||
max_linear_decel=2.0,
|
||||
max_angular_accel=1.5,
|
||||
max_angular_decel=3.0,
|
||||
max_linear_jerk=0.0,
|
||||
max_angular_jerk=0.0,
|
||||
)
|
||||
defaults.update(params)
|
||||
|
||||
# Pre-seed _params before __init__ so declare_parameter picks up overrides.
|
||||
node = VelocitySmoother.__new__(VelocitySmoother)
|
||||
node._params = {k: _FakeParam(v) for k, v in defaults.items()}
|
||||
|
||||
# Monkey-patch declare_parameter so it doesn't overwrite our pre-seeded values
|
||||
original_declare = _RealNodeBase.declare_parameter
|
||||
def _noop_declare(self, name, default):
|
||||
pass # params already seeded
|
||||
_RealNodeBase.declare_parameter = _noop_declare
|
||||
|
||||
try:
|
||||
VelocitySmoother.__init__(node)
|
||||
finally:
|
||||
_RealNodeBase.declare_parameter = original_declare
|
||||
|
||||
return node
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRampLogic(unittest.TestCase):
|
||||
|
||||
def _make(self, **kw):
|
||||
return _make_smoother(**kw)
|
||||
|
||||
def test_ramp_reaches_target_within_expected_steps(self):
|
||||
"""From 0 to 1 m/s with accel=1 m/s² at 50 Hz → ~50 steps."""
|
||||
node = self._make(max_linear_accel=1.0, publish_rate=50.0)
|
||||
node._target_lin = 1.0
|
||||
steps = 0
|
||||
while abs(node._current_lin - 1.0) > 0.01 and steps < 200:
|
||||
node._timer_cb()
|
||||
steps += 1
|
||||
self.assertLessEqual(steps, 55, "Should reach 1 m/s within ~55 steps at 50 Hz")
|
||||
self.assertAlmostEqual(node._current_lin, 1.0, places=2)
|
||||
|
||||
def test_decel_faster_than_accel(self):
|
||||
"""Deceleration should complete faster than acceleration."""
|
||||
node = self._make(max_linear_accel=1.0, max_linear_decel=2.0, publish_rate=50.0)
|
||||
# Ramp up
|
||||
node._target_lin = 1.0
|
||||
for _ in range(100):
|
||||
node._timer_cb()
|
||||
# Now decelerate
|
||||
node._target_lin = 0.0
|
||||
decel_steps = 0
|
||||
while abs(node._current_lin) > 0.01 and decel_steps < 200:
|
||||
node._timer_cb()
|
||||
decel_steps += 1
|
||||
# Ramp up again to measure accel steps
|
||||
node._current_lin = 0.0
|
||||
node._target_lin = 1.0
|
||||
accel_steps = 0
|
||||
while abs(node._current_lin - 1.0) > 0.01 and accel_steps < 200:
|
||||
node._timer_cb()
|
||||
accel_steps += 1
|
||||
self.assertLess(decel_steps, accel_steps,
|
||||
"Decel (2 m/s²) should finish in fewer steps than accel (1 m/s²)")
|
||||
|
||||
def test_e_stop_zeros_immediately(self):
|
||||
"""E-stop must zero velocity in the same callback, bypassing ramp."""
|
||||
node = self._make()
|
||||
node._current_lin = 2.0
|
||||
node._current_ang = 1.0
|
||||
node._target_lin = 2.0
|
||||
node._target_ang = 1.0
|
||||
|
||||
msg = MagicMock()
|
||||
msg.data = True
|
||||
node._e_stop_cb(msg)
|
||||
|
||||
self.assertEqual(node._current_lin, 0.0)
|
||||
self.assertEqual(node._current_ang, 0.0)
|
||||
self.assertTrue(node._e_stop)
|
||||
|
||||
def test_no_overshoot(self):
|
||||
"""Current velocity must never exceed target during ramp-up."""
|
||||
node = self._make(max_linear_accel=1.0, publish_rate=50.0)
|
||||
node._target_lin = 0.5
|
||||
for _ in range(100):
|
||||
node._timer_cb()
|
||||
self.assertLessEqual(node._current_lin, 0.5 + 1e-9)
|
||||
|
||||
def test_negative_velocity_ramp(self):
|
||||
"""Ramp works symmetrically for negative targets."""
|
||||
node = self._make(max_linear_accel=1.0, publish_rate=50.0)
|
||||
node._target_lin = -1.0
|
||||
for _ in range(200):
|
||||
node._timer_cb()
|
||||
self.assertAlmostEqual(node._current_lin, -1.0, places=2)
|
||||
|
||||
def test_angular_ramp(self):
|
||||
"""Angular velocity ramps correctly."""
|
||||
node = self._make(max_angular_accel=1.5, publish_rate=50.0)
|
||||
node._target_ang = 1.0
|
||||
for _ in range(200):
|
||||
node._timer_cb()
|
||||
self.assertAlmostEqual(node._current_ang, 1.0, places=2)
|
||||
|
||||
def test_ramp_timing_linear(self):
|
||||
"""Time to ramp from 0 to v_max ≈ v_max / accel (±10%)."""
|
||||
accel = 1.0 # m/s²
|
||||
v_max = 1.0 # m/s
|
||||
rate = 50.0
|
||||
expected_time = v_max / accel # 1.0 s
|
||||
node = self._make(max_linear_accel=accel, publish_rate=rate)
|
||||
node._target_lin = v_max
|
||||
steps = 0
|
||||
while abs(node._current_lin - v_max) > 0.01 and steps < 500:
|
||||
node._timer_cb()
|
||||
steps += 1
|
||||
actual_time = steps / rate
|
||||
self.assertAlmostEqual(actual_time, expected_time, delta=expected_time * 0.12,
|
||||
msg=f"Ramp time {actual_time:.2f}s expected ~{expected_time:.2f}s")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
x
Reference in New Issue
Block a user