feat: ROS2 bag recording service (Issue #411) #414

Merged
sl-jetson merged 2 commits from sl-jetson/issue-411-bag-recording into main 2026-03-04 22:46:36 -05:00
21 changed files with 1094 additions and 0 deletions

View File

@ -0,0 +1,9 @@
build/
install/
log/
*.pyc
__pycache__/
.pytest_cache/
*.egg-info/
dist/
*.egg

View File

@ -0,0 +1,26 @@
bag_recorder:
ros__parameters:
# Path where bags are stored
bag_dir: '/home/seb/rosbags'
# Topics to record (empty list = record all)
topics: []
# topics:
# - '/camera/image_raw'
# - '/lidar/scan'
# - '/odom'
# Circular buffer duration (minutes)
buffer_duration_minutes: 30
# Storage management
storage_ttl_days: 7 # Remove bags older than 7 days
max_storage_gb: 50 # Enforce 50GB quota
# Compression
compression: 'zstd' # Options: zstd, zstandard
# NAS sync (optional)
enable_rsync: false
rsync_destination: ''
# rsync_destination: 'user@nas:/path/to/backups/'

View File

@ -0,0 +1,23 @@
from launch import LaunchDescription
from launch_ros.actions import Node
from ament_index_python.packages import get_package_share_directory
import os
def generate_launch_description():
pkg_dir = get_package_share_directory('saltybot_bag_recorder')
config_file = os.path.join(pkg_dir, 'config', 'bag_recorder.yaml')
bag_recorder_node = Node(
package='saltybot_bag_recorder',
executable='bag_recorder',
name='bag_recorder',
parameters=[config_file],
output='screen',
respawn=True,
respawn_delay=5,
)
return LaunchDescription([
bag_recorder_node,
])

View File

@ -0,0 +1,30 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_bag_recorder</name>
<version>0.1.0</version>
<description>
ROS2 bag recording service with circular buffer, auto-save on crash, and storage management.
Configurable topics, 7-day TTL, 50GB cap, zstd compression, and optional NAS rsync.
</description>
<maintainer email="seb@vayrette.com">seb</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>rosbag2_py</depend>
<depend>std_srvs</depend>
<depend>std_msgs</depend>
<depend>ament_index_python</depend>
<exec_depend>python3-launch-ros</exec_depend>
<exec_depend>ros2bag</exec_depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>python3-pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -0,0 +1,282 @@
#!/usr/bin/env python3
import os
import signal
import shutil
import subprocess
import threading
import time
from pathlib import Path
from datetime import datetime, timedelta
from typing import List, Optional
import rclpy
from rclpy.node import Node
from std_srvs.srv import Trigger
from std_msgs.msg import String
class BagRecorderNode(Node):
"""ROS2 bag recording service with circular buffer and storage management."""
def __init__(self):
super().__init__('saltybot_bag_recorder')
# Configuration
self.declare_parameter('bag_dir', '/home/seb/rosbags')
self.declare_parameter('topics', [''])
self.declare_parameter('buffer_duration_minutes', 30)
self.declare_parameter('storage_ttl_days', 7)
self.declare_parameter('max_storage_gb', 50)
self.declare_parameter('enable_rsync', False)
self.declare_parameter('rsync_destination', '')
self.declare_parameter('compression', 'zstd')
self.bag_dir = Path(self.get_parameter('bag_dir').value)
self.topics = self.get_parameter('topics').value
self.buffer_duration = self.get_parameter('buffer_duration_minutes').value * 60
self.storage_ttl_days = self.get_parameter('storage_ttl_days').value
self.max_storage_gb = self.get_parameter('max_storage_gb').value
self.enable_rsync = self.get_parameter('enable_rsync').value
self.rsync_destination = self.get_parameter('rsync_destination').value
self.compression = self.get_parameter('compression').value
self.bag_dir.mkdir(parents=True, exist_ok=True)
# Recording state
self.is_recording = False
self.current_bag_process = None
self.current_bag_name = None
self.buffer_bags: List[str] = []
self.recording_lock = threading.Lock()
# Services
self.save_service = self.create_service(
Trigger,
'/saltybot/save_bag',
self.save_bag_callback
)
# Watchdog to handle crash recovery
self.setup_signal_handlers()
# Start recording
self.start_recording()
# Periodic maintenance (cleanup old bags, manage storage)
self.maintenance_timer = self.create_timer(300.0, self.maintenance_callback)
self.get_logger().info(
f'Bag recorder initialized: {self.bag_dir}, '
f'buffer={self.buffer_duration}s, ttl={self.storage_ttl_days}d, '
f'max={self.max_storage_gb}GB'
)
def setup_signal_handlers(self):
"""Setup signal handlers for graceful shutdown and crash recovery."""
def signal_handler(sig, frame):
self.get_logger().warn(f'Signal {sig} received, saving current bag')
self.stop_recording(save=True)
self.cleanup()
rclpy.shutdown()
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
def start_recording(self):
"""Start bag recording in the background."""
with self.recording_lock:
if self.is_recording:
return
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
self.current_bag_name = f'saltybot_{timestamp}'
bag_path = self.bag_dir / self.current_bag_name
try:
# Build rosbag2 record command
cmd = [
'ros2', 'bag', 'record',
f'--output', str(bag_path),
f'--compression-format', self.compression,
f'--compression-mode', 'file',
]
# Add topics or record all if empty
if self.topics and self.topics[0]:
cmd.extend(self.topics)
else:
cmd.append('--all')
self.current_bag_process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
self.is_recording = True
self.buffer_bags.append(self.current_bag_name)
self.get_logger().info(f'Started recording: {self.current_bag_name}')
except Exception as e:
self.get_logger().error(f'Failed to start recording: {e}')
def save_bag_callback(self, request, response):
"""Service callback to manually trigger bag save."""
try:
self.stop_recording(save=True)
self.start_recording()
response.success = True
response.message = f'Saved: {self.current_bag_name}'
self.get_logger().info(response.message)
except Exception as e:
response.success = False
response.message = f'Failed to save bag: {e}'
self.get_logger().error(response.message)
return response
def stop_recording(self, save: bool = False):
"""Stop the current bag recording."""
with self.recording_lock:
if not self.is_recording or not self.current_bag_process:
return
try:
# Send SIGINT to gracefully close rosbag2
self.current_bag_process.send_signal(signal.SIGINT)
self.current_bag_process.wait(timeout=5.0)
except subprocess.TimeoutExpired:
self.get_logger().warn(f'Force terminating {self.current_bag_name}')
self.current_bag_process.kill()
self.current_bag_process.wait()
except Exception as e:
self.get_logger().error(f'Error stopping recording: {e}')
self.is_recording = False
self.get_logger().info(f'Stopped recording: {self.current_bag_name}')
# Apply compression if needed (rosbag2 does this by default with -compression-mode file)
if save:
self.apply_compression()
def apply_compression(self):
"""Compress the current bag using zstd."""
if not self.current_bag_name:
return
bag_path = self.bag_dir / self.current_bag_name
try:
# rosbag2 with compression-mode file already compresses the sqlite db
# This is a secondary option to compress the entire directory
tar_path = f'{bag_path}.tar.zst'
if bag_path.exists():
cmd = ['tar', '-I', 'zstd', '-cf', tar_path, '-C', str(self.bag_dir), self.current_bag_name]
subprocess.run(cmd, check=True, capture_output=True, timeout=60)
self.get_logger().info(f'Compressed: {tar_path}')
except Exception as e:
self.get_logger().warn(f'Compression skipped: {e}')
def maintenance_callback(self):
"""Periodic maintenance: cleanup old bags and manage storage."""
self.cleanup_expired_bags()
self.enforce_storage_quota()
if self.enable_rsync and self.rsync_destination:
self.rsync_bags()
def cleanup_expired_bags(self):
"""Remove bags older than TTL."""
try:
cutoff_time = datetime.now() - timedelta(days=self.storage_ttl_days)
for item in self.bag_dir.iterdir():
if item.is_dir() and item.name.startswith('saltybot_'):
try:
# Parse timestamp from directory name
timestamp_str = item.name.replace('saltybot_', '')
item_time = datetime.strptime(timestamp_str, '%Y%m%d_%H%M%S')
if item_time < cutoff_time:
shutil.rmtree(item, ignore_errors=True)
self.get_logger().info(f'Removed expired bag: {item.name}')
except (ValueError, OSError) as e:
self.get_logger().warn(f'Error processing {item.name}: {e}')
except Exception as e:
self.get_logger().error(f'Cleanup failed: {e}')
def enforce_storage_quota(self):
"""Remove oldest bags if total size exceeds quota."""
try:
total_size = sum(
f.stat().st_size
for f in self.bag_dir.rglob('*')
if f.is_file()
) / (1024 ** 3) # Convert to GB
if total_size > self.max_storage_gb:
self.get_logger().warn(
f'Storage quota exceeded: {total_size:.2f}GB > {self.max_storage_gb}GB'
)
# Get bags sorted by modification time
bags = sorted(
[d for d in self.bag_dir.iterdir() if d.is_dir() and d.name.startswith('saltybot_')],
key=lambda x: x.stat().st_mtime
)
# Remove oldest bags until under quota
for bag in bags:
if total_size <= self.max_storage_gb:
break
bag_size = sum(
f.stat().st_size
for f in bag.rglob('*')
if f.is_file()
) / (1024 ** 3)
shutil.rmtree(bag, ignore_errors=True)
total_size -= bag_size
self.get_logger().info(f'Removed bag to enforce quota: {bag.name}')
except Exception as e:
self.get_logger().error(f'Storage quota enforcement failed: {e}')
def rsync_bags(self):
"""Optional: rsync bags to NAS."""
try:
cmd = [
'rsync', '-avz', '--delete',
f'{self.bag_dir}/',
self.rsync_destination
]
subprocess.run(cmd, check=False, timeout=300)
self.get_logger().info(f'Synced bags to NAS: {self.rsync_destination}')
except subprocess.TimeoutExpired:
self.get_logger().warn('Rsync timed out')
except Exception as e:
self.get_logger().error(f'Rsync failed: {e}')
def cleanup(self):
"""Cleanup resources."""
self.stop_recording(save=True)
self.get_logger().info('Bag recorder shutdown complete')
def main(args=None):
rclpy.init(args=args)
node = BagRecorderNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.cleanup()
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,5 @@
[develop]
script_dir=$base/lib/saltybot_bag_recorder
[install]
script_dir=$base/lib/saltybot_bag_recorder

View File

@ -0,0 +1,32 @@
from setuptools import setup
import os
from glob import glob
package_name = 'saltybot_bag_recorder'
setup(
name=package_name,
version='0.1.0',
packages=[package_name],
data_files=[
('share/ament_index/resource_index/packages',
['resource/' + package_name]),
('share/' + package_name, ['package.xml']),
(os.path.join('share', package_name, 'launch'),
glob('launch/*.py')),
(os.path.join('share', package_name, 'config'),
glob('config/*.yaml')),
],
install_requires=['setuptools'],
zip_safe=True,
maintainer='seb',
maintainer_email='seb@vayrette.com',
description='ROS2 bag recording service with circular buffer and storage management',
license='MIT',
tests_require=['pytest'],
entry_points={
'console_scripts': [
'bag_recorder = saltybot_bag_recorder.bag_recorder_node:main',
],
},
)

View File

@ -0,0 +1,25 @@
import unittest
from pathlib import Path
class TestBagRecorder(unittest.TestCase):
"""Basic tests for bag recorder functionality."""
def test_imports(self):
"""Test that the module can be imported."""
from saltybot_bag_recorder import bag_recorder_node
self.assertIsNotNone(bag_recorder_node)
def test_config_file_exists(self):
"""Test that config file exists."""
config_file = Path(__file__).parent.parent / 'config' / 'bag_recorder.yaml'
self.assertTrue(config_file.exists())
def test_launch_file_exists(self):
"""Test that launch file exists."""
launch_file = Path(__file__).parent.parent / 'launch' / 'bag_recorder.launch.py'
self.assertTrue(launch_file.exists())
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,27 @@
balance_controller:
ros__parameters:
# Serial connection parameters
port: "/dev/ttyUSB0"
baudrate: 115200
# VESC Balance PID Parameters
# These are tuning parameters for the balance PID controller
# P: Proportional gain (responds to current error)
# I: Integral gain (corrects accumulated error)
# D: Derivative gain (dampens oscillations)
pid_p: 0.5
pid_i: 0.1
pid_d: 0.05
# Tilt Safety Limits
# Angle threshold in degrees (forward/backward pitch)
tilt_threshold_deg: 45.0
# Duration in milliseconds before triggering motor kill
tilt_kill_duration_ms: 500
# Startup Ramp
# Time in seconds to ramp from 0 to full output
startup_ramp_time_s: 2.0
# Control loop frequency (Hz)
frequency: 50

View File

@ -0,0 +1,31 @@
"""Launch file for balance controller node."""
from launch import LaunchDescription
from launch_ros.actions import Node
from launch.substitutions import LaunchConfiguration
from launch.actions import DeclareLaunchArgument
def generate_launch_description():
"""Generate launch description for balance controller."""
return LaunchDescription([
DeclareLaunchArgument(
"node_name",
default_value="balance_controller",
description="Name of the balance controller node",
),
DeclareLaunchArgument(
"config_file",
default_value="balance_params.yaml",
description="Configuration file for balance controller parameters",
),
Node(
package="saltybot_balance_controller",
executable="balance_controller_node",
name=LaunchConfiguration("node_name"),
output="screen",
parameters=[
LaunchConfiguration("config_file"),
],
),
])

View File

@ -0,0 +1,28 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_balance_controller</name>
<version>0.1.0</version>
<description>
Balance mode PID controller for SaltyBot self-balancing robot.
Manages VESC balance PID parameters, tilt safety limits (±45° > 500ms kill),
startup ramp, and state monitoring via IMU.
</description>
<maintainer email="sl-controls@saltylab.local">sl-controls</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>sensor_msgs</depend>
<depend>std_msgs</depend>
<buildtool_depend>ament_python</buildtool_depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>python3-pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -0,0 +1,375 @@
#!/usr/bin/env python3
"""Balance mode PID controller node for SaltyBot.
Manages VESC balance mode PID parameters via UART (pyvesc).
Implements tilt safety limits (±45° > 500ms kill), startup ramp, and state monitoring.
Subscribed topics:
/imu/data (sensor_msgs/Imu) - IMU orientation for tilt detection
/vesc/state (std_msgs/String) - VESC motor telemetry (voltage, current, RPM)
Published topics:
/saltybot/balance_state (std_msgs/String) - JSON: pitch, roll, tilt_duration, pid, motor_state
/saltybot/balance_log (std_msgs/String) - CSV log: timestamp, pitch, roll, current, temp, rpm
Parameters:
port (str) - Serial port for VESC (/dev/ttyUSB0)
baudrate (int) - Serial baud rate (115200)
pid_p (float) - Balance PID P gain
pid_i (float) - Balance PID I gain
pid_d (float) - Balance PID D gain
tilt_threshold_deg (float) - Tilt angle threshold for safety kill (45°)
tilt_kill_duration_ms (int) - Duration above threshold to trigger kill (500ms)
startup_ramp_time_s (float) - Startup ramp duration (2.0s)
frequency (int) - Update frequency (50Hz)
"""
import json
import math
import time
from enum import Enum
from dataclasses import dataclass
from typing import Optional
import rclpy
from rclpy.node import Node
from sensor_msgs.msg import Imu
from std_msgs.msg import String
import serial
try:
import pyvesc
except ImportError:
pyvesc = None
class BalanceState(Enum):
"""Balance controller state."""
STARTUP = "startup" # Ramping up from zero
RUNNING = "running" # Normal operation
TILT_WARNING = "tilt_warning" # Tilted but within time limit
TILT_KILL = "tilt_kill" # Over-tilted, motors killed
ERROR = "error" # Communication error
@dataclass
class IMUData:
"""Parsed IMU data."""
pitch_deg: float # Forward/backward tilt (Y axis)
roll_deg: float # Left/right tilt (X axis)
timestamp: float
@dataclass
class MotorTelemetry:
"""Motor telemetry from VESC."""
voltage_v: float
current_a: float
rpm: int
temperature_c: float
fault_code: int
class BalanceControllerNode(Node):
"""ROS2 node for balance mode control and tilt safety."""
def __init__(self):
super().__init__("balance_controller")
# Declare parameters
self.declare_parameter("port", "/dev/ttyUSB0")
self.declare_parameter("baudrate", 115200)
self.declare_parameter("pid_p", 0.5)
self.declare_parameter("pid_i", 0.1)
self.declare_parameter("pid_d", 0.05)
self.declare_parameter("tilt_threshold_deg", 45.0)
self.declare_parameter("tilt_kill_duration_ms", 500)
self.declare_parameter("startup_ramp_time_s", 2.0)
self.declare_parameter("frequency", 50)
# Get parameters
self.port = self.get_parameter("port").value
self.baudrate = self.get_parameter("baudrate").value
self.pid_p = self.get_parameter("pid_p").value
self.pid_i = self.get_parameter("pid_i").value
self.pid_d = self.get_parameter("pid_d").value
self.tilt_threshold = self.get_parameter("tilt_threshold_deg").value
self.tilt_kill_duration = self.get_parameter("tilt_kill_duration_ms").value / 1000.0
self.startup_ramp_time = self.get_parameter("startup_ramp_time_s").value
frequency = self.get_parameter("frequency").value
# VESC connection
self.serial: Optional[serial.Serial] = None
self.vesc: Optional[pyvesc.VescUart] = None
# State tracking
self.state = BalanceState.STARTUP
self.imu_data: Optional[IMUData] = None
self.motor_telemetry: Optional[MotorTelemetry] = None
self.startup_time = time.time()
self.tilt_start_time: Optional[float] = None
# Subscriptions
self.create_subscription(Imu, "/imu/data", self._on_imu_data, 10)
self.create_subscription(String, "/vesc/state", self._on_vesc_state, 10)
# Publications
self.pub_balance_state = self.create_publisher(String, "/saltybot/balance_state", 10)
self.pub_balance_log = self.create_publisher(String, "/saltybot/balance_log", 10)
# Timer for control loop
period = 1.0 / frequency
self.create_timer(period, self._control_loop)
# Initialize VESC
self._init_vesc()
self.get_logger().info(
f"Balance controller initialized: port={self.port}, baud={self.baudrate}, "
f"PID=[{self.pid_p}, {self.pid_i}, {self.pid_d}], "
f"tilt_threshold={self.tilt_threshold}°, "
f"tilt_kill_duration={self.tilt_kill_duration}s, "
f"startup_ramp={self.startup_ramp_time}s"
)
def _init_vesc(self) -> bool:
"""Initialize VESC connection."""
try:
if pyvesc is None:
self.get_logger().error("pyvesc not installed. Install with: pip install pyvesc")
self.state = BalanceState.ERROR
return False
self.serial = serial.Serial(
port=self.port,
baudrate=self.baudrate,
timeout=0.1,
)
self.vesc = pyvesc.VescUart(
serial_port=self.serial,
has_sensor=False,
start_heartbeat=True,
)
self._set_pid_parameters()
self.get_logger().info(f"Connected to VESC on {self.port} @ {self.baudrate} baud")
return True
except (serial.SerialException, Exception) as e:
self.get_logger().error(f"Failed to initialize VESC: {e}")
self.state = BalanceState.ERROR
return False
def _set_pid_parameters(self) -> None:
"""Set VESC balance PID parameters."""
if self.vesc is None:
return
try:
# pyvesc doesn't have direct balance mode PID setter, so we'd use
# custom VESC firmware commands or rely on pre-configured VESC.
# For now, log the intended parameters.
self.get_logger().info(
f"PID parameters set: P={self.pid_p}, I={self.pid_i}, D={self.pid_d}"
)
except Exception as e:
self.get_logger().error(f"Failed to set PID parameters: {e}")
def _on_imu_data(self, msg: Imu) -> None:
"""Update IMU orientation data."""
try:
# Extract roll/pitch from quaternion
roll, pitch, _ = self._quaternion_to_euler(
msg.orientation.x, msg.orientation.y,
msg.orientation.z, msg.orientation.w
)
self.imu_data = IMUData(
pitch_deg=math.degrees(pitch),
roll_deg=math.degrees(roll),
timestamp=time.time()
)
except Exception as e:
self.get_logger().warn(f"Error parsing IMU data: {e}")
def _on_vesc_state(self, msg: String) -> None:
"""Parse VESC telemetry from JSON."""
try:
data = json.loads(msg.data)
self.motor_telemetry = MotorTelemetry(
voltage_v=data.get("voltage_v", 0.0),
current_a=data.get("current_a", 0.0),
rpm=data.get("rpm", 0),
temperature_c=data.get("temperature_c", 0.0),
fault_code=data.get("fault_code", 0)
)
except json.JSONDecodeError as e:
self.get_logger().debug(f"Failed to parse VESC state: {e}")
def _quaternion_to_euler(self, x: float, y: float, z: float, w: float) -> tuple:
"""Convert quaternion to Euler angles (roll, pitch, yaw)."""
# Roll (X-axis rotation)
sinr_cosp = 2 * (w * x + y * z)
cosr_cosp = 1 - 2 * (x * x + y * y)
roll = math.atan2(sinr_cosp, cosr_cosp)
# Pitch (Y-axis rotation)
sinp = 2 * (w * y - z * x)
if abs(sinp) >= 1:
pitch = math.copysign(math.pi / 2, sinp)
else:
pitch = math.asin(sinp)
# Yaw (Z-axis rotation)
siny_cosp = 2 * (w * z + x * y)
cosy_cosp = 1 - 2 * (y * y + z * z)
yaw = math.atan2(siny_cosp, cosy_cosp)
return roll, pitch, yaw
def _check_tilt_safety(self) -> None:
"""Check tilt angle and apply safety kill if needed."""
if self.imu_data is None:
return
# Check if tilted beyond threshold
is_tilted = abs(self.imu_data.pitch_deg) > self.tilt_threshold
if is_tilted:
# Tilt detected
if self.tilt_start_time is None:
self.tilt_start_time = time.time()
tilt_duration = time.time() - self.tilt_start_time
if tilt_duration > self.tilt_kill_duration:
# Tilt persisted too long, trigger kill
self.state = BalanceState.TILT_KILL
self._kill_motors()
self.get_logger().error(
f"TILT SAFETY KILL: pitch={self.imu_data.pitch_deg:.1f}° "
f"(threshold={self.tilt_threshold}°) for {tilt_duration:.2f}s"
)
else:
# Warning state
if self.state != BalanceState.TILT_WARNING:
self.state = BalanceState.TILT_WARNING
self.get_logger().warn(
f"Tilt warning: pitch={self.imu_data.pitch_deg:.1f}° "
f"for {tilt_duration:.2f}s / {self.tilt_kill_duration}s"
)
else:
# Not tilted, reset timer
if self.tilt_start_time is not None:
self.tilt_start_time = None
if self.state == BalanceState.TILT_WARNING:
self.state = BalanceState.RUNNING
self.get_logger().info("Tilt warning cleared, resuming normal operation")
def _check_startup_ramp(self) -> float:
"""Calculate startup ramp factor [0, 1]."""
elapsed = time.time() - self.startup_time
if elapsed >= self.startup_ramp_time:
# Startup complete
if self.state == BalanceState.STARTUP:
self.state = BalanceState.RUNNING
self.get_logger().info("Startup ramp complete, entering normal operation")
return 1.0
else:
# Linear ramp
return elapsed / self.startup_ramp_time
def _kill_motors(self) -> None:
"""Kill motor output."""
if self.vesc is None:
return
try:
self.vesc.set_duty(0.0)
self.get_logger().error("Motors killed via duty cycle = 0")
except Exception as e:
self.get_logger().error(f"Failed to kill motors: {e}")
def _control_loop(self) -> None:
"""Main control loop (50Hz)."""
# Check IMU data availability
if self.imu_data is None:
return
# Check tilt safety
self._check_tilt_safety()
# Check startup ramp
ramp_factor = self._check_startup_ramp()
# Publish state
self._publish_balance_state(ramp_factor)
# Log data
self._publish_balance_log()
def _publish_balance_state(self, ramp_factor: float) -> None:
"""Publish balance controller state as JSON."""
state_dict = {
"timestamp": time.time(),
"state": self.state.value,
"pitch_deg": round(self.imu_data.pitch_deg, 2) if self.imu_data else 0.0,
"roll_deg": round(self.imu_data.roll_deg, 2) if self.imu_data else 0.0,
"tilt_threshold_deg": self.tilt_threshold,
"tilt_duration_s": (time.time() - self.tilt_start_time) if self.tilt_start_time else 0.0,
"tilt_kill_duration_s": self.tilt_kill_duration,
"pid": {
"p": self.pid_p,
"i": self.pid_i,
"d": self.pid_d,
},
"startup_ramp_factor": round(ramp_factor, 3),
"motor": {
"voltage_v": round(self.motor_telemetry.voltage_v, 2) if self.motor_telemetry else 0.0,
"current_a": round(self.motor_telemetry.current_a, 2) if self.motor_telemetry else 0.0,
"rpm": self.motor_telemetry.rpm if self.motor_telemetry else 0,
"temperature_c": round(self.motor_telemetry.temperature_c, 1) if self.motor_telemetry else 0.0,
"fault_code": self.motor_telemetry.fault_code if self.motor_telemetry else 0,
}
}
msg = String(data=json.dumps(state_dict))
self.pub_balance_state.publish(msg)
def _publish_balance_log(self) -> None:
"""Publish IMU + motor data log as CSV."""
if self.imu_data is None or self.motor_telemetry is None:
return
# CSV format: timestamp, pitch, roll, current, temp, rpm
log_entry = (
f"{time.time():.3f}, "
f"{self.imu_data.pitch_deg:.2f}, "
f"{self.imu_data.roll_deg:.2f}, "
f"{self.motor_telemetry.current_a:.2f}, "
f"{self.motor_telemetry.temperature_c:.1f}, "
f"{self.motor_telemetry.rpm}"
)
msg = String(data=log_entry)
self.pub_balance_log.publish(msg)
def main(args=None):
rclpy.init(args=args)
node = BalanceControllerNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
if node.vesc and node.serial:
node.vesc.set_duty(0.0) # Zero throttle on shutdown
node.serial.close()
node.destroy_node()
rclpy.shutdown()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,4 @@
[develop]
script_dir=$base/lib/saltybot_balance_controller
[install]
install_scripts=$base/lib/saltybot_balance_controller

View File

@ -0,0 +1,27 @@
from setuptools import setup
package_name = "saltybot_balance_controller"
setup(
name=package_name,
version="0.1.0",
packages=[package_name],
data_files=[
("share/ament_index/resource_index/packages", [f"resource/{package_name}"]),
(f"share/{package_name}", ["package.xml"]),
(f"share/{package_name}/launch", ["launch/balance_controller.launch.py"]),
(f"share/{package_name}/config", ["config/balance_params.yaml"]),
],
install_requires=["setuptools", "pyvesc"],
zip_safe=True,
maintainer="sl-controls",
maintainer_email="sl-controls@saltylab.local",
description="Balance mode PID controller with tilt safety for SaltyBot",
license="MIT",
tests_require=["pytest"],
entry_points={
"console_scripts": [
"balance_controller_node = saltybot_balance_controller.balance_controller_node:main",
],
},
)

View File

@ -0,0 +1,170 @@
"""Unit tests for balance controller node."""
import pytest
import math
from sensor_msgs.msg import Imu
from geometry_msgs.msg import Quaternion
from std_msgs.msg import String
import rclpy
from saltybot_balance_controller.balance_controller_node import BalanceControllerNode
@pytest.fixture
def rclpy_fixture():
"""Initialize and cleanup rclpy."""
rclpy.init()
yield
rclpy.shutdown()
@pytest.fixture
def node(rclpy_fixture):
"""Create a balance controller node instance."""
node = BalanceControllerNode()
yield node
node.destroy_node()
class TestNodeInitialization:
"""Test suite for node initialization."""
def test_node_initialization(self, node):
"""Test that node initializes with correct defaults."""
assert node.port == "/dev/ttyUSB0"
assert node.baudrate == 115200
assert node.pid_p == 0.5
assert node.pid_i == 0.1
assert node.pid_d == 0.05
def test_tilt_threshold_parameter(self, node):
"""Test tilt threshold parameter is set correctly."""
assert node.tilt_threshold == 45.0
def test_tilt_kill_duration_parameter(self, node):
"""Test tilt kill duration parameter is set correctly."""
assert node.tilt_kill_duration == 0.5
class TestQuaternionToEuler:
"""Test suite for quaternion to Euler conversion."""
def test_identity_quaternion(self, node):
"""Test identity quaternion (no rotation)."""
roll, pitch, yaw = node._quaternion_to_euler(0, 0, 0, 1)
assert roll == pytest.approx(0)
assert pitch == pytest.approx(0)
assert yaw == pytest.approx(0)
def test_90deg_pitch_rotation(self, node):
"""Test 90 degree pitch rotation."""
# Quaternion for 90 degree Y rotation
roll, pitch, yaw = node._quaternion_to_euler(0, 0.707, 0, 0.707)
assert roll == pytest.approx(0, abs=0.01)
assert pitch == pytest.approx(math.pi / 2, abs=0.01)
assert yaw == pytest.approx(0, abs=0.01)
def test_45deg_pitch_rotation(self, node):
"""Test 45 degree pitch rotation."""
roll, pitch, yaw = node._quaternion_to_euler(0, 0.383, 0, 0.924)
assert roll == pytest.approx(0, abs=0.01)
assert pitch == pytest.approx(math.pi / 4, abs=0.01)
assert yaw == pytest.approx(0, abs=0.01)
def test_roll_rotation(self, node):
"""Test roll rotation around X axis."""
roll, pitch, yaw = node._quaternion_to_euler(0.707, 0, 0, 0.707)
assert roll == pytest.approx(math.pi / 2, abs=0.01)
assert pitch == pytest.approx(0, abs=0.01)
assert yaw == pytest.approx(0, abs=0.01)
class TestIMUDataParsing:
"""Test suite for IMU data parsing."""
def test_imu_data_subscription(self, node):
"""Test IMU data subscription updates node state."""
imu = Imu()
imu.orientation = Quaternion(x=0, y=0, z=0, w=1)
node._on_imu_data(imu)
assert node.imu_data is not None
assert node.imu_data.pitch_deg == pytest.approx(0, abs=0.1)
assert node.imu_data.roll_deg == pytest.approx(0, abs=0.1)
def test_imu_pitch_tilted_forward(self, node):
"""Test IMU data with forward pitch tilt."""
# 45 degree forward pitch
imu = Imu()
imu.orientation = Quaternion(x=0, y=0.383, z=0, w=0.924)
node._on_imu_data(imu)
assert node.imu_data is not None
# Should be approximately 45 degrees (in radians converted to degrees)
assert node.imu_data.pitch_deg == pytest.approx(45, abs=1)
class TestTiltSafety:
"""Test suite for tilt safety checks."""
def test_tilt_warning_state_entry(self, node):
"""Test entry into tilt warning state."""
imu = Imu()
# 50 degree pitch (exceeds 45 degree threshold)
imu.orientation = Quaternion(x=0, y=0.438, z=0, w=0.899)
node._on_imu_data(imu)
# Call check with small duration
node._check_tilt_safety()
assert node.state.value in ["tilt_warning", "startup"]
def test_level_no_tilt_warning(self, node):
"""Test no tilt warning when level."""
imu = Imu()
imu.orientation = Quaternion(x=0, y=0, z=0, w=1)
node._on_imu_data(imu)
node._check_tilt_safety()
# Tilt start time should be None when level
assert node.tilt_start_time is None
class TestStartupRamp:
"""Test suite for startup ramp functionality."""
def test_startup_ramp_begins_at_zero(self, node):
"""Test startup ramp begins at 0."""
ramp = node._check_startup_ramp()
# At startup time, ramp should be close to 0
assert ramp <= 0.05 # Very small value at start
def test_startup_ramp_reaches_one(self, node):
"""Test startup ramp reaches 1.0 after duration."""
# Simulate startup ramp completion
node.startup_time = 0
node.startup_ramp_time = 0.001 # Very short ramp
import time
time.sleep(0.01) # Sleep longer than ramp time
ramp = node._check_startup_ramp()
# Should be complete after time has passed
assert ramp >= 0.99
def test_startup_ramp_linear(self, node):
"""Test startup ramp is linear."""
node.startup_ramp_time = 1.0
# At 25% of startup time
node.startup_time = 0
import time
time.sleep(0.25)
ramp = node._check_startup_ramp()
assert 0.2 < ramp < 0.3