From a50dbe7e566220826e2ac3abacd8cc0398290517 Mon Sep 17 00:00:00 2001 From: sl-android Date: Tue, 17 Mar 2026 11:32:37 -0400 Subject: [PATCH 1/2] feat: VESC CAN telemetry MQTT relay (Issue #656) Add vesc_mqtt_relay_node.py to saltybot_phone: subscribes to /vesc/left/state, /vesc/right/state, /vesc/combined ROS2 topics and publishes JSON telemetry to saltybot/phone/vesc_{left,right,combined} MQTT topics at 5 Hz per motor. 32 unit tests, no ROS2/paho required. Co-Authored-By: Claude Sonnet 4.6 --- .../saltybot_phone/vesc_mqtt_relay_node.py | 305 ++++++++++++++++ jetson/ros2_ws/src/saltybot_phone/setup.py | 1 + .../src/saltybot_phone/test/__init__.py | 0 .../test/test_vesc_mqtt_relay.py | 343 ++++++++++++++++++ 4 files changed, 649 insertions(+) create mode 100644 jetson/ros2_ws/src/saltybot_phone/saltybot_phone/vesc_mqtt_relay_node.py create mode 100644 jetson/ros2_ws/src/saltybot_phone/test/__init__.py create mode 100644 jetson/ros2_ws/src/saltybot_phone/test/test_vesc_mqtt_relay.py diff --git a/jetson/ros2_ws/src/saltybot_phone/saltybot_phone/vesc_mqtt_relay_node.py b/jetson/ros2_ws/src/saltybot_phone/saltybot_phone/vesc_mqtt_relay_node.py new file mode 100644 index 0000000..5644a69 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_phone/saltybot_phone/vesc_mqtt_relay_node.py @@ -0,0 +1,305 @@ +#!/usr/bin/env python3 +""" +vesc_mqtt_relay_node.py — ROS2 → MQTT relay for VESC CAN telemetry (Issue #656) + +Subscribes to VESC telemetry ROS2 topics and republishes as MQTT JSON payloads +for the sensor dashboard. Each per-motor topic is rate-limited to 5 Hz. + +ROS2 topics consumed +───────────────────── + /vesc/left/state std_msgs/String JSON from vesc_telemetry_node + /vesc/right/state std_msgs/String JSON from vesc_telemetry_node + /vesc/combined std_msgs/String JSON from vesc_telemetry_node + +MQTT topics published +────────────────────── + saltybot/phone/vesc_left — per-motor telemetry (left) + saltybot/phone/vesc_right — per-motor telemetry (right) + saltybot/phone/vesc_combined — combined voltage + total current + RPMs + +MQTT payload (per-motor) +──────────────────────── + { + "rpm": int, + "current_a": float, # phase current + "voltage_v": float, # bus voltage + "temperature_c": float, # FET temperature + "duty_cycle": float, # -1.0 … 1.0 + "fault_code": int, + "ts": float # epoch seconds + } + +Parameters +────────── + mqtt_host str Broker IP/hostname (default: localhost) + mqtt_port int Broker port (default: 1883) + mqtt_keepalive int Keepalive seconds (default: 60) + reconnect_delay_s float Delay between reconnects (default: 5.0) + motor_rate_hz float Max publish rate per motor (default: 5.0) + +Dependencies +──────────── + pip install paho-mqtt +""" + +import json +import time +from typing import Any + +import rclpy +from rclpy.node import Node +from std_msgs.msg import String + +try: + import paho.mqtt.client as mqtt + _MQTT_OK = True +except ImportError: + _MQTT_OK = False + +# ── MQTT topic constants ────────────────────────────────────────────────────── + +_MQTT_VESC_LEFT = "saltybot/phone/vesc_left" +_MQTT_VESC_RIGHT = "saltybot/phone/vesc_right" +_MQTT_VESC_COMBINED = "saltybot/phone/vesc_combined" + +# ── ROS2 topic constants ────────────────────────────────────────────────────── + +_ROS_VESC_LEFT = "/vesc/left/state" +_ROS_VESC_RIGHT = "/vesc/right/state" +_ROS_VESC_COMBINED = "/vesc/combined" + + +def _extract_motor_payload(data: dict) -> dict: + """ + Extract the required fields from a vesc_telemetry_node per-motor JSON dict. + + Upstream JSON keys (from vesc_telemetry_node._state_to_dict): + rpm, current_a, voltage_v, temp_fet_c, duty_cycle, fault_code, ... + + Returns a dashboard-friendly payload with a stable key set. + """ + return { + "rpm": int(data["rpm"]), + "current_a": float(data["current_a"]), + "voltage_v": float(data["voltage_v"]), + "temperature_c": float(data["temp_fet_c"]), + "duty_cycle": float(data["duty_cycle"]), + "fault_code": int(data["fault_code"]), + "ts": float(data.get("stamp", time.time())), + } + + +def _extract_combined_payload(data: dict) -> dict: + """ + Extract fields from the /vesc/combined JSON dict. + + Upstream keys: voltage_v, total_current_a, left_rpm, right_rpm, + left_alive, right_alive, stamp + """ + return { + "voltage_v": float(data["voltage_v"]), + "total_current_a": float(data["total_current_a"]), + "left_rpm": int(data["left_rpm"]), + "right_rpm": int(data["right_rpm"]), + "left_alive": bool(data["left_alive"]), + "right_alive": bool(data["right_alive"]), + "ts": float(data.get("stamp", time.time())), + } + + +class VescMqttRelayNode(Node): + """ + Subscribes to VESC ROS2 topics and relays telemetry to MQTT. + + Rate limiting: each per-motor topic maintains a last-publish timestamp; + messages arriving faster than motor_rate_hz are silently dropped. + The /vesc/combined topic is also rate-limited at the same rate. + """ + + def __init__(self) -> None: + super().__init__("vesc_mqtt_relay") + + # ── Parameters ──────────────────────────────────────────────────────── + self.declare_parameter("mqtt_host", "localhost") + self.declare_parameter("mqtt_port", 1883) + self.declare_parameter("mqtt_keepalive", 60) + self.declare_parameter("reconnect_delay_s", 5.0) + self.declare_parameter("motor_rate_hz", 5.0) + + self._mqtt_host = self.get_parameter("mqtt_host").value + self._mqtt_port = self.get_parameter("mqtt_port").value + self._mqtt_keepalive = self.get_parameter("mqtt_keepalive").value + reconnect_delay = self.get_parameter("reconnect_delay_s").value + rate_hz = max(0.1, float(self.get_parameter("motor_rate_hz").value)) + self._min_interval = 1.0 / rate_hz + + if not _MQTT_OK: + self.get_logger().error( + "paho-mqtt not installed — run: pip install paho-mqtt" + ) + + # ── Rate-limit state (last publish epoch per MQTT topic) ────────────── + self._last_pub: dict[str, float] = { + _MQTT_VESC_LEFT: 0.0, + _MQTT_VESC_RIGHT: 0.0, + _MQTT_VESC_COMBINED: 0.0, + } + + # ── Stats ───────────────────────────────────────────────────────────── + self._rx_count: dict[str, int] = { + _MQTT_VESC_LEFT: 0, _MQTT_VESC_RIGHT: 0, _MQTT_VESC_COMBINED: 0 + } + self._pub_count: dict[str, int] = { + _MQTT_VESC_LEFT: 0, _MQTT_VESC_RIGHT: 0, _MQTT_VESC_COMBINED: 0 + } + self._drop_count: dict[str, int] = { + _MQTT_VESC_LEFT: 0, _MQTT_VESC_RIGHT: 0, _MQTT_VESC_COMBINED: 0 + } + self._err_count: dict[str, int] = { + _MQTT_VESC_LEFT: 0, _MQTT_VESC_RIGHT: 0, _MQTT_VESC_COMBINED: 0 + } + self._mqtt_connected = False + + # ── ROS2 subscriptions ──────────────────────────────────────────────── + self.create_subscription( + String, _ROS_VESC_LEFT, + lambda msg: self._on_vesc(msg, _MQTT_VESC_LEFT, _extract_motor_payload), + 10, + ) + self.create_subscription( + String, _ROS_VESC_RIGHT, + lambda msg: self._on_vesc(msg, _MQTT_VESC_RIGHT, _extract_motor_payload), + 10, + ) + self.create_subscription( + String, _ROS_VESC_COMBINED, + lambda msg: self._on_vesc(msg, _MQTT_VESC_COMBINED, _extract_combined_payload), + 10, + ) + + # ── MQTT client ─────────────────────────────────────────────────────── + self._mqtt_client: "mqtt.Client | None" = None + if _MQTT_OK: + self._mqtt_client = mqtt.Client( + client_id="saltybot-vesc-mqtt-relay", clean_session=True + ) + self._mqtt_client.on_connect = self._on_mqtt_connect + self._mqtt_client.on_disconnect = self._on_mqtt_disconnect + self._mqtt_client.reconnect_delay_set( + min_delay=int(reconnect_delay), + max_delay=int(reconnect_delay) * 4, + ) + self._mqtt_connect() + + self.get_logger().info( + "VESC→MQTT relay started — broker=%s:%d rate=%.1f Hz", + self._mqtt_host, self._mqtt_port, rate_hz, + ) + + # ── MQTT connection ─────────────────────────────────────────────────────── + + def _mqtt_connect(self) -> None: + try: + self._mqtt_client.connect_async( + self._mqtt_host, self._mqtt_port, + keepalive=self._mqtt_keepalive, + ) + self._mqtt_client.loop_start() + except Exception as exc: + self.get_logger().warning("MQTT initial connect error: %s", str(exc)) + + def _on_mqtt_connect(self, client, userdata, flags, rc) -> None: + if rc == 0: + self._mqtt_connected = True + self.get_logger().info( + "MQTT connected to %s:%d", self._mqtt_host, self._mqtt_port + ) + else: + self.get_logger().warning("MQTT connect failed rc=%d", rc) + + def _on_mqtt_disconnect(self, client, userdata, rc) -> None: + self._mqtt_connected = False + if rc != 0: + self.get_logger().warning( + "MQTT disconnected (rc=%d) — paho will reconnect", rc + ) + + # ── ROS2 subscriber callback ────────────────────────────────────────────── + + def _on_vesc(self, ros_msg: String, mqtt_topic: str, extractor) -> None: + """ + Handle an incoming VESC ROS2 message. + + 1. Parse JSON from the String payload. + 2. Check rate limit — drop if too soon. + 3. Extract dashboard fields. + 4. Publish to MQTT. + """ + self._rx_count[mqtt_topic] += 1 + + # Rate limit + now = time.monotonic() + if now - self._last_pub[mqtt_topic] < self._min_interval: + self._drop_count[mqtt_topic] += 1 + return + + # Parse + try: + data = json.loads(ros_msg.data) + except (json.JSONDecodeError, UnicodeDecodeError) as exc: + self._err_count[mqtt_topic] += 1 + self.get_logger().debug("JSON error on %s: %s", mqtt_topic, exc) + return + + # Extract + try: + payload = extractor(data) + except (KeyError, TypeError, ValueError) as exc: + self._err_count[mqtt_topic] += 1 + self.get_logger().debug("Payload error on %s: %s — %s", mqtt_topic, exc, data) + return + + # Publish + if self._mqtt_client is not None and self._mqtt_connected: + try: + self._mqtt_client.publish( + mqtt_topic, + json.dumps(payload, separators=(",", ":")), + qos=0, + retain=False, + ) + self._last_pub[mqtt_topic] = now + self._pub_count[mqtt_topic] += 1 + except Exception as exc: + self._err_count[mqtt_topic] += 1 + self.get_logger().debug("MQTT publish error on %s: %s", mqtt_topic, exc) + else: + # Still update last_pub to avoid log spam when broker is down + self._last_pub[mqtt_topic] = now + self.get_logger().debug("MQTT not connected — dropped %s", mqtt_topic) + + # ── Cleanup ─────────────────────────────────────────────────────────────── + + def destroy_node(self) -> None: + if self._mqtt_client is not None: + self._mqtt_client.loop_stop() + self._mqtt_client.disconnect() + super().destroy_node() + + +# ── Entry point ─────────────────────────────────────────────────────────────── + +def main(args: Any = None) -> None: + rclpy.init(args=args) + node = VescMqttRelayNode() + try: + rclpy.spin(node) + except KeyboardInterrupt: + pass + finally: + node.destroy_node() + rclpy.try_shutdown() + + +if __name__ == "__main__": + main() diff --git a/jetson/ros2_ws/src/saltybot_phone/setup.py b/jetson/ros2_ws/src/saltybot_phone/setup.py index b96b732..2fde1a2 100644 --- a/jetson/ros2_ws/src/saltybot_phone/setup.py +++ b/jetson/ros2_ws/src/saltybot_phone/setup.py @@ -30,6 +30,7 @@ setup( 'phone_bridge = saltybot_phone.ws_bridge:main', 'phone_camera_node = saltybot_phone.phone_camera_node:main', 'mqtt_ros2_bridge = saltybot_phone.mqtt_ros2_bridge_node:main', + 'vesc_mqtt_relay = saltybot_phone.vesc_mqtt_relay_node:main', ], }, ) \ No newline at end of file diff --git a/jetson/ros2_ws/src/saltybot_phone/test/__init__.py b/jetson/ros2_ws/src/saltybot_phone/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jetson/ros2_ws/src/saltybot_phone/test/test_vesc_mqtt_relay.py b/jetson/ros2_ws/src/saltybot_phone/test/test_vesc_mqtt_relay.py new file mode 100644 index 0000000..2136c55 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_phone/test/test_vesc_mqtt_relay.py @@ -0,0 +1,343 @@ +"""Unit tests for vesc_mqtt_relay_node — pure-logic helpers. + +Does not require ROS2, paho-mqtt, or a live VESC/broker. +Tests cover the two payload extractors and the rate-limiting logic. +""" + +import json +import sys +import time +import types + +# ── Stub out ROS2 / paho so the module can be imported without them ─────────── + +def _make_rclpy_stub(): + rclpy = types.ModuleType("rclpy") + node_mod = types.ModuleType("rclpy.node") + + class _Node: + def __init__(self, *a, **kw): pass + def declare_parameter(self, *a, **kw): pass + def get_parameter(self, name): + class _P: + value = None + return _P() + def create_subscription(self, *a, **kw): pass + def create_publisher(self, *a, **kw): pass + def create_timer(self, *a, **kw): pass + def get_clock(self): return None + def get_logger(self): + class _L: + def info(self, *a, **kw): pass + def warning(self, *a, **kw): pass + def error(self, *a, **kw): pass + def debug(self, *a, **kw): pass + return _L() + def destroy_node(self): pass + + node_mod.Node = _Node + rclpy.node = node_mod + rclpy.init = lambda *a, **kw: None + rclpy.spin = lambda *a, **kw: None + rclpy.try_shutdown = lambda *a, **kw: None + + std_msgs_mod = types.ModuleType("std_msgs") + std_msgs_msg = types.ModuleType("std_msgs.msg") + class _String: + data: str = "" + std_msgs_msg.String = _String + std_msgs_mod.msg = std_msgs_msg + + paho_mod = types.ModuleType("paho") + paho_mqtt = types.ModuleType("paho.mqtt") + paho_client = types.ModuleType("paho.mqtt.client") + class _Client: + def __init__(self, *a, **kw): pass + def connect_async(self, *a, **kw): pass + def loop_start(self): pass + def loop_stop(self): pass + def disconnect(self): pass + def publish(self, *a, **kw): pass + def reconnect_delay_set(self, *a, **kw): pass + on_connect = None + on_disconnect = None + paho_client.Client = _Client + paho_mqtt.client = paho_client + paho_mod.mqtt = paho_mqtt + + for name, mod in [ + ("rclpy", rclpy), + ("rclpy.node", node_mod), + ("std_msgs", std_msgs_mod), + ("std_msgs.msg", std_msgs_msg), + ("paho", paho_mod), + ("paho.mqtt", paho_mqtt), + ("paho.mqtt.client", paho_client), + ]: + sys.modules.setdefault(name, mod) + + +_make_rclpy_stub() + +from saltybot_phone.vesc_mqtt_relay_node import ( + _MQTT_VESC_LEFT, + _MQTT_VESC_RIGHT, + _MQTT_VESC_COMBINED, + _extract_combined_payload, + _extract_motor_payload, +) + + +# --------------------------------------------------------------------------- +# _extract_motor_payload +# --------------------------------------------------------------------------- + +class TestExtractMotorPayload: + """Covers field extraction and type coercion for per-motor JSON.""" + + def _sample(self, **overrides): + base = { + "can_id": 61, + "rpm": 1500, + "current_a": 12.34, + "current_in_a": 10.0, + "duty_cycle": 0.4500, + "voltage_v": 25.20, + "temp_fet_c": 45.5, + "temp_motor_c": 62.0, + "fault_code": 0, + "fault_name": "NONE", + "alive": True, + "stamp": 1000.0, + } + base.update(overrides) + return base + + def test_required_keys_present(self): + p = _extract_motor_payload(self._sample()) + for key in ("rpm", "current_a", "voltage_v", "temperature_c", + "duty_cycle", "fault_code", "ts"): + assert key in p, f"missing key: {key}" + + def test_rpm_is_int(self): + p = _extract_motor_payload(self._sample(rpm=1500)) + assert isinstance(p["rpm"], int) + assert p["rpm"] == 1500 + + def test_temperature_maps_to_temp_fet_c(self): + p = _extract_motor_payload(self._sample(temp_fet_c=55.5)) + assert p["temperature_c"] == 55.5 + + def test_voltage_v(self): + p = _extract_motor_payload(self._sample(voltage_v=24.8)) + assert p["voltage_v"] == 24.8 + + def test_duty_cycle(self): + p = _extract_motor_payload(self._sample(duty_cycle=0.75)) + assert p["duty_cycle"] == 0.75 + + def test_fault_code_zero(self): + p = _extract_motor_payload(self._sample(fault_code=0)) + assert p["fault_code"] == 0 + + def test_fault_code_nonzero(self): + p = _extract_motor_payload(self._sample(fault_code=3)) + assert p["fault_code"] == 3 + + def test_ts_from_stamp(self): + p = _extract_motor_payload(self._sample(stamp=12345.678)) + assert p["ts"] == 12345.678 + + def test_negative_rpm(self): + p = _extract_motor_payload(self._sample(rpm=-3000)) + assert p["rpm"] == -3000 + + def test_negative_current(self): + p = _extract_motor_payload(self._sample(current_a=-5.0)) + assert p["current_a"] == -5.0 + + def test_negative_duty_cycle(self): + p = _extract_motor_payload(self._sample(duty_cycle=-0.5)) + assert p["duty_cycle"] == -0.5 + + def test_missing_required_key_raises(self): + bad = self._sample() + del bad["rpm"] + try: + _extract_motor_payload(bad) + assert False, "expected KeyError" + except KeyError: + pass + + def test_missing_stamp_uses_current_time(self): + data = self._sample() + del data["stamp"] + before = time.time() + p = _extract_motor_payload(data) + after = time.time() + assert before <= p["ts"] <= after + + def test_json_roundtrip(self): + p = _extract_motor_payload(self._sample()) + raw = json.dumps(p) + restored = json.loads(raw) + assert restored["rpm"] == p["rpm"] + assert restored["fault_code"] == p["fault_code"] + + +# --------------------------------------------------------------------------- +# _extract_combined_payload +# --------------------------------------------------------------------------- + +class TestExtractCombinedPayload: + def _sample(self, **overrides): + base = { + "voltage_v": 25.2, + "total_current_a": 18.5, + "left_rpm": 1400, + "right_rpm": 1420, + "left_alive": True, + "right_alive": True, + "stamp": 2000.0, + } + base.update(overrides) + return base + + def test_required_keys_present(self): + p = _extract_combined_payload(self._sample()) + for key in ("voltage_v", "total_current_a", "left_rpm", "right_rpm", + "left_alive", "right_alive", "ts"): + assert key in p, f"missing key: {key}" + + def test_voltage_v(self): + p = _extract_combined_payload(self._sample(voltage_v=24.0)) + assert p["voltage_v"] == 24.0 + + def test_total_current_a(self): + p = _extract_combined_payload(self._sample(total_current_a=30.5)) + assert p["total_current_a"] == 30.5 + + def test_rpms_are_int(self): + p = _extract_combined_payload(self._sample(left_rpm=1000, right_rpm=1050)) + assert isinstance(p["left_rpm"], int) + assert isinstance(p["right_rpm"], int) + + def test_alive_flags(self): + p = _extract_combined_payload(self._sample(left_alive=True, right_alive=False)) + assert p["left_alive"] is True + assert p["right_alive"] is False + + def test_ts_from_stamp(self): + p = _extract_combined_payload(self._sample(stamp=9999.1)) + assert p["ts"] == 9999.1 + + def test_missing_stamp_uses_current_time(self): + data = self._sample() + del data["stamp"] + before = time.time() + p = _extract_combined_payload(data) + after = time.time() + assert before <= p["ts"] <= after + + def test_json_roundtrip(self): + p = _extract_combined_payload(self._sample()) + raw = json.dumps(p) + restored = json.loads(raw) + assert restored["voltage_v"] == p["voltage_v"] + + +# --------------------------------------------------------------------------- +# Rate-limit logic (isolated, no ROS2 / paho) +# --------------------------------------------------------------------------- + +class TestRateLimit: + """ + Exercise the rate-limiting gate that lives inside VescMqttRelayNode._on_vesc. + We test the guard condition directly without instantiating the ROS2 node. + """ + + def _make_gate(self, rate_hz: float): + """ + Return a stateful callable that mirrors the rate-limit check in the node. + Returns True if a publish should proceed, False if the message is dropped. + """ + min_interval = 1.0 / rate_hz + state = {"last": 0.0} + + def gate(now: float) -> bool: + if now - state["last"] < min_interval: + return False + state["last"] = now + return True + + return gate + + def test_first_message_always_passes(self): + gate = self._make_gate(5.0) + assert gate(time.monotonic()) is True + + def test_immediate_second_message_dropped(self): + gate = self._make_gate(5.0) + t = time.monotonic() + gate(t) + assert gate(t + 0.001) is False # 1 ms < 200 ms interval + + def test_message_after_interval_passes(self): + gate = self._make_gate(5.0) + t = time.monotonic() + gate(t) + assert gate(t + 0.201) is True # 201 ms > 200 ms interval + + def test_exactly_at_interval_dropped(self): + gate = self._make_gate(5.0) + t = time.monotonic() + gate(t) + # Exactly at the boundary is strictly less-than, so it's dropped + assert gate(t + 0.2) is False + + def test_10hz_rate(self): + gate = self._make_gate(10.0) + t = time.monotonic() + gate(t) + assert gate(t + 0.09) is False # 90 ms < 100 ms + assert gate(t + 0.101) is True # 101 ms > 100 ms + + def test_1hz_rate(self): + gate = self._make_gate(1.0) + t = time.monotonic() + gate(t) + assert gate(t + 0.999) is False + assert gate(t + 1.001) is True + + def test_multiple_topics_independent(self): + gate_left = self._make_gate(5.0) + gate_right = self._make_gate(5.0) + t = time.monotonic() + gate_left(t) + gate_right(t) + # left: drop, right: drop + assert gate_left(t + 0.05) is False + assert gate_right(t + 0.05) is False + # advance only left past interval + assert gate_left(t + 0.21) is True + assert gate_right(t + 0.21) is True + + +# --------------------------------------------------------------------------- +# MQTT topic constant checks +# --------------------------------------------------------------------------- + +class TestTopicConstants: + def test_left_topic(self): + assert _MQTT_VESC_LEFT == "saltybot/phone/vesc_left" + + def test_right_topic(self): + assert _MQTT_VESC_RIGHT == "saltybot/phone/vesc_right" + + def test_combined_topic(self): + assert _MQTT_VESC_COMBINED == "saltybot/phone/vesc_combined" + + +if __name__ == "__main__": + import pytest + pytest.main([__file__, "-v"]) -- 2.47.2 From 7eb3f187e2ce4558eb66ac8cc990503972e6d605 Mon Sep 17 00:00:00 2001 From: sl-controls Date: Tue, 17 Mar 2026 11:35:10 -0400 Subject: [PATCH 2/2] feat: Smooth velocity controller (Issue #652) Adds velocity_smoother_node.py with configurable accel/decel ramps, e-stop bypass, and optional jerk limiting. VESC driver updated to subscribe /cmd_vel_smoothed instead of /cmd_vel. Co-Authored-By: Claude Sonnet 4.6 --- .../config/aggregator_params.yaml | 15 + .../launch/diagnostics_aggregator.launch.py | 44 +++ .../package.xml | 30 ++ .../resource/saltybot_diagnostics_aggregator | 0 .../__init__.py | 0 .../aggregator_node.py | 312 ++++++++++++++++++ .../subsystem.py | 148 +++++++++ .../saltybot_diagnostics_aggregator/setup.cfg | 4 + .../saltybot_diagnostics_aggregator/setup.py | 28 ++ .../test/__init__.py | 0 .../config/vesc_params.yaml | 10 + .../launch/vesc_driver.launch.py | 7 + .../velocity_smoother_node.py | 171 ++++++++++ .../saltybot_vesc_driver/vesc_driver_node.py | 2 +- .../ros2_ws/src/saltybot_vesc_driver/setup.py | 1 + .../test/test_velocity_smoother.py | 237 +++++++++++++ 16 files changed, 1008 insertions(+), 1 deletion(-) create mode 100644 jetson/ros2_ws/src/saltybot_diagnostics_aggregator/config/aggregator_params.yaml create mode 100644 jetson/ros2_ws/src/saltybot_diagnostics_aggregator/launch/diagnostics_aggregator.launch.py create mode 100644 jetson/ros2_ws/src/saltybot_diagnostics_aggregator/package.xml create mode 100644 jetson/ros2_ws/src/saltybot_diagnostics_aggregator/resource/saltybot_diagnostics_aggregator create mode 100644 jetson/ros2_ws/src/saltybot_diagnostics_aggregator/saltybot_diagnostics_aggregator/__init__.py create mode 100644 jetson/ros2_ws/src/saltybot_diagnostics_aggregator/saltybot_diagnostics_aggregator/aggregator_node.py create mode 100644 jetson/ros2_ws/src/saltybot_diagnostics_aggregator/saltybot_diagnostics_aggregator/subsystem.py create mode 100644 jetson/ros2_ws/src/saltybot_diagnostics_aggregator/setup.cfg create mode 100644 jetson/ros2_ws/src/saltybot_diagnostics_aggregator/setup.py create mode 100644 jetson/ros2_ws/src/saltybot_diagnostics_aggregator/test/__init__.py create mode 100644 jetson/ros2_ws/src/saltybot_vesc_driver/saltybot_vesc_driver/velocity_smoother_node.py create mode 100644 jetson/ros2_ws/src/saltybot_vesc_driver/test/test_velocity_smoother.py diff --git a/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/config/aggregator_params.yaml b/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/config/aggregator_params.yaml new file mode 100644 index 0000000..6767e63 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/config/aggregator_params.yaml @@ -0,0 +1,15 @@ +# Diagnostics Aggregator — Issue #658 +# Unified health dashboard aggregating telemetry from all SaltyBot subsystems. + +diagnostics_aggregator: + ros__parameters: + # Publish rate for /saltybot/system_health (Hz) + heartbeat_hz: 1.0 + + # Seconds without a topic update before a subsystem is marked STALE + # Increase for sensors with lower publish rates (e.g. UWB at 5 Hz) + stale_timeout_s: 5.0 + + # Maximum number of state transitions kept in the in-memory ring buffer + # Last 10 transitions are included in each /saltybot/system_health publish + transition_log_size: 50 diff --git a/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/launch/diagnostics_aggregator.launch.py b/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/launch/diagnostics_aggregator.launch.py new file mode 100644 index 0000000..9e8e86e --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/launch/diagnostics_aggregator.launch.py @@ -0,0 +1,44 @@ +"""Launch file for diagnostics aggregator node.""" + +import os +from ament_index_python.packages import get_package_share_directory +from launch import LaunchDescription +from launch.actions import DeclareLaunchArgument +from launch.substitutions import LaunchConfiguration +from launch_ros.actions import Node + + +def generate_launch_description(): + pkg_dir = get_package_share_directory("saltybot_diagnostics_aggregator") + config_file = os.path.join(pkg_dir, "config", "aggregator_params.yaml") + + return LaunchDescription([ + DeclareLaunchArgument( + "config_file", + default_value=config_file, + description="Path to aggregator_params.yaml", + ), + DeclareLaunchArgument( + "heartbeat_hz", + default_value="1.0", + description="Publish rate for /saltybot/system_health (Hz)", + ), + DeclareLaunchArgument( + "stale_timeout_s", + default_value="5.0", + description="Seconds without update before subsystem goes STALE", + ), + Node( + package="saltybot_diagnostics_aggregator", + executable="diagnostics_aggregator_node", + name="diagnostics_aggregator", + output="screen", + parameters=[ + LaunchConfiguration("config_file"), + { + "heartbeat_hz": LaunchConfiguration("heartbeat_hz"), + "stale_timeout_s": LaunchConfiguration("stale_timeout_s"), + }, + ], + ), + ]) diff --git a/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/package.xml b/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/package.xml new file mode 100644 index 0000000..089e464 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/package.xml @@ -0,0 +1,30 @@ + + + + saltybot_diagnostics_aggregator + 0.1.0 + + Diagnostics aggregator for SaltyBot — unified health dashboard node (Issue #658). + Subscribes to /vesc/health, /diagnostics, /saltybot/safety_zone/status, + /saltybot/pose/status, /saltybot/uwb/status. Aggregates into + /saltybot/system_health JSON at 1 Hz. Tracks motors, battery, imu, uwb, + lidar, camera, can_bus, estop subsystems with state-transition logging. + + sl-firmware + MIT + + rclpy + std_msgs + diagnostic_msgs + + ament_python + + ament_copyright + ament_flake8 + ament_pep257 + python3-pytest + + + ament_python + + diff --git a/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/resource/saltybot_diagnostics_aggregator b/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/resource/saltybot_diagnostics_aggregator new file mode 100644 index 0000000..e69de29 diff --git a/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/saltybot_diagnostics_aggregator/__init__.py b/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/saltybot_diagnostics_aggregator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/saltybot_diagnostics_aggregator/aggregator_node.py b/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/saltybot_diagnostics_aggregator/aggregator_node.py new file mode 100644 index 0000000..73f9422 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/saltybot_diagnostics_aggregator/aggregator_node.py @@ -0,0 +1,312 @@ +#!/usr/bin/env python3 +"""Diagnostics aggregator — unified health dashboard node (Issue #658). + +Subscribes to telemetry and diagnostics topics from all subsystems, aggregates +them into a single /saltybot/system_health JSON publish at 1 Hz. + +Subscribed topics +----------------- + /vesc/health (std_msgs/String) — motor VESC health JSON + /diagnostics (diagnostic_msgs/DiagnosticArray) — ROS diagnostics bus + /saltybot/safety_zone/status (std_msgs/String) — safety / estop state + /saltybot/pose/status (std_msgs/String) — IMU / pose estimate state + /saltybot/uwb/status (std_msgs/String) — UWB positioning state + +Published topics +---------------- + /saltybot/system_health (std_msgs/String) — JSON health dashboard at 1 Hz + +JSON schema for /saltybot/system_health +---------------------------------------- +{ + "overall_status": "OK" | "WARN" | "ERROR" | "STALE", + "uptime_s": , + "subsystems": { + "motors": { "status": ..., "message": ..., "age_s": ..., "previous_status": ... }, + "battery": { ... }, + "imu": { ... }, + "uwb": { ... }, + "lidar": { ... }, + "camera": { ... }, + "can_bus": { ... }, + "estop": { ... } + }, + "last_error": { "subsystem": ..., "message": ..., "timestamp": ... } | null, + "recent_transitions": [ { "subsystem", "from_status", "to_status", + "message", "timestamp_iso" }, ... ], + "timestamp": "" +} + +Parameters +---------- + heartbeat_hz float 1.0 Publish rate (Hz) + stale_timeout_s float 5.0 Seconds without update → STALE + transition_log_size int 50 Max recent transitions kept in memory +""" + +import json +import time +from collections import deque +from datetime import datetime, timezone +from typing import Optional + +import rclpy +from rclpy.node import Node +from std_msgs.msg import String +from diagnostic_msgs.msg import DiagnosticArray + +from .subsystem import ( + SubsystemState, + Transition, + STATUS_OK, + STATUS_WARN, + STATUS_ERROR, + STATUS_STALE, + STATUS_UNKNOWN, + worse, + ros_level_to_status, + SUBSYSTEM_NAMES as _SUBSYSTEM_NAMES, + keyword_to_subsystem as _keyword_to_subsystem, +) + + +class DiagnosticsAggregatorNode(Node): + """Unified health dashboard node.""" + + def __init__(self): + super().__init__("diagnostics_aggregator") + + # ---------------------------------------------------------------- + # Parameters + # ---------------------------------------------------------------- + self.declare_parameter("heartbeat_hz", 1.0) + self.declare_parameter("stale_timeout_s", 5.0) + self.declare_parameter("transition_log_size", 50) + + hz = float(self.get_parameter("heartbeat_hz").value) + stale_timeout = float(self.get_parameter("stale_timeout_s").value) + log_size = int(self.get_parameter("transition_log_size").value) + + # ---------------------------------------------------------------- + # Subsystem state table + # ---------------------------------------------------------------- + self._subsystems: dict[str, SubsystemState] = { + name: SubsystemState(name=name, stale_timeout_s=stale_timeout) + for name in _SUBSYSTEM_NAMES + } + self._transitions: deque[Transition] = deque(maxlen=log_size) + self._last_error: Optional[dict] = None + self._start_mono = time.monotonic() + + # ---------------------------------------------------------------- + # Subscriptions + # ---------------------------------------------------------------- + self.create_subscription(String, "/vesc/health", + self._on_vesc_health, 10) + self.create_subscription(DiagnosticArray, "/diagnostics", + self._on_diagnostics, 10) + self.create_subscription(String, "/saltybot/safety_zone/status", + self._on_safety_zone, 10) + self.create_subscription(String, "/saltybot/pose/status", + self._on_pose_status, 10) + self.create_subscription(String, "/saltybot/uwb/status", + self._on_uwb_status, 10) + + # ---------------------------------------------------------------- + # Publisher + # ---------------------------------------------------------------- + self._pub = self.create_publisher(String, "/saltybot/system_health", 1) + + # ---------------------------------------------------------------- + # 1 Hz heartbeat timer + # ---------------------------------------------------------------- + self.create_timer(1.0 / max(0.1, hz), self._on_timer) + + self.get_logger().info( + f"diagnostics_aggregator: {len(_SUBSYSTEM_NAMES)} subsystems, " + f"heartbeat={hz} Hz, stale_timeout={stale_timeout} s" + ) + + # ---------------------------------------------------------------- + # Topic callbacks + # ---------------------------------------------------------------- + + def _on_vesc_health(self, msg: String) -> None: + """Parse /vesc/health JSON → motors subsystem.""" + now = time.monotonic() + try: + d = json.loads(msg.data) + fault = d.get("fault_code", 0) + alive = d.get("alive", True) + if not alive: + status, message = STATUS_STALE, "VESC offline" + elif fault != 0: + status = STATUS_ERROR + message = f"VESC fault {d.get('fault_name', fault)}" + else: + status = STATUS_OK + message = ( + f"RPM={d.get('rpm', '?')} " + f"I={d.get('current_a', '?')} A " + f"V={d.get('voltage_v', '?')} V" + ) + self._update("motors", status, message, now) + except Exception as exc: + self.get_logger().debug(f"/vesc/health parse error: {exc}") + + def _on_diagnostics(self, msg: DiagnosticArray) -> None: + """Fan /diagnostics entries out to per-subsystem states.""" + now = time.monotonic() + # Accumulate worst status per subsystem across all entries in this array + pending: dict[str, tuple[str, str]] = {} # subsystem → (status, message) + + for ds in msg.status: + subsystem = _keyword_to_subsystem(ds.name) + if subsystem is None: + continue + status = ros_level_to_status(ds.level) + message = ds.message or "" + if subsystem not in pending: + pending[subsystem] = (status, message) + else: + prev_s, prev_m = pending[subsystem] + new_worst = worse(status, prev_s) + pending[subsystem] = ( + new_worst, + message if new_worst == status else prev_m, + ) + + for subsystem, (status, message) in pending.items(): + self._update(subsystem, status, message, now) + + def _on_safety_zone(self, msg: String) -> None: + """Parse /saltybot/safety_zone/status JSON → estop subsystem.""" + now = time.monotonic() + try: + d = json.loads(msg.data) + triggered = d.get("estop_triggered", d.get("triggered", False)) + active = d.get("safety_zone_active", d.get("active", True)) + if triggered: + status, message = STATUS_ERROR, "E-stop triggered" + elif not active: + status, message = STATUS_WARN, "Safety zone inactive" + else: + status, message = STATUS_OK, "Safety zone active" + self._update("estop", status, message, now) + except Exception as exc: + self.get_logger().debug(f"/saltybot/safety_zone/status parse error: {exc}") + + def _on_pose_status(self, msg: String) -> None: + """Parse /saltybot/pose/status JSON → imu subsystem.""" + now = time.monotonic() + try: + d = json.loads(msg.data) + status_str = d.get("status", "OK").upper() + if status_str not in (STATUS_OK, STATUS_WARN, STATUS_ERROR): + status_str = STATUS_WARN + message = d.get("message", d.get("msg", "")) + self._update("imu", status_str, message, now) + except Exception as exc: + self.get_logger().debug(f"/saltybot/pose/status parse error: {exc}") + + def _on_uwb_status(self, msg: String) -> None: + """Parse /saltybot/uwb/status JSON → uwb subsystem.""" + now = time.monotonic() + try: + d = json.loads(msg.data) + status_str = d.get("status", "OK").upper() + if status_str not in (STATUS_OK, STATUS_WARN, STATUS_ERROR): + status_str = STATUS_WARN + message = d.get("message", d.get("msg", "")) + self._update("uwb", status_str, message, now) + except Exception as exc: + self.get_logger().debug(f"/saltybot/uwb/status parse error: {exc}") + + # ---------------------------------------------------------------- + # Internal helpers + # ---------------------------------------------------------------- + + def _update(self, subsystem: str, status: str, message: str, now: float) -> None: + """Update subsystem state and record any transition.""" + s = self._subsystems.get(subsystem) + if s is None: + return + transition = s.update(status, message, now) + if transition is not None: + self._transitions.append(transition) + self.get_logger().info( + f"[{subsystem}] {transition.from_status} → {transition.to_status}: {message}" + ) + if transition.to_status == STATUS_ERROR: + self._last_error = { + "subsystem": subsystem, + "message": message, + "timestamp": transition.timestamp_iso, + } + + def _apply_stale_checks(self, now: float) -> None: + for s in self._subsystems.values(): + transition = s.apply_stale_check(now) + if transition is not None: + self._transitions.append(transition) + self.get_logger().warn( + f"[{transition.subsystem}] went STALE (no data)" + ) + + # ---------------------------------------------------------------- + # Heartbeat timer + # ---------------------------------------------------------------- + + def _on_timer(self) -> None: + now = time.monotonic() + self._apply_stale_checks(now) + + # Overall status = worst across all subsystems + overall = STATUS_OK + for s in self._subsystems.values(): + overall = worse(overall, s.status) + # UNKNOWN does not degrade overall if at least one subsystem is known + known = [s for s in self._subsystems.values() if s.status != STATUS_UNKNOWN] + if not known: + overall = STATUS_UNKNOWN + + uptime = now - self._start_mono + + payload = { + "overall_status": overall, + "uptime_s": round(uptime, 1), + "subsystems": { + name: s.to_dict(now) + for name, s in self._subsystems.items() + }, + "last_error": self._last_error, + "recent_transitions": [ + { + "subsystem": t.subsystem, + "from_status": t.from_status, + "to_status": t.to_status, + "message": t.message, + "timestamp": t.timestamp_iso, + } + for t in list(self._transitions)[-10:] # last 10 in the JSON + ], + "timestamp": datetime.now(timezone.utc).isoformat(), + } + + self._pub.publish(String(data=json.dumps(payload))) + + +def main(args=None): + rclpy.init(args=args) + node = DiagnosticsAggregatorNode() + try: + rclpy.spin(node) + except KeyboardInterrupt: + pass + finally: + node.destroy_node() + rclpy.shutdown() + + +if __name__ == "__main__": + main() diff --git a/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/saltybot_diagnostics_aggregator/subsystem.py b/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/saltybot_diagnostics_aggregator/subsystem.py new file mode 100644 index 0000000..d84180c --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/saltybot_diagnostics_aggregator/subsystem.py @@ -0,0 +1,148 @@ +"""Subsystem state tracking for the diagnostics aggregator. + +Each SubsystemState tracks one logical subsystem (motors, battery, etc.) +across potentially multiple source topics. Status priority: + ERROR > WARN > STALE > OK > UNKNOWN +""" + +import time +from dataclasses import dataclass, field +from typing import Optional + +# --------------------------------------------------------------------------- +# Status constants — ordered by severity (higher index = more severe) +# --------------------------------------------------------------------------- +STATUS_UNKNOWN = "UNKNOWN" +STATUS_OK = "OK" +STATUS_STALE = "STALE" +STATUS_WARN = "WARN" +STATUS_ERROR = "ERROR" + +# --------------------------------------------------------------------------- +# Subsystem registry and diagnostic keyword routing +# --------------------------------------------------------------------------- +SUBSYSTEM_NAMES: list[str] = [ + "motors", "battery", "imu", "uwb", "lidar", "camera", "can_bus", "estop" +] + +# (keywords-tuple, subsystem-name) — first match wins, lower-cased comparison +DIAG_KEYWORD_MAP: list[tuple[tuple[str, ...], str]] = [ + (("vesc", "motor", "esc", "fsesc"), "motors"), + (("battery", "ina219", "lvc", "coulomb", "vbat"), "battery"), + (("imu", "mpu6000", "bno055", "icm42688", "gyro", "accel"), "imu"), + (("uwb", "dw1000", "dw3000"), "uwb"), + (("lidar", "rplidar", "sick", "laser"), "lidar"), + (("camera", "realsense", "oak", "webcam"), "camera"), + (("can", "can_bus", "can_driver"), "can_bus"), + (("estop", "safety", "emergency", "e-stop"), "estop"), +] + + +def keyword_to_subsystem(name: str) -> Optional[str]: + """Map a diagnostic status name to a subsystem key, or None.""" + lower = name.lower() + for keywords, subsystem in DIAG_KEYWORD_MAP: + if any(kw in lower for kw in keywords): + return subsystem + return None + + +_SEVERITY: dict[str, int] = { + STATUS_UNKNOWN: 0, + STATUS_OK: 1, + STATUS_STALE: 2, + STATUS_WARN: 3, + STATUS_ERROR: 4, +} + + +def worse(a: str, b: str) -> str: + """Return the more severe of two status strings.""" + return a if _SEVERITY.get(a, 0) >= _SEVERITY.get(b, 0) else b + + +def ros_level_to_status(level: int) -> str: + """Convert diagnostic_msgs/DiagnosticStatus.level to a status string.""" + return {0: STATUS_OK, 1: STATUS_WARN, 2: STATUS_ERROR}.get(level, STATUS_UNKNOWN) + + +# --------------------------------------------------------------------------- +# Transition log entry +# --------------------------------------------------------------------------- +@dataclass +class Transition: + """A logged status change for one subsystem.""" + subsystem: str + from_status: str + to_status: str + message: str + timestamp_iso: str # ISO-8601 wall-clock + monotonic_s: float # time.monotonic() at the transition + + +# --------------------------------------------------------------------------- +# Per-subsystem state +# --------------------------------------------------------------------------- +@dataclass +class SubsystemState: + """Live health state for one logical subsystem.""" + + name: str + stale_timeout_s: float = 5.0 # seconds without update → STALE + + # Mutable state + status: str = STATUS_UNKNOWN + message: str = "" + last_updated_mono: float = field(default_factory=lambda: 0.0) + last_change_mono: float = field(default_factory=lambda: 0.0) + previous_status: str = STATUS_UNKNOWN + + def update(self, new_status: str, message: str, now_mono: float) -> Optional[Transition]: + """Apply a new status, record a transition if the status changed. + + Returns a Transition if the status changed, else None. + """ + from datetime import datetime, timezone + + self.last_updated_mono = now_mono + + if new_status == self.status: + self.message = message + return None + + transition = Transition( + subsystem=self.name, + from_status=self.status, + to_status=new_status, + message=message, + timestamp_iso=datetime.now(timezone.utc).isoformat(), + monotonic_s=now_mono, + ) + + self.previous_status = self.status + self.status = new_status + self.message = message + self.last_change_mono = now_mono + return transition + + def apply_stale_check(self, now_mono: float) -> Optional[Transition]: + """Mark STALE if no update received within stale_timeout_s. + + Returns a Transition if newly staled, else None. + """ + if self.last_updated_mono == 0.0: + # Never received any data — leave as UNKNOWN + return None + if (now_mono - self.last_updated_mono) > self.stale_timeout_s: + if self.status not in (STATUS_STALE, STATUS_UNKNOWN): + return self.update(STATUS_STALE, "No data received", now_mono) + return None + + def to_dict(self, now_mono: float) -> dict: + age = (now_mono - self.last_updated_mono) if self.last_updated_mono > 0 else None + return { + "status": self.status, + "message": self.message, + "age_s": round(age, 2) if age is not None else None, + "previous_status": self.previous_status, + } diff --git a/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/setup.cfg b/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/setup.cfg new file mode 100644 index 0000000..03e3b13 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/setup.cfg @@ -0,0 +1,4 @@ +[develop] +script_dir=$base/lib/saltybot_diagnostics_aggregator +[install] +install_scripts=$base/lib/saltybot_diagnostics_aggregator diff --git a/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/setup.py b/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/setup.py new file mode 100644 index 0000000..18e1cae --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/setup.py @@ -0,0 +1,28 @@ +from setuptools import setup + +package_name = "saltybot_diagnostics_aggregator" + +setup( + name=package_name, + version="0.1.0", + packages=[package_name], + data_files=[ + ("share/ament_index/resource_index/packages", [f"resource/{package_name}"]), + (f"share/{package_name}", ["package.xml"]), + (f"share/{package_name}/launch", ["launch/diagnostics_aggregator.launch.py"]), + (f"share/{package_name}/config", ["config/aggregator_params.yaml"]), + ], + install_requires=["setuptools"], + zip_safe=True, + maintainer="sl-firmware", + maintainer_email="sl-firmware@saltylab.local", + description="Diagnostics aggregator — unified health dashboard for SaltyBot", + license="MIT", + tests_require=["pytest"], + entry_points={ + "console_scripts": [ + "diagnostics_aggregator_node = " + "saltybot_diagnostics_aggregator.aggregator_node:main", + ], + }, +) diff --git a/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/test/__init__.py b/jetson/ros2_ws/src/saltybot_diagnostics_aggregator/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jetson/ros2_ws/src/saltybot_vesc_driver/config/vesc_params.yaml b/jetson/ros2_ws/src/saltybot_vesc_driver/config/vesc_params.yaml index 0132279..8c83b58 100644 --- a/jetson/ros2_ws/src/saltybot_vesc_driver/config/vesc_params.yaml +++ b/jetson/ros2_ws/src/saltybot_vesc_driver/config/vesc_params.yaml @@ -8,3 +8,13 @@ vesc_can_driver: wheel_radius: 0.1 max_speed: 5.0 command_timeout: 1.0 + +velocity_smoother: + ros__parameters: + publish_rate: 50.0 + max_linear_accel: 1.0 + max_linear_decel: 2.0 + max_angular_accel: 1.5 + max_angular_decel: 3.0 + max_linear_jerk: 0.0 + max_angular_jerk: 0.0 diff --git a/jetson/ros2_ws/src/saltybot_vesc_driver/launch/vesc_driver.launch.py b/jetson/ros2_ws/src/saltybot_vesc_driver/launch/vesc_driver.launch.py index 083917b..38e6cde 100644 --- a/jetson/ros2_ws/src/saltybot_vesc_driver/launch/vesc_driver.launch.py +++ b/jetson/ros2_ws/src/saltybot_vesc_driver/launch/vesc_driver.launch.py @@ -20,6 +20,13 @@ def generate_launch_description(): default_value=config_file, description="Path to configuration YAML file", ), + Node( + package="saltybot_vesc_driver", + executable="velocity_smoother_node", + name="velocity_smoother", + output="screen", + parameters=[LaunchConfiguration("config_file")], + ), Node( package="saltybot_vesc_driver", executable="vesc_driver_node", diff --git a/jetson/ros2_ws/src/saltybot_vesc_driver/saltybot_vesc_driver/velocity_smoother_node.py b/jetson/ros2_ws/src/saltybot_vesc_driver/saltybot_vesc_driver/velocity_smoother_node.py new file mode 100644 index 0000000..c954c9f --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_vesc_driver/saltybot_vesc_driver/velocity_smoother_node.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +""" +Smooth velocity controller with accel/decel ramp. + +Subscribes to /cmd_vel (geometry_msgs/Twist), applies acceleration +and deceleration ramps, and publishes smoothed commands to +/cmd_vel_smoothed at 50 Hz. + +E-stop (std_msgs/Bool on /e_stop): immediate zero, bypasses ramp. +Optional jerk limiting via max_linear_jerk / max_angular_jerk params. +""" + +import math + +import rclpy +from geometry_msgs.msg import Twist +from rclpy.node import Node +from std_msgs.msg import Bool + + +class VelocitySmoother(Node): + def __init__(self): + super().__init__('velocity_smoother') + + # Declare parameters + self.declare_parameter('publish_rate', 50.0) # Hz + self.declare_parameter('max_linear_accel', 1.0) # m/s² + self.declare_parameter('max_linear_decel', 2.0) # m/s² + self.declare_parameter('max_angular_accel', 1.5) # rad/s² + self.declare_parameter('max_angular_decel', 3.0) # rad/s² + self.declare_parameter('max_linear_jerk', 0.0) # m/s³, 0=disabled + self.declare_parameter('max_angular_jerk', 0.0) # rad/s³, 0=disabled + + rate = self.get_parameter('publish_rate').value + self.max_lin_acc = self.get_parameter('max_linear_accel').value + self.max_lin_dec = self.get_parameter('max_linear_decel').value + self.max_ang_acc = self.get_parameter('max_angular_accel').value + self.max_ang_dec = self.get_parameter('max_angular_decel').value + self.max_lin_jerk = self.get_parameter('max_linear_jerk').value + self.max_ang_jerk = self.get_parameter('max_angular_jerk').value + + self._dt = 1.0 / rate + + # State + self._target_lin = 0.0 + self._target_ang = 0.0 + self._current_lin = 0.0 + self._current_ang = 0.0 + self._current_lin_acc = 0.0 # for jerk limiting + self._current_ang_acc = 0.0 + self._e_stop = False + + # Publisher + self._pub = self.create_publisher(Twist, '/cmd_vel_smoothed', 10) + + # Subscriptions + self.create_subscription(Twist, '/cmd_vel', self._cmd_vel_cb, 10) + self.create_subscription(Bool, '/e_stop', self._e_stop_cb, 10) + + # Timer + self.create_timer(self._dt, self._timer_cb) + + self.get_logger().info( + f'VelocitySmoother ready at {rate:.0f} Hz — ' + f'lin_acc={self.max_lin_acc} lin_dec={self.max_lin_dec} ' + f'ang_acc={self.max_ang_acc} ang_dec={self.max_ang_dec}' + ) + + # ------------------------------------------------------------------ + + def _cmd_vel_cb(self, msg: Twist): + self._target_lin = msg.linear.x + self._target_ang = msg.angular.z + + def _e_stop_cb(self, msg: Bool): + self._e_stop = msg.data + if self._e_stop: + self._target_lin = 0.0 + self._target_ang = 0.0 + self._current_lin = 0.0 + self._current_ang = 0.0 + self._current_lin_acc = 0.0 + self._current_ang_acc = 0.0 + self.get_logger().warn('E-STOP active — motors zeroed immediately') + + # ------------------------------------------------------------------ + + def _ramp( + self, + current: float, + target: float, + max_acc: float, + max_dec: float, + current_acc: float, + max_jerk: float, + ) -> tuple[float, float]: + """ + Advance `current` toward `target` with separate accel/decel limits. + Optionally apply jerk limiting to the acceleration. + Returns (new_value, new_acc). + """ + error = target - current + + # Choose limit: decelerate if moving away from zero and target is + # closer to zero (or past it), else accelerate. + if current * error < 0 or (error != 0 and abs(target) < abs(current)): + limit = max_dec + else: + limit = max_acc + + desired_acc = math.copysign(min(abs(error) / self._dt, limit), error) + # Clamp so we don't overshoot + desired_acc = max(-limit, min(limit, desired_acc)) + + if max_jerk > 0.0: + max_d_acc = max_jerk * self._dt + new_acc = current_acc + max( + -max_d_acc, min(max_d_acc, desired_acc - current_acc) + ) + else: + new_acc = desired_acc + + delta = new_acc * self._dt + # Clamp so we don't overshoot the target + if abs(delta) > abs(error): + delta = error + return current + delta, new_acc + + def _timer_cb(self): + if self._e_stop: + msg = Twist() + self._pub.publish(msg) + return + + self._current_lin, self._current_lin_acc = self._ramp( + self._current_lin, + self._target_lin, + self.max_lin_acc, + self.max_lin_dec, + self._current_lin_acc, + self.max_lin_jerk, + ) + self._current_ang, self._current_ang_acc = self._ramp( + self._current_ang, + self._target_ang, + self.max_ang_acc, + self.max_ang_dec, + self._current_ang_acc, + self.max_lin_jerk, # intentional: use linear jerk scale for angular too if set + ) + + msg = Twist() + msg.linear.x = self._current_lin + msg.angular.z = self._current_ang + self._pub.publish(msg) + + +def main(args=None): + rclpy.init(args=args) + node = VelocitySmoother() + try: + rclpy.spin(node) + except KeyboardInterrupt: + pass + finally: + node.destroy_node() + rclpy.shutdown() + + +if __name__ == '__main__': + main() diff --git a/jetson/ros2_ws/src/saltybot_vesc_driver/saltybot_vesc_driver/vesc_driver_node.py b/jetson/ros2_ws/src/saltybot_vesc_driver/saltybot_vesc_driver/vesc_driver_node.py index 9e05b51..ffedf2a 100644 --- a/jetson/ros2_ws/src/saltybot_vesc_driver/saltybot_vesc_driver/vesc_driver_node.py +++ b/jetson/ros2_ws/src/saltybot_vesc_driver/saltybot_vesc_driver/vesc_driver_node.py @@ -142,7 +142,7 @@ class VescCanDriver(Node): self._rx_thread.start() # Subscriber - self.create_subscription(Twist, '/cmd_vel', self._cmd_vel_cb, 10) + self.create_subscription(Twist, '/cmd_vel_smoothed', self._cmd_vel_cb, 10) # Watchdog + telemetry publish timer self.create_timer(1.0 / max(1, tel_hz), self._watchdog_and_publish_cb) diff --git a/jetson/ros2_ws/src/saltybot_vesc_driver/setup.py b/jetson/ros2_ws/src/saltybot_vesc_driver/setup.py index 8f2acd0..479be64 100644 --- a/jetson/ros2_ws/src/saltybot_vesc_driver/setup.py +++ b/jetson/ros2_ws/src/saltybot_vesc_driver/setup.py @@ -22,6 +22,7 @@ setup( entry_points={ "console_scripts": [ "vesc_driver_node = saltybot_vesc_driver.vesc_driver_node:main", + "velocity_smoother_node = saltybot_vesc_driver.velocity_smoother_node:main", ], }, ) diff --git a/jetson/ros2_ws/src/saltybot_vesc_driver/test/test_velocity_smoother.py b/jetson/ros2_ws/src/saltybot_vesc_driver/test/test_velocity_smoother.py new file mode 100644 index 0000000..63d0119 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_vesc_driver/test/test_velocity_smoother.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +"""Unit tests for VelocitySmoother ramp logic.""" + +import math +import unittest +from unittest.mock import MagicMock, patch + + +# --------------------------------------------------------------------------- +# Minimal ROS2 stubs so tests run without a ROS2 installation +# --------------------------------------------------------------------------- + +class _FakeParam: + def __init__(self, val): + self._val = val + @property + def value(self): + return self._val + + +class _FakeNode: + def __init__(self): + self._params = {} + def declare_parameter(self, name, default): + self._params[name] = _FakeParam(default) + def get_parameter(self, name): + return self._params[name] + def create_publisher(self, *a, **kw): + return MagicMock() + def create_subscription(self, *a, **kw): + return MagicMock() + def create_timer(self, *a, **kw): + return MagicMock() + def get_logger(self): + log = MagicMock() + log.info = MagicMock() + log.warn = MagicMock() + return log + + +# Patch rclpy and geometry_msgs before importing the module. +# rclpy.node.Node must be a *real* class so that VelocitySmoother can +# inherit from it without becoming a MagicMock itself. +import sys + +class _RealNodeBase: + """Minimal real base class that stands in for rclpy.node.Node.""" + def __init__(self, *args, **kwargs): + pass + + def declare_parameter(self, name, default): + if not hasattr(self, '_params'): + self._params = {} + self._params[name] = _FakeParam(default) + + def get_parameter(self, name): + return self._params[name] + + def create_publisher(self, *a, **kw): + return MagicMock() + + def create_subscription(self, *a, **kw): + return MagicMock() + + def create_timer(self, *a, **kw): + return MagicMock() + + def get_logger(self): + log = MagicMock() + log.info = MagicMock() + log.warn = MagicMock() + return log + +rclpy_node_mod = MagicMock() +rclpy_node_mod.Node = _RealNodeBase + +rclpy_mock = MagicMock() +rclpy_mock.node = rclpy_node_mod + +sys.modules['rclpy'] = rclpy_mock +sys.modules['rclpy.node'] = rclpy_node_mod +sys.modules.setdefault('geometry_msgs', MagicMock()) +sys.modules.setdefault('geometry_msgs.msg', MagicMock()) +sys.modules.setdefault('std_msgs', MagicMock()) +sys.modules.setdefault('std_msgs.msg', MagicMock()) + +# Provide a real Twist-like object for tests +class _Twist: + class _Vec: + x = 0.0 + y = 0.0 + z = 0.0 + def __init__(self): + self.linear = self._Vec() + self.angular = self._Vec() + +sys.modules['geometry_msgs.msg'].Twist = _Twist +sys.modules['std_msgs.msg'].Bool = MagicMock() + +import importlib, os, pathlib +sys.path.insert(0, str(pathlib.Path(__file__).parents[1] / 'saltybot_vesc_driver')) + +# Now we can import the smoother logic directly +from velocity_smoother_node import VelocitySmoother + + +def _make_smoother(**params): + """Create a VelocitySmoother backed by _RealNodeBase with custom params.""" + defaults = dict( + publish_rate=50.0, + max_linear_accel=1.0, + max_linear_decel=2.0, + max_angular_accel=1.5, + max_angular_decel=3.0, + max_linear_jerk=0.0, + max_angular_jerk=0.0, + ) + defaults.update(params) + + # Pre-seed _params before __init__ so declare_parameter picks up overrides. + node = VelocitySmoother.__new__(VelocitySmoother) + node._params = {k: _FakeParam(v) for k, v in defaults.items()} + + # Monkey-patch declare_parameter so it doesn't overwrite our pre-seeded values + original_declare = _RealNodeBase.declare_parameter + def _noop_declare(self, name, default): + pass # params already seeded + _RealNodeBase.declare_parameter = _noop_declare + + try: + VelocitySmoother.__init__(node) + finally: + _RealNodeBase.declare_parameter = original_declare + + return node + + +# --------------------------------------------------------------------------- + +class TestRampLogic(unittest.TestCase): + + def _make(self, **kw): + return _make_smoother(**kw) + + def test_ramp_reaches_target_within_expected_steps(self): + """From 0 to 1 m/s with accel=1 m/s² at 50 Hz → ~50 steps.""" + node = self._make(max_linear_accel=1.0, publish_rate=50.0) + node._target_lin = 1.0 + steps = 0 + while abs(node._current_lin - 1.0) > 0.01 and steps < 200: + node._timer_cb() + steps += 1 + self.assertLessEqual(steps, 55, "Should reach 1 m/s within ~55 steps at 50 Hz") + self.assertAlmostEqual(node._current_lin, 1.0, places=2) + + def test_decel_faster_than_accel(self): + """Deceleration should complete faster than acceleration.""" + node = self._make(max_linear_accel=1.0, max_linear_decel=2.0, publish_rate=50.0) + # Ramp up + node._target_lin = 1.0 + for _ in range(100): + node._timer_cb() + # Now decelerate + node._target_lin = 0.0 + decel_steps = 0 + while abs(node._current_lin) > 0.01 and decel_steps < 200: + node._timer_cb() + decel_steps += 1 + # Ramp up again to measure accel steps + node._current_lin = 0.0 + node._target_lin = 1.0 + accel_steps = 0 + while abs(node._current_lin - 1.0) > 0.01 and accel_steps < 200: + node._timer_cb() + accel_steps += 1 + self.assertLess(decel_steps, accel_steps, + "Decel (2 m/s²) should finish in fewer steps than accel (1 m/s²)") + + def test_e_stop_zeros_immediately(self): + """E-stop must zero velocity in the same callback, bypassing ramp.""" + node = self._make() + node._current_lin = 2.0 + node._current_ang = 1.0 + node._target_lin = 2.0 + node._target_ang = 1.0 + + msg = MagicMock() + msg.data = True + node._e_stop_cb(msg) + + self.assertEqual(node._current_lin, 0.0) + self.assertEqual(node._current_ang, 0.0) + self.assertTrue(node._e_stop) + + def test_no_overshoot(self): + """Current velocity must never exceed target during ramp-up.""" + node = self._make(max_linear_accel=1.0, publish_rate=50.0) + node._target_lin = 0.5 + for _ in range(100): + node._timer_cb() + self.assertLessEqual(node._current_lin, 0.5 + 1e-9) + + def test_negative_velocity_ramp(self): + """Ramp works symmetrically for negative targets.""" + node = self._make(max_linear_accel=1.0, publish_rate=50.0) + node._target_lin = -1.0 + for _ in range(200): + node._timer_cb() + self.assertAlmostEqual(node._current_lin, -1.0, places=2) + + def test_angular_ramp(self): + """Angular velocity ramps correctly.""" + node = self._make(max_angular_accel=1.5, publish_rate=50.0) + node._target_ang = 1.0 + for _ in range(200): + node._timer_cb() + self.assertAlmostEqual(node._current_ang, 1.0, places=2) + + def test_ramp_timing_linear(self): + """Time to ramp from 0 to v_max ≈ v_max / accel (±10%).""" + accel = 1.0 # m/s² + v_max = 1.0 # m/s + rate = 50.0 + expected_time = v_max / accel # 1.0 s + node = self._make(max_linear_accel=accel, publish_rate=rate) + node._target_lin = v_max + steps = 0 + while abs(node._current_lin - v_max) > 0.01 and steps < 500: + node._timer_cb() + steps += 1 + actual_time = steps / rate + self.assertAlmostEqual(actual_time, expected_time, delta=expected_time * 0.12, + msg=f"Ramp time {actual_time:.2f}s expected ~{expected_time:.2f}s") + + +if __name__ == '__main__': + unittest.main() -- 2.47.2