Compare commits
No commits in common. "f69c02880ea8b97e51134d35f0e0b61f093cf6f1" and "9257f4c7de1f1881cc3edb46315b0ec2762e32fb" have entirely different histories.
f69c02880e
...
9257f4c7de
@ -1,9 +0,0 @@
|
|||||||
build/
|
|
||||||
install/
|
|
||||||
log/
|
|
||||||
*.pyc
|
|
||||||
__pycache__/
|
|
||||||
.pytest_cache/
|
|
||||||
*.egg-info/
|
|
||||||
dist/
|
|
||||||
*.egg
|
|
||||||
@ -1,26 +0,0 @@
|
|||||||
bag_recorder:
|
|
||||||
ros__parameters:
|
|
||||||
# Path where bags are stored
|
|
||||||
bag_dir: '/home/seb/rosbags'
|
|
||||||
|
|
||||||
# Topics to record (empty list = record all)
|
|
||||||
topics: []
|
|
||||||
# topics:
|
|
||||||
# - '/camera/image_raw'
|
|
||||||
# - '/lidar/scan'
|
|
||||||
# - '/odom'
|
|
||||||
|
|
||||||
# Circular buffer duration (minutes)
|
|
||||||
buffer_duration_minutes: 30
|
|
||||||
|
|
||||||
# Storage management
|
|
||||||
storage_ttl_days: 7 # Remove bags older than 7 days
|
|
||||||
max_storage_gb: 50 # Enforce 50GB quota
|
|
||||||
|
|
||||||
# Compression
|
|
||||||
compression: 'zstd' # Options: zstd, zstandard
|
|
||||||
|
|
||||||
# NAS sync (optional)
|
|
||||||
enable_rsync: false
|
|
||||||
rsync_destination: ''
|
|
||||||
# rsync_destination: 'user@nas:/path/to/backups/'
|
|
||||||
@ -1,23 +0,0 @@
|
|||||||
from launch import LaunchDescription
|
|
||||||
from launch_ros.actions import Node
|
|
||||||
from ament_index_python.packages import get_package_share_directory
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
def generate_launch_description():
|
|
||||||
pkg_dir = get_package_share_directory('saltybot_bag_recorder')
|
|
||||||
config_file = os.path.join(pkg_dir, 'config', 'bag_recorder.yaml')
|
|
||||||
|
|
||||||
bag_recorder_node = Node(
|
|
||||||
package='saltybot_bag_recorder',
|
|
||||||
executable='bag_recorder',
|
|
||||||
name='bag_recorder',
|
|
||||||
parameters=[config_file],
|
|
||||||
output='screen',
|
|
||||||
respawn=True,
|
|
||||||
respawn_delay=5,
|
|
||||||
)
|
|
||||||
|
|
||||||
return LaunchDescription([
|
|
||||||
bag_recorder_node,
|
|
||||||
])
|
|
||||||
@ -1,30 +0,0 @@
|
|||||||
<?xml version="1.0"?>
|
|
||||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
|
||||||
<package format="3">
|
|
||||||
<name>saltybot_bag_recorder</name>
|
|
||||||
<version>0.1.0</version>
|
|
||||||
<description>
|
|
||||||
ROS2 bag recording service with circular buffer, auto-save on crash, and storage management.
|
|
||||||
Configurable topics, 7-day TTL, 50GB cap, zstd compression, and optional NAS rsync.
|
|
||||||
</description>
|
|
||||||
<maintainer email="seb@vayrette.com">seb</maintainer>
|
|
||||||
<license>MIT</license>
|
|
||||||
|
|
||||||
<depend>rclpy</depend>
|
|
||||||
<depend>rosbag2_py</depend>
|
|
||||||
<depend>std_srvs</depend>
|
|
||||||
<depend>std_msgs</depend>
|
|
||||||
<depend>ament_index_python</depend>
|
|
||||||
|
|
||||||
<exec_depend>python3-launch-ros</exec_depend>
|
|
||||||
<exec_depend>ros2bag</exec_depend>
|
|
||||||
|
|
||||||
<test_depend>ament_copyright</test_depend>
|
|
||||||
<test_depend>ament_flake8</test_depend>
|
|
||||||
<test_depend>ament_pep257</test_depend>
|
|
||||||
<test_depend>python3-pytest</test_depend>
|
|
||||||
|
|
||||||
<export>
|
|
||||||
<build_type>ament_python</build_type>
|
|
||||||
</export>
|
|
||||||
</package>
|
|
||||||
@ -1,282 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
import os
|
|
||||||
import signal
|
|
||||||
import shutil
|
|
||||||
import subprocess
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import rclpy
|
|
||||||
from rclpy.node import Node
|
|
||||||
from std_srvs.srv import Trigger
|
|
||||||
from std_msgs.msg import String
|
|
||||||
|
|
||||||
|
|
||||||
class BagRecorderNode(Node):
|
|
||||||
"""ROS2 bag recording service with circular buffer and storage management."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__('saltybot_bag_recorder')
|
|
||||||
|
|
||||||
# Configuration
|
|
||||||
self.declare_parameter('bag_dir', '/home/seb/rosbags')
|
|
||||||
self.declare_parameter('topics', [''])
|
|
||||||
self.declare_parameter('buffer_duration_minutes', 30)
|
|
||||||
self.declare_parameter('storage_ttl_days', 7)
|
|
||||||
self.declare_parameter('max_storage_gb', 50)
|
|
||||||
self.declare_parameter('enable_rsync', False)
|
|
||||||
self.declare_parameter('rsync_destination', '')
|
|
||||||
self.declare_parameter('compression', 'zstd')
|
|
||||||
|
|
||||||
self.bag_dir = Path(self.get_parameter('bag_dir').value)
|
|
||||||
self.topics = self.get_parameter('topics').value
|
|
||||||
self.buffer_duration = self.get_parameter('buffer_duration_minutes').value * 60
|
|
||||||
self.storage_ttl_days = self.get_parameter('storage_ttl_days').value
|
|
||||||
self.max_storage_gb = self.get_parameter('max_storage_gb').value
|
|
||||||
self.enable_rsync = self.get_parameter('enable_rsync').value
|
|
||||||
self.rsync_destination = self.get_parameter('rsync_destination').value
|
|
||||||
self.compression = self.get_parameter('compression').value
|
|
||||||
|
|
||||||
self.bag_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Recording state
|
|
||||||
self.is_recording = False
|
|
||||||
self.current_bag_process = None
|
|
||||||
self.current_bag_name = None
|
|
||||||
self.buffer_bags: List[str] = []
|
|
||||||
self.recording_lock = threading.Lock()
|
|
||||||
|
|
||||||
# Services
|
|
||||||
self.save_service = self.create_service(
|
|
||||||
Trigger,
|
|
||||||
'/saltybot/save_bag',
|
|
||||||
self.save_bag_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
# Watchdog to handle crash recovery
|
|
||||||
self.setup_signal_handlers()
|
|
||||||
|
|
||||||
# Start recording
|
|
||||||
self.start_recording()
|
|
||||||
|
|
||||||
# Periodic maintenance (cleanup old bags, manage storage)
|
|
||||||
self.maintenance_timer = self.create_timer(300.0, self.maintenance_callback)
|
|
||||||
|
|
||||||
self.get_logger().info(
|
|
||||||
f'Bag recorder initialized: {self.bag_dir}, '
|
|
||||||
f'buffer={self.buffer_duration}s, ttl={self.storage_ttl_days}d, '
|
|
||||||
f'max={self.max_storage_gb}GB'
|
|
||||||
)
|
|
||||||
|
|
||||||
def setup_signal_handlers(self):
|
|
||||||
"""Setup signal handlers for graceful shutdown and crash recovery."""
|
|
||||||
def signal_handler(sig, frame):
|
|
||||||
self.get_logger().warn(f'Signal {sig} received, saving current bag')
|
|
||||||
self.stop_recording(save=True)
|
|
||||||
self.cleanup()
|
|
||||||
rclpy.shutdown()
|
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
|
||||||
|
|
||||||
def start_recording(self):
|
|
||||||
"""Start bag recording in the background."""
|
|
||||||
with self.recording_lock:
|
|
||||||
if self.is_recording:
|
|
||||||
return
|
|
||||||
|
|
||||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
|
||||||
self.current_bag_name = f'saltybot_{timestamp}'
|
|
||||||
bag_path = self.bag_dir / self.current_bag_name
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Build rosbag2 record command
|
|
||||||
cmd = [
|
|
||||||
'ros2', 'bag', 'record',
|
|
||||||
f'--output', str(bag_path),
|
|
||||||
f'--compression-format', self.compression,
|
|
||||||
f'--compression-mode', 'file',
|
|
||||||
]
|
|
||||||
|
|
||||||
# Add topics or record all if empty
|
|
||||||
if self.topics and self.topics[0]:
|
|
||||||
cmd.extend(self.topics)
|
|
||||||
else:
|
|
||||||
cmd.append('--all')
|
|
||||||
|
|
||||||
self.current_bag_process = subprocess.Popen(
|
|
||||||
cmd,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=subprocess.PIPE
|
|
||||||
)
|
|
||||||
self.is_recording = True
|
|
||||||
self.buffer_bags.append(self.current_bag_name)
|
|
||||||
|
|
||||||
self.get_logger().info(f'Started recording: {self.current_bag_name}')
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.get_logger().error(f'Failed to start recording: {e}')
|
|
||||||
|
|
||||||
def save_bag_callback(self, request, response):
|
|
||||||
"""Service callback to manually trigger bag save."""
|
|
||||||
try:
|
|
||||||
self.stop_recording(save=True)
|
|
||||||
self.start_recording()
|
|
||||||
response.success = True
|
|
||||||
response.message = f'Saved: {self.current_bag_name}'
|
|
||||||
self.get_logger().info(response.message)
|
|
||||||
except Exception as e:
|
|
||||||
response.success = False
|
|
||||||
response.message = f'Failed to save bag: {e}'
|
|
||||||
self.get_logger().error(response.message)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
def stop_recording(self, save: bool = False):
|
|
||||||
"""Stop the current bag recording."""
|
|
||||||
with self.recording_lock:
|
|
||||||
if not self.is_recording or not self.current_bag_process:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Send SIGINT to gracefully close rosbag2
|
|
||||||
self.current_bag_process.send_signal(signal.SIGINT)
|
|
||||||
self.current_bag_process.wait(timeout=5.0)
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
self.get_logger().warn(f'Force terminating {self.current_bag_name}')
|
|
||||||
self.current_bag_process.kill()
|
|
||||||
self.current_bag_process.wait()
|
|
||||||
except Exception as e:
|
|
||||||
self.get_logger().error(f'Error stopping recording: {e}')
|
|
||||||
|
|
||||||
self.is_recording = False
|
|
||||||
self.get_logger().info(f'Stopped recording: {self.current_bag_name}')
|
|
||||||
|
|
||||||
# Apply compression if needed (rosbag2 does this by default with -compression-mode file)
|
|
||||||
if save:
|
|
||||||
self.apply_compression()
|
|
||||||
|
|
||||||
def apply_compression(self):
|
|
||||||
"""Compress the current bag using zstd."""
|
|
||||||
if not self.current_bag_name:
|
|
||||||
return
|
|
||||||
|
|
||||||
bag_path = self.bag_dir / self.current_bag_name
|
|
||||||
try:
|
|
||||||
# rosbag2 with compression-mode file already compresses the sqlite db
|
|
||||||
# This is a secondary option to compress the entire directory
|
|
||||||
tar_path = f'{bag_path}.tar.zst'
|
|
||||||
|
|
||||||
if bag_path.exists():
|
|
||||||
cmd = ['tar', '-I', 'zstd', '-cf', tar_path, '-C', str(self.bag_dir), self.current_bag_name]
|
|
||||||
subprocess.run(cmd, check=True, capture_output=True, timeout=60)
|
|
||||||
self.get_logger().info(f'Compressed: {tar_path}')
|
|
||||||
except Exception as e:
|
|
||||||
self.get_logger().warn(f'Compression skipped: {e}')
|
|
||||||
|
|
||||||
def maintenance_callback(self):
|
|
||||||
"""Periodic maintenance: cleanup old bags and manage storage."""
|
|
||||||
self.cleanup_expired_bags()
|
|
||||||
self.enforce_storage_quota()
|
|
||||||
if self.enable_rsync and self.rsync_destination:
|
|
||||||
self.rsync_bags()
|
|
||||||
|
|
||||||
def cleanup_expired_bags(self):
|
|
||||||
"""Remove bags older than TTL."""
|
|
||||||
try:
|
|
||||||
cutoff_time = datetime.now() - timedelta(days=self.storage_ttl_days)
|
|
||||||
|
|
||||||
for item in self.bag_dir.iterdir():
|
|
||||||
if item.is_dir() and item.name.startswith('saltybot_'):
|
|
||||||
try:
|
|
||||||
# Parse timestamp from directory name
|
|
||||||
timestamp_str = item.name.replace('saltybot_', '')
|
|
||||||
item_time = datetime.strptime(timestamp_str, '%Y%m%d_%H%M%S')
|
|
||||||
|
|
||||||
if item_time < cutoff_time:
|
|
||||||
shutil.rmtree(item, ignore_errors=True)
|
|
||||||
self.get_logger().info(f'Removed expired bag: {item.name}')
|
|
||||||
except (ValueError, OSError) as e:
|
|
||||||
self.get_logger().warn(f'Error processing {item.name}: {e}')
|
|
||||||
except Exception as e:
|
|
||||||
self.get_logger().error(f'Cleanup failed: {e}')
|
|
||||||
|
|
||||||
def enforce_storage_quota(self):
|
|
||||||
"""Remove oldest bags if total size exceeds quota."""
|
|
||||||
try:
|
|
||||||
total_size = sum(
|
|
||||||
f.stat().st_size
|
|
||||||
for f in self.bag_dir.rglob('*')
|
|
||||||
if f.is_file()
|
|
||||||
) / (1024 ** 3) # Convert to GB
|
|
||||||
|
|
||||||
if total_size > self.max_storage_gb:
|
|
||||||
self.get_logger().warn(
|
|
||||||
f'Storage quota exceeded: {total_size:.2f}GB > {self.max_storage_gb}GB'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get bags sorted by modification time
|
|
||||||
bags = sorted(
|
|
||||||
[d for d in self.bag_dir.iterdir() if d.is_dir() and d.name.startswith('saltybot_')],
|
|
||||||
key=lambda x: x.stat().st_mtime
|
|
||||||
)
|
|
||||||
|
|
||||||
# Remove oldest bags until under quota
|
|
||||||
for bag in bags:
|
|
||||||
if total_size <= self.max_storage_gb:
|
|
||||||
break
|
|
||||||
|
|
||||||
bag_size = sum(
|
|
||||||
f.stat().st_size
|
|
||||||
for f in bag.rglob('*')
|
|
||||||
if f.is_file()
|
|
||||||
) / (1024 ** 3)
|
|
||||||
|
|
||||||
shutil.rmtree(bag, ignore_errors=True)
|
|
||||||
total_size -= bag_size
|
|
||||||
self.get_logger().info(f'Removed bag to enforce quota: {bag.name}')
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.get_logger().error(f'Storage quota enforcement failed: {e}')
|
|
||||||
|
|
||||||
def rsync_bags(self):
|
|
||||||
"""Optional: rsync bags to NAS."""
|
|
||||||
try:
|
|
||||||
cmd = [
|
|
||||||
'rsync', '-avz', '--delete',
|
|
||||||
f'{self.bag_dir}/',
|
|
||||||
self.rsync_destination
|
|
||||||
]
|
|
||||||
subprocess.run(cmd, check=False, timeout=300)
|
|
||||||
self.get_logger().info(f'Synced bags to NAS: {self.rsync_destination}')
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
self.get_logger().warn('Rsync timed out')
|
|
||||||
except Exception as e:
|
|
||||||
self.get_logger().error(f'Rsync failed: {e}')
|
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
"""Cleanup resources."""
|
|
||||||
self.stop_recording(save=True)
|
|
||||||
self.get_logger().info('Bag recorder shutdown complete')
|
|
||||||
|
|
||||||
|
|
||||||
def main(args=None):
|
|
||||||
rclpy.init(args=args)
|
|
||||||
node = BagRecorderNode()
|
|
||||||
|
|
||||||
try:
|
|
||||||
rclpy.spin(node)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
node.cleanup()
|
|
||||||
node.destroy_node()
|
|
||||||
rclpy.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
@ -1,5 +0,0 @@
|
|||||||
[develop]
|
|
||||||
script_dir=$base/lib/saltybot_bag_recorder
|
|
||||||
|
|
||||||
[install]
|
|
||||||
script_dir=$base/lib/saltybot_bag_recorder
|
|
||||||
@ -1,32 +0,0 @@
|
|||||||
from setuptools import setup
|
|
||||||
import os
|
|
||||||
from glob import glob
|
|
||||||
|
|
||||||
package_name = 'saltybot_bag_recorder'
|
|
||||||
|
|
||||||
setup(
|
|
||||||
name=package_name,
|
|
||||||
version='0.1.0',
|
|
||||||
packages=[package_name],
|
|
||||||
data_files=[
|
|
||||||
('share/ament_index/resource_index/packages',
|
|
||||||
['resource/' + package_name]),
|
|
||||||
('share/' + package_name, ['package.xml']),
|
|
||||||
(os.path.join('share', package_name, 'launch'),
|
|
||||||
glob('launch/*.py')),
|
|
||||||
(os.path.join('share', package_name, 'config'),
|
|
||||||
glob('config/*.yaml')),
|
|
||||||
],
|
|
||||||
install_requires=['setuptools'],
|
|
||||||
zip_safe=True,
|
|
||||||
maintainer='seb',
|
|
||||||
maintainer_email='seb@vayrette.com',
|
|
||||||
description='ROS2 bag recording service with circular buffer and storage management',
|
|
||||||
license='MIT',
|
|
||||||
tests_require=['pytest'],
|
|
||||||
entry_points={
|
|
||||||
'console_scripts': [
|
|
||||||
'bag_recorder = saltybot_bag_recorder.bag_recorder_node:main',
|
|
||||||
],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
@ -1,25 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
class TestBagRecorder(unittest.TestCase):
|
|
||||||
"""Basic tests for bag recorder functionality."""
|
|
||||||
|
|
||||||
def test_imports(self):
|
|
||||||
"""Test that the module can be imported."""
|
|
||||||
from saltybot_bag_recorder import bag_recorder_node
|
|
||||||
self.assertIsNotNone(bag_recorder_node)
|
|
||||||
|
|
||||||
def test_config_file_exists(self):
|
|
||||||
"""Test that config file exists."""
|
|
||||||
config_file = Path(__file__).parent.parent / 'config' / 'bag_recorder.yaml'
|
|
||||||
self.assertTrue(config_file.exists())
|
|
||||||
|
|
||||||
def test_launch_file_exists(self):
|
|
||||||
"""Test that launch file exists."""
|
|
||||||
launch_file = Path(__file__).parent.parent / 'launch' / 'bag_recorder.launch.py'
|
|
||||||
self.assertTrue(launch_file.exists())
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main()
|
|
||||||
@ -1,27 +0,0 @@
|
|||||||
balance_controller:
|
|
||||||
ros__parameters:
|
|
||||||
# Serial connection parameters
|
|
||||||
port: "/dev/ttyUSB0"
|
|
||||||
baudrate: 115200
|
|
||||||
|
|
||||||
# VESC Balance PID Parameters
|
|
||||||
# These are tuning parameters for the balance PID controller
|
|
||||||
# P: Proportional gain (responds to current error)
|
|
||||||
# I: Integral gain (corrects accumulated error)
|
|
||||||
# D: Derivative gain (dampens oscillations)
|
|
||||||
pid_p: 0.5
|
|
||||||
pid_i: 0.1
|
|
||||||
pid_d: 0.05
|
|
||||||
|
|
||||||
# Tilt Safety Limits
|
|
||||||
# Angle threshold in degrees (forward/backward pitch)
|
|
||||||
tilt_threshold_deg: 45.0
|
|
||||||
# Duration in milliseconds before triggering motor kill
|
|
||||||
tilt_kill_duration_ms: 500
|
|
||||||
|
|
||||||
# Startup Ramp
|
|
||||||
# Time in seconds to ramp from 0 to full output
|
|
||||||
startup_ramp_time_s: 2.0
|
|
||||||
|
|
||||||
# Control loop frequency (Hz)
|
|
||||||
frequency: 50
|
|
||||||
@ -1,31 +0,0 @@
|
|||||||
"""Launch file for balance controller node."""
|
|
||||||
|
|
||||||
from launch import LaunchDescription
|
|
||||||
from launch_ros.actions import Node
|
|
||||||
from launch.substitutions import LaunchConfiguration
|
|
||||||
from launch.actions import DeclareLaunchArgument
|
|
||||||
|
|
||||||
|
|
||||||
def generate_launch_description():
|
|
||||||
"""Generate launch description for balance controller."""
|
|
||||||
return LaunchDescription([
|
|
||||||
DeclareLaunchArgument(
|
|
||||||
"node_name",
|
|
||||||
default_value="balance_controller",
|
|
||||||
description="Name of the balance controller node",
|
|
||||||
),
|
|
||||||
DeclareLaunchArgument(
|
|
||||||
"config_file",
|
|
||||||
default_value="balance_params.yaml",
|
|
||||||
description="Configuration file for balance controller parameters",
|
|
||||||
),
|
|
||||||
Node(
|
|
||||||
package="saltybot_balance_controller",
|
|
||||||
executable="balance_controller_node",
|
|
||||||
name=LaunchConfiguration("node_name"),
|
|
||||||
output="screen",
|
|
||||||
parameters=[
|
|
||||||
LaunchConfiguration("config_file"),
|
|
||||||
],
|
|
||||||
),
|
|
||||||
])
|
|
||||||
@ -1,28 +0,0 @@
|
|||||||
<?xml version="1.0"?>
|
|
||||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
|
||||||
<package format="3">
|
|
||||||
<name>saltybot_balance_controller</name>
|
|
||||||
<version>0.1.0</version>
|
|
||||||
<description>
|
|
||||||
Balance mode PID controller for SaltyBot self-balancing robot.
|
|
||||||
Manages VESC balance PID parameters, tilt safety limits (±45° > 500ms kill),
|
|
||||||
startup ramp, and state monitoring via IMU.
|
|
||||||
</description>
|
|
||||||
<maintainer email="sl-controls@saltylab.local">sl-controls</maintainer>
|
|
||||||
<license>MIT</license>
|
|
||||||
|
|
||||||
<depend>rclpy</depend>
|
|
||||||
<depend>sensor_msgs</depend>
|
|
||||||
<depend>std_msgs</depend>
|
|
||||||
|
|
||||||
<buildtool_depend>ament_python</buildtool_depend>
|
|
||||||
|
|
||||||
<test_depend>ament_copyright</test_depend>
|
|
||||||
<test_depend>ament_flake8</test_depend>
|
|
||||||
<test_depend>ament_pep257</test_depend>
|
|
||||||
<test_depend>python3-pytest</test_depend>
|
|
||||||
|
|
||||||
<export>
|
|
||||||
<build_type>ament_python</build_type>
|
|
||||||
</export>
|
|
||||||
</package>
|
|
||||||
@ -1,375 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""Balance mode PID controller node for SaltyBot.
|
|
||||||
|
|
||||||
Manages VESC balance mode PID parameters via UART (pyvesc).
|
|
||||||
Implements tilt safety limits (±45° > 500ms kill), startup ramp, and state monitoring.
|
|
||||||
|
|
||||||
Subscribed topics:
|
|
||||||
/imu/data (sensor_msgs/Imu) - IMU orientation for tilt detection
|
|
||||||
/vesc/state (std_msgs/String) - VESC motor telemetry (voltage, current, RPM)
|
|
||||||
|
|
||||||
Published topics:
|
|
||||||
/saltybot/balance_state (std_msgs/String) - JSON: pitch, roll, tilt_duration, pid, motor_state
|
|
||||||
/saltybot/balance_log (std_msgs/String) - CSV log: timestamp, pitch, roll, current, temp, rpm
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
port (str) - Serial port for VESC (/dev/ttyUSB0)
|
|
||||||
baudrate (int) - Serial baud rate (115200)
|
|
||||||
pid_p (float) - Balance PID P gain
|
|
||||||
pid_i (float) - Balance PID I gain
|
|
||||||
pid_d (float) - Balance PID D gain
|
|
||||||
tilt_threshold_deg (float) - Tilt angle threshold for safety kill (45°)
|
|
||||||
tilt_kill_duration_ms (int) - Duration above threshold to trigger kill (500ms)
|
|
||||||
startup_ramp_time_s (float) - Startup ramp duration (2.0s)
|
|
||||||
frequency (int) - Update frequency (50Hz)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import math
|
|
||||||
import time
|
|
||||||
from enum import Enum
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import rclpy
|
|
||||||
from rclpy.node import Node
|
|
||||||
from sensor_msgs.msg import Imu
|
|
||||||
from std_msgs.msg import String
|
|
||||||
import serial
|
|
||||||
|
|
||||||
try:
|
|
||||||
import pyvesc
|
|
||||||
except ImportError:
|
|
||||||
pyvesc = None
|
|
||||||
|
|
||||||
|
|
||||||
class BalanceState(Enum):
|
|
||||||
"""Balance controller state."""
|
|
||||||
STARTUP = "startup" # Ramping up from zero
|
|
||||||
RUNNING = "running" # Normal operation
|
|
||||||
TILT_WARNING = "tilt_warning" # Tilted but within time limit
|
|
||||||
TILT_KILL = "tilt_kill" # Over-tilted, motors killed
|
|
||||||
ERROR = "error" # Communication error
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class IMUData:
|
|
||||||
"""Parsed IMU data."""
|
|
||||||
pitch_deg: float # Forward/backward tilt (Y axis)
|
|
||||||
roll_deg: float # Left/right tilt (X axis)
|
|
||||||
timestamp: float
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MotorTelemetry:
|
|
||||||
"""Motor telemetry from VESC."""
|
|
||||||
voltage_v: float
|
|
||||||
current_a: float
|
|
||||||
rpm: int
|
|
||||||
temperature_c: float
|
|
||||||
fault_code: int
|
|
||||||
|
|
||||||
|
|
||||||
class BalanceControllerNode(Node):
|
|
||||||
"""ROS2 node for balance mode control and tilt safety."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__("balance_controller")
|
|
||||||
|
|
||||||
# Declare parameters
|
|
||||||
self.declare_parameter("port", "/dev/ttyUSB0")
|
|
||||||
self.declare_parameter("baudrate", 115200)
|
|
||||||
self.declare_parameter("pid_p", 0.5)
|
|
||||||
self.declare_parameter("pid_i", 0.1)
|
|
||||||
self.declare_parameter("pid_d", 0.05)
|
|
||||||
self.declare_parameter("tilt_threshold_deg", 45.0)
|
|
||||||
self.declare_parameter("tilt_kill_duration_ms", 500)
|
|
||||||
self.declare_parameter("startup_ramp_time_s", 2.0)
|
|
||||||
self.declare_parameter("frequency", 50)
|
|
||||||
|
|
||||||
# Get parameters
|
|
||||||
self.port = self.get_parameter("port").value
|
|
||||||
self.baudrate = self.get_parameter("baudrate").value
|
|
||||||
self.pid_p = self.get_parameter("pid_p").value
|
|
||||||
self.pid_i = self.get_parameter("pid_i").value
|
|
||||||
self.pid_d = self.get_parameter("pid_d").value
|
|
||||||
self.tilt_threshold = self.get_parameter("tilt_threshold_deg").value
|
|
||||||
self.tilt_kill_duration = self.get_parameter("tilt_kill_duration_ms").value / 1000.0
|
|
||||||
self.startup_ramp_time = self.get_parameter("startup_ramp_time_s").value
|
|
||||||
frequency = self.get_parameter("frequency").value
|
|
||||||
|
|
||||||
# VESC connection
|
|
||||||
self.serial: Optional[serial.Serial] = None
|
|
||||||
self.vesc: Optional[pyvesc.VescUart] = None
|
|
||||||
|
|
||||||
# State tracking
|
|
||||||
self.state = BalanceState.STARTUP
|
|
||||||
self.imu_data: Optional[IMUData] = None
|
|
||||||
self.motor_telemetry: Optional[MotorTelemetry] = None
|
|
||||||
self.startup_time = time.time()
|
|
||||||
self.tilt_start_time: Optional[float] = None
|
|
||||||
|
|
||||||
# Subscriptions
|
|
||||||
self.create_subscription(Imu, "/imu/data", self._on_imu_data, 10)
|
|
||||||
self.create_subscription(String, "/vesc/state", self._on_vesc_state, 10)
|
|
||||||
|
|
||||||
# Publications
|
|
||||||
self.pub_balance_state = self.create_publisher(String, "/saltybot/balance_state", 10)
|
|
||||||
self.pub_balance_log = self.create_publisher(String, "/saltybot/balance_log", 10)
|
|
||||||
|
|
||||||
# Timer for control loop
|
|
||||||
period = 1.0 / frequency
|
|
||||||
self.create_timer(period, self._control_loop)
|
|
||||||
|
|
||||||
# Initialize VESC
|
|
||||||
self._init_vesc()
|
|
||||||
|
|
||||||
self.get_logger().info(
|
|
||||||
f"Balance controller initialized: port={self.port}, baud={self.baudrate}, "
|
|
||||||
f"PID=[{self.pid_p}, {self.pid_i}, {self.pid_d}], "
|
|
||||||
f"tilt_threshold={self.tilt_threshold}°, "
|
|
||||||
f"tilt_kill_duration={self.tilt_kill_duration}s, "
|
|
||||||
f"startup_ramp={self.startup_ramp_time}s"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _init_vesc(self) -> bool:
|
|
||||||
"""Initialize VESC connection."""
|
|
||||||
try:
|
|
||||||
if pyvesc is None:
|
|
||||||
self.get_logger().error("pyvesc not installed. Install with: pip install pyvesc")
|
|
||||||
self.state = BalanceState.ERROR
|
|
||||||
return False
|
|
||||||
|
|
||||||
self.serial = serial.Serial(
|
|
||||||
port=self.port,
|
|
||||||
baudrate=self.baudrate,
|
|
||||||
timeout=0.1,
|
|
||||||
)
|
|
||||||
self.vesc = pyvesc.VescUart(
|
|
||||||
serial_port=self.serial,
|
|
||||||
has_sensor=False,
|
|
||||||
start_heartbeat=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._set_pid_parameters()
|
|
||||||
self.get_logger().info(f"Connected to VESC on {self.port} @ {self.baudrate} baud")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except (serial.SerialException, Exception) as e:
|
|
||||||
self.get_logger().error(f"Failed to initialize VESC: {e}")
|
|
||||||
self.state = BalanceState.ERROR
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _set_pid_parameters(self) -> None:
|
|
||||||
"""Set VESC balance PID parameters."""
|
|
||||||
if self.vesc is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
# pyvesc doesn't have direct balance mode PID setter, so we'd use
|
|
||||||
# custom VESC firmware commands or rely on pre-configured VESC.
|
|
||||||
# For now, log the intended parameters.
|
|
||||||
self.get_logger().info(
|
|
||||||
f"PID parameters set: P={self.pid_p}, I={self.pid_i}, D={self.pid_d}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
self.get_logger().error(f"Failed to set PID parameters: {e}")
|
|
||||||
|
|
||||||
def _on_imu_data(self, msg: Imu) -> None:
|
|
||||||
"""Update IMU orientation data."""
|
|
||||||
try:
|
|
||||||
# Extract roll/pitch from quaternion
|
|
||||||
roll, pitch, _ = self._quaternion_to_euler(
|
|
||||||
msg.orientation.x, msg.orientation.y,
|
|
||||||
msg.orientation.z, msg.orientation.w
|
|
||||||
)
|
|
||||||
|
|
||||||
self.imu_data = IMUData(
|
|
||||||
pitch_deg=math.degrees(pitch),
|
|
||||||
roll_deg=math.degrees(roll),
|
|
||||||
timestamp=time.time()
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
self.get_logger().warn(f"Error parsing IMU data: {e}")
|
|
||||||
|
|
||||||
def _on_vesc_state(self, msg: String) -> None:
|
|
||||||
"""Parse VESC telemetry from JSON."""
|
|
||||||
try:
|
|
||||||
data = json.loads(msg.data)
|
|
||||||
self.motor_telemetry = MotorTelemetry(
|
|
||||||
voltage_v=data.get("voltage_v", 0.0),
|
|
||||||
current_a=data.get("current_a", 0.0),
|
|
||||||
rpm=data.get("rpm", 0),
|
|
||||||
temperature_c=data.get("temperature_c", 0.0),
|
|
||||||
fault_code=data.get("fault_code", 0)
|
|
||||||
)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
self.get_logger().debug(f"Failed to parse VESC state: {e}")
|
|
||||||
|
|
||||||
def _quaternion_to_euler(self, x: float, y: float, z: float, w: float) -> tuple:
|
|
||||||
"""Convert quaternion to Euler angles (roll, pitch, yaw)."""
|
|
||||||
# Roll (X-axis rotation)
|
|
||||||
sinr_cosp = 2 * (w * x + y * z)
|
|
||||||
cosr_cosp = 1 - 2 * (x * x + y * y)
|
|
||||||
roll = math.atan2(sinr_cosp, cosr_cosp)
|
|
||||||
|
|
||||||
# Pitch (Y-axis rotation)
|
|
||||||
sinp = 2 * (w * y - z * x)
|
|
||||||
if abs(sinp) >= 1:
|
|
||||||
pitch = math.copysign(math.pi / 2, sinp)
|
|
||||||
else:
|
|
||||||
pitch = math.asin(sinp)
|
|
||||||
|
|
||||||
# Yaw (Z-axis rotation)
|
|
||||||
siny_cosp = 2 * (w * z + x * y)
|
|
||||||
cosy_cosp = 1 - 2 * (y * y + z * z)
|
|
||||||
yaw = math.atan2(siny_cosp, cosy_cosp)
|
|
||||||
|
|
||||||
return roll, pitch, yaw
|
|
||||||
|
|
||||||
def _check_tilt_safety(self) -> None:
|
|
||||||
"""Check tilt angle and apply safety kill if needed."""
|
|
||||||
if self.imu_data is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check if tilted beyond threshold
|
|
||||||
is_tilted = abs(self.imu_data.pitch_deg) > self.tilt_threshold
|
|
||||||
|
|
||||||
if is_tilted:
|
|
||||||
# Tilt detected
|
|
||||||
if self.tilt_start_time is None:
|
|
||||||
self.tilt_start_time = time.time()
|
|
||||||
|
|
||||||
tilt_duration = time.time() - self.tilt_start_time
|
|
||||||
|
|
||||||
if tilt_duration > self.tilt_kill_duration:
|
|
||||||
# Tilt persisted too long, trigger kill
|
|
||||||
self.state = BalanceState.TILT_KILL
|
|
||||||
self._kill_motors()
|
|
||||||
self.get_logger().error(
|
|
||||||
f"TILT SAFETY KILL: pitch={self.imu_data.pitch_deg:.1f}° "
|
|
||||||
f"(threshold={self.tilt_threshold}°) for {tilt_duration:.2f}s"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Warning state
|
|
||||||
if self.state != BalanceState.TILT_WARNING:
|
|
||||||
self.state = BalanceState.TILT_WARNING
|
|
||||||
self.get_logger().warn(
|
|
||||||
f"Tilt warning: pitch={self.imu_data.pitch_deg:.1f}° "
|
|
||||||
f"for {tilt_duration:.2f}s / {self.tilt_kill_duration}s"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Not tilted, reset timer
|
|
||||||
if self.tilt_start_time is not None:
|
|
||||||
self.tilt_start_time = None
|
|
||||||
if self.state == BalanceState.TILT_WARNING:
|
|
||||||
self.state = BalanceState.RUNNING
|
|
||||||
self.get_logger().info("Tilt warning cleared, resuming normal operation")
|
|
||||||
|
|
||||||
def _check_startup_ramp(self) -> float:
|
|
||||||
"""Calculate startup ramp factor [0, 1]."""
|
|
||||||
elapsed = time.time() - self.startup_time
|
|
||||||
|
|
||||||
if elapsed >= self.startup_ramp_time:
|
|
||||||
# Startup complete
|
|
||||||
if self.state == BalanceState.STARTUP:
|
|
||||||
self.state = BalanceState.RUNNING
|
|
||||||
self.get_logger().info("Startup ramp complete, entering normal operation")
|
|
||||||
return 1.0
|
|
||||||
else:
|
|
||||||
# Linear ramp
|
|
||||||
return elapsed / self.startup_ramp_time
|
|
||||||
|
|
||||||
def _kill_motors(self) -> None:
|
|
||||||
"""Kill motor output."""
|
|
||||||
if self.vesc is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.vesc.set_duty(0.0)
|
|
||||||
self.get_logger().error("Motors killed via duty cycle = 0")
|
|
||||||
except Exception as e:
|
|
||||||
self.get_logger().error(f"Failed to kill motors: {e}")
|
|
||||||
|
|
||||||
def _control_loop(self) -> None:
|
|
||||||
"""Main control loop (50Hz)."""
|
|
||||||
# Check IMU data availability
|
|
||||||
if self.imu_data is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check tilt safety
|
|
||||||
self._check_tilt_safety()
|
|
||||||
|
|
||||||
# Check startup ramp
|
|
||||||
ramp_factor = self._check_startup_ramp()
|
|
||||||
|
|
||||||
# Publish state
|
|
||||||
self._publish_balance_state(ramp_factor)
|
|
||||||
|
|
||||||
# Log data
|
|
||||||
self._publish_balance_log()
|
|
||||||
|
|
||||||
def _publish_balance_state(self, ramp_factor: float) -> None:
|
|
||||||
"""Publish balance controller state as JSON."""
|
|
||||||
state_dict = {
|
|
||||||
"timestamp": time.time(),
|
|
||||||
"state": self.state.value,
|
|
||||||
"pitch_deg": round(self.imu_data.pitch_deg, 2) if self.imu_data else 0.0,
|
|
||||||
"roll_deg": round(self.imu_data.roll_deg, 2) if self.imu_data else 0.0,
|
|
||||||
"tilt_threshold_deg": self.tilt_threshold,
|
|
||||||
"tilt_duration_s": (time.time() - self.tilt_start_time) if self.tilt_start_time else 0.0,
|
|
||||||
"tilt_kill_duration_s": self.tilt_kill_duration,
|
|
||||||
"pid": {
|
|
||||||
"p": self.pid_p,
|
|
||||||
"i": self.pid_i,
|
|
||||||
"d": self.pid_d,
|
|
||||||
},
|
|
||||||
"startup_ramp_factor": round(ramp_factor, 3),
|
|
||||||
"motor": {
|
|
||||||
"voltage_v": round(self.motor_telemetry.voltage_v, 2) if self.motor_telemetry else 0.0,
|
|
||||||
"current_a": round(self.motor_telemetry.current_a, 2) if self.motor_telemetry else 0.0,
|
|
||||||
"rpm": self.motor_telemetry.rpm if self.motor_telemetry else 0,
|
|
||||||
"temperature_c": round(self.motor_telemetry.temperature_c, 1) if self.motor_telemetry else 0.0,
|
|
||||||
"fault_code": self.motor_telemetry.fault_code if self.motor_telemetry else 0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
msg = String(data=json.dumps(state_dict))
|
|
||||||
self.pub_balance_state.publish(msg)
|
|
||||||
|
|
||||||
def _publish_balance_log(self) -> None:
|
|
||||||
"""Publish IMU + motor data log as CSV."""
|
|
||||||
if self.imu_data is None or self.motor_telemetry is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# CSV format: timestamp, pitch, roll, current, temp, rpm
|
|
||||||
log_entry = (
|
|
||||||
f"{time.time():.3f}, "
|
|
||||||
f"{self.imu_data.pitch_deg:.2f}, "
|
|
||||||
f"{self.imu_data.roll_deg:.2f}, "
|
|
||||||
f"{self.motor_telemetry.current_a:.2f}, "
|
|
||||||
f"{self.motor_telemetry.temperature_c:.1f}, "
|
|
||||||
f"{self.motor_telemetry.rpm}"
|
|
||||||
)
|
|
||||||
|
|
||||||
msg = String(data=log_entry)
|
|
||||||
self.pub_balance_log.publish(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def main(args=None):
|
|
||||||
rclpy.init(args=args)
|
|
||||||
node = BalanceControllerNode()
|
|
||||||
try:
|
|
||||||
rclpy.spin(node)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
if node.vesc and node.serial:
|
|
||||||
node.vesc.set_duty(0.0) # Zero throttle on shutdown
|
|
||||||
node.serial.close()
|
|
||||||
node.destroy_node()
|
|
||||||
rclpy.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@ -1,4 +0,0 @@
|
|||||||
[develop]
|
|
||||||
script_dir=$base/lib/saltybot_balance_controller
|
|
||||||
[install]
|
|
||||||
install_scripts=$base/lib/saltybot_balance_controller
|
|
||||||
@ -1,27 +0,0 @@
|
|||||||
from setuptools import setup
|
|
||||||
|
|
||||||
package_name = "saltybot_balance_controller"
|
|
||||||
|
|
||||||
setup(
|
|
||||||
name=package_name,
|
|
||||||
version="0.1.0",
|
|
||||||
packages=[package_name],
|
|
||||||
data_files=[
|
|
||||||
("share/ament_index/resource_index/packages", [f"resource/{package_name}"]),
|
|
||||||
(f"share/{package_name}", ["package.xml"]),
|
|
||||||
(f"share/{package_name}/launch", ["launch/balance_controller.launch.py"]),
|
|
||||||
(f"share/{package_name}/config", ["config/balance_params.yaml"]),
|
|
||||||
],
|
|
||||||
install_requires=["setuptools", "pyvesc"],
|
|
||||||
zip_safe=True,
|
|
||||||
maintainer="sl-controls",
|
|
||||||
maintainer_email="sl-controls@saltylab.local",
|
|
||||||
description="Balance mode PID controller with tilt safety for SaltyBot",
|
|
||||||
license="MIT",
|
|
||||||
tests_require=["pytest"],
|
|
||||||
entry_points={
|
|
||||||
"console_scripts": [
|
|
||||||
"balance_controller_node = saltybot_balance_controller.balance_controller_node:main",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
@ -1,170 +0,0 @@
|
|||||||
"""Unit tests for balance controller node."""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import math
|
|
||||||
from sensor_msgs.msg import Imu
|
|
||||||
from geometry_msgs.msg import Quaternion
|
|
||||||
from std_msgs.msg import String
|
|
||||||
|
|
||||||
import rclpy
|
|
||||||
from saltybot_balance_controller.balance_controller_node import BalanceControllerNode
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def rclpy_fixture():
|
|
||||||
"""Initialize and cleanup rclpy."""
|
|
||||||
rclpy.init()
|
|
||||||
yield
|
|
||||||
rclpy.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def node(rclpy_fixture):
|
|
||||||
"""Create a balance controller node instance."""
|
|
||||||
node = BalanceControllerNode()
|
|
||||||
yield node
|
|
||||||
node.destroy_node()
|
|
||||||
|
|
||||||
|
|
||||||
class TestNodeInitialization:
|
|
||||||
"""Test suite for node initialization."""
|
|
||||||
|
|
||||||
def test_node_initialization(self, node):
|
|
||||||
"""Test that node initializes with correct defaults."""
|
|
||||||
assert node.port == "/dev/ttyUSB0"
|
|
||||||
assert node.baudrate == 115200
|
|
||||||
assert node.pid_p == 0.5
|
|
||||||
assert node.pid_i == 0.1
|
|
||||||
assert node.pid_d == 0.05
|
|
||||||
|
|
||||||
def test_tilt_threshold_parameter(self, node):
|
|
||||||
"""Test tilt threshold parameter is set correctly."""
|
|
||||||
assert node.tilt_threshold == 45.0
|
|
||||||
|
|
||||||
def test_tilt_kill_duration_parameter(self, node):
|
|
||||||
"""Test tilt kill duration parameter is set correctly."""
|
|
||||||
assert node.tilt_kill_duration == 0.5
|
|
||||||
|
|
||||||
|
|
||||||
class TestQuaternionToEuler:
|
|
||||||
"""Test suite for quaternion to Euler conversion."""
|
|
||||||
|
|
||||||
def test_identity_quaternion(self, node):
|
|
||||||
"""Test identity quaternion (no rotation)."""
|
|
||||||
roll, pitch, yaw = node._quaternion_to_euler(0, 0, 0, 1)
|
|
||||||
assert roll == pytest.approx(0)
|
|
||||||
assert pitch == pytest.approx(0)
|
|
||||||
assert yaw == pytest.approx(0)
|
|
||||||
|
|
||||||
def test_90deg_pitch_rotation(self, node):
|
|
||||||
"""Test 90 degree pitch rotation."""
|
|
||||||
# Quaternion for 90 degree Y rotation
|
|
||||||
roll, pitch, yaw = node._quaternion_to_euler(0, 0.707, 0, 0.707)
|
|
||||||
assert roll == pytest.approx(0, abs=0.01)
|
|
||||||
assert pitch == pytest.approx(math.pi / 2, abs=0.01)
|
|
||||||
assert yaw == pytest.approx(0, abs=0.01)
|
|
||||||
|
|
||||||
def test_45deg_pitch_rotation(self, node):
|
|
||||||
"""Test 45 degree pitch rotation."""
|
|
||||||
roll, pitch, yaw = node._quaternion_to_euler(0, 0.383, 0, 0.924)
|
|
||||||
assert roll == pytest.approx(0, abs=0.01)
|
|
||||||
assert pitch == pytest.approx(math.pi / 4, abs=0.01)
|
|
||||||
assert yaw == pytest.approx(0, abs=0.01)
|
|
||||||
|
|
||||||
def test_roll_rotation(self, node):
|
|
||||||
"""Test roll rotation around X axis."""
|
|
||||||
roll, pitch, yaw = node._quaternion_to_euler(0.707, 0, 0, 0.707)
|
|
||||||
assert roll == pytest.approx(math.pi / 2, abs=0.01)
|
|
||||||
assert pitch == pytest.approx(0, abs=0.01)
|
|
||||||
assert yaw == pytest.approx(0, abs=0.01)
|
|
||||||
|
|
||||||
|
|
||||||
class TestIMUDataParsing:
|
|
||||||
"""Test suite for IMU data parsing."""
|
|
||||||
|
|
||||||
def test_imu_data_subscription(self, node):
|
|
||||||
"""Test IMU data subscription updates node state."""
|
|
||||||
imu = Imu()
|
|
||||||
imu.orientation = Quaternion(x=0, y=0, z=0, w=1)
|
|
||||||
|
|
||||||
node._on_imu_data(imu)
|
|
||||||
|
|
||||||
assert node.imu_data is not None
|
|
||||||
assert node.imu_data.pitch_deg == pytest.approx(0, abs=0.1)
|
|
||||||
assert node.imu_data.roll_deg == pytest.approx(0, abs=0.1)
|
|
||||||
|
|
||||||
def test_imu_pitch_tilted_forward(self, node):
|
|
||||||
"""Test IMU data with forward pitch tilt."""
|
|
||||||
# 45 degree forward pitch
|
|
||||||
imu = Imu()
|
|
||||||
imu.orientation = Quaternion(x=0, y=0.383, z=0, w=0.924)
|
|
||||||
|
|
||||||
node._on_imu_data(imu)
|
|
||||||
|
|
||||||
assert node.imu_data is not None
|
|
||||||
# Should be approximately 45 degrees (in radians converted to degrees)
|
|
||||||
assert node.imu_data.pitch_deg == pytest.approx(45, abs=1)
|
|
||||||
|
|
||||||
|
|
||||||
class TestTiltSafety:
|
|
||||||
"""Test suite for tilt safety checks."""
|
|
||||||
|
|
||||||
def test_tilt_warning_state_entry(self, node):
|
|
||||||
"""Test entry into tilt warning state."""
|
|
||||||
imu = Imu()
|
|
||||||
# 50 degree pitch (exceeds 45 degree threshold)
|
|
||||||
imu.orientation = Quaternion(x=0, y=0.438, z=0, w=0.899)
|
|
||||||
node._on_imu_data(imu)
|
|
||||||
|
|
||||||
# Call check with small duration
|
|
||||||
node._check_tilt_safety()
|
|
||||||
|
|
||||||
assert node.state.value in ["tilt_warning", "startup"]
|
|
||||||
|
|
||||||
def test_level_no_tilt_warning(self, node):
|
|
||||||
"""Test no tilt warning when level."""
|
|
||||||
imu = Imu()
|
|
||||||
imu.orientation = Quaternion(x=0, y=0, z=0, w=1)
|
|
||||||
node._on_imu_data(imu)
|
|
||||||
|
|
||||||
node._check_tilt_safety()
|
|
||||||
|
|
||||||
# Tilt start time should be None when level
|
|
||||||
assert node.tilt_start_time is None
|
|
||||||
|
|
||||||
|
|
||||||
class TestStartupRamp:
|
|
||||||
"""Test suite for startup ramp functionality."""
|
|
||||||
|
|
||||||
def test_startup_ramp_begins_at_zero(self, node):
|
|
||||||
"""Test startup ramp begins at 0."""
|
|
||||||
ramp = node._check_startup_ramp()
|
|
||||||
|
|
||||||
# At startup time, ramp should be close to 0
|
|
||||||
assert ramp <= 0.05 # Very small value at start
|
|
||||||
|
|
||||||
def test_startup_ramp_reaches_one(self, node):
|
|
||||||
"""Test startup ramp reaches 1.0 after duration."""
|
|
||||||
# Simulate startup ramp completion
|
|
||||||
node.startup_time = 0
|
|
||||||
node.startup_ramp_time = 0.001 # Very short ramp
|
|
||||||
|
|
||||||
import time
|
|
||||||
time.sleep(0.01) # Sleep longer than ramp time
|
|
||||||
|
|
||||||
ramp = node._check_startup_ramp()
|
|
||||||
|
|
||||||
# Should be complete after time has passed
|
|
||||||
assert ramp >= 0.99
|
|
||||||
|
|
||||||
def test_startup_ramp_linear(self, node):
|
|
||||||
"""Test startup ramp is linear."""
|
|
||||||
node.startup_ramp_time = 1.0
|
|
||||||
|
|
||||||
# At 25% of startup time
|
|
||||||
node.startup_time = 0
|
|
||||||
import time
|
|
||||||
time.sleep(0.25)
|
|
||||||
|
|
||||||
ramp = node._check_startup_ramp()
|
|
||||||
assert 0.2 < ramp < 0.3
|
|
||||||
Loading…
x
Reference in New Issue
Block a user