Merge pull request 'feat: Smooth velocity controller (Issue #652)' (#661) from sl-controls/issue-652-smooth-velocity into main

This commit is contained in:
sl-jetson 2026-03-18 07:56:26 -04:00
commit 4f3a30d871
16 changed files with 1008 additions and 1 deletions

View File

@ -0,0 +1,15 @@
# Diagnostics Aggregator — Issue #658
# Unified health dashboard aggregating telemetry from all SaltyBot subsystems.
diagnostics_aggregator:
ros__parameters:
# Publish rate for /saltybot/system_health (Hz)
heartbeat_hz: 1.0
# Seconds without a topic update before a subsystem is marked STALE
# Increase for sensors with lower publish rates (e.g. UWB at 5 Hz)
stale_timeout_s: 5.0
# Maximum number of state transitions kept in the in-memory ring buffer
# Last 10 transitions are included in each /saltybot/system_health publish
transition_log_size: 50

View File

@ -0,0 +1,44 @@
"""Launch file for diagnostics aggregator node."""
import os
from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
pkg_dir = get_package_share_directory("saltybot_diagnostics_aggregator")
config_file = os.path.join(pkg_dir, "config", "aggregator_params.yaml")
return LaunchDescription([
DeclareLaunchArgument(
"config_file",
default_value=config_file,
description="Path to aggregator_params.yaml",
),
DeclareLaunchArgument(
"heartbeat_hz",
default_value="1.0",
description="Publish rate for /saltybot/system_health (Hz)",
),
DeclareLaunchArgument(
"stale_timeout_s",
default_value="5.0",
description="Seconds without update before subsystem goes STALE",
),
Node(
package="saltybot_diagnostics_aggregator",
executable="diagnostics_aggregator_node",
name="diagnostics_aggregator",
output="screen",
parameters=[
LaunchConfiguration("config_file"),
{
"heartbeat_hz": LaunchConfiguration("heartbeat_hz"),
"stale_timeout_s": LaunchConfiguration("stale_timeout_s"),
},
],
),
])

View File

@ -0,0 +1,30 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_diagnostics_aggregator</name>
<version>0.1.0</version>
<description>
Diagnostics aggregator for SaltyBot — unified health dashboard node (Issue #658).
Subscribes to /vesc/health, /diagnostics, /saltybot/safety_zone/status,
/saltybot/pose/status, /saltybot/uwb/status. Aggregates into
/saltybot/system_health JSON at 1 Hz. Tracks motors, battery, imu, uwb,
lidar, camera, can_bus, estop subsystems with state-transition logging.
</description>
<maintainer email="sl-firmware@saltylab.local">sl-firmware</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>std_msgs</depend>
<depend>diagnostic_msgs</depend>
<buildtool_depend>ament_python</buildtool_depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>python3-pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -0,0 +1,312 @@
#!/usr/bin/env python3
"""Diagnostics aggregator — unified health dashboard node (Issue #658).
Subscribes to telemetry and diagnostics topics from all subsystems, aggregates
them into a single /saltybot/system_health JSON publish at 1 Hz.
Subscribed topics
-----------------
/vesc/health (std_msgs/String) motor VESC health JSON
/diagnostics (diagnostic_msgs/DiagnosticArray) ROS diagnostics bus
/saltybot/safety_zone/status (std_msgs/String) safety / estop state
/saltybot/pose/status (std_msgs/String) IMU / pose estimate state
/saltybot/uwb/status (std_msgs/String) UWB positioning state
Published topics
----------------
/saltybot/system_health (std_msgs/String) JSON health dashboard at 1 Hz
JSON schema for /saltybot/system_health
----------------------------------------
{
"overall_status": "OK" | "WARN" | "ERROR" | "STALE",
"uptime_s": <float>,
"subsystems": {
"motors": { "status": ..., "message": ..., "age_s": ..., "previous_status": ... },
"battery": { ... },
"imu": { ... },
"uwb": { ... },
"lidar": { ... },
"camera": { ... },
"can_bus": { ... },
"estop": { ... }
},
"last_error": { "subsystem": ..., "message": ..., "timestamp": ... } | null,
"recent_transitions": [ { "subsystem", "from_status", "to_status",
"message", "timestamp_iso" }, ... ],
"timestamp": "<ISO-8601>"
}
Parameters
----------
heartbeat_hz float 1.0 Publish rate (Hz)
stale_timeout_s float 5.0 Seconds without update STALE
transition_log_size int 50 Max recent transitions kept in memory
"""
import json
import time
from collections import deque
from datetime import datetime, timezone
from typing import Optional
import rclpy
from rclpy.node import Node
from std_msgs.msg import String
from diagnostic_msgs.msg import DiagnosticArray
from .subsystem import (
SubsystemState,
Transition,
STATUS_OK,
STATUS_WARN,
STATUS_ERROR,
STATUS_STALE,
STATUS_UNKNOWN,
worse,
ros_level_to_status,
SUBSYSTEM_NAMES as _SUBSYSTEM_NAMES,
keyword_to_subsystem as _keyword_to_subsystem,
)
class DiagnosticsAggregatorNode(Node):
"""Unified health dashboard node."""
def __init__(self):
super().__init__("diagnostics_aggregator")
# ----------------------------------------------------------------
# Parameters
# ----------------------------------------------------------------
self.declare_parameter("heartbeat_hz", 1.0)
self.declare_parameter("stale_timeout_s", 5.0)
self.declare_parameter("transition_log_size", 50)
hz = float(self.get_parameter("heartbeat_hz").value)
stale_timeout = float(self.get_parameter("stale_timeout_s").value)
log_size = int(self.get_parameter("transition_log_size").value)
# ----------------------------------------------------------------
# Subsystem state table
# ----------------------------------------------------------------
self._subsystems: dict[str, SubsystemState] = {
name: SubsystemState(name=name, stale_timeout_s=stale_timeout)
for name in _SUBSYSTEM_NAMES
}
self._transitions: deque[Transition] = deque(maxlen=log_size)
self._last_error: Optional[dict] = None
self._start_mono = time.monotonic()
# ----------------------------------------------------------------
# Subscriptions
# ----------------------------------------------------------------
self.create_subscription(String, "/vesc/health",
self._on_vesc_health, 10)
self.create_subscription(DiagnosticArray, "/diagnostics",
self._on_diagnostics, 10)
self.create_subscription(String, "/saltybot/safety_zone/status",
self._on_safety_zone, 10)
self.create_subscription(String, "/saltybot/pose/status",
self._on_pose_status, 10)
self.create_subscription(String, "/saltybot/uwb/status",
self._on_uwb_status, 10)
# ----------------------------------------------------------------
# Publisher
# ----------------------------------------------------------------
self._pub = self.create_publisher(String, "/saltybot/system_health", 1)
# ----------------------------------------------------------------
# 1 Hz heartbeat timer
# ----------------------------------------------------------------
self.create_timer(1.0 / max(0.1, hz), self._on_timer)
self.get_logger().info(
f"diagnostics_aggregator: {len(_SUBSYSTEM_NAMES)} subsystems, "
f"heartbeat={hz} Hz, stale_timeout={stale_timeout} s"
)
# ----------------------------------------------------------------
# Topic callbacks
# ----------------------------------------------------------------
def _on_vesc_health(self, msg: String) -> None:
"""Parse /vesc/health JSON → motors subsystem."""
now = time.monotonic()
try:
d = json.loads(msg.data)
fault = d.get("fault_code", 0)
alive = d.get("alive", True)
if not alive:
status, message = STATUS_STALE, "VESC offline"
elif fault != 0:
status = STATUS_ERROR
message = f"VESC fault {d.get('fault_name', fault)}"
else:
status = STATUS_OK
message = (
f"RPM={d.get('rpm', '?')} "
f"I={d.get('current_a', '?')} A "
f"V={d.get('voltage_v', '?')} V"
)
self._update("motors", status, message, now)
except Exception as exc:
self.get_logger().debug(f"/vesc/health parse error: {exc}")
def _on_diagnostics(self, msg: DiagnosticArray) -> None:
"""Fan /diagnostics entries out to per-subsystem states."""
now = time.monotonic()
# Accumulate worst status per subsystem across all entries in this array
pending: dict[str, tuple[str, str]] = {} # subsystem → (status, message)
for ds in msg.status:
subsystem = _keyword_to_subsystem(ds.name)
if subsystem is None:
continue
status = ros_level_to_status(ds.level)
message = ds.message or ""
if subsystem not in pending:
pending[subsystem] = (status, message)
else:
prev_s, prev_m = pending[subsystem]
new_worst = worse(status, prev_s)
pending[subsystem] = (
new_worst,
message if new_worst == status else prev_m,
)
for subsystem, (status, message) in pending.items():
self._update(subsystem, status, message, now)
def _on_safety_zone(self, msg: String) -> None:
"""Parse /saltybot/safety_zone/status JSON → estop subsystem."""
now = time.monotonic()
try:
d = json.loads(msg.data)
triggered = d.get("estop_triggered", d.get("triggered", False))
active = d.get("safety_zone_active", d.get("active", True))
if triggered:
status, message = STATUS_ERROR, "E-stop triggered"
elif not active:
status, message = STATUS_WARN, "Safety zone inactive"
else:
status, message = STATUS_OK, "Safety zone active"
self._update("estop", status, message, now)
except Exception as exc:
self.get_logger().debug(f"/saltybot/safety_zone/status parse error: {exc}")
def _on_pose_status(self, msg: String) -> None:
"""Parse /saltybot/pose/status JSON → imu subsystem."""
now = time.monotonic()
try:
d = json.loads(msg.data)
status_str = d.get("status", "OK").upper()
if status_str not in (STATUS_OK, STATUS_WARN, STATUS_ERROR):
status_str = STATUS_WARN
message = d.get("message", d.get("msg", ""))
self._update("imu", status_str, message, now)
except Exception as exc:
self.get_logger().debug(f"/saltybot/pose/status parse error: {exc}")
def _on_uwb_status(self, msg: String) -> None:
"""Parse /saltybot/uwb/status JSON → uwb subsystem."""
now = time.monotonic()
try:
d = json.loads(msg.data)
status_str = d.get("status", "OK").upper()
if status_str not in (STATUS_OK, STATUS_WARN, STATUS_ERROR):
status_str = STATUS_WARN
message = d.get("message", d.get("msg", ""))
self._update("uwb", status_str, message, now)
except Exception as exc:
self.get_logger().debug(f"/saltybot/uwb/status parse error: {exc}")
# ----------------------------------------------------------------
# Internal helpers
# ----------------------------------------------------------------
def _update(self, subsystem: str, status: str, message: str, now: float) -> None:
"""Update subsystem state and record any transition."""
s = self._subsystems.get(subsystem)
if s is None:
return
transition = s.update(status, message, now)
if transition is not None:
self._transitions.append(transition)
self.get_logger().info(
f"[{subsystem}] {transition.from_status}{transition.to_status}: {message}"
)
if transition.to_status == STATUS_ERROR:
self._last_error = {
"subsystem": subsystem,
"message": message,
"timestamp": transition.timestamp_iso,
}
def _apply_stale_checks(self, now: float) -> None:
for s in self._subsystems.values():
transition = s.apply_stale_check(now)
if transition is not None:
self._transitions.append(transition)
self.get_logger().warn(
f"[{transition.subsystem}] went STALE (no data)"
)
# ----------------------------------------------------------------
# Heartbeat timer
# ----------------------------------------------------------------
def _on_timer(self) -> None:
now = time.monotonic()
self._apply_stale_checks(now)
# Overall status = worst across all subsystems
overall = STATUS_OK
for s in self._subsystems.values():
overall = worse(overall, s.status)
# UNKNOWN does not degrade overall if at least one subsystem is known
known = [s for s in self._subsystems.values() if s.status != STATUS_UNKNOWN]
if not known:
overall = STATUS_UNKNOWN
uptime = now - self._start_mono
payload = {
"overall_status": overall,
"uptime_s": round(uptime, 1),
"subsystems": {
name: s.to_dict(now)
for name, s in self._subsystems.items()
},
"last_error": self._last_error,
"recent_transitions": [
{
"subsystem": t.subsystem,
"from_status": t.from_status,
"to_status": t.to_status,
"message": t.message,
"timestamp": t.timestamp_iso,
}
for t in list(self._transitions)[-10:] # last 10 in the JSON
],
"timestamp": datetime.now(timezone.utc).isoformat(),
}
self._pub.publish(String(data=json.dumps(payload)))
def main(args=None):
rclpy.init(args=args)
node = DiagnosticsAggregatorNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,148 @@
"""Subsystem state tracking for the diagnostics aggregator.
Each SubsystemState tracks one logical subsystem (motors, battery, etc.)
across potentially multiple source topics. Status priority:
ERROR > WARN > STALE > OK > UNKNOWN
"""
import time
from dataclasses import dataclass, field
from typing import Optional
# ---------------------------------------------------------------------------
# Status constants — ordered by severity (higher index = more severe)
# ---------------------------------------------------------------------------
STATUS_UNKNOWN = "UNKNOWN"
STATUS_OK = "OK"
STATUS_STALE = "STALE"
STATUS_WARN = "WARN"
STATUS_ERROR = "ERROR"
# ---------------------------------------------------------------------------
# Subsystem registry and diagnostic keyword routing
# ---------------------------------------------------------------------------
SUBSYSTEM_NAMES: list[str] = [
"motors", "battery", "imu", "uwb", "lidar", "camera", "can_bus", "estop"
]
# (keywords-tuple, subsystem-name) — first match wins, lower-cased comparison
DIAG_KEYWORD_MAP: list[tuple[tuple[str, ...], str]] = [
(("vesc", "motor", "esc", "fsesc"), "motors"),
(("battery", "ina219", "lvc", "coulomb", "vbat"), "battery"),
(("imu", "mpu6000", "bno055", "icm42688", "gyro", "accel"), "imu"),
(("uwb", "dw1000", "dw3000"), "uwb"),
(("lidar", "rplidar", "sick", "laser"), "lidar"),
(("camera", "realsense", "oak", "webcam"), "camera"),
(("can", "can_bus", "can_driver"), "can_bus"),
(("estop", "safety", "emergency", "e-stop"), "estop"),
]
def keyword_to_subsystem(name: str) -> Optional[str]:
"""Map a diagnostic status name to a subsystem key, or None."""
lower = name.lower()
for keywords, subsystem in DIAG_KEYWORD_MAP:
if any(kw in lower for kw in keywords):
return subsystem
return None
_SEVERITY: dict[str, int] = {
STATUS_UNKNOWN: 0,
STATUS_OK: 1,
STATUS_STALE: 2,
STATUS_WARN: 3,
STATUS_ERROR: 4,
}
def worse(a: str, b: str) -> str:
"""Return the more severe of two status strings."""
return a if _SEVERITY.get(a, 0) >= _SEVERITY.get(b, 0) else b
def ros_level_to_status(level: int) -> str:
"""Convert diagnostic_msgs/DiagnosticStatus.level to a status string."""
return {0: STATUS_OK, 1: STATUS_WARN, 2: STATUS_ERROR}.get(level, STATUS_UNKNOWN)
# ---------------------------------------------------------------------------
# Transition log entry
# ---------------------------------------------------------------------------
@dataclass
class Transition:
"""A logged status change for one subsystem."""
subsystem: str
from_status: str
to_status: str
message: str
timestamp_iso: str # ISO-8601 wall-clock
monotonic_s: float # time.monotonic() at the transition
# ---------------------------------------------------------------------------
# Per-subsystem state
# ---------------------------------------------------------------------------
@dataclass
class SubsystemState:
"""Live health state for one logical subsystem."""
name: str
stale_timeout_s: float = 5.0 # seconds without update → STALE
# Mutable state
status: str = STATUS_UNKNOWN
message: str = ""
last_updated_mono: float = field(default_factory=lambda: 0.0)
last_change_mono: float = field(default_factory=lambda: 0.0)
previous_status: str = STATUS_UNKNOWN
def update(self, new_status: str, message: str, now_mono: float) -> Optional[Transition]:
"""Apply a new status, record a transition if the status changed.
Returns a Transition if the status changed, else None.
"""
from datetime import datetime, timezone
self.last_updated_mono = now_mono
if new_status == self.status:
self.message = message
return None
transition = Transition(
subsystem=self.name,
from_status=self.status,
to_status=new_status,
message=message,
timestamp_iso=datetime.now(timezone.utc).isoformat(),
monotonic_s=now_mono,
)
self.previous_status = self.status
self.status = new_status
self.message = message
self.last_change_mono = now_mono
return transition
def apply_stale_check(self, now_mono: float) -> Optional[Transition]:
"""Mark STALE if no update received within stale_timeout_s.
Returns a Transition if newly staled, else None.
"""
if self.last_updated_mono == 0.0:
# Never received any data — leave as UNKNOWN
return None
if (now_mono - self.last_updated_mono) > self.stale_timeout_s:
if self.status not in (STATUS_STALE, STATUS_UNKNOWN):
return self.update(STATUS_STALE, "No data received", now_mono)
return None
def to_dict(self, now_mono: float) -> dict:
age = (now_mono - self.last_updated_mono) if self.last_updated_mono > 0 else None
return {
"status": self.status,
"message": self.message,
"age_s": round(age, 2) if age is not None else None,
"previous_status": self.previous_status,
}

View File

@ -0,0 +1,4 @@
[develop]
script_dir=$base/lib/saltybot_diagnostics_aggregator
[install]
install_scripts=$base/lib/saltybot_diagnostics_aggregator

View File

@ -0,0 +1,28 @@
from setuptools import setup
package_name = "saltybot_diagnostics_aggregator"
setup(
name=package_name,
version="0.1.0",
packages=[package_name],
data_files=[
("share/ament_index/resource_index/packages", [f"resource/{package_name}"]),
(f"share/{package_name}", ["package.xml"]),
(f"share/{package_name}/launch", ["launch/diagnostics_aggregator.launch.py"]),
(f"share/{package_name}/config", ["config/aggregator_params.yaml"]),
],
install_requires=["setuptools"],
zip_safe=True,
maintainer="sl-firmware",
maintainer_email="sl-firmware@saltylab.local",
description="Diagnostics aggregator — unified health dashboard for SaltyBot",
license="MIT",
tests_require=["pytest"],
entry_points={
"console_scripts": [
"diagnostics_aggregator_node = "
"saltybot_diagnostics_aggregator.aggregator_node:main",
],
},
)

View File

@ -8,3 +8,13 @@ vesc_can_driver:
wheel_radius: 0.1 wheel_radius: 0.1
max_speed: 5.0 max_speed: 5.0
command_timeout: 1.0 command_timeout: 1.0
velocity_smoother:
ros__parameters:
publish_rate: 50.0
max_linear_accel: 1.0
max_linear_decel: 2.0
max_angular_accel: 1.5
max_angular_decel: 3.0
max_linear_jerk: 0.0
max_angular_jerk: 0.0

View File

@ -20,6 +20,13 @@ def generate_launch_description():
default_value=config_file, default_value=config_file,
description="Path to configuration YAML file", description="Path to configuration YAML file",
), ),
Node(
package="saltybot_vesc_driver",
executable="velocity_smoother_node",
name="velocity_smoother",
output="screen",
parameters=[LaunchConfiguration("config_file")],
),
Node( Node(
package="saltybot_vesc_driver", package="saltybot_vesc_driver",
executable="vesc_driver_node", executable="vesc_driver_node",

View File

@ -0,0 +1,171 @@
#!/usr/bin/env python3
"""
Smooth velocity controller with accel/decel ramp.
Subscribes to /cmd_vel (geometry_msgs/Twist), applies acceleration
and deceleration ramps, and publishes smoothed commands to
/cmd_vel_smoothed at 50 Hz.
E-stop (std_msgs/Bool on /e_stop): immediate zero, bypasses ramp.
Optional jerk limiting via max_linear_jerk / max_angular_jerk params.
"""
import math
import rclpy
from geometry_msgs.msg import Twist
from rclpy.node import Node
from std_msgs.msg import Bool
class VelocitySmoother(Node):
def __init__(self):
super().__init__('velocity_smoother')
# Declare parameters
self.declare_parameter('publish_rate', 50.0) # Hz
self.declare_parameter('max_linear_accel', 1.0) # m/s²
self.declare_parameter('max_linear_decel', 2.0) # m/s²
self.declare_parameter('max_angular_accel', 1.5) # rad/s²
self.declare_parameter('max_angular_decel', 3.0) # rad/s²
self.declare_parameter('max_linear_jerk', 0.0) # m/s³, 0=disabled
self.declare_parameter('max_angular_jerk', 0.0) # rad/s³, 0=disabled
rate = self.get_parameter('publish_rate').value
self.max_lin_acc = self.get_parameter('max_linear_accel').value
self.max_lin_dec = self.get_parameter('max_linear_decel').value
self.max_ang_acc = self.get_parameter('max_angular_accel').value
self.max_ang_dec = self.get_parameter('max_angular_decel').value
self.max_lin_jerk = self.get_parameter('max_linear_jerk').value
self.max_ang_jerk = self.get_parameter('max_angular_jerk').value
self._dt = 1.0 / rate
# State
self._target_lin = 0.0
self._target_ang = 0.0
self._current_lin = 0.0
self._current_ang = 0.0
self._current_lin_acc = 0.0 # for jerk limiting
self._current_ang_acc = 0.0
self._e_stop = False
# Publisher
self._pub = self.create_publisher(Twist, '/cmd_vel_smoothed', 10)
# Subscriptions
self.create_subscription(Twist, '/cmd_vel', self._cmd_vel_cb, 10)
self.create_subscription(Bool, '/e_stop', self._e_stop_cb, 10)
# Timer
self.create_timer(self._dt, self._timer_cb)
self.get_logger().info(
f'VelocitySmoother ready at {rate:.0f} Hz — '
f'lin_acc={self.max_lin_acc} lin_dec={self.max_lin_dec} '
f'ang_acc={self.max_ang_acc} ang_dec={self.max_ang_dec}'
)
# ------------------------------------------------------------------
def _cmd_vel_cb(self, msg: Twist):
self._target_lin = msg.linear.x
self._target_ang = msg.angular.z
def _e_stop_cb(self, msg: Bool):
self._e_stop = msg.data
if self._e_stop:
self._target_lin = 0.0
self._target_ang = 0.0
self._current_lin = 0.0
self._current_ang = 0.0
self._current_lin_acc = 0.0
self._current_ang_acc = 0.0
self.get_logger().warn('E-STOP active — motors zeroed immediately')
# ------------------------------------------------------------------
def _ramp(
self,
current: float,
target: float,
max_acc: float,
max_dec: float,
current_acc: float,
max_jerk: float,
) -> tuple[float, float]:
"""
Advance `current` toward `target` with separate accel/decel limits.
Optionally apply jerk limiting to the acceleration.
Returns (new_value, new_acc).
"""
error = target - current
# Choose limit: decelerate if moving away from zero and target is
# closer to zero (or past it), else accelerate.
if current * error < 0 or (error != 0 and abs(target) < abs(current)):
limit = max_dec
else:
limit = max_acc
desired_acc = math.copysign(min(abs(error) / self._dt, limit), error)
# Clamp so we don't overshoot
desired_acc = max(-limit, min(limit, desired_acc))
if max_jerk > 0.0:
max_d_acc = max_jerk * self._dt
new_acc = current_acc + max(
-max_d_acc, min(max_d_acc, desired_acc - current_acc)
)
else:
new_acc = desired_acc
delta = new_acc * self._dt
# Clamp so we don't overshoot the target
if abs(delta) > abs(error):
delta = error
return current + delta, new_acc
def _timer_cb(self):
if self._e_stop:
msg = Twist()
self._pub.publish(msg)
return
self._current_lin, self._current_lin_acc = self._ramp(
self._current_lin,
self._target_lin,
self.max_lin_acc,
self.max_lin_dec,
self._current_lin_acc,
self.max_lin_jerk,
)
self._current_ang, self._current_ang_acc = self._ramp(
self._current_ang,
self._target_ang,
self.max_ang_acc,
self.max_ang_dec,
self._current_ang_acc,
self.max_lin_jerk, # intentional: use linear jerk scale for angular too if set
)
msg = Twist()
msg.linear.x = self._current_lin
msg.angular.z = self._current_ang
self._pub.publish(msg)
def main(args=None):
rclpy.init(args=args)
node = VelocitySmoother()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()

View File

@ -142,7 +142,7 @@ class VescCanDriver(Node):
self._rx_thread.start() self._rx_thread.start()
# Subscriber # Subscriber
self.create_subscription(Twist, '/cmd_vel', self._cmd_vel_cb, 10) self.create_subscription(Twist, '/cmd_vel_smoothed', self._cmd_vel_cb, 10)
# Watchdog + telemetry publish timer # Watchdog + telemetry publish timer
self.create_timer(1.0 / max(1, tel_hz), self._watchdog_and_publish_cb) self.create_timer(1.0 / max(1, tel_hz), self._watchdog_and_publish_cb)

View File

@ -22,6 +22,7 @@ setup(
entry_points={ entry_points={
"console_scripts": [ "console_scripts": [
"vesc_driver_node = saltybot_vesc_driver.vesc_driver_node:main", "vesc_driver_node = saltybot_vesc_driver.vesc_driver_node:main",
"velocity_smoother_node = saltybot_vesc_driver.velocity_smoother_node:main",
], ],
}, },
) )

View File

@ -0,0 +1,237 @@
#!/usr/bin/env python3
"""Unit tests for VelocitySmoother ramp logic."""
import math
import unittest
from unittest.mock import MagicMock, patch
# ---------------------------------------------------------------------------
# Minimal ROS2 stubs so tests run without a ROS2 installation
# ---------------------------------------------------------------------------
class _FakeParam:
def __init__(self, val):
self._val = val
@property
def value(self):
return self._val
class _FakeNode:
def __init__(self):
self._params = {}
def declare_parameter(self, name, default):
self._params[name] = _FakeParam(default)
def get_parameter(self, name):
return self._params[name]
def create_publisher(self, *a, **kw):
return MagicMock()
def create_subscription(self, *a, **kw):
return MagicMock()
def create_timer(self, *a, **kw):
return MagicMock()
def get_logger(self):
log = MagicMock()
log.info = MagicMock()
log.warn = MagicMock()
return log
# Patch rclpy and geometry_msgs before importing the module.
# rclpy.node.Node must be a *real* class so that VelocitySmoother can
# inherit from it without becoming a MagicMock itself.
import sys
class _RealNodeBase:
"""Minimal real base class that stands in for rclpy.node.Node."""
def __init__(self, *args, **kwargs):
pass
def declare_parameter(self, name, default):
if not hasattr(self, '_params'):
self._params = {}
self._params[name] = _FakeParam(default)
def get_parameter(self, name):
return self._params[name]
def create_publisher(self, *a, **kw):
return MagicMock()
def create_subscription(self, *a, **kw):
return MagicMock()
def create_timer(self, *a, **kw):
return MagicMock()
def get_logger(self):
log = MagicMock()
log.info = MagicMock()
log.warn = MagicMock()
return log
rclpy_node_mod = MagicMock()
rclpy_node_mod.Node = _RealNodeBase
rclpy_mock = MagicMock()
rclpy_mock.node = rclpy_node_mod
sys.modules['rclpy'] = rclpy_mock
sys.modules['rclpy.node'] = rclpy_node_mod
sys.modules.setdefault('geometry_msgs', MagicMock())
sys.modules.setdefault('geometry_msgs.msg', MagicMock())
sys.modules.setdefault('std_msgs', MagicMock())
sys.modules.setdefault('std_msgs.msg', MagicMock())
# Provide a real Twist-like object for tests
class _Twist:
class _Vec:
x = 0.0
y = 0.0
z = 0.0
def __init__(self):
self.linear = self._Vec()
self.angular = self._Vec()
sys.modules['geometry_msgs.msg'].Twist = _Twist
sys.modules['std_msgs.msg'].Bool = MagicMock()
import importlib, os, pathlib
sys.path.insert(0, str(pathlib.Path(__file__).parents[1] / 'saltybot_vesc_driver'))
# Now we can import the smoother logic directly
from velocity_smoother_node import VelocitySmoother
def _make_smoother(**params):
"""Create a VelocitySmoother backed by _RealNodeBase with custom params."""
defaults = dict(
publish_rate=50.0,
max_linear_accel=1.0,
max_linear_decel=2.0,
max_angular_accel=1.5,
max_angular_decel=3.0,
max_linear_jerk=0.0,
max_angular_jerk=0.0,
)
defaults.update(params)
# Pre-seed _params before __init__ so declare_parameter picks up overrides.
node = VelocitySmoother.__new__(VelocitySmoother)
node._params = {k: _FakeParam(v) for k, v in defaults.items()}
# Monkey-patch declare_parameter so it doesn't overwrite our pre-seeded values
original_declare = _RealNodeBase.declare_parameter
def _noop_declare(self, name, default):
pass # params already seeded
_RealNodeBase.declare_parameter = _noop_declare
try:
VelocitySmoother.__init__(node)
finally:
_RealNodeBase.declare_parameter = original_declare
return node
# ---------------------------------------------------------------------------
class TestRampLogic(unittest.TestCase):
def _make(self, **kw):
return _make_smoother(**kw)
def test_ramp_reaches_target_within_expected_steps(self):
"""From 0 to 1 m/s with accel=1 m/s² at 50 Hz → ~50 steps."""
node = self._make(max_linear_accel=1.0, publish_rate=50.0)
node._target_lin = 1.0
steps = 0
while abs(node._current_lin - 1.0) > 0.01 and steps < 200:
node._timer_cb()
steps += 1
self.assertLessEqual(steps, 55, "Should reach 1 m/s within ~55 steps at 50 Hz")
self.assertAlmostEqual(node._current_lin, 1.0, places=2)
def test_decel_faster_than_accel(self):
"""Deceleration should complete faster than acceleration."""
node = self._make(max_linear_accel=1.0, max_linear_decel=2.0, publish_rate=50.0)
# Ramp up
node._target_lin = 1.0
for _ in range(100):
node._timer_cb()
# Now decelerate
node._target_lin = 0.0
decel_steps = 0
while abs(node._current_lin) > 0.01 and decel_steps < 200:
node._timer_cb()
decel_steps += 1
# Ramp up again to measure accel steps
node._current_lin = 0.0
node._target_lin = 1.0
accel_steps = 0
while abs(node._current_lin - 1.0) > 0.01 and accel_steps < 200:
node._timer_cb()
accel_steps += 1
self.assertLess(decel_steps, accel_steps,
"Decel (2 m/s²) should finish in fewer steps than accel (1 m/s²)")
def test_e_stop_zeros_immediately(self):
"""E-stop must zero velocity in the same callback, bypassing ramp."""
node = self._make()
node._current_lin = 2.0
node._current_ang = 1.0
node._target_lin = 2.0
node._target_ang = 1.0
msg = MagicMock()
msg.data = True
node._e_stop_cb(msg)
self.assertEqual(node._current_lin, 0.0)
self.assertEqual(node._current_ang, 0.0)
self.assertTrue(node._e_stop)
def test_no_overshoot(self):
"""Current velocity must never exceed target during ramp-up."""
node = self._make(max_linear_accel=1.0, publish_rate=50.0)
node._target_lin = 0.5
for _ in range(100):
node._timer_cb()
self.assertLessEqual(node._current_lin, 0.5 + 1e-9)
def test_negative_velocity_ramp(self):
"""Ramp works symmetrically for negative targets."""
node = self._make(max_linear_accel=1.0, publish_rate=50.0)
node._target_lin = -1.0
for _ in range(200):
node._timer_cb()
self.assertAlmostEqual(node._current_lin, -1.0, places=2)
def test_angular_ramp(self):
"""Angular velocity ramps correctly."""
node = self._make(max_angular_accel=1.5, publish_rate=50.0)
node._target_ang = 1.0
for _ in range(200):
node._timer_cb()
self.assertAlmostEqual(node._current_ang, 1.0, places=2)
def test_ramp_timing_linear(self):
"""Time to ramp from 0 to v_max ≈ v_max / accel (±10%)."""
accel = 1.0 # m/s²
v_max = 1.0 # m/s
rate = 50.0
expected_time = v_max / accel # 1.0 s
node = self._make(max_linear_accel=accel, publish_rate=rate)
node._target_lin = v_max
steps = 0
while abs(node._current_lin - v_max) > 0.01 and steps < 500:
node._timer_cb()
steps += 1
actual_time = steps / rate
self.assertAlmostEqual(actual_time, expected_time, delta=expected_time * 0.12,
msg=f"Ramp time {actual_time:.2f}s expected ~{expected_time:.2f}s")
if __name__ == '__main__':
unittest.main()