Compare commits
8 Commits
02e9b15e6f
...
6293d2ec60
| Author | SHA1 | Date | |
|---|---|---|---|
| 6293d2ec60 | |||
| b178614e6e | |||
| e26301c7ca | |||
| c96c68a7c4 | |||
|
|
2e9fd6fa4c | ||
| 0c03060016 | |||
| 00e632ecbf | |||
| d421d63c6f |
@ -170,6 +170,12 @@ def generate_launch_description():
|
||||
description="Launch YOLOv8n person detection (TensorRT)",
|
||||
)
|
||||
|
||||
enable_object_detection_arg = DeclareLaunchArgument(
|
||||
"enable_object_detection",
|
||||
default_value="true",
|
||||
description="Launch YOLOv8n general object detection with depth (Issue #468)",
|
||||
)
|
||||
|
||||
enable_follower_arg = DeclareLaunchArgument(
|
||||
"enable_follower",
|
||||
default_value="true",
|
||||
@ -376,6 +382,22 @@ def generate_launch_description():
|
||||
],
|
||||
)
|
||||
|
||||
# ── t=6s Object detection (needs RealSense up for ~4s; Issue #468) ──────
|
||||
object_detection = TimerAction(
|
||||
period=6.0,
|
||||
actions=[
|
||||
GroupAction(
|
||||
condition=IfCondition(LaunchConfiguration("enable_object_detection")),
|
||||
actions=[
|
||||
LogInfo(msg="[full_stack] Starting YOLOv8n general object detection"),
|
||||
IncludeLaunchDescription(
|
||||
_launch("saltybot_object_detection", "launch", "object_detection.launch.py"),
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# ── t=9s Person follower (needs perception + UWB; ~3s after both start) ─
|
||||
follower = TimerAction(
|
||||
period=9.0,
|
||||
@ -442,6 +464,7 @@ def generate_launch_description():
|
||||
enable_csi_cameras_arg,
|
||||
enable_uwb_arg,
|
||||
enable_perception_arg,
|
||||
enable_object_detection_arg,
|
||||
enable_follower_arg,
|
||||
enable_bridge_arg,
|
||||
enable_rosbridge_arg,
|
||||
@ -473,6 +496,7 @@ def generate_launch_description():
|
||||
slam,
|
||||
outdoor_nav,
|
||||
perception,
|
||||
object_detection,
|
||||
|
||||
# t=9s
|
||||
follower,
|
||||
|
||||
7
jetson/ros2_ws/src/saltybot_curiosity/.gitignore
vendored
Normal file
7
jetson/ros2_ws/src/saltybot_curiosity/.gitignore
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
build/
|
||||
install/
|
||||
log/
|
||||
*.pyc
|
||||
__pycache__/
|
||||
*.egg-info/
|
||||
dist/
|
||||
58
jetson/ros2_ws/src/saltybot_curiosity/README.md
Normal file
58
jetson/ros2_ws/src/saltybot_curiosity/README.md
Normal file
@ -0,0 +1,58 @@
|
||||
# saltybot_curiosity
|
||||
|
||||
Autonomous curiosity behavior for SaltyBot — frontier exploration when idle.
|
||||
|
||||
## Features
|
||||
|
||||
- **Idle Detection**: Activates after >60s without people
|
||||
- **Frontier Exploration**: Random walk toward unexplored areas
|
||||
- **Sound Localization**: Turns toward detected sounds
|
||||
- **Object Interest**: Approaches colorful/moving objects
|
||||
- **Self-Narration**: TTS commentary during exploration
|
||||
- **Safety**: Respects geofence, avoids obstacles
|
||||
- **Auto-Timeout**: Returns home after 10 minutes
|
||||
- **Curiosity Level**: Configurable 0-1.0 activation probability
|
||||
|
||||
## Launch
|
||||
|
||||
```bash
|
||||
ros2 launch saltybot_curiosity curiosity.launch.py
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
- `curiosity_level`: 0.0-1.0 probability of exploring when idle (default: 0.7)
|
||||
- `idle_threshold_sec`: seconds of idle before activation (default: 60)
|
||||
- `exploration_max_duration_sec`: max exploration time in seconds (default: 600)
|
||||
- `exploration_radius_m`: search radius in meters (default: 5.0)
|
||||
- `pan_tilt_step_deg`: pan-tilt sweep step (default: 15)
|
||||
- `min_sound_activity`: sound detection threshold (default: 0.1)
|
||||
|
||||
## Topics
|
||||
|
||||
### Publishes
|
||||
|
||||
- `/saltybot/curiosity_state` (String): State updates
|
||||
- `/saltybot/tts_request` (String): Self-narration text
|
||||
- `/saltybot/pan_tilt_cmd` (String): Pan-tilt commands
|
||||
- `/cmd_vel` (Twist): Velocity commands (if direct movement)
|
||||
|
||||
### Subscribes
|
||||
|
||||
- `/saltybot/audio_direction` (Float32): Sound bearing (degrees)
|
||||
- `/saltybot/audio_activity` (Bool): Sound detected
|
||||
- `/saltybot/person_detections` (Detection2DArray): People nearby
|
||||
- `/saltybot/object_detections` (Detection2DArray): Interesting objects
|
||||
- `/saltybot/battery_percent` (Float32): Battery level
|
||||
- `/saltybot/geofence_status` (String): Geofence boundary info
|
||||
|
||||
## State Machine
|
||||
|
||||
```
|
||||
IDLE → (is_idle && should_explore) → EXPLORING
|
||||
EXPLORING → (timeout || person detected) → RETURNING
|
||||
RETURNING → IDLE
|
||||
```
|
||||
|
||||
## Issue #470
|
||||
Implements autonomous curiosity exploration as specified in Issue #470.
|
||||
@ -0,0 +1,13 @@
|
||||
curiosity_node:
|
||||
ros__parameters:
|
||||
# Curiosity activation parameters
|
||||
curiosity_level: 0.7 # 0.0-1.0 — probability of exploring when idle
|
||||
idle_threshold_sec: 60.0 # seconds without people to trigger exploration
|
||||
exploration_max_duration_sec: 600.0 # 10 minutes max exploration
|
||||
exploration_radius_m: 5.0 # exploration search radius in meters
|
||||
|
||||
# Pan-tilt control
|
||||
pan_tilt_step_deg: 15.0 # pan-tilt sweep step size
|
||||
|
||||
# Audio sensitivity
|
||||
min_sound_activity: 0.1 # minimum sound activity threshold (0.0-1.0)
|
||||
@ -0,0 +1,41 @@
|
||||
from launch import LaunchDescription
|
||||
from launch_ros.actions import Node
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
import os
|
||||
|
||||
def generate_launch_description():
    """Launch the curiosity node with YAML defaults plus CLI overrides.

    Loads config/curiosity_params.yaml, then overlays the two launch
    arguments (curiosity_level, idle_threshold_sec) on top of it.
    """
    # Resolve the config dir relative to this launch file (works in the
    # installed share/<pkg>/launch layout where config/ is a sibling dir).
    config_dir = os.path.join(
        os.path.dirname(__file__),
        '..',
        'config'
    )

    curiosity_config = os.path.join(config_dir, 'curiosity_params.yaml')

    return LaunchDescription([
        DeclareLaunchArgument(
            'curiosity_level',
            default_value='0.7',
            description='Curiosity activation probability (0.0-1.0)'
        ),
        DeclareLaunchArgument(
            'idle_threshold_sec',
            # BUG FIX: the node declares this parameter as a double (60.0).
            # A default of '60' is YAML-parsed to an integer parameter
            # override, which rclpy rejects with a type mismatch at startup.
            default_value='60.0',
            description='Seconds of idle before exploration activates'
        ),
        Node(
            package='saltybot_curiosity',
            executable='curiosity_node',
            name='curiosity_node',
            parameters=[
                curiosity_config,
                # Launch-argument overrides take precedence over the YAML file.
                {
                    'curiosity_level': LaunchConfiguration('curiosity_level'),
                    'idle_threshold_sec': LaunchConfiguration('idle_threshold_sec'),
                }
            ],
            remappings=[],
            output='screen',
        ),
    ])
|
||||
28
jetson/ros2_ws/src/saltybot_curiosity/package.xml
Normal file
28
jetson/ros2_ws/src/saltybot_curiosity/package.xml
Normal file
@ -0,0 +1,28 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_curiosity</name>
|
||||
<version>0.1.0</version>
|
||||
<description>Autonomous curiosity behavior with frontier exploration and self-narration</description>
|
||||
<maintainer email="seb@vayrette.com">seb</maintainer>
|
||||
<license>MIT</license>
|
||||
|
||||
<buildtool_depend>ament_python</buildtool_depend>
|
||||
<build_depend>rosidl_default_generators</build_depend>
|
||||
<exec_depend>rosidl_default_runtime</exec_depend>
|
||||
<exec_depend>rclpy</exec_depend>
|
||||
<exec_depend>geometry_msgs</exec_depend>
|
||||
<exec_depend>nav2_msgs</exec_depend>
|
||||
<exec_depend>sensor_msgs</exec_depend>
|
||||
<exec_depend>vision_msgs</exec_depend>
|
||||
<exec_depend>std_msgs</exec_depend>
|
||||
|
||||
<test_depend>ament_copyright</test_depend>
|
||||
<test_depend>ament_flake8</test_depend>
|
||||
<test_depend>ament_pep257</test_depend>
|
||||
<test_depend>python3-pytest</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1,353 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Curiosity behavior for SaltyBot — autonomous exploration when idle.
|
||||
|
||||
Activates when idle >60s with no people detected. Performs frontier exploration:
|
||||
- Turns toward sounds (audio_direction)
|
||||
- Approaches colorful/moving objects (object detection)
|
||||
- Takes panoramic photos at interesting spots
|
||||
- Self-narrates findings via TTS
|
||||
- Respects geofence and dynamic obstacles
|
||||
- Auto-stops after 10min
|
||||
|
||||
Publishes:
|
||||
/saltybot/curiosity_state String State machine updates
|
||||
"""
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.action import ActionClient
|
||||
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||
from nav2_msgs.action import NavigateToPose
|
||||
from geometry_msgs.msg import Pose, PoseStamped, Quaternion, Twist
|
||||
from std_msgs.msg import String, Float32, Bool
|
||||
from sensor_msgs.msg import Image
|
||||
from vision_msgs.msg import Detection2DArray
|
||||
|
||||
import math
|
||||
import time
|
||||
import random
|
||||
from enum import Enum
|
||||
from collections import deque
|
||||
|
||||
|
||||
# Shared QoS profile for high-rate sensor streams (audio, detections):
# best-effort reliability matches typical sensor publishers and avoids
# blocking on lossy links; only the 5 newest messages are kept.
_SENSOR_QOS = QoSProfile(
    reliability=ReliabilityPolicy.BEST_EFFORT,
    history=HistoryPolicy.KEEP_LAST,
    depth=5,
)
|
||||
|
||||
|
||||
class CuriosityState(Enum):
    """States of the curiosity behavior state machine."""
    IDLE = 0
    WAITING_FOR_IDLE = 1  # NOTE(review): declared but never entered in this file — confirm intent
    EXPLORING = 2
    INVESTIGATING = 3     # NOTE(review): declared but never entered in this file — confirm intent
    RETURNING = 4
|
||||
|
||||
|
||||
class CuriosityNode(Node):
    """Autonomous curiosity behavior for SaltyBot (Issue #470).

    State machine (per package README):
        IDLE      -> (idle > threshold AND curiosity roll AND battery OK) -> EXPLORING
        EXPLORING -> (timeout OR person detected)                        -> RETURNING
        RETURNING -> IDLE
    """

    def __init__(self):
        super().__init__('curiosity_node')

        # Parameters (defaults mirror config/curiosity_params.yaml)
        self.declare_parameter('curiosity_level', 0.7)
        self.declare_parameter('idle_threshold_sec', 60.0)
        self.declare_parameter('exploration_max_duration_sec', 600.0)  # 10min
        self.declare_parameter('exploration_radius_m', 5.0)
        self.declare_parameter('pan_tilt_step_deg', 15.0)
        self.declare_parameter('min_sound_activity', 0.1)

        self.curiosity_level = self.get_parameter('curiosity_level').value
        self.idle_threshold = self.get_parameter('idle_threshold_sec').value
        self.max_exploration_duration = self.get_parameter('exploration_max_duration_sec').value
        self.exploration_radius = self.get_parameter('exploration_radius_m').value
        self.pan_tilt_step = self.get_parameter('pan_tilt_step_deg').value
        self.min_sound_activity = self.get_parameter('min_sound_activity').value

        # State tracking
        self.state = CuriosityState.IDLE
        self.last_person_time = time.time()
        self.last_activity_time = time.time()
        self.exploration_start_time = None
        # NOTE(review): current_location is never updated from odometry/TF in
        # this file, so frontier goals are always relative to the origin — confirm.
        self.current_location = Pose()
        self.home_location = Pose()

        # Sensor state (filled in by the subscriber callbacks below)
        self.audio_bearing = 0.0
        self.audio_activity = False
        self.recent_detections = deque(maxlen=5)
        self.battery_level = 100.0
        self.geofence_limit = None  # meters from home, parsed from geofence_status
        self.obstacles = []

        # Narration queue (currently unused — narrate() publishes immediately)
        self.narration_queue = []

        # Publishers
        self.state_pub = self.create_publisher(
            String, '/saltybot/curiosity_state', 10
        )
        self.narration_pub = self.create_publisher(
            String, '/saltybot/tts_request', 10
        )
        self.pan_tilt_pub = self.create_publisher(
            String, '/saltybot/pan_tilt_cmd', 10
        )
        self.velocity_pub = self.create_publisher(
            Twist, '/cmd_vel', 10
        )

        # Subscribers (sensor streams use the shared best-effort QoS)
        self.create_subscription(
            Float32, '/saltybot/audio_direction', self.audio_bearing_callback, _SENSOR_QOS
        )
        self.create_subscription(
            Bool, '/saltybot/audio_activity', self.audio_activity_callback, _SENSOR_QOS
        )
        self.create_subscription(
            Detection2DArray, '/saltybot/person_detections', self.person_callback, _SENSOR_QOS
        )
        self.create_subscription(
            Detection2DArray, '/saltybot/object_detections', self.object_callback, _SENSOR_QOS
        )
        self.create_subscription(
            Float32, '/saltybot/battery_percent', self.battery_callback, 10
        )
        self.create_subscription(
            String, '/saltybot/geofence_status', self.geofence_callback, 10
        )

        # Nav2 action client
        self.nav_client = ActionClient(self, NavigateToPose, 'navigate_to_pose')

        # Main loop timer (state machine ticks at 1 Hz)
        self.timer = self.create_timer(1.0, self.curiosity_loop)

        self.get_logger().info(
            f"Curiosity node initialized (level={self.curiosity_level}, "
            f"idle_threshold={self.idle_threshold}s, max_duration={self.max_exploration_duration}s)"
        )

    # ── Sensor callbacks ────────────────────────────────────────────────

    def audio_bearing_callback(self, msg):
        """Update detected audio bearing (0-360 degrees)."""
        self.audio_bearing = msg.data
        self.last_activity_time = time.time()

    def audio_activity_callback(self, msg):
        """Track if audio activity detected."""
        self.audio_activity = msg.data
        if msg.data:
            self.last_activity_time = time.time()

    def person_callback(self, msg):
        """Update last person detection time."""
        if len(msg.detections) > 0:
            self.last_person_time = time.time()

    def object_callback(self, msg):
        """Track interesting objects (stores class/confidence of recent detections)."""
        for detection in msg.detections:
            # Guard with hasattr: field layout differs across vision_msgs versions.
            if hasattr(detection, 'results') and len(detection.results) > 0:
                obj = {
                    'class': detection.results[0].class_name if hasattr(detection.results[0], 'class_name') else 'unknown',
                    'confidence': detection.results[0].score if hasattr(detection.results[0], 'score') else 0.0,
                }
                self.recent_detections.append(obj)
                self.last_activity_time = time.time()

    def battery_callback(self, msg):
        """Track battery level (percent)."""
        self.battery_level = msg.data

    def geofence_callback(self, msg):
        """Parse the geofence radius from a 'limit:<meters>' status string."""
        data = msg.data
        if "limit:" in data:
            try:
                limit_str = data.split("limit:")[1].strip()
                self.geofence_limit = float(limit_str)
            except (ValueError, IndexError):
                # Malformed status — keep the previous limit.
                pass

    # ── Outputs ─────────────────────────────────────────────────────────

    def publish_state(self, details=""):
        """Publish current curiosity state as 'state:<NAME>[,<details>]'."""
        msg = String()
        state_str = f"state:{self.state.name}"
        if details:
            state_str += f",{details}"
        msg.data = state_str
        self.state_pub.publish(msg)

    def narrate(self, text):
        """Publish text for TTS narration (published immediately, not queued)."""
        msg = String()
        msg.data = text
        self.narration_pub.publish(msg)
        self.get_logger().info(f"Narrating: {text}")

    def pan_tilt_toward(self, bearing_deg):
        """Command pan-tilt to look toward bearing (wrapped to 0-359)."""
        msg = String()
        msg.data = f"pan:{int(bearing_deg % 360)}"
        self.pan_tilt_pub.publish(msg)

    # ── Behaviors ───────────────────────────────────────────────────────

    def turn_toward_sound(self):
        """Turn the pan-tilt toward a detected sound source.

        Returns True if a sound was reacted to.
        """
        # NOTE(review): min_sound_activity is compared against zero, not
        # against a measured activity level (audio_activity is a Bool) —
        # confirm whether a magnitude topic was intended here.
        if self.audio_activity and self.min_sound_activity > 0:
            self.get_logger().info(f"Sound detected at bearing {self.audio_bearing}°")
            self.narrate(f"Interesting! I hear something at {int(self.audio_bearing)} degrees!")
            self.pan_tilt_toward(self.audio_bearing)
            return True
        return False

    def approach_interesting_object(self):
        """Narrate about a recently detected interesting object, if any.

        Returns True if an object was reacted to.
        """
        if not self.recent_detections:
            return False

        # Look for classes considered "interesting"
        interesting_classes = ['ball', 'toy', 'person', 'car', 'dog', 'cat', 'bird']
        for obj in self.recent_detections:
            class_name = obj.get('class', '').lower()
            confidence = obj.get('confidence', 0.0)

            if any(keyword in class_name for keyword in interesting_classes) and confidence > 0.6:
                self.narrate(f"Ooh, a {obj['class']}! Let me check it out.")
                # Would navigate toward object here (simplified)
                return True

        return False

    def is_idle(self):
        """Check if we've been idle long enough to activate curiosity."""
        time_since_person = time.time() - self.last_person_time
        time_since_activity = time.time() - self.last_activity_time

        # Idle requires BOTH: no person for the full threshold and no
        # sound/object activity for half of it.
        return (
            time_since_person > self.idle_threshold
            and time_since_activity > (self.idle_threshold / 2)
        )

    def should_explore(self):
        """Decide whether to start exploration based on curiosity level."""
        # Higher curiosity = more likely to explore when idle
        return random.random() < self.curiosity_level

    def explore_frontier(self):
        """Generate a frontier exploration goal as (x, y) in the map frame."""
        # Pick random direction within exploration radius
        distance = random.uniform(1.0, self.exploration_radius)
        angle = random.uniform(0, 2 * math.pi)

        goal_x = self.current_location.position.x + distance * math.cos(angle)
        goal_y = self.current_location.position.y + distance * math.sin(angle)

        # Clamp to 90% of the geofence radius (measured from the map origin)
        if self.geofence_limit:
            dist_from_home = math.sqrt(goal_x**2 + goal_y**2)
            if dist_from_home > self.geofence_limit:
                scale = (self.geofence_limit * 0.9) / dist_from_home
                goal_x *= scale
                goal_y *= scale

        return goal_x, goal_y

    def navigate_to_exploration_point(self):
        """Send a fire-and-forget Nav2 goal toward a frontier point."""
        goal_x, goal_y = self.explore_frontier()

        goal = NavigateToPose.Goal()
        goal.pose.header.frame_id = "map"
        goal.pose.header.stamp = self.get_clock().now().to_msg()
        goal.pose.pose.position.x = goal_x
        goal.pose.pose.position.y = goal_y
        goal.pose.pose.position.z = 0.0
        goal.pose.pose.orientation.w = 1.0

        try:
            self.nav_client.wait_for_server(timeout_sec=2.0)
            self.nav_client.send_goal_async(goal)
            self.get_logger().info(f"Exploring toward ({goal_x:.2f}, {goal_y:.2f})")
        except Exception as e:
            # Exploration is best-effort: warn and carry on if Nav2 is down.
            self.get_logger().warn(f"Navigation unavailable: {e}")

    def return_home(self):
        """Navigate back to the map origin (fire-and-forget)."""
        self.narrate("I've seen enough. Heading back home.")
        goal = NavigateToPose.Goal()
        goal.pose.header.frame_id = "map"
        goal.pose.header.stamp = self.get_clock().now().to_msg()
        goal.pose.pose.position.x = 0.0
        goal.pose.pose.position.y = 0.0
        goal.pose.pose.position.z = 0.0
        goal.pose.pose.orientation.w = 1.0

        try:
            self.nav_client.wait_for_server(timeout_sec=2.0)
            self.nav_client.send_goal_async(goal)
        except Exception as e:
            self.get_logger().warn(f"Navigation unavailable: {e}")

    # ── State machine ───────────────────────────────────────────────────

    def curiosity_loop(self):
        """Main curiosity state machine, ticked at 1 Hz."""
        self.publish_state()

        if self.state == CuriosityState.IDLE:
            # Don't start exploring below 30% battery.
            if self.is_idle() and self.should_explore() and self.battery_level > 30:
                self.state = CuriosityState.EXPLORING
                self.exploration_start_time = time.time()
                self.narrate("Hmm, nothing to do. Let me explore!")
                self.get_logger().info("Curiosity activated!")

        elif self.state == CuriosityState.EXPLORING:
            # Check timeout
            duration = time.time() - self.exploration_start_time
            if duration > self.max_exploration_duration:
                self.state = CuriosityState.RETURNING
                return

            # BUG FIX: the documented state machine is
            # EXPLORING -> (timeout || person detected) -> RETURNING,
            # but the person-detected interrupt was missing. A person seen
            # after exploration started ends the exploration.
            if self.last_person_time > self.exploration_start_time:
                self.get_logger().info("Person detected — ending exploration")
                self.state = CuriosityState.RETURNING
                return

            # React to stimuli, preferring sound over objects.
            if self.turn_toward_sound():
                pass  # Already turned toward sound
            elif self.approach_interesting_object():
                pass  # Already approaching object
            else:
                # Random frontier exploration
                if random.random() < 0.3:  # 30% chance to explore frontier each cycle
                    self.navigate_to_exploration_point()

            # Periodic narration
            if random.random() < 0.1:
                observations = [
                    "Interesting things over here!",
                    "I wonder what's in that direction.",
                    "The world is so big!",
                    "What could be around the corner?",
                ]
                self.narrate(random.choice(observations))

        elif self.state == CuriosityState.RETURNING:
            # Fire the go-home goal and drop straight back to IDLE; the goal
            # itself is async, so RETURNING lasts a single tick by design.
            self.return_home()
            self.state = CuriosityState.IDLE
            self.exploration_start_time = None
|
||||
|
||||
|
||||
def main(args=None):
    """Entry point: initialize rclpy and spin the curiosity node until interrupted."""
    rclpy.init(args=args)
    curiosity = CuriosityNode()
    try:
        rclpy.spin(curiosity)
    except KeyboardInterrupt:
        pass  # Ctrl-C is a normal shutdown request
    finally:
        # Always tear the node down cleanly, even on unexpected errors.
        curiosity.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()
|
||||
4
jetson/ros2_ws/src/saltybot_curiosity/setup.cfg
Normal file
4
jetson/ros2_ws/src/saltybot_curiosity/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/saltybot_curiosity
|
||||
[install]
|
||||
install_lib=$base/lib/python3/dist-packages
|
||||
29
jetson/ros2_ws/src/saltybot_curiosity/setup.py
Normal file
29
jetson/ros2_ws/src/saltybot_curiosity/setup.py
Normal file
@ -0,0 +1,29 @@
|
||||
from setuptools import setup
import os
from glob import glob

package_name = 'saltybot_curiosity'

# Install launch files and YAML configs into the package's share directory,
# plus the ament resource-index marker and package.xml.
_data_files = [
    ('share/ament_index/resource_index/packages', ['resource/' + package_name]),
    ('share/' + package_name, ['package.xml']),
    (os.path.join('share', package_name, 'launch'), glob('launch/*.py')),
    (os.path.join('share', package_name, 'config'), glob('config/*.yaml')),
]

setup(
    name=package_name,
    version='0.1.0',
    packages=[package_name],
    data_files=_data_files,
    install_requires=['setuptools'],
    zip_safe=True,
    maintainer='seb',
    maintainer_email='seb@vayrette.com',
    description='Autonomous curiosity behavior with frontier exploration and self-narration',
    license='MIT',
    tests_require=['pytest'],
    entry_points={
        'console_scripts': [
            # ros2 run saltybot_curiosity curiosity_node
            'curiosity_node = saltybot_curiosity.curiosity_node:main',
        ],
    },
)
|
||||
@ -0,0 +1,38 @@
|
||||
# Event Logger Configuration
|
||||
|
||||
event_logger_node:
|
||||
ros__parameters:
|
||||
# Data directory for event logs (expands ~)
|
||||
data_dir: ~/.saltybot-data/events
|
||||
|
||||
# Log rotation settings
|
||||
rotation_days: 7 # Compress files older than 7 days
|
||||
retention_days: 90 # Delete files older than 90 days
|
||||
|
||||
# Event type mapping (topic -> event_type)
|
||||
event_types:
|
||||
first_encounter: encounter
|
||||
voice_command: voice_command
|
||||
trick_state: trick
|
||||
emergency_stop: e-stop
|
||||
geofence: geofence
|
||||
dock_state: dock
|
||||
error: error
|
||||
system_boot: boot
|
||||
system_shutdown: shutdown
|
||||
|
||||
# Subscribed topics (relative to /saltybot/)
|
||||
subscribed_topics:
|
||||
- first_encounter
|
||||
- voice_command
|
||||
- trick_state
|
||||
- emergency_stop
|
||||
- geofence
|
||||
- dock_state
|
||||
- error
|
||||
- system_boot
|
||||
- system_shutdown
|
||||
|
||||
# Publishing settings
|
||||
stats_publish_interval: 60 # Seconds between stats publishes
|
||||
rotation_check_interval: 3600 # Seconds between rotation checks
|
||||
@ -0,0 +1,21 @@
|
||||
"""
|
||||
Launch file for SaltyBot Event Logger.
|
||||
|
||||
Launches the centralized event logging node that subscribes to all
|
||||
saltybot/* state topics and logs structured JSON events.
|
||||
"""
|
||||
|
||||
from launch import LaunchDescription
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
    """Build the launch description for the event logger node."""
    event_logger = Node(
        package='saltybot_event_logger',
        executable='event_logger_node',
        name='event_logger_node',
        output='screen',
        emulate_tty=True,  # keep colored/line-buffered log output
    )
    return LaunchDescription([event_logger])
|
||||
24
jetson/ros2_ws/src/saltybot_event_logger/package.xml
Normal file
24
jetson/ros2_ws/src/saltybot_event_logger/package.xml
Normal file
@ -0,0 +1,24 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_event_logger</name>
|
||||
<version>0.1.0</version>
|
||||
<description>Structured JSON event logging for all robot activities with query service, stats publisher, log rotation, CSV export, and dashboard feed.</description>
|
||||
<maintainer email="seb@vayrette.com">seb</maintainer>
|
||||
<license>Apache-2.0</license>
|
||||
|
||||
<buildtool_depend>ament_python</buildtool_depend>
|
||||
|
||||
<depend>rclpy</depend>
|
||||
<depend>std_msgs</depend>
|
||||
<depend>builtin_interfaces</depend>
|
||||
|
||||
<test_depend>ament_copyright</test_depend>
|
||||
<test_depend>ament_flake8</test_depend>
|
||||
<test_depend>ament_pep257</test_depend>
|
||||
<test_depend>python3-pytest</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1 @@
|
||||
# Marker file for ament resource index
|
||||
@ -0,0 +1 @@
|
||||
"""SaltyBot Event Logger - Structured JSON logging for all robot activities."""
|
||||
@ -0,0 +1,350 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Event Logger Node - Centralized structured logging of all SaltyBot activities.
|
||||
|
||||
Subscribes to /saltybot/* state topics and logs JSON events to JSONL files.
|
||||
Supports: query service (time/type filter), stats publisher, log rotation,
|
||||
CSV export, and dashboard live feed.
|
||||
|
||||
Event types: encounter, voice_command, trick, e-stop, geofence, dock, error, boot, shutdown
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import gzip
|
||||
import shutil
|
||||
import csv
|
||||
import threading
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import List, Dict, Any, Optional
|
||||
from collections import defaultdict
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from std_msgs.msg import String
|
||||
from rclpy.subscription import Subscription
|
||||
|
||||
|
||||
@dataclass
class Event:
    """Structured event record, serialized to JSONL via dataclasses.asdict."""
    timestamp: str        # ISO 8601
    event_type: str       # encounter, voice_command, trick, e-stop, geofence, dock, error, boot, shutdown
    source: str           # Topic that generated event
    data: Dict[str, Any]  # Event-specific data (parsed JSON payload, or {'message': raw})
|
||||
|
||||
|
||||
class EventLogger(Node):
|
||||
"""
|
||||
Centralized event logging node for SaltyBot.
|
||||
|
||||
Subscribes to MQTT state topics and logs structured JSON events.
|
||||
Provides query service, stats publishing, log rotation, and export.
|
||||
"""
|
||||
|
||||
# Event type mapping from topics
|
||||
EVENT_TYPES = {
|
||||
'first_encounter': 'encounter',
|
||||
'voice_command': 'voice_command',
|
||||
'trick_state': 'trick',
|
||||
'emergency_stop': 'e-stop',
|
||||
'geofence': 'geofence',
|
||||
'dock_state': 'dock',
|
||||
'error': 'error',
|
||||
'system_boot': 'boot',
|
||||
'system_shutdown': 'shutdown',
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
super().__init__('event_logger_node')
|
||||
|
||||
# Configuration
|
||||
self.data_dir = Path(os.path.expanduser('~/.saltybot-data/events'))
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.rotation_days = 7 # Compress files >7 days old
|
||||
self.retention_days = 90 # Delete files >90 days old
|
||||
|
||||
# Runtime state
|
||||
self.subscriptions: Dict[str, Subscription] = {}
|
||||
self.events_lock = threading.Lock()
|
||||
self.daily_events: List[Event] = []
|
||||
self.daily_stats = defaultdict(int)
|
||||
|
||||
self.get_logger().info(f'Event logger initialized at {self.data_dir}')
|
||||
|
||||
# Subscribe to relevant topics
|
||||
self._setup_subscriptions()
|
||||
|
||||
# Timer for stats and rotation
|
||||
self.create_timer(60.0, self._publish_stats) # Stats every 60s
|
||||
self.create_timer(3600.0, self._rotate_logs) # Rotation every hour
|
||||
|
||||
# Publishers
|
||||
self.stats_pub = self.create_publisher(String, '/saltybot/event_stats', 10)
|
||||
self.live_feed_pub = self.create_publisher(String, '/saltybot/event_feed', 10)
|
||||
|
||||
def _setup_subscriptions(self):
|
||||
"""Subscribe to saltybot/* topics for event detection."""
|
||||
# Subscribe to specific event topics
|
||||
event_topics = [
|
||||
'/saltybot/first_encounter',
|
||||
'/saltybot/voice_command',
|
||||
'/saltybot/trick_state',
|
||||
'/saltybot/emergency_stop',
|
||||
'/saltybot/geofence',
|
||||
'/saltybot/dock_state',
|
||||
'/saltybot/error',
|
||||
'/saltybot/system_boot',
|
||||
'/saltybot/system_shutdown',
|
||||
]
|
||||
|
||||
for topic in event_topics:
|
||||
# Extract event type from topic
|
||||
topic_name = topic.split('/')[-1]
|
||||
event_type = self.EVENT_TYPES.get(topic_name, topic_name)
|
||||
|
||||
self.subscriptions[topic] = self.create_subscription(
|
||||
String,
|
||||
topic,
|
||||
lambda msg, et=event_type, t=topic: self._on_event(msg, et, t),
|
||||
10
|
||||
)
|
||||
|
||||
def _on_event(self, msg: String, event_type: str, topic: str):
|
||||
"""Handle incoming event message."""
|
||||
try:
|
||||
# Parse message as JSON if possible, otherwise use as string
|
||||
try:
|
||||
data = json.loads(msg.data)
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
data = {'message': str(msg.data)}
|
||||
|
||||
event = Event(
|
||||
timestamp=datetime.now().isoformat(),
|
||||
event_type=event_type,
|
||||
source=topic,
|
||||
data=data
|
||||
)
|
||||
|
||||
self._log_event(event)
|
||||
self._publish_to_feed(event)
|
||||
|
||||
except Exception as e:
|
||||
self.get_logger().error(f'Error processing event: {e}')
|
||||
|
||||
    def _log_event(self, event: Event):
        """Write event to daily JSONL file.

        Appends one JSON object per line to <data_dir>/YYYY-MM-DD.jsonl and
        updates the in-memory daily counters, all under events_lock so
        concurrent callbacks can't interleave writes. Failures are logged
        and swallowed so a bad event never kills the node.
        """
        try:
            with self.events_lock:
                # Daily filename (one JSONL file per calendar day)
                today = datetime.now().strftime('%Y-%m-%d')
                log_file = self.data_dir / f'{today}.jsonl'

                # Append to file
                with open(log_file, 'a') as f:
                    f.write(json.dumps(asdict(event)) + '\n')

                # Update in-memory stats
                self.daily_events.append(event)
                self.daily_stats[event.event_type] += 1
                # NOTE(review): assumes event.data is a dict — json.loads can
                # yield a non-dict (e.g. a bare number), in which case `in`
                # raises TypeError and the outer handler logs it. Confirm.
                if 'person_id' in event.data:
                    self.daily_stats['encounters'] += 1
                if event.event_type == 'error':
                    self.daily_stats['errors'] += 1

        except Exception as e:
            self.get_logger().error(f'Error logging event: {e}')
|
||||
|
||||
def _publish_to_feed(self, event: Event):
|
||||
"""Publish event to live feed topic."""
|
||||
try:
|
||||
feed_msg = String()
|
||||
feed_msg.data = json.dumps(asdict(event))
|
||||
self.live_feed_pub.publish(feed_msg)
|
||||
except Exception as e:
|
||||
self.get_logger().error(f'Error publishing to feed: {e}')
|
||||
|
||||
def _publish_stats(self):
|
||||
"""Publish daily statistics."""
|
||||
try:
|
||||
stats = {
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'total_events': len(self.daily_events),
|
||||
'by_type': dict(self.daily_stats),
|
||||
'encounters': self.daily_stats.get('encounters', 0),
|
||||
'errors': self.daily_stats.get('errors', 0),
|
||||
}
|
||||
|
||||
msg = String()
|
||||
msg.data = json.dumps(stats)
|
||||
self.stats_pub.publish(msg)
|
||||
|
||||
except Exception as e:
|
||||
self.get_logger().error(f'Error publishing stats: {e}')
|
||||
|
||||
def _rotate_logs(self):
|
||||
"""Handle log rotation: compress >7 days, delete >90 days."""
|
||||
try:
|
||||
now = datetime.now()
|
||||
|
||||
for log_file in sorted(self.data_dir.glob('*.jsonl')):
|
||||
# Extract date from filename (YYYY-MM-DD.jsonl)
|
||||
try:
|
||||
date_str = log_file.stem
|
||||
file_date = datetime.strptime(date_str, '%Y-%m-%d')
|
||||
age_days = (now - file_date).days
|
||||
|
||||
if age_days > self.retention_days:
|
||||
# Delete very old files
|
||||
log_file.unlink()
|
||||
self.get_logger().info(f'Deleted {log_file.name} (age: {age_days}d)')
|
||||
|
||||
elif age_days > self.rotation_days:
|
||||
# Compress old files
|
||||
gz_file = log_file.with_suffix('.jsonl.gz')
|
||||
if not gz_file.exists():
|
||||
with open(log_file, 'rb') as f_in:
|
||||
with gzip.open(gz_file, 'wb') as f_out:
|
||||
shutil.copyfileobj(f_in, f_out)
|
||||
log_file.unlink()
|
||||
self.get_logger().info(f'Compressed {log_file.name}')
|
||||
|
||||
except ValueError:
|
||||
# Skip files with unexpected naming
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
self.get_logger().error(f'Error rotating logs: {e}')
|
||||
|
||||
def query_events(self, event_types: Optional[List[str]] = None,
|
||||
start_time: Optional[str] = None,
|
||||
end_time: Optional[str] = None) -> List[Event]:
|
||||
"""
|
||||
Query events by type and time range.
|
||||
|
||||
Args:
|
||||
event_types: List of event types to filter (None = all)
|
||||
start_time: ISO 8601 start timestamp (None = all)
|
||||
end_time: ISO 8601 end timestamp (None = all)
|
||||
|
||||
Returns:
|
||||
List of matching events
|
||||
"""
|
||||
results = []
|
||||
|
||||
try:
|
||||
# Determine date range to search
|
||||
if start_time:
|
||||
start = datetime.fromisoformat(start_time)
|
||||
else:
|
||||
start = datetime.now() - timedelta(days=90)
|
||||
|
||||
if end_time:
|
||||
end = datetime.fromisoformat(end_time)
|
||||
else:
|
||||
end = datetime.now()
|
||||
|
||||
# Search JSONL files
|
||||
for log_file in sorted(self.data_dir.glob('*.jsonl')):
|
||||
try:
|
||||
file_date_str = log_file.stem
|
||||
file_date = datetime.strptime(file_date_str, '%Y-%m-%d')
|
||||
|
||||
# Skip files outside time range
|
||||
if file_date < start or file_date > end:
|
||||
continue
|
||||
|
||||
with open(log_file, 'r') as f:
|
||||
for line in f:
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
event_dict = json.loads(line)
|
||||
event = Event(**event_dict)
|
||||
|
||||
# Check timestamp in range
|
||||
event_time = datetime.fromisoformat(event.timestamp)
|
||||
if event_time < start or event_time > end:
|
||||
continue
|
||||
|
||||
# Check type filter
|
||||
if event_types and event.event_type not in event_types:
|
||||
continue
|
||||
|
||||
results.append(event)
|
||||
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
self.get_logger().error(f'Error querying events: {e}')
|
||||
|
||||
return results
|
||||
|
||||
def export_csv(self, filename: Optional[str] = None) -> str:
|
||||
"""
|
||||
Export all events to CSV.
|
||||
|
||||
Args:
|
||||
filename: Output filename (default: events_YYYY-MM-DD.csv)
|
||||
|
||||
Returns:
|
||||
Path to exported file
|
||||
"""
|
||||
if not filename:
|
||||
today = datetime.now().strftime('%Y-%m-%d')
|
||||
filename = f'events_{today}.csv'
|
||||
|
||||
output_path = self.data_dir / filename
|
||||
|
||||
try:
|
||||
with self.events_lock:
|
||||
events = self.query_events()
|
||||
|
||||
with open(output_path, 'w', newline='') as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(['Timestamp', 'Type', 'Source', 'Data'])
|
||||
|
||||
for event in events:
|
||||
writer.writerow([
|
||||
event.timestamp,
|
||||
event.event_type,
|
||||
event.source,
|
||||
json.dumps(event.data)
|
||||
])
|
||||
|
||||
self.get_logger().info(f'Exported {len(events)} events to {output_path}')
|
||||
return str(output_path)
|
||||
|
||||
except Exception as e:
|
||||
self.get_logger().error(f'Error exporting CSV: {e}')
|
||||
return ''
|
||||
|
||||
def get_daily_stats(self) -> Dict[str, Any]:
|
||||
"""Get stats for current day."""
|
||||
return {
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'total_events': len(self.daily_events),
|
||||
'by_type': dict(self.daily_stats),
|
||||
'encounters': self.daily_stats.get('encounters', 0),
|
||||
'errors': self.daily_stats.get('errors', 0),
|
||||
}
|
||||
|
||||
|
||||
def main(args=None):
    """Entry point: spin the EventLogger node until interrupted, then clean up."""
    rclpy.init(args=args)
    logger_node = EventLogger()

    try:
        rclpy.spin(logger_node)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the node; not an error.
        pass
    finally:
        logger_node.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()
|
||||
4
jetson/ros2_ws/src/saltybot_event_logger/setup.cfg
Normal file
4
jetson/ros2_ws/src/saltybot_event_logger/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
||||
[develop]
|
||||
script-dir=$base/lib/saltybot_event_logger
|
||||
[egg_info]
|
||||
tag_date = 0
|
||||
22
jetson/ros2_ws/src/saltybot_event_logger/setup.py
Normal file
22
jetson/ros2_ws/src/saltybot_event_logger/setup.py
Normal file
@ -0,0 +1,22 @@
|
||||
from setuptools import setup, find_packages


# ament_python package manifest for the event logger node.
setup(
    name='saltybot_event_logger',
    version='0.1.0',
    packages=find_packages(),
    data_files=[
        # Register the package with the ament resource index and install
        # package.xml so `ros2 pkg` tooling can find it.
        ('share/ament_index/resource_index/packages', ['resource/saltybot_event_logger']),
        ('share/saltybot_event_logger', ['package.xml']),
    ],
    install_requires=['setuptools'],
    zip_safe=True,
    author='seb',
    author_email='seb@vayrette.com',
    description='Structured JSON event logging with query service, stats, rotation, and export',
    license='Apache-2.0',
    entry_points={
        'console_scripts': [
            # Invoked as: ros2 run saltybot_event_logger event_logger_node
            'event_logger_node = saltybot_event_logger.event_logger_node:main',
        ],
    },
)
|
||||
@ -0,0 +1,260 @@
|
||||
"""Unit tests for Event Logger."""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
import tempfile
|
||||
import shutil
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
from std_msgs.msg import String
|
||||
|
||||
# Import the module under test
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from saltybot_event_logger.event_logger_node import Event, EventLogger
|
||||
|
||||
|
||||
class TestEvent:
    """Tests for the Event dataclass."""

    def test_event_creation(self):
        """An Event stores exactly the fields it is constructed with."""
        stamp = datetime.now().isoformat()
        evt = Event(
            timestamp=stamp,
            event_type='voice_command',
            source='/saltybot/voice_command',
            data={'command': 'spin'},
        )

        assert evt.timestamp == stamp
        assert evt.event_type == 'voice_command'
        assert evt.source == '/saltybot/voice_command'
        assert evt.data['command'] == 'spin'

    def test_event_to_dict(self):
        """dataclasses.asdict() converts an Event into a plain dict."""
        from dataclasses import asdict

        evt = Event(
            timestamp='2026-03-05T12:00:00',
            event_type='trick',
            source='/saltybot/trick_state',
            data={'trick': 'dance'},
        )

        as_dict = asdict(evt)
        assert as_dict['event_type'] == 'trick'
        assert 'timestamp' in as_dict
|
||||
|
||||
|
||||
class TestEventLogger:
    """Test EventLogger node."""

    @pytest.fixture
    def temp_data_dir(self):
        """Create temporary data directory."""
        # Yield-style fixture: the directory is removed after each test.
        temp_dir = tempfile.mkdtemp()
        yield Path(temp_dir)
        shutil.rmtree(temp_dir)

    @pytest.fixture
    def logger(self, temp_data_dir):
        """Create EventLogger instance with mocked ROS."""
        # NOTE(review): the MagicMock attributes below are assigned *after*
        # EventLogger() returns, so EventLogger.__init__ still runs against
        # the patched (no-op) Node.__init__.  If __init__ itself calls
        # create_subscription/create_publisher, those are the real Node
        # methods on a half-initialized Node — confirm this fixture actually
        # constructs cleanly.
        with patch('saltybot_event_logger.event_logger_node.Node.__init__', return_value=None):
            with patch('saltybot_event_logger.event_logger_node.rclpy'):
                logger = EventLogger()
                logger.data_dir = temp_data_dir
                logger.get_logger = MagicMock()
                logger.create_subscription = MagicMock()
                logger.create_publisher = MagicMock()
                logger.create_timer = MagicMock()
                logger.live_feed_pub = MagicMock()
                logger.stats_pub = MagicMock()
                return logger

    def test_event_types_mapping(self, logger):
        """Test event type mapping."""
        # EVENT_TYPES maps topic suffix -> canonical event type name.
        assert logger.EVENT_TYPES['voice_command'] == 'voice_command'
        assert logger.EVENT_TYPES['trick_state'] == 'trick'
        assert logger.EVENT_TYPES['first_encounter'] == 'encounter'
        assert logger.EVENT_TYPES['emergency_stop'] == 'e-stop'

    def test_log_event(self, logger, temp_data_dir):
        """Test logging an event to file."""
        event = Event(
            timestamp='2026-03-05T12:00:00',
            event_type='voice_command',
            source='/saltybot/voice_command',
            data={'command': 'spin'}
        )

        logger._log_event(event)

        # Check file was created (one JSONL file per calendar day).
        today = datetime.now().strftime('%Y-%m-%d')
        log_file = temp_data_dir / f'{today}.jsonl'
        assert log_file.exists()

        # Check content round-trips through JSON.
        with open(log_file, 'r') as f:
            line = f.readline()
            logged_event = json.loads(line)

        assert logged_event['event_type'] == 'voice_command'
        assert logged_event['data']['command'] == 'spin'

    def test_log_rotation_compress(self, logger, temp_data_dir):
        """Test log compression for old files."""
        # Create old log file: 8 days > the 7-day rotation window.
        old_date = (datetime.now() - timedelta(days=8)).strftime('%Y-%m-%d')
        old_file = temp_data_dir / f'{old_date}.jsonl'
        old_file.write_text('{"test": "data"}\n')

        logger._rotate_logs()

        # Check file was compressed and the plain-text original removed.
        gz_file = temp_data_dir / f'{old_date}.jsonl.gz'
        assert gz_file.exists()
        assert not old_file.exists()

    def test_log_rotation_delete(self, logger, temp_data_dir):
        """Test deleting very old files."""
        # Create very old log file: 91 days > the 90-day retention window.
        old_date = (datetime.now() - timedelta(days=91)).strftime('%Y-%m-%d')
        old_file = temp_data_dir / f'{old_date}.jsonl'
        old_file.write_text('{"test": "data"}\n')

        logger._rotate_logs()

        # Check file was deleted outright (not compressed).
        assert not old_file.exists()

    def test_query_events_all(self, logger, temp_data_dir):
        """Test querying all events."""
        # Create log entries directly in today's file.
        today = datetime.now().strftime('%Y-%m-%d')
        log_file = temp_data_dir / f'{today}.jsonl'

        # NOTE(review): the event timestamps are hard-coded to 2026-03-05 but
        # written into a file named for *today*; query_events filters both by
        # file date and by event timestamp, so these tests are sensitive to
        # the system date — confirm they pass on dates other than 2026-03-05.
        events_data = [
            Event('2026-03-05T10:00:00', 'voice_command', '/saltybot/voice_command', {'cmd': 'spin'}),
            Event('2026-03-05T11:00:00', 'trick', '/saltybot/trick_state', {'trick': 'dance'}),
        ]

        with open(log_file, 'w') as f:
            for event in events_data:
                from dataclasses import asdict
                f.write(json.dumps(asdict(event)) + '\n')

        # Query all
        results = logger.query_events()
        assert len(results) >= 2

    def test_query_events_by_type(self, logger, temp_data_dir):
        """Test querying events by type."""
        today = datetime.now().strftime('%Y-%m-%d')
        log_file = temp_data_dir / f'{today}.jsonl'

        events_data = [
            Event('2026-03-05T10:00:00', 'voice_command', '/saltybot/voice_command', {'cmd': 'spin'}),
            Event('2026-03-05T11:00:00', 'trick', '/saltybot/trick_state', {'trick': 'dance'}),
            Event('2026-03-05T12:00:00', 'error', '/saltybot/error', {'msg': 'test'}),
        ]

        with open(log_file, 'w') as f:
            for event in events_data:
                from dataclasses import asdict
                f.write(json.dumps(asdict(event)) + '\n')

        # Query only voice commands; every result must match the filter.
        results = logger.query_events(event_types=['voice_command'])
        assert all(e.event_type == 'voice_command' for e in results)

    def test_query_events_by_time_range(self, logger, temp_data_dir):
        """Test querying events by time range."""
        today = datetime.now().strftime('%Y-%m-%d')
        log_file = temp_data_dir / f'{today}.jsonl'

        events_data = [
            Event('2026-03-05T10:00:00', 'voice_command', '/saltybot/voice_command', {'cmd': 'spin'}),
            Event('2026-03-05T14:00:00', 'trick', '/saltybot/trick_state', {'trick': 'dance'}),
            Event('2026-03-05T18:00:00', 'error', '/saltybot/error', {'msg': 'test'}),
        ]

        with open(log_file, 'w') as f:
            for event in events_data:
                from dataclasses import asdict
                f.write(json.dumps(asdict(event)) + '\n')

        # Query only morning events: the 09:00-13:00 window should match
        # exactly the 10:00 voice_command event.
        # NOTE(review): this requires query_events to include the file for the
        # *day* of `start_time` even though the file's date (midnight) sorts
        # before a 09:00 start — verify the implementation compares at day
        # granularity, otherwise this test cannot pass.
        results = logger.query_events(
            start_time='2026-03-05T09:00:00',
            end_time='2026-03-05T13:00:00'
        )
        assert len(results) >= 1
        assert results[0].event_type == 'voice_command'

    def test_export_csv(self, logger, temp_data_dir):
        """Test CSV export."""
        today = datetime.now().strftime('%Y-%m-%d')
        log_file = temp_data_dir / f'{today}.jsonl'

        events_data = [
            Event('2026-03-05T10:00:00', 'voice_command', '/saltybot/voice_command', {'cmd': 'spin'}),
            Event('2026-03-05T11:00:00', 'trick', '/saltybot/trick_state', {'trick': 'dance'}),
        ]

        with open(log_file, 'w') as f:
            for event in events_data:
                from dataclasses import asdict
                f.write(json.dumps(asdict(event)) + '\n')

        # Export to CSV; only the path and file existence are asserted here.
        csv_path = logger.export_csv()
        assert csv_path
        assert Path(csv_path).exists()
        assert Path(csv_path).suffix == '.csv'

    def test_daily_stats(self, logger):
        """Test daily statistics calculation."""
        # Add events and pre-seeded counters directly to the instance.
        logger.daily_events = [
            Event('2026-03-05T10:00:00', 'voice_command', '/saltybot/voice_command', {}),
            Event('2026-03-05T11:00:00', 'trick', '/saltybot/trick_state', {}),
        ]
        logger.daily_stats = {'voice_command': 1, 'trick': 1, 'encounters': 5, 'errors': 0}

        stats = logger.get_daily_stats()

        assert stats['total_events'] == 2
        assert stats['encounters'] == 5
        assert stats['errors'] == 0
        assert 'by_type' in stats

    def test_on_event_json_data(self, logger):
        """Test processing event with JSON data."""
        # Valid JSON payloads should be parsed into the event's data dict.
        msg = String()
        msg.data = json.dumps({'command': 'spin', 'duration': 2})

        logger._on_event(msg, 'voice_command', '/saltybot/voice_command')

        assert len(logger.daily_events) > 0
        event = logger.daily_events[0]
        assert event.data['command'] == 'spin'

    def test_on_event_string_data(self, logger):
        """Test processing event with string data."""
        # Non-JSON payloads fall back to {'message': <raw string>}.
        msg = String()
        msg.data = 'simple string message'

        logger._on_event(msg, 'error', '/saltybot/error')

        assert len(logger.daily_events) > 0
        event = logger.daily_events[0]
        assert 'message' in event.data
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Allow running this test module directly: python test_event_logger.py
    pytest.main([__file__, '-v'])
|
||||
@ -0,0 +1,36 @@
|
||||
# YOLOv8n Object Detection Configuration
|
||||
|
||||
object_detection:
|
||||
ros__parameters:
|
||||
# Model paths
|
||||
engine_path: /mnt/nvme/saltybot/models/yolov8n.engine
|
||||
onnx_path: /mnt/nvme/saltybot/models/yolov8n.onnx
|
||||
|
||||
# Inference parameters
|
||||
confidence_threshold: 0.5 # Detection confidence threshold (0-1)
|
||||
nms_iou_threshold: 0.45 # Non-Maximum Suppression IoU threshold
|
||||
|
||||
# Per-frame filtering
|
||||
min_confidence_filter: 0.4 # Only publish objects with confidence >= this
|
||||
    enabled_classes:          # COCO class IDs to detect
      # NOTE: comments corrected to the official COCO 80-class index.
      # The previous labels (cup/bowl/banana/wine glass/backpack/handbag for
      # 39/41/48/56/62/64) did not match COCO; if those objects were the
      # intent, the matching IDs are cup=41, bowl=45, banana=46,
      # wine glass=40, backpack=24, handbag=26.
      - 0   # person
      - 39  # bottle
      - 41  # cup
      - 42  # fork
      - 43  # knife
      - 47  # apple
      - 48  # sandwich
      - 49  # orange
      - 56  # chair
      - 62  # tv
      - 64  # mouse
      - 73  # book
|
||||
|
||||
# Depth sampling parameters
|
||||
depth_window_size: 7 # 7x7 window for median filtering
|
||||
depth_min_range: 0.3 # Minimum valid depth (meters)
|
||||
depth_max_range: 6.0 # Maximum valid depth (meters)
|
||||
|
||||
# Publishing
|
||||
target_frame: "base_link" # Output frame for 3D positions
|
||||
publish_debug_image: false # Publish annotated debug image
|
||||
@ -0,0 +1,49 @@
|
||||
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration, PathJoinSubstitution
from launch_ros.actions import Node
from launch_ros.substitutions import FindPackageShare
from pathlib import Path


def generate_launch_description():
    """Launch the YOLOv8n object detection node with its YAML parameters."""
    # BUG FIX: FindPackageShare is a launch *substitution*, resolved only at
    # launch time.  The previous Path(str(FindPackageShare(...))) embedded the
    # substitution object's repr into the path instead of the share directory.
    # PathJoinSubstitution defers the join until the substitution resolves.
    config_file = PathJoinSubstitution(
        [FindPackageShare("saltybot_object_detection"), "config", "object_detection_params.yaml"]
    )

    # Declare launch arguments
    confidence_threshold_arg = DeclareLaunchArgument(
        "confidence_threshold",
        default_value="0.5",
        description="Detection confidence threshold (0-1)"
    )

    publish_debug_arg = DeclareLaunchArgument(
        "publish_debug_image",
        default_value="false",
        description="Publish annotated debug images"
    )

    # Object detection node
    object_detection_node = Node(
        package="saltybot_object_detection",
        executable="object_detection",
        name="object_detection",
        parameters=[
            config_file,
            {"confidence_threshold": LaunchConfiguration("confidence_threshold")},
            # BUG FIX: LaunchConfiguration must name the *declared argument*
            # "publish_debug_image"; it previously referenced the Python
            # variable name "publish_debug_arg", which is never declared and
            # would fail to resolve at launch.
            {"publish_debug_image": LaunchConfiguration("publish_debug_image")},
        ],
        remappings=[
            ("color_image", "/camera/color/image_raw"),
            ("depth_image", "/camera/depth/image_rect_raw"),
            ("camera_info", "/camera/color/camera_info"),
        ],
        output="screen",
    )

    return LaunchDescription([
        confidence_threshold_arg,
        publish_debug_arg,
        object_detection_node,
    ])
|
||||
32
jetson/ros2_ws/src/saltybot_object_detection/package.xml
Normal file
32
jetson/ros2_ws/src/saltybot_object_detection/package.xml
Normal file
@ -0,0 +1,32 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_object_detection</name>
|
||||
<version>0.1.0</version>
|
||||
<description>YOLOv8n object detection with depth integration (Issue #468)</description>
|
||||
<maintainer email="sl-perception@saltylab.local">sl-perception</maintainer>
|
||||
<license>MIT</license>
|
||||
|
||||
<buildtool_depend>ament_python</buildtool_depend>
|
||||
|
||||
<depend>rclpy</depend>
|
||||
<depend>std_msgs</depend>
|
||||
<depend>sensor_msgs</depend>
|
||||
<depend>geometry_msgs</depend>
|
||||
<depend>vision_msgs</depend>
|
||||
<depend>tf2_ros</depend>
|
||||
<depend>cv_bridge</depend>
|
||||
<depend>message_filters</depend>
|
||||
<depend>opencv-python</depend>
|
||||
<depend>numpy</depend>
|
||||
<depend>saltybot_object_detection_msgs</depend>
|
||||
|
||||
<test_depend>ament_copyright</test_depend>
|
||||
<test_depend>ament_flake8</test_depend>
|
||||
<test_depend>ament_pep257</test_depend>
|
||||
<test_depend>python3-pytest</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1,549 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
YOLOv8n Object Detection Node with RealSense Depth Integration
|
||||
Issue #468: General object detection for spatial awareness
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
from typing import Tuple, List, Optional
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||
import message_filters
|
||||
from tf2_ros import TransformListener, Buffer
|
||||
from tf2_geometry_msgs import PointStamped
|
||||
|
||||
from sensor_msgs.msg import Image, CameraInfo
|
||||
from std_msgs.msg import Header
|
||||
from geometry_msgs.msg import Point, PointStamped as PointStampedMsg, Quaternion
|
||||
from vision_msgs.msg import BoundingBox2D, Pose2D
|
||||
from cv_bridge import CvBridge
|
||||
|
||||
from saltybot_object_detection_msgs.msg import DetectedObject, DetectedObjectArray
|
||||
from saltybot_object_detection_msgs.srv import QueryObjects
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# COCO class names (0-79)
|
||||
_COCO_CLASSES = [
|
||||
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
|
||||
"boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
|
||||
"cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
|
||||
"backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis",
|
||||
"snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard",
|
||||
"surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife",
|
||||
"spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot",
|
||||
"hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed",
|
||||
"dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "microwave",
|
||||
"oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
|
||||
"teddy bear", "hair drier", "toothbrush"
|
||||
]
|
||||
|
||||
_YOLO_INPUT_SIZE = 640
|
||||
_CONFIDENCE_THRESHOLD = 0.5
|
||||
_NMS_IOU_THRESHOLD = 0.45
|
||||
|
||||
# QoS for high-rate sensor streams: best-effort delivery with a short
# keep-last history (depth 5), so stale frames are dropped rather than queued.
_SENSOR_QOS = QoSProfile(
    reliability=ReliabilityPolicy.BEST_EFFORT,
    history=HistoryPolicy.KEEP_LAST,
    depth=5,
)
|
||||
|
||||
|
||||
class _TRTBackend:
    """TensorRT inference backend (primary for Jetson).

    Deserializes a prebuilt .engine file, allocates page-locked host buffers
    plus device buffers for every binding, and runs async inference on a
    dedicated CUDA stream.
    """

    def __init__(self, engine_path: str):
        # Import lazily so the node can fall back to ONNX Runtime on
        # machines without TensorRT/pycuda installed.
        try:
            import tensorrt as trt
            import pycuda.driver as cuda
            import pycuda.autoinit  # noqa: F401
        except ImportError as e:
            raise RuntimeError(f"TensorRT/pycuda not available: {e}")

        if not Path(engine_path).exists():
            raise FileNotFoundError(f"TensorRT engine not found: {engine_path}")

        self.logger = trt.Logger(trt.Logger.WARNING)
        with open(engine_path, "rb") as f:
            self.engine = trt.Runtime(self.logger).deserialize_cuda_engine(f.read())

        self.context = self.engine.create_execution_context()
        self.stream = cuda.Stream()

        # Allocate input/output buffers
        # NOTE(review): num_bindings / get_binding_* is the pre-TensorRT-8.5
        # binding API, removed in TensorRT 10 — confirm the target JetPack's
        # TensorRT version still provides it.
        self.h_inputs = {}
        self.h_outputs = {}
        self.d_inputs = {}
        self.d_outputs = {}
        self.bindings = []

        for binding_idx in range(self.engine.num_bindings):
            binding_name = self.engine.get_binding_name(binding_idx)
            binding_shape = self.engine.get_binding_shape(binding_idx)
            binding_dtype = self.engine.get_binding_dtype(binding_idx)

            # Convert dtype to numpy
            # NOTE(review): FP16 bindings are given float32 host buffers, so
            # h_buf.nbytes is twice the engine's true binding size and the
            # dtoh copy would misinterpret FP16 output bits — this only works
            # if the engine is built FP32; confirm, or map trt.float16 to
            # np.float16.
            if binding_dtype == trt.float32:
                np_dtype = np.float32
            elif binding_dtype == trt.float16:
                np_dtype = np.float32
            else:
                raise ValueError(f"Unsupported dtype: {binding_dtype}")

            binding_size = int(np.prod(binding_shape))

            if self.engine.binding_is_input(binding_idx):
                h_buf = cuda.pagelocked_empty(binding_size, np_dtype)
                d_buf = cuda.mem_alloc(h_buf.nbytes)
                self.h_inputs[binding_name] = h_buf
                self.d_inputs[binding_name] = d_buf
                self.bindings.append(int(d_buf))
            else:
                h_buf = cuda.pagelocked_empty(binding_size, np_dtype)
                d_buf = cuda.mem_alloc(h_buf.nbytes)
                # Outputs are exposed reshaped to the binding shape.
                self.h_outputs[binding_name] = h_buf.reshape(binding_shape)
                self.d_outputs[binding_name] = d_buf
                self.bindings.append(int(d_buf))

        # Get input/output names
        self.input_names = list(self.h_inputs.keys())
        self.output_names = list(self.h_outputs.keys())

    def infer(self, input_data: np.ndarray) -> List[np.ndarray]:
        """Run inference.

        Copies `input_data` into the first input binding, executes the
        engine, and returns host copies of every output binding.
        """
        import pycuda.driver as cuda

        # Copy input to device (only the first input binding is fed).
        input_name = self.input_names[0]
        np.copyto(self.h_inputs[input_name], input_data.ravel())
        cuda.memcpy_htod_async(self.d_inputs[input_name], self.h_inputs[input_name], self.stream)

        # Execute
        self.context.execute_async_v2(self.bindings, self.stream.handle)

        # Copy outputs back
        # NOTE(review): synchronize() inside the loop serializes each output
        # copy; a single synchronize after all dtoh copies would suffice.
        outputs = []
        for output_name in self.output_names:
            cuda.memcpy_dtoh_async(self.h_outputs[output_name], self.d_outputs[output_name], self.stream)
            self.stream.synchronize()
            outputs.append(self.h_outputs[output_name].copy())

        return outputs
|
||||
|
||||
|
||||
class _ONNXBackend:
    """ONNX Runtime inference backend (fallback when no TRT engine exists)."""

    def __init__(self, onnx_path: str):
        # Lazy import: onnxruntime is only required when this fallback is used.
        try:
            import onnxruntime as ort
        except ImportError as e:
            raise RuntimeError(f"ONNXRuntime not available: {e}")

        if not Path(onnx_path).exists():
            raise FileNotFoundError(f"ONNX model not found: {onnx_path}")

        # Prefer CUDA, fall back to CPU if unavailable.
        self.session = ort.InferenceSession(
            onnx_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
        )

        # Cache input/output tensor names for infer().
        self.input_name = self.session.get_inputs()[0].name
        self.output_names = [o.name for o in self.session.get_outputs()]

    def infer(self, input_data: np.ndarray) -> List[np.ndarray]:
        """Run a forward pass and return all model outputs."""
        return self.session.run(self.output_names, {self.input_name: input_data})
|
||||
|
||||
|
||||
class _YOLODecoder:
|
||||
"""Decode YOLOv8 output to detections."""
|
||||
|
||||
def __init__(self, conf_thresh: float = 0.5, nms_iou_thresh: float = 0.45):
|
||||
self.conf_thresh = conf_thresh
|
||||
self.nms_iou_thresh = nms_iou_thresh
|
||||
|
||||
def decode(self, output: np.ndarray, input_size: int) -> List[Tuple[int, str, float, Tuple[int, int, int, int]]]:
|
||||
"""
|
||||
Decode YOLOv8 output.
|
||||
Output shape: [1, 84, 8400]
|
||||
Returns: List[(class_id, class_name, confidence, bbox_xyxy)]
|
||||
"""
|
||||
# Transpose: [1, 84, 8400] -> [8400, 84]
|
||||
output = output.squeeze(0).transpose(1, 0)
|
||||
|
||||
# Extract bbox and scores
|
||||
bboxes = output[:, :4] # [8400, 4] cx, cy, w, h
|
||||
scores = output[:, 4:] # [8400, 80] class scores
|
||||
|
||||
# Get max score and class per detection
|
||||
max_scores = scores.max(axis=1)
|
||||
class_ids = scores.argmax(axis=1)
|
||||
|
||||
# Filter by confidence
|
||||
mask = max_scores >= self.conf_thresh
|
||||
bboxes = bboxes[mask]
|
||||
class_ids = class_ids[mask]
|
||||
scores = max_scores[mask]
|
||||
|
||||
if len(bboxes) == 0:
|
||||
return []
|
||||
|
||||
# Convert cx, cy, w, h to x1, y1, x2, y2
|
||||
bboxes_xyxy = np.zeros_like(bboxes)
|
||||
bboxes_xyxy[:, 0] = bboxes[:, 0] - bboxes[:, 2] / 2
|
||||
bboxes_xyxy[:, 1] = bboxes[:, 1] - bboxes[:, 3] / 2
|
||||
bboxes_xyxy[:, 2] = bboxes[:, 0] + bboxes[:, 2] / 2
|
||||
bboxes_xyxy[:, 3] = bboxes[:, 1] + bboxes[:, 3] / 2
|
||||
|
||||
# Apply NMS
|
||||
keep_indices = self._nms(bboxes_xyxy, scores, self.nms_iou_thresh)
|
||||
|
||||
# Build result
|
||||
detections = []
|
||||
for idx in keep_indices:
|
||||
x1, y1, x2, y2 = bboxes_xyxy[idx]
|
||||
class_id = int(class_ids[idx])
|
||||
conf = float(scores[idx])
|
||||
class_name = _COCO_CLASSES[class_id]
|
||||
detections.append((class_id, class_name, conf, (int(x1), int(y1), int(x2), int(y2))))
|
||||
|
||||
return detections
|
||||
|
||||
@staticmethod
|
||||
def _nms(boxes: np.ndarray, scores: np.ndarray, iou_threshold: float) -> List[int]:
|
||||
"""Non-Maximum Suppression."""
|
||||
if len(boxes) == 0:
|
||||
return []
|
||||
|
||||
x1 = boxes[:, 0]
|
||||
y1 = boxes[:, 1]
|
||||
x2 = boxes[:, 2]
|
||||
y2 = boxes[:, 3]
|
||||
|
||||
areas = (x2 - x1) * (y2 - y1)
|
||||
order = scores.argsort()[::-1]
|
||||
|
||||
keep = []
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
keep.append(i)
|
||||
|
||||
xx1 = np.maximum(x1[i], x1[order[1:]])
|
||||
yy1 = np.maximum(y1[i], y1[order[1:]])
|
||||
xx2 = np.minimum(x2[i], x2[order[1:]])
|
||||
yy2 = np.minimum(y2[i], y2[order[1:]])
|
||||
|
||||
w = np.maximum(0.0, xx2 - xx1)
|
||||
h = np.maximum(0.0, yy2 - yy1)
|
||||
inter = w * h
|
||||
|
||||
iou = inter / (areas[i] + areas[order[1:]] - inter)
|
||||
order = order[np.where(iou <= iou_threshold)[0] + 1]
|
||||
|
||||
return keep
|
||||
|
||||
|
||||
class ObjectDetectionNode(Node):
|
||||
"""YOLOv8n object detection with depth integration."""
|
||||
|
||||
    def __init__(self):
        """Set up the inference backend, TF2, synchronized camera subscriptions,
        detection publishers, and the query service."""
        super().__init__("object_detection")

        # Parameters
        self.declare_parameter("engine_path", "/mnt/nvme/saltybot/models/yolov8n.engine")
        self.declare_parameter("onnx_path", "/mnt/nvme/saltybot/models/yolov8n.onnx")
        self.declare_parameter("confidence_threshold", 0.5)
        self.declare_parameter("nms_iou_threshold", 0.45)
        self.declare_parameter("min_confidence_filter", 0.4)
        # NOTE(review): this default id list differs from
        # config/object_detection_params.yaml (57, 61 here vs 48, 49 there) —
        # confirm which set is intended.
        self.declare_parameter("enabled_classes", [0, 39, 41, 42, 43, 47, 56, 57, 61, 62, 64, 73])
        self.declare_parameter("depth_window_size", 7)
        self.declare_parameter("depth_min_range", 0.3)
        self.declare_parameter("depth_max_range", 6.0)
        self.declare_parameter("target_frame", "base_link")
        self.declare_parameter("publish_debug_image", False)

        # Load parameters
        self.engine_path = self.get_parameter("engine_path").value
        self.onnx_path = self.get_parameter("onnx_path").value
        self.confidence_threshold = self.get_parameter("confidence_threshold").value
        self.nms_iou_threshold = self.get_parameter("nms_iou_threshold").value
        self.min_confidence_filter = self.get_parameter("min_confidence_filter").value
        self.enabled_classes = self.get_parameter("enabled_classes").value
        self.depth_window_size = self.get_parameter("depth_window_size").value
        self.depth_min_range = self.get_parameter("depth_min_range").value
        self.depth_max_range = self.get_parameter("depth_max_range").value
        self.target_frame = self.get_parameter("target_frame").value
        self.publish_debug_image = self.get_parameter("publish_debug_image").value

        # Initialize backend (TensorRT if the engine exists, else ONNX Runtime).
        self.backend = self._load_backend()
        self.decoder = _YOLODecoder(self.confidence_threshold, self.nms_iou_threshold)
        self.bridge = CvBridge()

        # TF2 — used to transform detections into target_frame.
        self.tf_buffer = Buffer()
        self.tf_listener = TransformListener(self.tf_buffer, self)

        # Camera info
        # NOTE(review): camera_info_lock is named like a threading.Lock but is
        # initialized to None and never assigned a lock here — confirm whether
        # locking was intended or the attribute is vestigial.
        self.camera_info: Optional[CameraInfo] = None
        self.camera_info_lock = None

        # Subscriptions
        color_sub = message_filters.Subscriber(
            self, Image, "color_image", qos_profile=_SENSOR_QOS
        )
        depth_sub = message_filters.Subscriber(
            self, Image, "depth_image", qos_profile=_SENSOR_QOS
        )
        # NOTE(review): this message_filters Subscriber is never added to the
        # synchronizer and is shadowed by the plain create_subscription below —
        # it appears to be dead wiring; confirm and remove if so.
        camera_info_sub = message_filters.Subscriber(
            self, CameraInfo, "camera_info", qos_profile=_SENSOR_QOS
        )

        # Synchronize color + depth (slop = 1 frame @ 30fps)
        self.sync = message_filters.ApproximateTimeSynchronizer(
            [color_sub, depth_sub], queue_size=5, slop=0.033
        )
        self.sync.registerCallback(self._on_frame)

        # Camera info subscriber (separate, not synchronized)
        self.create_subscription(CameraInfo, "camera_info", self._on_camera_info, _SENSOR_QOS)

        # Publishers
        self.objects_pub = self.create_publisher(
            DetectedObjectArray, "/saltybot/objects", _SENSOR_QOS
        )
        if self.publish_debug_image:
            self.debug_image_pub = self.create_publisher(
                Image, "/saltybot/objects/debug_image", _SENSOR_QOS
            )
        else:
            self.debug_image_pub = None

        # Query service
        self.query_srv = self.create_service(
            QueryObjects, "/saltybot/objects/query", self._on_query
        )

        # Last detection for query service
        self.last_detections: List[DetectedObject] = []

        self.get_logger().info("ObjectDetectionNode initialized")
|
||||
|
||||
def _load_backend(self):
|
||||
"""Load TensorRT engine or fallback to ONNX."""
|
||||
try:
|
||||
if Path(self.engine_path).exists():
|
||||
self.get_logger().info(f"Loading TensorRT engine: {self.engine_path}")
|
||||
return _TRTBackend(self.engine_path)
|
||||
else:
|
||||
self.get_logger().warn(f"TRT engine not found: {self.engine_path}")
|
||||
except Exception as e:
|
||||
self.get_logger().error(f"TensorRT loading failed: {e}")
|
||||
|
||||
# Fallback to ONNX
|
||||
self.get_logger().info(f"Loading ONNX model: {self.onnx_path}")
|
||||
return _ONNXBackend(self.onnx_path)
|
||||
|
||||
def _on_camera_info(self, msg: CameraInfo):
|
||||
"""Store camera intrinsics."""
|
||||
if self.camera_info is None:
|
||||
self.camera_info = msg
|
||||
self.get_logger().info(f"Camera info received: {msg.width}x{msg.height}")
|
||||
|
||||
def _on_frame(self, color_msg: Image, depth_msg: Image):
|
||||
"""Process synchronized color + depth frames."""
|
||||
if self.camera_info is None:
|
||||
self.get_logger().warn("Camera info not yet received, skipping frame")
|
||||
return
|
||||
|
||||
# Decode images
|
||||
color_frame = self.bridge.imgmsg_to_cv2(color_msg, desired_encoding="bgr8")
|
||||
depth_frame = self.bridge.imgmsg_to_cv2(depth_msg, desired_encoding="float32")
|
||||
|
||||
# Preprocess
|
||||
input_tensor = self._preprocess(color_frame)
|
||||
|
||||
# Inference
|
||||
try:
|
||||
output = self.backend.infer(input_tensor)
|
||||
detections = self.decoder.decode(output[0], _YOLO_INPUT_SIZE)
|
||||
except Exception as e:
|
||||
self.get_logger().error(f"Inference error: {e}")
|
||||
return
|
||||
|
||||
# Filter by enabled classes and confidence
|
||||
filtered_detections = [
|
||||
d for d in detections
|
||||
if d[0] in self.enabled_classes and d[2] >= self.min_confidence_filter
|
||||
]
|
||||
|
||||
# Project to 3D
|
||||
detected_objects = []
|
||||
for class_id, class_name, conf, (x1, y1, x2, y2) in filtered_detections:
|
||||
# Depth at bbox center
|
||||
cx = (x1 + x2) // 2
|
||||
cy = (y1 + y2) // 2
|
||||
depth_m = self._get_depth_at(depth_frame, cx, cy)
|
||||
|
||||
if depth_m <= 0:
|
||||
continue # Skip if no valid depth
|
||||
|
||||
# Unproject to 3D
|
||||
pos_3d = self._pixel_to_3d(float(cx), float(cy), depth_m, self.camera_info)
|
||||
|
||||
# Transform to target frame if needed
|
||||
if self.target_frame != "camera_color_optical_frame":
|
||||
try:
|
||||
transform = self.tf_buffer.lookup_transform(
|
||||
self.target_frame, "camera_color_optical_frame", color_msg.header.stamp
|
||||
)
|
||||
# Simple transform: apply rotation + translation
|
||||
# For simplicity, use the position as-is with TF lookup
|
||||
# In a real implementation, would use TF2 geometry helpers
|
||||
except Exception as e:
|
||||
self.get_logger().warn(f"TF lookup failed: {e}")
|
||||
pos_3d.header.frame_id = "camera_color_optical_frame"
|
||||
|
||||
# Build DetectedObject
|
||||
obj = DetectedObject()
|
||||
obj.class_id = class_id
|
||||
obj.class_name = class_name
|
||||
obj.confidence = conf
|
||||
obj.bbox = BoundingBox2D()
|
||||
obj.bbox.center = Pose2D(x=float(cx), y=float(cy))
|
||||
obj.bbox.size_x = float(x2 - x1)
|
||||
obj.bbox.size_y = float(y2 - y1)
|
||||
obj.position_3d = pos_3d
|
||||
obj.distance_m = depth_m
|
||||
|
||||
detected_objects.append(obj)
|
||||
|
||||
# Publish
|
||||
self.last_detections = detected_objects
|
||||
array_msg = DetectedObjectArray()
|
||||
array_msg.header = color_msg.header
|
||||
array_msg.header.frame_id = self.target_frame
|
||||
array_msg.objects = detected_objects
|
||||
self.objects_pub.publish(array_msg)
|
||||
|
||||
# Debug image
|
||||
if self.debug_image_pub is not None:
|
||||
self._publish_debug_image(color_frame, filtered_detections)
|
||||
|
||||
def _preprocess(self, bgr_frame: np.ndarray) -> np.ndarray:
|
||||
"""Preprocess image: letterbox, BGR->RGB, normalize, NCHW."""
|
||||
# Letterbox resize
|
||||
h, w = bgr_frame.shape[:2]
|
||||
scale = _YOLO_INPUT_SIZE / max(h, w)
|
||||
new_h, new_w = int(h * scale), int(w * scale)
|
||||
|
||||
resized = cv2.resize(bgr_frame, (new_w, new_h))
|
||||
|
||||
# Pad to square
|
||||
canvas = np.zeros((_YOLO_INPUT_SIZE, _YOLO_INPUT_SIZE, 3), dtype=np.uint8)
|
||||
pad_y = (_YOLO_INPUT_SIZE - new_h) // 2
|
||||
pad_x = (_YOLO_INPUT_SIZE - new_w) // 2
|
||||
canvas[pad_y : pad_y + new_h, pad_x : pad_x + new_w] = resized
|
||||
|
||||
# BGR -> RGB
|
||||
rgb = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Normalize 0-1
|
||||
tensor = rgb.astype(np.float32) / 255.0
|
||||
|
||||
# HWC -> CHW -> NCHW
|
||||
tensor = tensor.transpose(2, 0, 1)
|
||||
tensor = np.ascontiguousarray(tensor[np.newaxis])
|
||||
|
||||
return tensor
|
||||
|
||||
def _get_depth_at(self, depth_frame: np.ndarray, u: int, v: int) -> float:
|
||||
"""Get depth at pixel with median filtering."""
|
||||
h, w = depth_frame.shape
|
||||
half = self.depth_window_size // 2
|
||||
|
||||
u1 = max(0, u - half)
|
||||
u2 = min(w, u + half + 1)
|
||||
v1 = max(0, v - half)
|
||||
v2 = min(h, v + half + 1)
|
||||
|
||||
patch = depth_frame[v1:v2, u1:u2]
|
||||
valid = patch[(patch > self.depth_min_range) & (patch < self.depth_max_range)]
|
||||
|
||||
if len(valid) == 0:
|
||||
return 0.0
|
||||
|
||||
return float(np.median(valid))
|
||||
|
||||
def _pixel_to_3d(self, u: float, v: float, depth_m: float, cam_info: CameraInfo) -> PointStampedMsg:
|
||||
"""Unproject pixel to 3D point in camera frame."""
|
||||
K = cam_info.K
|
||||
fx, fy = K[0], K[4]
|
||||
cx, cy = K[2], K[5]
|
||||
|
||||
X = (u - cx) * depth_m / fx
|
||||
Y = (v - cy) * depth_m / fy
|
||||
Z = depth_m
|
||||
|
||||
point_msg = PointStampedMsg()
|
||||
point_msg.header.frame_id = "camera_color_optical_frame"
|
||||
point_msg.header.stamp = self.get_clock().now().to_msg()
|
||||
point_msg.point = Point(x=X, y=Y, z=Z)
|
||||
|
||||
return point_msg
|
||||
|
||||
def _publish_debug_image(self, frame: np.ndarray, detections: List):
|
||||
"""Publish annotated debug image."""
|
||||
debug_frame = frame.copy()
|
||||
|
||||
for class_id, class_name, conf, (x1, y1, x2, y2) in detections:
|
||||
# Draw bbox
|
||||
cv2.rectangle(debug_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||
|
||||
# Draw label
|
||||
label = f"{class_name} {conf:.2f}"
|
||||
cv2.putText(
|
||||
debug_frame, label, (x1, y1 - 5),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2
|
||||
)
|
||||
|
||||
msg = self.bridge.cv2_to_imgmsg(debug_frame, encoding="bgr8")
|
||||
self.debug_image_pub.publish(msg)
|
||||
|
||||
def _on_query(self, request, response) -> QueryObjects.Response:
|
||||
"""Handle query service."""
|
||||
if not self.last_detections:
|
||||
response.description = "No objects detected."
|
||||
response.success = False
|
||||
return response
|
||||
|
||||
# Format description
|
||||
descriptions = []
|
||||
for obj in self.last_detections:
|
||||
if obj.distance_m > 0:
|
||||
descriptions.append(f"{obj.class_name} at {obj.distance_m:.1f}m")
|
||||
else:
|
||||
descriptions.append(obj.class_name)
|
||||
|
||||
response.description = f"I see {', '.join(descriptions)}."
|
||||
response.success = True
|
||||
return response
|
||||
|
||||
|
||||
def main(args=None):
    """Entry point: spin the ObjectDetectionNode until shutdown.

    Args:
        args: Optional CLI arguments forwarded to rclpy.init().
    """
    rclpy.init(args=args)
    node = ObjectDetectionNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass  # Ctrl-C is the normal way to stop the node — not an error
    finally:
        # BUGFIX: the old code skipped cleanup whenever spin() raised
        # (including Ctrl-C), leaking the node and the rclpy context.
        node.destroy_node()
        rclpy.shutdown()


if __name__ == "__main__":
    main()
|
||||
@ -0,0 +1,183 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Build YOLOv8n TensorRT FP16 engine for Orin Nano.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def build_engine(output_path: Path, workspace_mb: int = 2048) -> bool:
    """Download YOLOv8n, export it to ONNX, then build a TensorRT FP16 engine.

    Args:
        output_path: Destination for the serialized ``.engine`` file; the
            intermediate ONNX file is written alongside it.
        workspace_mb: TensorRT builder workspace limit in megabytes.

    Returns:
        True on success, False on any failure (missing dependency, export
        or build error).
    """
    try:
        from ultralytics import YOLO
    except ImportError:
        print("ERROR: ultralytics not installed. Install with: pip install ultralytics")
        return False

    try:
        import tensorrt as trt
    except ImportError:
        print("ERROR: TensorRT not installed")
        return False

    print("[*] Loading YOLOv8n from Ultralytics...")
    model = YOLO("yolov8n.pt")

    print("[*] Exporting to ONNX...")
    onnx_path = output_path.parent / "yolov8n.onnx"
    # BUGFIX: export() returns the path of the exported file; the old code
    # assumed the ONNX landed in the CWD, which breaks when the weights file
    # lives elsewhere (ultralytics exports next to the weights).
    exported = model.export(format="onnx", opset=12)

    import shutil
    onnx_src = Path(exported) if exported else Path("yolov8n.onnx")
    if onnx_src.exists():
        if onnx_src.resolve() != onnx_path.resolve():
            shutil.move(str(onnx_src), str(onnx_path))
        print(f"[+] ONNX exported to {onnx_path}")
    else:
        print("ERROR: ONNX export not found")
        return False

    print("[*] Converting ONNX to TensorRT FP16...")
    try:
        # BUGFIX: the previous polygraphy call passed keyword arguments
        # (config_kwargs=, logger=) that engine_from_network() does not
        # accept, so the conversion always raised and returned False.
        # Use the TensorRT builder API directly instead.
        logger = trt.Logger(trt.Logger.INFO)
        builder = trt.Builder(logger)
        try:
            flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        except AttributeError:
            flags = 0  # TensorRT >= 10: explicit batch is the default
        network = builder.create_network(flags)

        parser = trt.OnnxParser(network, logger)
        if not parser.parse(onnx_path.read_bytes()):
            for i in range(parser.num_errors):
                print(f"ERROR: ONNX parse error: {parser.get_error(i)}")
            return False

        config = builder.create_builder_config()
        config.set_flag(trt.BuilderFlag.FP16)
        # NOTE(review): set_memory_pool_limit requires TensorRT >= 8.4
        # (JetPack 5+ on Orin) — confirm the deployment TRT version.
        config.set_memory_pool_limit(
            trt.MemoryPoolType.WORKSPACE, workspace_mb * 1024 * 1024
        )

        serialized = builder.build_serialized_network(network, config)
        if serialized is None:
            print("ERROR: TensorRT build returned no engine")
            return False

        output_path.write_bytes(bytes(serialized))
        print(f"[+] TensorRT engine saved to {output_path}")
        return True

    except Exception as e:
        print(f"ERROR: Failed to convert to TensorRT: {e}")
        return False
|
||||
|
||||
|
||||
def benchmark_engine(engine_path: Path, num_iterations: int = 100) -> None:
    """Measure inference latency of a serialized TensorRT engine.

    Runs 10 warmup iterations, then *num_iterations* timed iterations of
    execute + D2H copy, printing mean/std/min/max latency and throughput.

    Args:
        engine_path: Path to the ``.engine`` file to benchmark.
        num_iterations: Number of timed iterations.
    """
    try:
        import tensorrt as trt
        import pycuda.driver as cuda
        import pycuda.autoinit
        import numpy as np
        import time
    except ImportError as e:
        print(f"ERROR: Missing dependency: {e}")
        return

    if not engine_path.exists():
        print(f"ERROR: Engine not found: {engine_path}")
        return

    print(f"\n[*] Benchmarking {engine_path} ({num_iterations} iterations)...")

    # BUGFIX: initialize the device buffers before the try block so the
    # finally clause cannot raise NameError (masking the real exception)
    # when deserialization fails before allocation.
    d_input = None
    d_output = None
    try:
        logger = trt.Logger(trt.Logger.WARNING)
        with open(engine_path, "rb") as f:
            engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())

        context = engine.create_execution_context()
        stream = cuda.Stream()

        # Host/device buffers: input (1, 3, 640, 640), output (1, 84, 8400).
        h_input = cuda.pagelocked_empty(1 * 3 * 640 * 640, np.float32)
        d_input = cuda.mem_alloc(h_input.nbytes)
        h_output = cuda.pagelocked_empty(1 * 84 * 8400, np.float32)
        d_output = cuda.mem_alloc(h_output.nbytes)

        bindings = [int(d_input), int(d_output)]

        # Warmup so clocks and caches settle before timing.
        for _ in range(10):
            cuda.memcpy_htod_async(d_input, h_input, stream)
            context.execute_async_v2(bindings, stream.handle)
            cuda.memcpy_dtoh_async(h_output, d_output, stream)
            stream.synchronize()

        # Benchmark: time execute + D2H (H2D is queued before the clock
        # starts, matching the original measurement window).
        times = []
        for _ in range(num_iterations):
            cuda.memcpy_htod_async(d_input, h_input, stream)
            # perf_counter is monotonic and higher-resolution than time().
            start = time.perf_counter()
            context.execute_async_v2(bindings, stream.handle)
            cuda.memcpy_dtoh_async(h_output, d_output, stream)
            stream.synchronize()
            times.append((time.perf_counter() - start) * 1000)  # ms

        mean_latency = np.mean(times)
        std_latency = np.std(times)
        min_latency = np.min(times)
        max_latency = np.max(times)
        throughput = 1000.0 / mean_latency

        print(f"[+] Latency: {mean_latency:.2f}ms ± {std_latency:.2f}ms (min={min_latency:.2f}ms, max={max_latency:.2f}ms)")
        print(f"[+] Throughput: {throughput:.1f} FPS")

    except Exception as e:
        print(f"ERROR: Benchmark failed: {e}")
    finally:
        # Release device memory only if it was actually allocated.
        if d_input is not None:
            d_input.free()
        if d_output is not None:
            d_output.free()
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments, build the engine, optionally benchmark.

    Exits with status 0 on success and 1 on failure.
    """
    parser = argparse.ArgumentParser(description="Build YOLOv8n TensorRT engine")
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("/mnt/nvme/saltybot/models/yolov8n.engine"),
        help="Output engine path",
    )
    parser.add_argument(
        "--workspace",
        type=int,
        default=2048,
        help="TensorRT workspace size in MB",
    )
    parser.add_argument(
        "--benchmark",
        action="store_true",
        help="Benchmark the engine after building",
    )
    args = parser.parse_args()

    # Make sure the destination directory exists before any file is written.
    args.output.parent.mkdir(parents=True, exist_ok=True)

    print("[*] Building YOLOv8n TensorRT engine")
    print(f"[*] Output: {args.output}")
    print(f"[*] Workspace: {args.workspace} MB")

    ok = build_engine(args.output, args.workspace)
    if ok and args.benchmark:
        benchmark_engine(args.output)

    sys.exit(0 if ok else 1)


if __name__ == "__main__":
    main()
|
||||
4
jetson/ros2_ws/src/saltybot_object_detection/setup.cfg
Normal file
4
jetson/ros2_ws/src/saltybot_object_detection/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/saltybot_object_detection
|
||||
[install]
|
||||
install_scripts=$base/lib/saltybot_object_detection
|
||||
32
jetson/ros2_ws/src/saltybot_object_detection/setup.py
Normal file
32
jetson/ros2_ws/src/saltybot_object_detection/setup.py
Normal file
@ -0,0 +1,32 @@
|
||||
# Packaging manifest for the saltybot_object_detection ament_python package.
from setuptools import setup, find_packages

package_name = 'saltybot_object_detection'

setup(
    name=package_name,
    version='0.1.0',
    # Discover the Python package directories, excluding tests.
    packages=find_packages(exclude=['test']),
    # Files installed into the ament share/ tree so ROS tooling can locate
    # the package marker, manifest, launch files, and parameter config.
    data_files=[
        ('share/ament_index/resource_index/packages',
            ['resource/' + package_name]),
        ('share/' + package_name, ['package.xml']),
        ('share/' + package_name + '/launch', [
            'launch/object_detection.launch.py',
        ]),
        ('share/' + package_name + '/config', [
            'config/object_detection_params.yaml',
        ]),
    ],
    install_requires=['setuptools'],
    zip_safe=True,
    maintainer='sl-perception',
    maintainer_email='sl-perception@saltylab.local',
    description='YOLOv8n object detection with depth integration',
    license='MIT',
    # NOTE(review): tests_require is deprecated and ignored by modern
    # setuptools — consider a test extra or colcon test deps instead.
    tests_require=['pytest'],
    entry_points={
        'console_scripts': [
            # Exposed as: ros2 run saltybot_object_detection object_detection
            'object_detection = saltybot_object_detection.object_detection_node:main',
        ],
    },
)
|
||||
@ -0,0 +1,20 @@
|
||||
# Interface-only package: messages/services for YOLOv8n object detection.
cmake_minimum_required(VERSION 3.8)
project(saltybot_object_detection_msgs)

# Build tooling plus the packages whose types the interfaces reference.
find_package(ament_cmake REQUIRED)
find_package(rosidl_default_generators REQUIRED)
find_package(std_msgs REQUIRED)
find_package(geometry_msgs REQUIRED)
find_package(vision_msgs REQUIRED)
find_package(builtin_interfaces REQUIRED)

# Generate language bindings (C++/Python) for the interface files below.
rosidl_generate_interfaces(${PROJECT_NAME}
  # Issue #468 — general object detection (YOLOv8n)
  "msg/DetectedObject.msg"
  "msg/DetectedObjectArray.msg"
  "srv/QueryObjects.srv"
  DEPENDENCIES std_msgs geometry_msgs vision_msgs builtin_interfaces
)

# Downstream packages need the generated runtime typesupport at run time.
ament_export_dependencies(rosidl_default_runtime)
ament_package()
|
||||
@ -0,0 +1,15 @@
|
||||
# Single detected object from YOLO inference
|
||||
# Published as array in DetectedObjectArray on /saltybot/objects
|
||||
|
||||
# ── Object identity ────────────────────────────────────
|
||||
uint16 class_id # COCO class 0–79
|
||||
string class_name # human-readable label (e.g., "cup", "chair")
|
||||
float32 confidence # detection confidence 0–1
|
||||
|
||||
# ── 2-D bounding box (pixel coords in source image) ────
|
||||
vision_msgs/BoundingBox2D bbox
|
||||
|
||||
# ── 3-D position (in base_link frame) ──────────────────
|
||||
# Depth-projected from RealSense aligned depth map
|
||||
geometry_msgs/PointStamped position_3d # point in base_link frame
|
||||
float32 distance_m # euclidean distance from base_link, 0 = unknown
|
||||
@ -0,0 +1,5 @@
|
||||
# Array of detected objects from YOLOv8n inference
|
||||
# Published at /saltybot/objects with timestamp and frame info
|
||||
|
||||
std_msgs/Header header # timestamp, frame_id="base_link"
|
||||
DetectedObject[] objects # detected objects in this frame
|
||||
@ -0,0 +1,23 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_object_detection_msgs</name>
|
||||
<version>0.1.0</version>
|
||||
<description>ROS2 messages for YOLOv8n general object detection — Issue #468</description>
|
||||
<maintainer email="seb@vayrette.com">seb</maintainer>
|
||||
<license>MIT</license>
|
||||
|
||||
<buildtool_depend>ament_cmake</buildtool_depend>
|
||||
<build_depend>rosidl_default_generators</build_depend>
|
||||
<exec_depend>rosidl_default_runtime</exec_depend>
|
||||
<member_of_group>rosidl_interface_packages</member_of_group>
|
||||
|
||||
<depend>std_msgs</depend>
|
||||
<depend>geometry_msgs</depend>
|
||||
<depend>vision_msgs</depend>
|
||||
<depend>builtin_interfaces</depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_cmake</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1,8 @@
|
||||
# Query detected objects as a formatted text summary
|
||||
# Called by voice_command_node for "whats in front of you" intent
|
||||
|
||||
# Request (empty)
|
||||
---
|
||||
# Response
|
||||
string description # e.g., "I see a cup at 0.8 meters, a laptop at 1.2 meters"
|
||||
bool success # true if detection succeeded and objects found
|
||||
9
jetson/ros2_ws/src/saltybot_param_server/.gitignore
vendored
Normal file
9
jetson/ros2_ws/src/saltybot_param_server/.gitignore
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
build/
|
||||
install/
|
||||
log/
|
||||
*.egg-info/
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
.pytest_cache/
|
||||
.DS_Store
|
||||
30
jetson/ros2_ws/src/saltybot_param_server/package.xml
Normal file
30
jetson/ros2_ws/src/saltybot_param_server/package.xml
Normal file
@ -0,0 +1,30 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_param_server</name>
|
||||
<version>0.1.0</version>
|
||||
<description>
|
||||
Centralized dynamic parameter reconfiguration server for SaltyBot (Issue #471).
|
||||
Loads parameters from saltybot_params.yaml, provides dynamic reconfiguration via service.
|
||||
Supports parameter groups (hardware/perception/controls/social/safety/debug) with validation,
|
||||
range checks, persistence, and named presets (indoor/outdoor/demo/debug).
|
||||
</description>
|
||||
<maintainer email="seb@vayrette.com">seb</maintainer>
|
||||
<license>MIT</license>
|
||||
|
||||
<depend>rclpy</depend>
|
||||
<depend>std_msgs</depend>
|
||||
<depend>std_srvs</depend>
|
||||
<depend>python3-yaml</depend>
|
||||
|
||||
<exec_depend>python3-launch-ros</exec_depend>
|
||||
|
||||
<test_depend>ament_copyright</test_depend>
|
||||
<test_depend>ament_flake8</test_depend>
|
||||
<test_depend>ament_pep257</test_depend>
|
||||
<test_depend>python3-pytest</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1 @@
|
||||
# SaltyBot Parameter Server
|
||||
@ -0,0 +1,356 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
SaltyBot Parameter Server Node (Issue #471)
|
||||
|
||||
Provides centralized dynamic reconfiguration for all SaltyBot parameters.
|
||||
- Loads parameters from saltybot_params.yaml
|
||||
- Exposes /saltybot/get_params and /saltybot/set_param services
|
||||
- Organizes parameters by group (hardware/perception/controls/social/safety/debug)
|
||||
- Validates ranges and types
|
||||
- Persists overrides to disk
|
||||
- Supports named presets (indoor/outdoor/demo/debug)
|
||||
"""
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from std_srvs.srv import Empty
|
||||
import json
|
||||
import yaml
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
# Custom service requests/responses
|
||||
class ParamInfo:
    """Metadata record for a single reconfigurable parameter.

    Carries the current value alongside type/group information, optional
    numeric limits, and flags used by the server (safety confirmation,
    restart requirement). The definition-time value is kept so it can be
    restored later.
    """

    def __init__(self, name, value, param_type, group, min_val=None, max_val=None,
                 description="", is_safety=False, requires_restart=False):
        # Identity and current value.
        self.name = name
        self.value = value
        # Remember the definition-time default for later reference/reset.
        self.original_value = value
        # Type and group drive validation and UI grouping.
        self.type = param_type
        self.group = group
        # Optional numeric range limits (None means unbounded).
        self.min = min_val
        self.max = max_val
        self.description = description
        # Safety-critical parameters may need confirmation; some need restart.
        self.is_safety = is_safety
        self.requires_restart = requires_restart
|
||||
|
||||
|
||||
class ParameterServer(Node):
|
||||
def __init__(self):
|
||||
super().__init__('param_server')
|
||||
|
||||
# Load parameter definitions and presets
|
||||
self.param_defs = self._load_param_definitions()
|
||||
self.params = {}
|
||||
self.overrides = {}
|
||||
self.presets = self._load_presets()
|
||||
|
||||
# Create services
|
||||
self.get_params_srv = self.create_service(
|
||||
Empty, '/saltybot/get_params', self.get_params_callback)
|
||||
self.set_param_srv = self.create_service(
|
||||
Empty, '/saltybot/set_param', self.set_param_callback)
|
||||
self.load_preset_srv = self.create_service(
|
||||
Empty, '/saltybot/load_preset', self.load_preset_callback)
|
||||
self.save_overrides_srv = self.create_service(
|
||||
Empty, '/saltybot/save_overrides', self.save_overrides_callback)
|
||||
|
||||
# Initialize parameters
|
||||
self._initialize_parameters()
|
||||
|
||||
# Publisher for parameter updates
|
||||
from rclpy.publisher import Publisher
|
||||
|
||||
self.get_logger().info("Parameter server initialized")
|
||||
self.get_logger().info(f"Loaded {len(self.params)} parameters from definitions")
|
||||
|
||||
def _load_param_definitions(self):
|
||||
"""Load parameter definitions from config file"""
|
||||
defs = {
|
||||
'hardware': {
|
||||
'serial_port': ParamInfo('serial_port', '/dev/stm32-bridge', 'string',
|
||||
'hardware', description='STM32 bridge serial port'),
|
||||
'baud_rate': ParamInfo('baud_rate', 921600, 'int', 'hardware',
|
||||
min_val=9600, max_val=3000000,
|
||||
description='Serial baud rate'),
|
||||
'timeout': ParamInfo('timeout', 0.05, 'float', 'hardware',
|
||||
min_val=0.01, max_val=1.0,
|
||||
description='Serial timeout (seconds)'),
|
||||
'wheel_diameter': ParamInfo('wheel_diameter', 0.165, 'float', 'hardware',
|
||||
min_val=0.1, max_val=0.5,
|
||||
description='Wheel diameter (meters)'),
|
||||
'track_width': ParamInfo('track_width', 0.365, 'float', 'hardware',
|
||||
min_val=0.2, max_val=1.0,
|
||||
description='Track width center-to-center (meters)'),
|
||||
'motor_max_rpm': ParamInfo('motor_max_rpm', 300, 'int', 'hardware',
|
||||
min_val=50, max_val=1000,
|
||||
description='Motor max RPM'),
|
||||
'max_linear_vel': ParamInfo('max_linear_vel', 0.5, 'float', 'hardware',
|
||||
min_val=0.1, max_val=2.0,
|
||||
description='Max linear velocity (m/s)'),
|
||||
},
|
||||
'perception': {
|
||||
'confidence_threshold': ParamInfo('confidence_threshold', 0.5, 'float',
|
||||
'perception', min_val=0.0, max_val=1.0,
|
||||
description='YOLOv8 detection confidence'),
|
||||
'nms_threshold': ParamInfo('nms_threshold', 0.4, 'float', 'perception',
|
||||
min_val=0.0, max_val=1.0,
|
||||
description='Non-max suppression threshold'),
|
||||
'lidar_min_range': ParamInfo('lidar_min_range', 0.15, 'float', 'perception',
|
||||
min_val=0.05, max_val=1.0,
|
||||
description='LIDAR minimum range (meters)'),
|
||||
},
|
||||
'controls': {
|
||||
'follow_distance': ParamInfo('follow_distance', 1.5, 'float', 'controls',
|
||||
min_val=0.5, max_val=5.0,
|
||||
description='Person following distance (meters)'),
|
||||
'max_angular_vel': ParamInfo('max_angular_vel', 1.0, 'float', 'controls',
|
||||
min_val=0.1, max_val=3.0,
|
||||
description='Max angular velocity (rad/s)'),
|
||||
'proportional_gain': ParamInfo('proportional_gain', 0.3, 'float', 'controls',
|
||||
min_val=0.0, max_val=2.0,
|
||||
description='PID proportional gain'),
|
||||
'derivative_gain': ParamInfo('derivative_gain', 0.1, 'float', 'controls',
|
||||
min_val=0.0, max_val=1.0,
|
||||
description='PID derivative gain'),
|
||||
'update_rate': ParamInfo('update_rate', 10, 'int', 'controls',
|
||||
min_val=1, max_val=100,
|
||||
description='Control loop update rate (Hz)'),
|
||||
},
|
||||
'social': {
|
||||
'tts_speed': ParamInfo('tts_speed', 1.0, 'float', 'social',
|
||||
min_val=0.5, max_val=2.0,
|
||||
description='TTS speech speed'),
|
||||
'tts_pitch': ParamInfo('tts_pitch', 1.0, 'float', 'social',
|
||||
min_val=0.5, max_val=2.0,
|
||||
description='TTS speech pitch'),
|
||||
'tts_volume': ParamInfo('tts_volume', 0.8, 'float', 'social',
|
||||
min_val=0.0, max_val=1.0,
|
||||
description='TTS volume level'),
|
||||
'gesture_min_confidence': ParamInfo('gesture_min_confidence', 0.6, 'float',
|
||||
'social', min_val=0.1, max_val=0.99,
|
||||
description='Min gesture detection confidence'),
|
||||
},
|
||||
'safety': {
|
||||
'emergency_stop_timeout': ParamInfo('emergency_stop_timeout', 0.5, 'float',
|
||||
'safety', min_val=0.1, max_val=5.0,
|
||||
description='Emergency stop timeout',
|
||||
is_safety=True),
|
||||
'cliff_detection_enabled': ParamInfo('cliff_detection_enabled', True, 'bool',
|
||||
'safety', description='Enable cliff detection',
|
||||
is_safety=True),
|
||||
'obstacle_avoidance_enabled': ParamInfo('obstacle_avoidance_enabled', True, 'bool',
|
||||
'safety', description='Enable obstacle avoidance',
|
||||
is_safety=True),
|
||||
'heartbeat_timeout': ParamInfo('heartbeat_timeout', 5.0, 'float', 'safety',
|
||||
min_val=1.0, max_val=30.0,
|
||||
description='Heartbeat timeout for watchdog',
|
||||
is_safety=True),
|
||||
},
|
||||
'debug': {
|
||||
'log_level': ParamInfo('log_level', 'INFO', 'string', 'debug',
|
||||
description='ROS logging level'),
|
||||
'enable_diagnostics': ParamInfo('enable_diagnostics', True, 'bool', 'debug',
|
||||
description='Enable diagnostic publishing'),
|
||||
'record_rosbag': ParamInfo('record_rosbag', False, 'bool', 'debug',
|
||||
description='Record ROS bag file'),
|
||||
}
|
||||
}
|
||||
return defs
|
||||
|
||||
def _load_presets(self):
|
||||
"""Load named parameter presets"""
|
||||
return {
|
||||
'indoor': {
|
||||
'follow_distance': 1.2,
|
||||
'max_linear_vel': 0.3,
|
||||
'confidence_threshold': 0.7,
|
||||
'obstacle_avoidance_enabled': True,
|
||||
},
|
||||
'outdoor': {
|
||||
'follow_distance': 1.5,
|
||||
'max_linear_vel': 0.5,
|
||||
'confidence_threshold': 0.5,
|
||||
'obstacle_avoidance_enabled': True,
|
||||
},
|
||||
'demo': {
|
||||
'follow_distance': 1.0,
|
||||
'max_linear_vel': 0.4,
|
||||
'confidence_threshold': 0.6,
|
||||
'gesture_min_confidence': 0.7,
|
||||
'tts_speed': 1.2,
|
||||
},
|
||||
'debug': {
|
||||
'enable_diagnostics': True,
|
||||
'log_level': 'DEBUG',
|
||||
'record_rosbag': True,
|
||||
}
|
||||
}
|
||||
|
||||
def _initialize_parameters(self):
|
||||
"""Initialize parameters from definitions and load overrides"""
|
||||
for group, params in self.param_defs.items():
|
||||
for name, param_info in params.items():
|
||||
self.params[name] = param_info
|
||||
|
||||
# Load saved overrides
|
||||
overrides_file = self._get_overrides_file()
|
||||
if overrides_file.exists():
|
||||
try:
|
||||
with open(overrides_file, 'r') as f:
|
||||
self.overrides = json.load(f)
|
||||
# Apply overrides to parameters
|
||||
for name, value in self.overrides.items():
|
||||
if name in self.params:
|
||||
self.params[name].value = value
|
||||
self.get_logger().info(f"Loaded {len(self.overrides)} parameter overrides")
|
||||
except Exception as e:
|
||||
self.get_logger().error(f"Failed to load overrides: {e}")
|
||||
|
||||
def _get_overrides_file(self):
|
||||
"""Get path to persistent overrides file"""
|
||||
home = Path.home()
|
||||
config_dir = home / '.saltybot' / 'params'
|
||||
config_dir.mkdir(parents=True, exist_ok=True)
|
||||
return config_dir / 'overrides.json'
|
||||
|
||||
def _validate_parameter(self, name, value):
|
||||
"""Validate parameter value against type and range"""
|
||||
if name not in self.params:
|
||||
return False, f"Unknown parameter: {name}"
|
||||
|
||||
param = self.params[name]
|
||||
|
||||
# Type checking
|
||||
if param.type == 'int':
|
||||
if not isinstance(value, int):
|
||||
try:
|
||||
value = int(value)
|
||||
except:
|
||||
return False, f"Invalid int value for {name}: {value}"
|
||||
elif param.type == 'float':
|
||||
if not isinstance(value, (int, float)):
|
||||
try:
|
||||
value = float(value)
|
||||
except:
|
||||
return False, f"Invalid float value for {name}: {value}"
|
||||
elif param.type == 'bool':
|
||||
if isinstance(value, str):
|
||||
value = value.lower() in ('true', '1', 'yes', 'on')
|
||||
elif not isinstance(value, bool):
|
||||
value = bool(value)
|
||||
elif param.type != 'string':
|
||||
return False, f"Unknown parameter type: {param.type}"
|
||||
|
||||
# Range checking
|
||||
if param.min is not None and value < param.min:
|
||||
return False, f"{name} value {value} below minimum {param.min}"
|
||||
if param.max is not None and value > param.max:
|
||||
return False, f"{name} value {value} above maximum {param.max}"
|
||||
|
||||
return True, None
|
||||
|
||||
def get_params_callback(self, request, response):
|
||||
"""Handle get parameters request"""
|
||||
params_dict = {}
|
||||
for group, params in self.param_defs.items():
|
||||
params_dict[group] = {}
|
||||
for name, param_info in params.items():
|
||||
params_dict[group][name] = {
|
||||
'value': param_info.value,
|
||||
'type': param_info.type,
|
||||
'min': param_info.min,
|
||||
'max': param_info.max,
|
||||
'description': param_info.description,
|
||||
'is_safety': param_info.is_safety,
|
||||
'is_override': name in self.overrides,
|
||||
}
|
||||
|
||||
# Store in node parameters for rosbridge access
|
||||
self.set_parameters([
|
||||
rclpy.Parameter('_params_json', rclpy.Parameter.Type.STRING,
|
||||
json.dumps(params_dict))
|
||||
])
|
||||
|
||||
return response
|
||||
|
||||
def set_param_callback(self, request, response):
|
||||
"""Handle set parameter request"""
|
||||
# This would be called with parameter name and value
|
||||
# In actual usage, rosbridge would pass these via message body
|
||||
return response
|
||||
|
||||
def set_parameter(self, name, value, requires_confirmation=False):
|
||||
"""Set a parameter value with validation"""
|
||||
valid, error = self._validate_parameter(name, value)
|
||||
if not valid:
|
||||
self.get_logger().error(f"Parameter validation failed: {error}")
|
||||
return False
|
||||
|
||||
if name in self.params:
|
||||
self.params[name].value = value
|
||||
self.overrides[name] = value
|
||||
self.get_logger().info(f"Parameter {name} set to {value}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def load_preset(self, preset_name):
|
||||
"""Load a named parameter preset"""
|
||||
if preset_name not in self.presets:
|
||||
self.get_logger().error(f"Unknown preset: {preset_name}")
|
||||
return False
|
||||
|
||||
preset = self.presets[preset_name]
|
||||
count = 0
|
||||
for param_name, value in preset.items():
|
||||
if self.set_parameter(param_name, value):
|
||||
count += 1
|
||||
|
||||
self.get_logger().info(f"Loaded preset '{preset_name}' ({count} parameters)")
|
||||
return True
|
||||
|
||||
def load_preset_callback(self, request, response):
|
||||
"""Handle load preset request"""
|
||||
return response
|
||||
|
||||
def save_overrides_callback(self, request, response):
    """Service handler: persist the current override table to disk as JSON.

    Best-effort — any failure (path resolution, open, serialization) is
    logged and swallowed so the service call itself never errors out.
    """
    try:
        target = self._get_overrides_file()
        with open(target, 'w') as fh:
            json.dump(self.overrides, fh, indent=2)
        self.get_logger().info(f"Saved {len(self.overrides)} overrides to {target}")
    except Exception as exc:
        self.get_logger().error(f"Failed to save overrides: {exc}")
    return response
|
||||
|
||||
def get_parameters_as_json(self):
|
||||
"""Get all parameters as JSON for WebUI"""
|
||||
params_dict = {}
|
||||
for group, params in self.param_defs.items():
|
||||
params_dict[group] = {}
|
||||
for name, param_info in params.items():
|
||||
params_dict[group][name] = {
|
||||
'value': param_info.value,
|
||||
'type': param_info.type,
|
||||
'min': param_info.min,
|
||||
'max': param_info.max,
|
||||
'description': param_info.description,
|
||||
'is_safety': param_info.is_safety,
|
||||
'is_override': name in self.overrides,
|
||||
}
|
||||
return json.dumps(params_dict)
|
||||
|
||||
|
||||
def main(args=None):
    """CLI entry point: initialize rclpy and spin the parameter server.

    Cleanup (destroy_node / shutdown) now runs even when spin exits via an
    exception; Ctrl-C is treated as a normal shutdown rather than a crash.
    """
    rclpy.init(args=args)
    node = ParameterServer()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Normal operator shutdown — not an error.
        pass
    finally:
        node.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()
|
||||
5
jetson/ros2_ws/src/saltybot_param_server/setup.cfg
Normal file
5
jetson/ros2_ws/src/saltybot_param_server/setup.cfg
Normal file
@ -0,0 +1,5 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/saltybot_param_server
|
||||
|
||||
[install]
|
||||
install_scripts=$base/lib/saltybot_param_server
|
||||
30
jetson/ros2_ws/src/saltybot_param_server/setup.py
Normal file
30
jetson/ros2_ws/src/saltybot_param_server/setup.py
Normal file
@ -0,0 +1,30 @@
|
||||
from setuptools import setup
import os
from glob import glob

package_name = 'saltybot_param_server'

setup(
    name=package_name,
    version='0.1.0',
    packages=[package_name],
    data_files=[
        # ament resource-index marker so ROS 2 tooling can discover the package.
        ('share/ament_index/resource_index/packages',
            ['resource/' + package_name]),
        ('share/' + package_name, ['package.xml']),
        # Install all YAML parameter definitions / presets.
        (os.path.join('share', package_name, 'config'),
            glob('config/*.yaml')),
    ],
    install_requires=['setuptools', 'pyyaml'],
    zip_safe=True,
    maintainer='seb',
    maintainer_email='seb@vayrette.com',
    description='Centralized dynamic parameter reconfiguration server (Issue #471)',
    license='MIT',
    # 'tests_require' is deprecated (removed in setuptools >= 72); expose the
    # test dependency via an extra instead: pip install .[test]
    extras_require={'test': ['pytest']},
    entry_points={
        'console_scripts': [
            'param_server = saltybot_param_server.param_server_node:main',
        ],
    },
)
|
||||
@ -88,6 +88,9 @@ import { HandTracker } from './components/HandTracker.jsx';
|
||||
// Salty Face animated expression UI (issue #370)
|
||||
import { SaltyFace } from './components/SaltyFace.jsx';
|
||||
|
||||
// Parameter server (issue #471)
|
||||
import { ParameterServer } from './components/ParameterServer.jsx';
|
||||
|
||||
const TAB_GROUPS = [
|
||||
{
|
||||
label: 'DISPLAY',
|
||||
@ -160,8 +163,9 @@ const TAB_GROUPS = [
|
||||
label: 'CONFIG',
|
||||
color: 'text-purple-600',
|
||||
tabs: [
|
||||
{ id: 'network', label: 'Network' },
|
||||
{ id: 'settings', label: 'Settings' },
|
||||
{ id: 'parameters', label: 'Parameters' },
|
||||
{ id: 'network', label: 'Network' },
|
||||
{ id: 'settings', label: 'Settings' },
|
||||
],
|
||||
},
|
||||
];
|
||||
@ -322,6 +326,8 @@ export default function App() {
|
||||
|
||||
{activeTab === 'logs' && <LogViewer subscribe={subscribe} />}
|
||||
|
||||
{activeTab === 'parameters' && <ParameterServer subscribe={subscribe} callService={callService} />}
|
||||
|
||||
{activeTab === 'network' && <NetworkPanel subscribe={subscribe} connected={connected} wsUrl={wsUrl} />}
|
||||
|
||||
{activeTab === 'settings' && <SettingsPanel subscribe={subscribe} callService={callService} connected={connected} wsUrl={wsUrl} />}
|
||||
|
||||
415
ui/social-bot/src/components/ParameterServer.jsx
Normal file
415
ui/social-bot/src/components/ParameterServer.jsx
Normal file
@ -0,0 +1,415 @@
|
||||
/**
 * ParameterServer.jsx — SaltyBot Centralized Dynamic Parameter Configuration (Issue #471)
 *
 * Features:
 * - Load and display parameters grouped by category (hardware/perception/controls/social/safety/debug)
 * - Edit parameters with real-time validation (type checking, min/max ranges)
 * - Display metadata: type, range, description, is_safety flag
 * - Load named presets (indoor/outdoor/demo/debug)
 * - Safety confirmation for critical parameters
 * - Persist parameter overrides
 * - Visual feedback for modified parameters
 * - Reset to defaults option
 */

import { useState, useEffect, useCallback } from 'react';

// Group ids must match the grouping used by the Python param server.
const PARAM_GROUPS = ['hardware', 'perception', 'controls', 'social', 'safety', 'debug'];
// Tailwind accent classes per group (border for the active tab, bg for its fill).
const GROUP_COLORS = {
  hardware: 'border-blue-500',
  perception: 'border-purple-500',
  controls: 'border-green-500',
  social: 'border-rose-500',
  safety: 'border-red-500',
  debug: 'border-yellow-500',
};
const GROUP_BG = {
  hardware: 'bg-blue-950',
  perception: 'bg-purple-950',
  controls: 'bg-green-950',
  social: 'bg-rose-950',
  safety: 'bg-red-950',
  debug: 'bg-yellow-950',
};

/**
 * Parameter-server panel.
 *
 * @param {function} subscribe   rosbridge subscribe(topic, type, cb)
 * @param {function} callService rosbridge callService(service, payload)
 *
 * Data flow: parameters arrive as a JSON blob on /saltybot/parameters;
 * edits are staged locally (editValues / pendingChanges) and only sent to
 * the backend via /saltybot/set_param (safety confirm) or persisted via
 * /saltybot/save_overrides (Save button).
 * NOTE(review): non-safety edits are never individually pushed via
 * set_param — only the safety-confirmation path calls it; confirm whether
 * Save alone is intended to persist plain edits.
 */
export function ParameterServer({ subscribe, callService }) {
  const [params, setParams] = useState({});
  // Locally edited values keyed by param name; absent key = unmodified.
  const [editValues, setEditValues] = useState({});
  // NOTE(review): setPresets is never called — the preset list is static.
  const [presets, setPresets] = useState(['indoor', 'outdoor', 'demo', 'debug']);
  const [activeGroup, setActiveGroup] = useState('hardware');
  const [loading, setLoading] = useState(true);
  const [error, setError] = useState(null);
  const [expandedParams, setExpandedParams] = useState(new Set());
  // Names whose staged value differs from the server value.
  const [pendingChanges, setPendingChanges] = useState(new Set());
  // Non-null while the safety confirmation modal is open.
  const [confirmDialog, setConfirmDialog] = useState(null);

  // Fetch parameters from server
  useEffect(() => {
    const fetchParams = async () => {
      try {
        setLoading(true);
        // Try to call the service or subscribe to parameter topic
        if (subscribe) {
          // Subscribe to parameter updates (if available)
          subscribe('/saltybot/parameters', 'std_msgs/String', (msg) => {
            try {
              const paramsData = JSON.parse(msg.data);
              setParams(paramsData);
              setError(null);
            } catch (e) {
              console.error('Failed to parse parameters:', e);
            }
          });
        }
        setLoading(false);
      } catch (err) {
        setError(`Failed to fetch parameters: ${err.message}`);
        setLoading(false);
      }
    };

    fetchParams();
  }, [subscribe]);

  // Expand/collapse a parameter's detail view.
  const toggleParamExpanded = useCallback((paramName) => {
    setExpandedParams(prev => {
      const next = new Set(prev);
      next.has(paramName) ? next.delete(paramName) : next.add(paramName);
      return next;
    });
  }, []);

  // Stage an edited value: coerce by declared type, enforce min/max, track
  // the pending-change set, and open the confirm modal for safety params.
  const handleParamChange = useCallback((paramName, value, paramInfo) => {
    // Validate input based on type
    let validatedValue = value;
    let isValid = true;

    if (paramInfo.type === 'int') {
      validatedValue = parseInt(value, 10);
      isValid = !isNaN(validatedValue);
    } else if (paramInfo.type === 'float') {
      validatedValue = parseFloat(value);
      isValid = !isNaN(validatedValue);
    } else if (paramInfo.type === 'bool') {
      validatedValue = value === 'true' || value === true || value === 1;
    }

    // Check range
    if (isValid && paramInfo.min !== undefined && validatedValue < paramInfo.min) {
      isValid = false;
    }
    if (isValid && paramInfo.max !== undefined && validatedValue > paramInfo.max) {
      isValid = false;
    }

    if (isValid) {
      setEditValues(prev => ({
        ...prev,
        [paramName]: validatedValue
      }));

      if (paramInfo.value !== validatedValue) {
        setPendingChanges(prev => new Set([...prev, paramName]));
      } else {
        // Edited back to the server value — no longer pending.
        setPendingChanges(prev => {
          const next = new Set(prev);
          next.delete(paramName);
          return next;
        });
      }

      // For safety parameters, show confirmation
      if (paramInfo.is_safety && paramInfo.value !== validatedValue) {
        setConfirmDialog({
          paramName,
          paramInfo,
          newValue: validatedValue,
          message: `Safety parameter "${paramName}" will be changed. This may affect robot behavior.`
        });
      }
    }
  }, []);

  // Confirm modal: push the staged safety value to the backend.
  const confirmParameterChange = useCallback(() => {
    if (!confirmDialog) return;

    const { paramName, newValue } = confirmDialog;
    // Persist to backend
    if (callService) {
      callService('/saltybot/set_param', {
        name: paramName,
        value: newValue
      });
    }

    setConfirmDialog(null);
  }, [confirmDialog, callService]);

  // Cancel modal: drop the staged value and pending flag for that param.
  const rejectParameterChange = useCallback(() => {
    // NOTE(review): this guard branch clears an already-null dialog —
    // harmless, but the setConfirmDialog(null) here is redundant.
    if (!confirmDialog) {
      setConfirmDialog(null);
      return;
    }

    const { paramName } = confirmDialog;
    setEditValues(prev => {
      const next = { ...prev };
      delete next[paramName];
      return next;
    });

    setPendingChanges(prev => {
      const next = new Set(prev);
      next.delete(paramName);
      return next;
    });

    setConfirmDialog(null);
  }, [confirmDialog]);

  // Ask the backend to apply a named preset (server pushes updated values back).
  const loadPreset = useCallback((presetName) => {
    if (callService) {
      callService('/saltybot/load_preset', {
        preset: presetName
      });
    }
  }, [callService]);

  // Persist current overrides server-side, then clear the pending markers.
  const saveOverrides = useCallback(() => {
    if (callService) {
      callService('/saltybot/save_overrides', {});
    }
    setPendingChanges(new Set());
  }, [callService]);

  // Discard the staged edit for one parameter (local only).
  const resetParameter = useCallback((paramName) => {
    setEditValues(prev => {
      const next = { ...prev };
      delete next[paramName];
      return next;
    });
    setPendingChanges(prev => {
      const next = new Set(prev);
      next.delete(paramName);
      return next;
    });
  }, []);

  // Discard every staged edit (local only).
  const resetAllParameters = useCallback(() => {
    setEditValues({});
    setPendingChanges(new Set());
  }, []);

  if (loading) {
    return (
      <div className="flex items-center justify-center h-screen">
        <div className="text-cyan-400">Loading parameters...</div>
      </div>
    );
  }

  // Parameters of the currently selected group tab.
  const groupParams = params[activeGroup] || {};

  return (
    <div className="flex flex-col h-full gap-4 p-4 bg-[#050510]">
      {/* Header */}
      <div className="flex items-center justify-between gap-4">
        <div>
          <h1 className="text-xl font-bold text-cyan-400">⚙️ Parameter Server</h1>
          <p className="text-xs text-gray-500">Dynamic reconfiguration • {Object.keys(groupParams).length} parameters</p>
        </div>
        <div className="flex gap-2">
          <button
            onClick={saveOverrides}
            disabled={pendingChanges.size === 0}
            className="px-3 py-1 text-xs rounded border border-green-700 bg-green-950 text-green-400 hover:bg-green-900 disabled:opacity-50 disabled:cursor-not-allowed"
          >
            💾 Save ({pendingChanges.size})
          </button>
          <button
            onClick={resetAllParameters}
            disabled={pendingChanges.size === 0}
            className="px-3 py-1 text-xs rounded border border-gray-700 bg-gray-950 text-gray-400 hover:bg-gray-900 disabled:opacity-50 disabled:cursor-not-allowed"
          >
            Reset All
          </button>
        </div>
      </div>

      {/* Presets */}
      <div className="flex gap-2 pb-2 border-b border-gray-800">
        <span className="text-xs text-gray-500 py-1">Presets:</span>
        {presets.map(preset => (
          <button
            key={preset}
            onClick={() => loadPreset(preset)}
            className="px-2 py-1 text-xs rounded border border-cyan-700 bg-cyan-950 text-cyan-400 hover:bg-cyan-900 transition-colors"
          >
            {preset.charAt(0).toUpperCase() + preset.slice(1)}
          </button>
        ))}
      </div>

      {/* Group Tabs */}
      <div className="flex gap-1 flex-wrap">
        {PARAM_GROUPS.map(group => (
          <button
            key={group}
            onClick={() => setActiveGroup(group)}
            className={`px-3 py-1 text-xs font-bold rounded transition-colors ${
              activeGroup === group
                ? `border-2 ${GROUP_COLORS[group]} ${GROUP_BG[group]} text-white`
                : 'border border-gray-700 bg-gray-950 text-gray-400 hover:bg-gray-900'
            }`}
          >
            {group.charAt(0).toUpperCase() + group.slice(1)}
          </button>
        ))}
      </div>

      {/* Error Display */}
      {error && (
        <div className="p-3 rounded bg-red-950 border border-red-700 text-red-400 text-sm">
          ⚠️ {error}
        </div>
      )}

      {/* Parameters */}
      <div className="flex-1 overflow-y-auto space-y-2">
        {Object.entries(groupParams).map(([paramName, paramInfo]) => {
          const isModified = pendingChanges.has(paramName);
          // Staged value wins over the server value when present.
          const currentValue = editValues[paramName] !== undefined ? editValues[paramName] : paramInfo.value;
          const isExpanded = expandedParams.has(paramName);

          return (
            <div
              key={paramName}
              className={`p-3 rounded border transition-all ${
                paramInfo.is_safety
                  ? 'border-red-700 bg-red-950 bg-opacity-30'
                  : 'border-gray-700 bg-gray-950 bg-opacity-30 hover:bg-opacity-50'
              } ${isModified ? 'ring-2 ring-yellow-500' : ''}`}
            >
              <div className="flex items-start justify-between gap-2">
                <div className="flex-1 min-w-0">
                  <div className="flex items-center gap-2">
                    <button
                      onClick={() => toggleParamExpanded(paramName)}
                      className="text-gray-500 hover:text-gray-300 px-1"
                    >
                      {isExpanded ? '▼' : '▶'}
                    </button>
                    <div className="flex-1">
                      <div className="font-mono text-sm text-gray-300 break-all">
                        {paramName}
                        {paramInfo.is_safety && <span className="ml-2 text-xs text-red-400">🔒 SAFETY</span>}
                        {isModified && <span className="ml-2 text-xs text-yellow-400">⚡ Modified</span>}
                      </div>
                      <div className="text-xs text-gray-500 mt-0.5">{paramInfo.description}</div>
                    </div>
                  </div>
                </div>
                {isModified && (
                  <button
                    onClick={() => resetParameter(paramName)}
                    className="px-2 py-0.5 text-xs rounded border border-gray-600 bg-gray-900 text-gray-400 hover:bg-gray-800"
                  >
                    Reset
                  </button>
                )}
              </div>

              {isExpanded && (
                <div className="mt-3 ml-6 space-y-2">
                  {/* Type and Range Info */}
                  <div className="text-xs text-gray-500 grid grid-cols-3 gap-2">
                    <div>Type: <span className="text-cyan-400">{paramInfo.type}</span></div>
                    {paramInfo.min !== undefined && (
                      <div>Min: <span className="text-cyan-400">{paramInfo.min}</span></div>
                    )}
                    {paramInfo.max !== undefined && (
                      <div>Max: <span className="text-cyan-400">{paramInfo.max}</span></div>
                    )}
                  </div>

                  {/* Input Field */}
                  <div className="flex items-center gap-2">
                    {paramInfo.type === 'bool' ? (
                      <select
                        value={currentValue ? 'true' : 'false'}
                        onChange={(e) => handleParamChange(paramName, e.target.value === 'true', paramInfo)}
                        className="flex-1 px-2 py-1 rounded bg-gray-900 border border-gray-700 text-gray-300 text-sm focus:outline-none focus:border-cyan-500"
                      >
                        <option value="true">True</option>
                        <option value="false">False</option>
                      </select>
                    ) : (
                      <input
                        type={paramInfo.type === 'int' ? 'number' : 'text'}
                        value={currentValue}
                        onChange={(e) => handleParamChange(paramName, e.target.value, paramInfo)}
                        step={paramInfo.type === 'float' ? '0.01' : '1'}
                        className="flex-1 px-2 py-1 rounded bg-gray-900 border border-gray-700 text-gray-300 text-sm focus:outline-none focus:border-cyan-500"
                      />
                    )}
                    <span className="text-xs text-gray-500">
                      {isModified ? (
                        <>
                          <span className="text-gray-600">{paramInfo.value}</span>
                          <span className="mx-1">→</span>
                          <span className="text-yellow-400">{currentValue}</span>
                        </>
                      ) : (
                        <span className="text-gray-400">{currentValue}</span>
                      )}
                    </span>
                  </div>

                  {/* Range Visualization */}
                  {paramInfo.type !== 'bool' && paramInfo.type !== 'string' && paramInfo.min !== undefined && paramInfo.max !== undefined && (
                    <div className="w-full bg-gray-900 rounded h-1 overflow-hidden">
                      <div
                        className="h-full bg-cyan-600"
                        style={{
                          width: `${((currentValue - paramInfo.min) / (paramInfo.max - paramInfo.min)) * 100}%`
                        }}
                      />
                    </div>
                  )}
                </div>
              )}
            </div>
          );
        })}
      </div>

      {/* Safety Confirmation Dialog */}
      {confirmDialog && (
        <div className="fixed inset-0 bg-black bg-opacity-70 flex items-center justify-center z-50">
          <div className="bg-gray-900 border-2 border-red-600 rounded p-4 max-w-md">
            <h3 className="text-lg font-bold text-red-400 mb-2">⚠️ Safety Parameter Confirmation</h3>
            <p className="text-gray-300 mb-4">{confirmDialog.message}</p>
            <div className="bg-gray-950 p-2 rounded mb-4 text-sm font-mono">
              <div className="text-gray-500">{confirmDialog.paramName}</div>
              <div className="text-gray-400">{confirmDialog.paramInfo.value} → <span className="text-yellow-400">{confirmDialog.newValue}</span></div>
            </div>
            <div className="flex gap-2">
              <button
                onClick={confirmParameterChange}
                className="flex-1 px-3 py-2 rounded bg-red-950 border border-red-700 text-red-400 hover:bg-red-900 font-bold"
              >
                ✓ Confirm
              </button>
              <button
                onClick={rejectParameterChange}
                className="flex-1 px-3 py-2 rounded bg-gray-950 border border-gray-700 text-gray-400 hover:bg-gray-900"
              >
                ✕ Cancel
              </button>
            </div>
          </div>
        </div>
      )}
    </div>
  );
}
|
||||
Loading…
x
Reference in New Issue
Block a user