From d872ea5e34e586a29d23e70d33aeaea7f63e5479 Mon Sep 17 00:00:00 2001 From: sl-perception Date: Sun, 1 Mar 2026 23:15:00 -0500 Subject: [PATCH] feat(social): navigation + follow modes + MiDaS depth + waypoints (Issue #91) - saltybot_social_msgs: full message/service definitions (standalone compilation) - saltybot_social_nav: social navigation orchestrator - Follow modes: shadow/lead/side/orbit/loose/tight - Voice steering: mode switching + route commands via /social/speech/* - A* obstacle avoidance on Nav2/SLAM occupancy grid (8-directional, inflation) - MiDaS monocular depth for CSI cameras (TRT FP16 + ONNX fallback) - Waypoint teaching + replay with WaypointRoute persistence - High-speed EUC tracking (5.5 m/s = ~20 km/h) - Predictive position extrapolation (0.3s ahead at high speed) - Launch: social_nav.launch.py (social_nav + midas_depth + waypoint_teacher) - Config: social_nav_params.yaml - Script: build_midas_trt_engine.py (ONNX -> TRT FP16) --- .../src/saltybot_social_msgs/CMakeLists.txt | 24 + .../msg/FaceDetection.msg | 10 + .../msg/FaceDetectionArray.msg | 2 + .../msg/FaceEmbedding.msg | 5 + .../msg/FaceEmbeddingArray.msg | 2 + .../saltybot_social_msgs/msg/PersonState.msg | 19 + .../msg/PersonStateArray.msg | 3 + .../src/saltybot_social_msgs/package.xml | 19 + .../saltybot_social_msgs/srv/DeletePerson.srv | 4 + .../saltybot_social_msgs/srv/EnrollPerson.srv | 7 + .../saltybot_social_msgs/srv/ListPersons.srv | 2 + .../saltybot_social_msgs/srv/UpdatePerson.srv | 5 + .../config/social_nav_params.yaml | 22 + .../launch/social_nav.launch.py | 57 ++ .../src/saltybot_social_nav/package.xml | 28 + .../resource/saltybot_social_nav | 0 .../saltybot_social_nav/__init__.py | 0 .../saltybot_social_nav/astar.py | 82 +++ .../saltybot_social_nav/follow_modes.py | 82 +++ .../saltybot_social_nav/midas_depth_node.py | 231 +++++++ .../saltybot_social_nav/social_nav_node.py | 584 ++++++++++++++++++ .../saltybot_social_nav/waypoint_teacher.py | 91 +++ .../waypoint_teacher_node.py | 135 ++++ .../scripts/build_midas_trt_engine.py | 80 +++ .../ros2_ws/src/saltybot_social_nav/setup.cfg | 4 + .../ros2_ws/src/saltybot_social_nav/setup.py | 31 + .../test/test_copyright.py | 12 + .../saltybot_social_nav/test/test_flake8.py | 14 + .../saltybot_social_nav/test/test_pep257.py | 12 + 29 files changed, 1567 insertions(+) create mode 100644 jetson/ros2_ws/src/saltybot_social_msgs/CMakeLists.txt create mode 100644 jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceDetection.msg create mode 100644 jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceDetectionArray.msg create mode 100644 jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceEmbedding.msg create mode 100644 jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceEmbeddingArray.msg create mode 100644 jetson/ros2_ws/src/saltybot_social_msgs/msg/PersonState.msg create mode 100644 jetson/ros2_ws/src/saltybot_social_msgs/msg/PersonStateArray.msg create mode 100644 jetson/ros2_ws/src/saltybot_social_msgs/package.xml create mode 100644 jetson/ros2_ws/src/saltybot_social_msgs/srv/DeletePerson.srv create mode 100644 jetson/ros2_ws/src/saltybot_social_msgs/srv/EnrollPerson.srv create mode 100644 jetson/ros2_ws/src/saltybot_social_msgs/srv/ListPersons.srv create mode 100644 jetson/ros2_ws/src/saltybot_social_msgs/srv/UpdatePerson.srv create mode 100644 jetson/ros2_ws/src/saltybot_social_nav/config/social_nav_params.yaml create mode 100644 jetson/ros2_ws/src/saltybot_social_nav/launch/social_nav.launch.py create mode 100644 jetson/ros2_ws/src/saltybot_social_nav/package.xml create mode 100644 jetson/ros2_ws/src/saltybot_social_nav/resource/saltybot_social_nav create mode 100644 jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/__init__.py create mode 100644 jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/astar.py create mode 100644 jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/follow_modes.py create mode 100644 jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/midas_depth_node.py create mode 100644 jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/social_nav_node.py create mode 100644 jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/waypoint_teacher.py create mode 100644 jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/waypoint_teacher_node.py create mode 100644 jetson/ros2_ws/src/saltybot_social_nav/scripts/build_midas_trt_engine.py create mode 100644 jetson/ros2_ws/src/saltybot_social_nav/setup.cfg create mode 100644 jetson/ros2_ws/src/saltybot_social_nav/setup.py create mode 100644 jetson/ros2_ws/src/saltybot_social_nav/test/test_copyright.py create mode 100644 jetson/ros2_ws/src/saltybot_social_nav/test/test_flake8.py create mode 100644 jetson/ros2_ws/src/saltybot_social_nav/test/test_pep257.py diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/CMakeLists.txt b/jetson/ros2_ws/src/saltybot_social_msgs/CMakeLists.txt new file mode 100644 index 0000000..b6e70ab --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_msgs/CMakeLists.txt @@ -0,0 +1,24 @@ +cmake_minimum_required(VERSION 3.8) +project(saltybot_social_msgs) + +find_package(ament_cmake REQUIRED) +find_package(rosidl_default_generators REQUIRED) +find_package(std_msgs REQUIRED) +find_package(geometry_msgs REQUIRED) +find_package(builtin_interfaces REQUIRED) + +rosidl_generate_interfaces(${PROJECT_NAME} + "msg/FaceDetection.msg" + "msg/FaceDetectionArray.msg" + "msg/FaceEmbedding.msg" + "msg/FaceEmbeddingArray.msg" + "msg/PersonState.msg" + "msg/PersonStateArray.msg" + "srv/EnrollPerson.srv" + "srv/ListPersons.srv" + "srv/DeletePerson.srv" + "srv/UpdatePerson.srv" + DEPENDENCIES std_msgs geometry_msgs builtin_interfaces +) + +ament_package() diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceDetection.msg b/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceDetection.msg new file mode 100644 index 0000000..53b3a10 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceDetection.msg @@ -0,0 +1,10 @@ +std_msgs/Header header +int32 face_id +string person_name +float32 confidence +float32 recognition_score +float32 bbox_x +float32 bbox_y +float32 bbox_w +float32 bbox_h +float32[10] landmarks diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceDetectionArray.msg b/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceDetectionArray.msg new file mode 100644 index 0000000..66550cc --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceDetectionArray.msg @@ -0,0 +1,2 @@ +std_msgs/Header header +saltybot_social_msgs/FaceDetection[] faces diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceEmbedding.msg b/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceEmbedding.msg new file mode 100644 index 0000000..456f0a2 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceEmbedding.msg @@ -0,0 +1,5 @@ +int32 person_id +string person_name +float32[] embedding +builtin_interfaces/Time enrolled_at +int32 sample_count diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceEmbeddingArray.msg b/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceEmbeddingArray.msg new file mode 100644 index 0000000..a9c23d9 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceEmbeddingArray.msg @@ -0,0 +1,2 @@ +std_msgs/Header header +saltybot_social_msgs/FaceEmbedding[] embeddings diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/msg/PersonState.msg b/jetson/ros2_ws/src/saltybot_social_msgs/msg/PersonState.msg new file mode 100644 index 0000000..f3c5fa1 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_msgs/msg/PersonState.msg @@ -0,0 +1,19 @@ +std_msgs/Header header +int32 person_id +string person_name +int32 face_id +string speaker_id +string uwb_anchor_id +geometry_msgs/Point position +float32 distance +float32 bearing_deg +uint8 state +uint8 STATE_UNKNOWN=0 +uint8 STATE_APPROACHING=1 +uint8 STATE_ENGAGED=2 +uint8 STATE_TALKING=3 +uint8 STATE_LEAVING=4 +uint8 STATE_ABSENT=5 +float32 engagement_score +builtin_interfaces/Time last_seen +int32 camera_id diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/msg/PersonStateArray.msg b/jetson/ros2_ws/src/saltybot_social_msgs/msg/PersonStateArray.msg new file mode 100644 index 0000000..2ade234 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_msgs/msg/PersonStateArray.msg @@ -0,0 +1,3 @@ +std_msgs/Header header +saltybot_social_msgs/PersonState[] persons +int32 primary_attention_id diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/package.xml b/jetson/ros2_ws/src/saltybot_social_msgs/package.xml new file mode 100644 index 0000000..83572c5 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_msgs/package.xml @@ -0,0 +1,19 @@ + + + + saltybot_social_msgs + 0.1.0 + Custom ROS2 messages and services for saltybot social capabilities + seb + MIT + ament_cmake + std_msgs + geometry_msgs + builtin_interfaces + rosidl_default_generators + rosidl_default_runtime + rosidl_interface_packages + + ament_cmake + + diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/srv/DeletePerson.srv b/jetson/ros2_ws/src/saltybot_social_msgs/srv/DeletePerson.srv new file mode 100644 index 0000000..0a77e93 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_msgs/srv/DeletePerson.srv @@ -0,0 +1,4 @@ +int32 person_id +--- +bool success +string message diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/srv/EnrollPerson.srv b/jetson/ros2_ws/src/saltybot_social_msgs/srv/EnrollPerson.srv new file mode 100644 index 0000000..3ba7231 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_msgs/srv/EnrollPerson.srv @@ -0,0 +1,7 @@ +string name +string mode +int32 n_samples +--- +bool success +string message +int32 person_id diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/srv/ListPersons.srv b/jetson/ros2_ws/src/saltybot_social_msgs/srv/ListPersons.srv new file mode 100644 index 0000000..bed755a --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_msgs/srv/ListPersons.srv @@ -0,0 +1,2 @@ +--- +saltybot_social_msgs/FaceEmbedding[] persons diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/srv/UpdatePerson.srv b/jetson/ros2_ws/src/saltybot_social_msgs/srv/UpdatePerson.srv new file mode 100644 index 0000000..8fc0abf --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_msgs/srv/UpdatePerson.srv @@ -0,0 +1,5 @@ +int32 person_id +string new_name +--- +bool success +string message diff --git a/jetson/ros2_ws/src/saltybot_social_nav/config/social_nav_params.yaml b/jetson/ros2_ws/src/saltybot_social_nav/config/social_nav_params.yaml new file mode 100644 index 0000000..f68e323 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_nav/config/social_nav_params.yaml @@ -0,0 +1,22 @@ +social_nav: + ros__parameters: + follow_mode: 'shadow' + follow_distance: 1.2 + lead_distance: 2.0 + orbit_radius: 1.5 + max_linear_speed: 1.0 + max_linear_speed_fast: 5.5 + max_angular_speed: 1.0 + goal_tolerance: 0.3 + routes_dir: '/mnt/nvme/saltybot/routes' + home_x: 0.0 + home_y: 0.0 + map_resolution: 0.05 + obstacle_inflation_cells: 3 + +midas_depth: + ros__parameters: + onnx_path: '/mnt/nvme/saltybot/models/midas_small.onnx' + engine_path: '/mnt/nvme/saltybot/models/midas_small.engine' + process_rate: 5.0 + output_scale: 1.0 diff --git a/jetson/ros2_ws/src/saltybot_social_nav/launch/social_nav.launch.py b/jetson/ros2_ws/src/saltybot_social_nav/launch/social_nav.launch.py new file mode 100644 index 0000000..8684854 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_nav/launch/social_nav.launch.py @@ -0,0 +1,57 @@ +"""Launch file for saltybot social navigation.""" + +from launch import LaunchDescription +from launch.actions import DeclareLaunchArgument +from launch.substitutions import LaunchConfiguration +from launch_ros.actions import Node + + +def generate_launch_description(): + return LaunchDescription([ + DeclareLaunchArgument('follow_mode', default_value='shadow', + description='Follow mode: shadow/lead/side/orbit/loose/tight'), + DeclareLaunchArgument('follow_distance', default_value='1.2', + description='Follow distance in meters'), + DeclareLaunchArgument('max_linear_speed', default_value='1.0', + description='Max linear speed (m/s)'), + DeclareLaunchArgument('routes_dir', + default_value='/mnt/nvme/saltybot/routes', + description='Directory for saved routes'), + + Node( + package='saltybot_social_nav', + executable='social_nav', + name='social_nav', + output='screen', + parameters=[{ + 'follow_mode': LaunchConfiguration('follow_mode'), + 'follow_distance': LaunchConfiguration('follow_distance'), + 'max_linear_speed': LaunchConfiguration('max_linear_speed'), + 'routes_dir': LaunchConfiguration('routes_dir'), + }], + ), + + Node( + package='saltybot_social_nav', + executable='midas_depth', + name='midas_depth', + output='screen', + parameters=[{ + 'onnx_path': '/mnt/nvme/saltybot/models/midas_small.onnx', + 'engine_path': '/mnt/nvme/saltybot/models/midas_small.engine', + 'process_rate': 5.0, + 'output_scale': 1.0, + }], + ), + + Node( + package='saltybot_social_nav', + executable='waypoint_teacher', + name='waypoint_teacher', + output='screen', + parameters=[{ + 'routes_dir': LaunchConfiguration('routes_dir'), + 'recording_interval': 0.5, + }], + ), + ]) diff --git a/jetson/ros2_ws/src/saltybot_social_nav/package.xml b/jetson/ros2_ws/src/saltybot_social_nav/package.xml new file mode 100644 index 0000000..1338f17 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_nav/package.xml @@ -0,0 +1,28 @@ + + + + saltybot_social_nav + 0.1.0 + Social navigation for saltybot: follow modes, waypoint teaching, A* avoidance, MiDaS depth + seb + MIT + + rclpy + std_msgs + geometry_msgs + nav_msgs + sensor_msgs + cv_bridge + tf2_ros + tf2_geometry_msgs + saltybot_social_msgs + + ament_copyright + ament_flake8 + ament_pep257 + python3-pytest + + + ament_python + + diff --git a/jetson/ros2_ws/src/saltybot_social_nav/resource/saltybot_social_nav b/jetson/ros2_ws/src/saltybot_social_nav/resource/saltybot_social_nav new file mode 100644 index 0000000..e69de29 diff --git a/jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/__init__.py b/jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/astar.py b/jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/astar.py new file mode 100644 index 0000000..c758784 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/astar.py @@ -0,0 +1,82 @@ +"""astar.py -- A* path planner for saltybot social navigation.""" + +import heapq +import numpy as np + + +def astar(grid: np.ndarray, start: tuple, goal: tuple, + obstacle_val: int = 100) -> list | None: + """ + A* on a 2D occupancy grid (row, col indexing). + + Args: + grid: 2D numpy array, values 0=free, >=obstacle_val=obstacle + start: (row, col) start cell + goal: (row, col) goal cell + obstacle_val: cells with value >= obstacle_val are blocked + + Returns: + List of (row, col) tuples from start to goal, or None if no path. + """ + rows, cols = grid.shape + + def h(a, b): + return abs(a[0] - b[0]) + abs(a[1] - b[1]) # Manhattan heuristic + + open_set = [] + heapq.heappush(open_set, (h(start, goal), 0, start)) + came_from = {} + g_score = {start: 0} + + # 8-directional movement + neighbors_delta = [ + (-1, -1), (-1, 0), (-1, 1), + (0, -1), (0, 1), + (1, -1), (1, 0), (1, 1), + ] + + while open_set: + _, cost, current = heapq.heappop(open_set) + + if current == goal: + path = [] + while current in came_from: + path.append(current) + current = came_from[current] + path.append(start) + return list(reversed(path)) + + if cost > g_score.get(current, float('inf')): + continue + + for dr, dc in neighbors_delta: + nr, nc = current[0] + dr, current[1] + dc + if not (0 <= nr < rows and 0 <= nc < cols): + continue + if grid[nr, nc] >= obstacle_val: + continue + move_cost = 1.414 if (dr != 0 and dc != 0) else 1.0 + new_g = g_score[current] + move_cost + neighbor = (nr, nc) + if new_g < g_score.get(neighbor, float('inf')): + g_score[neighbor] = new_g + f = new_g + h(neighbor, goal) + came_from[neighbor] = current + heapq.heappush(open_set, (f, new_g, neighbor)) + + return None # No path found + + +def inflate_obstacles(grid: np.ndarray, inflation_radius_cells: int) -> np.ndarray: + """Inflate obstacles for robot footprint safety.""" + from scipy.ndimage import binary_dilation + + obstacle_mask = grid >= 50 + kernel = np.ones( + (2 * inflation_radius_cells + 1, 2 * inflation_radius_cells + 1), + dtype=bool, + ) + inflated = binary_dilation(obstacle_mask, structure=kernel) + result = grid.copy() + result[inflated] = 100 + return result diff --git a/jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/follow_modes.py b/jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/follow_modes.py new file mode 100644 index 0000000..02a11b4 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/follow_modes.py @@ -0,0 +1,82 @@ +"""follow_modes.py -- Follow mode geometry for saltybot social navigation.""" + +import math +from enum import Enum +import numpy as np + + +class FollowMode(Enum): + SHADOW = 'shadow' # stay directly behind at follow_distance + LEAD = 'lead' # move ahead of person by lead_distance + SIDE = 'side' # stay to the right (or left) at side_offset + ORBIT = 'orbit' # circle around person at orbit_radius + LOOSE = 'loose' # general follow, larger tolerance + TIGHT = 'tight' # close follow, small tolerance + + +def compute_shadow_target(person_pos, person_bearing_deg, follow_dist=1.2): + """Target position: behind person along their movement direction.""" + bearing_rad = math.radians(person_bearing_deg + 180.0) + tx = person_pos[0] + follow_dist * math.sin(bearing_rad) + ty = person_pos[1] + follow_dist * math.cos(bearing_rad) + return (tx, ty, person_pos[2]) + + +def compute_lead_target(person_pos, person_bearing_deg, lead_dist=2.0): + """Target position: ahead of person.""" + bearing_rad = math.radians(person_bearing_deg) + tx = person_pos[0] + lead_dist * math.sin(bearing_rad) + ty = person_pos[1] + lead_dist * math.cos(bearing_rad) + return (tx, ty, person_pos[2]) + + +def compute_side_target(person_pos, person_bearing_deg, side_dist=1.0, right=True): + """Target position: to the right (or left) of person.""" + sign = 1.0 if right else -1.0 + bearing_rad = math.radians(person_bearing_deg + sign * 90.0) + tx = person_pos[0] + side_dist * math.sin(bearing_rad) + ty = person_pos[1] + side_dist * math.cos(bearing_rad) + return (tx, ty, person_pos[2]) + + +def compute_orbit_target(person_pos, orbit_angle_deg, orbit_radius=1.5): + """Target on circle of radius orbit_radius around person.""" + angle_rad = math.radians(orbit_angle_deg) + tx = person_pos[0] + orbit_radius * math.sin(angle_rad) + ty = person_pos[1] + orbit_radius * math.cos(angle_rad) + return (tx, ty, person_pos[2]) + + +def compute_loose_target(person_pos, robot_pos, follow_dist=2.0, tolerance=0.8): + """Only move if farther than follow_dist + tolerance.""" + dx = person_pos[0] - robot_pos[0] + dy = person_pos[1] - robot_pos[1] + dist = math.hypot(dx, dy) + if dist <= follow_dist + tolerance: + return robot_pos + # Target at follow_dist behind person (toward robot) + scale = (dist - follow_dist) / dist + return (robot_pos[0] + dx * scale, robot_pos[1] + dy * scale, person_pos[2]) + + +def compute_tight_target(person_pos, follow_dist=0.6): + """Close follow: stay very near person.""" + return (person_pos[0], person_pos[1] - follow_dist, person_pos[2]) + + +MODE_VOICE_COMMANDS = { + 'shadow': FollowMode.SHADOW, + 'follow me': FollowMode.SHADOW, + 'behind me': FollowMode.SHADOW, + 'lead': FollowMode.LEAD, + 'go ahead': FollowMode.LEAD, + 'lead me': FollowMode.LEAD, + 'side': FollowMode.SIDE, + 'stay beside': FollowMode.SIDE, + 'orbit': FollowMode.ORBIT, + 'circle me': FollowMode.ORBIT, + 'loose': FollowMode.LOOSE, + 'give me space': FollowMode.LOOSE, + 'tight': FollowMode.TIGHT, + 'stay close': FollowMode.TIGHT, +} diff --git a/jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/midas_depth_node.py b/jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/midas_depth_node.py new file mode 100644 index 0000000..dcc4600 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/midas_depth_node.py @@ -0,0 +1,231 @@ +""" +midas_depth_node.py -- MiDaS monocular depth estimation for saltybot. + +Uses MiDaS_small via ONNX Runtime or TensorRT FP16. +Provides relative depth estimates for cameras without active depth (CSI cameras). + +Publishes /social/depth/cam{i}/image (sensor_msgs/Image, float32, relative depth) +""" + +import os +import numpy as np +import rclpy +from rclpy.node import Node +from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy +from sensor_msgs.msg import Image +from cv_bridge import CvBridge + +# MiDaS_small input size +_MIDAS_H = 256 +_MIDAS_W = 256 +# ImageNet normalization +_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32) +_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32) + + +class _TRTBackend: + """TensorRT inference backend for MiDaS.""" + + def __init__(self, engine_path: str, logger): + self._logger = logger + try: + import tensorrt as trt + import pycuda.driver as cuda + import pycuda.autoinit # noqa: F401 + + self._cuda = cuda + rt_logger = trt.Logger(trt.Logger.WARNING) + with open(engine_path, 'rb') as f: + engine = trt.Runtime(rt_logger).deserialize_cuda_engine(f.read()) + self._context = engine.create_execution_context() + + # Allocate buffers + self._d_input = cuda.mem_alloc(1 * 3 * _MIDAS_H * _MIDAS_W * 4) + self._d_output = cuda.mem_alloc(1 * _MIDAS_H * _MIDAS_W * 4) + self._h_output = np.empty((_MIDAS_H, _MIDAS_W), dtype=np.float32) + self._stream = cuda.Stream() + self._logger.info(f'TRT engine loaded: {engine_path}') + except Exception as e: + raise RuntimeError(f'TRT init failed: {e}') + + def infer(self, input_tensor: np.ndarray) -> np.ndarray: + self._cuda.memcpy_htod_async( + self._d_input, input_tensor.ravel(), self._stream) + self._context.execute_async_v2( + bindings=[int(self._d_input), int(self._d_output)], + stream_handle=self._stream.handle) + self._cuda.memcpy_dtoh_async( + self._h_output, self._d_output, self._stream) + self._stream.synchronize() + return self._h_output.copy() + + +class _ONNXBackend: + """ONNX Runtime inference backend for MiDaS.""" + + def __init__(self, onnx_path: str, logger): + self._logger = logger + try: + import onnxruntime as ort + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + self._session = ort.InferenceSession(onnx_path, providers=providers) + self._input_name = self._session.get_inputs()[0].name + self._logger.info(f'ONNX model loaded: {onnx_path}') + except Exception as e: + raise RuntimeError(f'ONNX init failed: {e}') + + def infer(self, input_tensor: np.ndarray) -> np.ndarray: + result = self._session.run(None, {self._input_name: input_tensor}) + return result[0].squeeze() + + +class MiDaSDepthNode(Node): + """MiDaS monocular depth estimation node.""" + + def __init__(self): + super().__init__('midas_depth') + + # Parameters + self.declare_parameter('onnx_path', + '/mnt/nvme/saltybot/models/midas_small.onnx') + self.declare_parameter('engine_path', + '/mnt/nvme/saltybot/models/midas_small.engine') + self.declare_parameter('camera_topics', [ + '/surround/cam0/image_raw', + '/surround/cam1/image_raw', + '/surround/cam2/image_raw', + '/surround/cam3/image_raw', + ]) + self.declare_parameter('output_scale', 1.0) + self.declare_parameter('process_rate', 5.0) + + onnx_path = self.get_parameter('onnx_path').value + engine_path = self.get_parameter('engine_path').value + self._camera_topics = self.get_parameter('camera_topics').value + self._output_scale = self.get_parameter('output_scale').value + process_rate = self.get_parameter('process_rate').value + + # Initialize inference backend (TRT preferred, ONNX fallback) + self._backend = None + if os.path.exists(engine_path): + try: + self._backend = _TRTBackend(engine_path, self.get_logger()) + except RuntimeError: + self.get_logger().warn('TRT failed, trying ONNX fallback') + if self._backend is None and os.path.exists(onnx_path): + try: + self._backend = _ONNXBackend(onnx_path, self.get_logger()) + except RuntimeError: + self.get_logger().error('Both TRT and ONNX backends failed') + if self._backend is None: + self.get_logger().error( + 'No MiDaS model found. Depth estimation disabled.') + + self._bridge = CvBridge() + + # Latest frames per camera (round-robin processing) + self._latest_frames = [None] * len(self._camera_topics) + self._current_cam_idx = 0 + + # QoS for camera subscriptions + cam_qos = QoSProfile( + reliability=ReliabilityPolicy.BEST_EFFORT, + history=HistoryPolicy.KEEP_LAST, + depth=1, + ) + + # Subscribe to each camera topic + self._cam_subs = [] + for i, topic in enumerate(self._camera_topics): + sub = self.create_subscription( + Image, topic, + lambda msg, idx=i: self._on_image(msg, idx), + cam_qos) + self._cam_subs.append(sub) + + # Publishers: one per camera + self._depth_pubs = [] + for i in range(len(self._camera_topics)): + pub = self.create_publisher( + Image, f'/social/depth/cam{i}/image', 10) + self._depth_pubs.append(pub) + + # Timer: round-robin across cameras + timer_period = 1.0 / process_rate + self._timer = self.create_timer(timer_period, self._timer_callback) + + self.get_logger().info( + f'MiDaS depth node started: {len(self._camera_topics)} cameras ' + f'@ {process_rate} Hz') + + def _on_image(self, msg: Image, cam_idx: int): + """Cache latest frame for each camera.""" + self._latest_frames[cam_idx] = msg + + def _preprocess(self, bgr: np.ndarray) -> np.ndarray: + """Preprocess BGR image to MiDaS input tensor [1,3,256,256].""" + import cv2 + rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) + resized = cv2.resize(rgb, (_MIDAS_W, _MIDAS_H), + interpolation=cv2.INTER_LINEAR) + normalized = (resized.astype(np.float32) / 255.0 - _MEAN) / _STD + # HWC -> CHW -> NCHW + tensor = normalized.transpose(2, 0, 1)[np.newaxis, ...] + return tensor.astype(np.float32) + + def _infer(self, tensor: np.ndarray) -> np.ndarray: + """Run inference, returns [256,256] float32 relative inverse depth.""" + if self._backend is None: + return np.zeros((_MIDAS_H, _MIDAS_W), dtype=np.float32) + return self._backend.infer(tensor) + + def _postprocess(self, raw: np.ndarray, orig_shape: tuple) -> np.ndarray: + """Resize depth back to original image shape, apply output_scale.""" + import cv2 + h, w = orig_shape[:2] + depth = cv2.resize(raw, (w, h), interpolation=cv2.INTER_LINEAR) + depth = depth * self._output_scale + return depth + + def _timer_callback(self): + """Process one camera per tick (round-robin).""" + if not self._camera_topics: + return + + idx = self._current_cam_idx + self._current_cam_idx = (idx + 1) % len(self._camera_topics) + + msg = self._latest_frames[idx] + if msg is None: + return + + try: + bgr = self._bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8') + except Exception as e: + self.get_logger().warn(f'cv_bridge error cam{idx}: {e}') + return + + tensor = self._preprocess(bgr) + raw_depth = self._infer(tensor) + depth_map = self._postprocess(raw_depth, bgr.shape) + + # Publish as float32 Image + depth_msg = self._bridge.cv2_to_imgmsg(depth_map, encoding='32FC1') + depth_msg.header = msg.header + self._depth_pubs[idx].publish(depth_msg) + + +def main(args=None): + rclpy.init(args=args) + node = MiDaSDepthNode() + try: + rclpy.spin(node) + except KeyboardInterrupt: + pass + finally: + node.destroy_node() + rclpy.shutdown() + + +if __name__ == '__main__': + main() diff --git a/jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/social_nav_node.py b/jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/social_nav_node.py new file mode 100644 index 0000000..aeefe55 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/social_nav_node.py @@ -0,0 +1,584 @@ +""" +social_nav_node.py -- Social navigation node for saltybot. + +Orchestrates person following with multiple modes, voice commands, +waypoint teaching/replay, and A* obstacle avoidance. + +Follow modes: + shadow -- stay directly behind at follow_distance + lead -- move ahead of person + side -- stay beside (default right) + orbit -- circle around person + loose -- relaxed follow with deadband + tight -- close follow + +Waypoint teaching: + Voice command "teach route " -> record mode ON + Voice command "stop teaching" -> save route + Voice command "replay route " -> playback + +Voice commands: + "follow me" / "shadow" -> SHADOW mode + "lead me" / "go ahead" -> LEAD mode + "stay beside" -> SIDE mode + "orbit" -> ORBIT mode + "give me space" -> LOOSE mode + "stay close" -> TIGHT mode + "stop" / "halt" -> STOP + "go home" -> navigate to home waypoint + "teach route " -> start recording + "stop teaching" -> finish recording + "replay route " -> playback recorded route +""" + +import math +import time +import re +from collections import deque + +import numpy as np +import rclpy +from rclpy.node import Node +from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy +from geometry_msgs.msg import Twist, PoseStamped +from nav_msgs.msg import OccupancyGrid, Odometry +from std_msgs.msg import String, Int32 + +from .follow_modes import ( + FollowMode, MODE_VOICE_COMMANDS, + compute_shadow_target, compute_lead_target, compute_side_target, + compute_orbit_target, compute_loose_target, compute_tight_target, +) +from .astar import astar, inflate_obstacles +from .waypoint_teacher import WaypointRoute, WaypointReplayer + +# Try importing social msgs; fallback gracefully +try: + from saltybot_social_msgs.msg import PersonStateArray + _HAS_SOCIAL_MSGS = True +except ImportError: + _HAS_SOCIAL_MSGS = False + +# Proportional controller gains +_K_ANG = 2.0 # angular gain +_K_LIN = 0.8 # linear gain +_HIGH_SPEED_THRESHOLD = 3.0 # m/s person velocity triggers fast mode +_PREDICT_AHEAD_S = 0.3 # seconds to extrapolate position +_TEACH_MIN_DIST = 0.5 # meters between recorded waypoints + + +class SocialNavNode(Node): + """Main social navigation orchestrator.""" + + def __init__(self): + super().__init__('social_nav') + + # -- Parameters -- + self.declare_parameter('follow_mode', 'shadow') + self.declare_parameter('follow_distance', 1.2) + self.declare_parameter('lead_distance', 2.0) + self.declare_parameter('orbit_radius', 1.5) + self.declare_parameter('max_linear_speed', 1.0) + self.declare_parameter('max_linear_speed_fast', 5.5) + self.declare_parameter('max_angular_speed', 1.0) + self.declare_parameter('goal_tolerance', 0.3) + self.declare_parameter('routes_dir', '/mnt/nvme/saltybot/routes') + self.declare_parameter('home_x', 0.0) + self.declare_parameter('home_y', 0.0) + self.declare_parameter('map_resolution', 0.05) + self.declare_parameter('obstacle_inflation_cells', 3) + + self._follow_mode = FollowMode( + self.get_parameter('follow_mode').value) + self._follow_distance = self.get_parameter('follow_distance').value + self._lead_distance = self.get_parameter('lead_distance').value + self._orbit_radius = self.get_parameter('orbit_radius').value + self._max_lin = self.get_parameter('max_linear_speed').value + self._max_lin_fast = self.get_parameter('max_linear_speed_fast').value + self._max_ang = self.get_parameter('max_angular_speed').value + self._goal_tol = self.get_parameter('goal_tolerance').value + self._routes_dir = self.get_parameter('routes_dir').value + self._home_x = self.get_parameter('home_x').value + self._home_y = self.get_parameter('home_y').value + self._map_resolution = self.get_parameter('map_resolution').value + self._inflation_cells = self.get_parameter( + 'obstacle_inflation_cells').value + + # -- State -- + self._robot_x = 0.0 + self._robot_y = 0.0 + self._robot_yaw = 0.0 + self._target_person_pos = None # (x, y, z) + self._target_person_bearing = 0.0 + self._target_person_id = -1 + self._person_history = deque(maxlen=5) # for velocity estimation + self._stopped = False + self._go_home = False + + # Occupancy grid for A* + self._occ_grid = None + self._occ_origin = (0.0, 0.0) + self._occ_resolution = 0.05 + self._astar_path = None + + # Orbit state + self._orbit_angle = 0.0 + + # Waypoint teaching / replay + self._teaching = False + self._current_route = None + self._last_teach_x = None + self._last_teach_y = None + self._replayer = None + + # -- QoS profiles -- + best_effort_qos = QoSProfile( + reliability=ReliabilityPolicy.BEST_EFFORT, + history=HistoryPolicy.KEEP_LAST, depth=1) + reliable_qos = QoSProfile( + reliability=ReliabilityPolicy.RELIABLE, + history=HistoryPolicy.KEEP_LAST, depth=1) + + # -- Subscriptions -- + if _HAS_SOCIAL_MSGS: + self.create_subscription( + PersonStateArray, '/social/persons', + self._on_persons, best_effort_qos) + else: + self.get_logger().warn( + 'saltybot_social_msgs not found; ' + 'using /person/target fallback') + + self.create_subscription( + PoseStamped, '/person/target', + self._on_person_target, best_effort_qos) + self.create_subscription( + String, '/social/speech/command', + self._on_voice_command, 10) + self.create_subscription( + String, '/social/speech/transcript', + self._on_transcript, 10) + self.create_subscription( + OccupancyGrid, '/map', + self._on_map, reliable_qos) + self.create_subscription( + Odometry, '/odom', + self._on_odom, best_effort_qos) + self.create_subscription( + Int32, '/social/attention/target_id', + self._on_target_id, 10) + + # -- Publishers -- + self._cmd_vel_pub = self.create_publisher( + Twist, '/cmd_vel', best_effort_qos) + self._mode_pub = self.create_publisher( + String, '/social/nav/mode', reliable_qos) + self._target_pub = self.create_publisher( + PoseStamped, '/social/nav/target_pos', 10) + self._status_pub = self.create_publisher( + String, '/social/nav/status', best_effort_qos) + + # -- Main loop timer (20 Hz) -- + self._timer = self.create_timer(0.05, self._control_loop) + + self.get_logger().info( + f'Social nav started: mode={self._follow_mode.value}, ' + f'dist={self._follow_distance}m') + + # ---------------------------------------------------------------- + # Subscriptions + # ---------------------------------------------------------------- + + def _on_persons(self, msg): + """Handle PersonStateArray from social perception.""" + target_id = msg.primary_attention_id + if self._target_person_id >= 0: + target_id = self._target_person_id + + for p in msg.persons: + if p.person_id == target_id or ( + target_id < 0 and len(msg.persons) > 0): + pos = (p.position.x, p.position.y, p.position.z) + self._update_person_position(pos, p.bearing_deg) + break + + def _on_person_target(self, msg: PoseStamped): + """Fallback: single person target pose.""" + pos = (msg.pose.position.x, msg.pose.position.y, + msg.pose.position.z) + # Estimate bearing from quaternion yaw + q = msg.pose.orientation + yaw = math.atan2(2.0 * (q.w * q.z + q.x * q.y), + 1.0 - 2.0 * (q.y * q.y + q.z * q.z)) + self._update_person_position(pos, math.degrees(yaw)) + + def _update_person_position(self, pos, bearing_deg): + """Update person tracking state and record history.""" + now = time.time() + self._target_person_pos = pos + self._target_person_bearing = bearing_deg + self._person_history.append((now, pos[0], pos[1])) + + def _on_odom(self, msg: Odometry): + """Update robot pose from odometry.""" + self._robot_x = msg.pose.pose.position.x + self._robot_y = msg.pose.pose.position.y + q = msg.pose.pose.orientation + self._robot_yaw = math.atan2( + 2.0 * (q.w * q.z + q.x * q.y), + 1.0 - 2.0 * (q.y * q.y + q.z * q.z)) + + def _on_map(self, msg: OccupancyGrid): + """Cache occupancy grid for A* planning.""" + w, h = msg.info.width, msg.info.height + data = np.array(msg.data, dtype=np.int8).reshape((h, w)) + # Convert -1 (unknown) to free (0) for planning + data[data < 0] = 0 + self._occ_grid = data.astype(np.int32) + self._occ_origin = (msg.info.origin.position.x, + msg.info.origin.position.y) + self._occ_resolution = msg.info.resolution + + def _on_target_id(self, msg: Int32): + """Switch target person.""" + self._target_person_id = msg.data + self.get_logger().info(f'Target person ID set to {msg.data}') + + def _on_voice_command(self, msg: String): + """Handle discrete voice commands for mode switching.""" + cmd = msg.data.strip().lower() + + if cmd in ('stop', 'halt'): + self._stopped = True + self._replayer = None + self._publish_status('STOPPED') + return + + if cmd in ('resume', 'go', 'start'): + self._stopped = False + self._publish_status('RESUMED') + return + + matched = MODE_VOICE_COMMANDS.get(cmd) + if matched: + self._follow_mode = matched + self._stopped = False + mode_msg = String() + mode_msg.data = self._follow_mode.value + self._mode_pub.publish(mode_msg) + self._publish_status(f'MODE: {self._follow_mode.value}') + + def _on_transcript(self, msg: String): + """Handle free-form voice transcripts for route teaching.""" + text = msg.data.strip().lower() + + # "teach route " + m = re.match(r'teach\s+route\s+(\w+)', text) + if m: + name = m.group(1) + self._teaching = True + self._current_route = WaypointRoute(name) + self._last_teach_x = self._robot_x + self._last_teach_y = self._robot_y + self._publish_status(f'TEACHING: {name}') + self.get_logger().info(f'Recording route: {name}') + return + + # "stop teaching" + if 'stop teaching' in text: + if self._teaching and self._current_route: + self._current_route.save(self._routes_dir) + self._publish_status( + f'SAVED: {self._current_route.name} ' + f'({len(self._current_route.waypoints)} pts)') + self.get_logger().info( + f'Route saved: {self._current_route.name}') + self._teaching = False + self._current_route = None + return + + # "replay route " + m = re.match(r'replay\s+route\s+(\w+)', text) + if m: + name = m.group(1) + try: + route = WaypointRoute.load(self._routes_dir, name) + self._replayer = WaypointReplayer(route) + self._stopped = False + self._publish_status(f'REPLAY: {name}') + self.get_logger().info(f'Replaying route: {name}') + except FileNotFoundError: + self._publish_status(f'ROUTE NOT FOUND: {name}') + return + + # "go home" + if 'go home' in text: + self._go_home = True + self._stopped = False + self._publish_status('GO HOME') + return + + # Also try mode commands from transcript + for phrase, mode in MODE_VOICE_COMMANDS.items(): + if phrase in text: + self._follow_mode = mode + self._stopped = False + mode_msg = String() + mode_msg.data = self._follow_mode.value + self._mode_pub.publish(mode_msg) + self._publish_status(f'MODE: {self._follow_mode.value}') + return + + # ---------------------------------------------------------------- + # Control loop + # ---------------------------------------------------------------- + + def _control_loop(self): + """Main 20Hz control loop.""" + # Record waypoint if teaching + if self._teaching and self._current_route: + self._maybe_record_waypoint() + + # Publish zero velocity if stopped + if self._stopped: + self._publish_cmd_vel(0.0, 0.0) + return + + # Determine navigation target + target = self._get_nav_target() + if target is None: + self._publish_cmd_vel(0.0, 0.0) + return + + tx, ty, tz = target + + # Publish debug target + self._publish_target_pose(tx, ty, tz) + + # Check if arrived + dist_to_target = math.hypot(tx - self._robot_x, ty - self._robot_y) + if dist_to_target < self._goal_tol: + self._publish_cmd_vel(0.0, 0.0) + if self._go_home: + self._go_home = False + self._publish_status('HOME REACHED') + return + + # Try A* path if map available + if self._occ_grid is not None: + path_target = self._plan_astar(tx, ty) + if path_target: + tx, ty = path_target + + # Determine speed limit + max_lin = self._max_lin + person_vel = self._estimate_person_velocity() + if person_vel > _HIGH_SPEED_THRESHOLD: + max_lin = self._max_lin_fast + + # Compute and publish cmd_vel + lin, ang = self._compute_cmd_vel( + self._robot_x, self._robot_y, self._robot_yaw, + tx, ty, max_lin) + self._publish_cmd_vel(lin, ang) + + def _get_nav_target(self): + """Determine current navigation target based on mode/state.""" + # Route replay takes priority + if self._replayer and not self._replayer.is_done: + self._replayer.check_arrived(self._robot_x, self._robot_y) + wp = self._replayer.current_waypoint() + if wp: + return (wp.x, wp.y, wp.z) + else: + self._replayer = None + self._publish_status('REPLAY DONE') + return None + + # Go home + if self._go_home: + return (self._home_x, self._home_y, 0.0) + + # Person following + if self._target_person_pos is None: + return None + + # Predict person position ahead for high-speed tracking + px, py, pz = self._predict_person_position() + bearing = self._target_person_bearing + robot_pos = (self._robot_x, self._robot_y, 0.0) + + if self._follow_mode == FollowMode.SHADOW: + return compute_shadow_target( + (px, py, pz), bearing, self._follow_distance) + elif self._follow_mode == FollowMode.LEAD: + return compute_lead_target( + (px, py, pz), bearing, self._lead_distance) + elif self._follow_mode == FollowMode.SIDE: + return compute_side_target( + (px, py, pz), bearing, self._follow_distance) + elif self._follow_mode == FollowMode.ORBIT: + self._orbit_angle = (self._orbit_angle + 1.0) % 360.0 + return compute_orbit_target( + (px, py, pz), self._orbit_angle, self._orbit_radius) + elif self._follow_mode == FollowMode.LOOSE: + return compute_loose_target( + (px, py, pz), robot_pos, self._follow_distance) + elif self._follow_mode == FollowMode.TIGHT: + return compute_tight_target( + (px, py, pz), self._follow_distance) + + return (px, py, pz) + + def _predict_person_position(self): + """Extrapolate person position using velocity from recent samples.""" + if self._target_person_pos is None: + return (0.0, 0.0, 0.0) + + px, py, pz = self._target_person_pos + + if len(self._person_history) >= 3: + # Use last 3 samples for velocity estimation + t0, x0, y0 = self._person_history[-3] + t1, x1, y1 = self._person_history[-1] + dt = t1 - t0 + if dt > 0.01: + vx = (x1 - x0) / dt + vy = (y1 - y0) / dt + speed = math.hypot(vx, vy) + if speed > _HIGH_SPEED_THRESHOLD: + px += vx * _PREDICT_AHEAD_S + py += vy * _PREDICT_AHEAD_S + + return (px, py, pz) + + def _estimate_person_velocity(self) -> float: + """Estimate person speed from recent position history.""" + if len(self._person_history) < 2: + return 0.0 + t0, x0, y0 = self._person_history[-2] + t1, x1, y1 = self._person_history[-1] + dt = t1 - t0 + if dt < 0.01: + return 0.0 + return math.hypot(x1 - x0, y1 - y0) / dt + + def _plan_astar(self, target_x, target_y): + """Run A* on occupancy grid, return next waypoint in world coords.""" + grid = self._occ_grid + res = self._occ_resolution + ox, oy = self._occ_origin + + # World to grid + def w2g(wx, wy): + return (int((wy - oy) / res), int((wx - ox) / res)) + + start = w2g(self._robot_x, self._robot_y) + goal = w2g(target_x, target_y) + + rows, cols = grid.shape + if not (0 <= start[0] < rows and 0 <= start[1] < cols): + return None + if not (0 <= goal[0] < rows and 0 <= goal[1] < cols): + return None + + inflated = inflate_obstacles(grid, self._inflation_cells) + path = astar(inflated, start, goal) + + if path and len(path) > 1: + # Follow a lookahead point (3 steps ahead or end) + lookahead_idx = min(3, len(path) - 1) + r, c = path[lookahead_idx] + wx = ox + c * res + res / 2.0 + wy = oy + r * res + res / 2.0 + self._astar_path = path + return (wx, wy) + + return None + + def _compute_cmd_vel(self, rx, ry, ryaw, tx, ty, max_lin): + """Proportional controller: compute linear and angular velocity.""" + dx = tx - rx + dy = ty - ry + dist = math.hypot(dx, dy) + angle_to_target = math.atan2(dy, dx) + angle_error = angle_to_target - ryaw + + # Normalize angle error to [-pi, pi] + while angle_error > math.pi: + angle_error -= 2.0 * math.pi + while angle_error < -math.pi: + angle_error += 2.0 * math.pi + + angular_vel = _K_ANG * angle_error + angular_vel = max(-self._max_ang, + min(self._max_ang, angular_vel)) + + # Reduce linear speed when turning hard + angle_factor = max(0.0, 1.0 - abs(angle_error) / (math.pi / 2.0)) + linear_vel = _K_LIN * dist * angle_factor + linear_vel = max(0.0, min(max_lin, linear_vel)) + + return (linear_vel, angular_vel) + + # ---------------------------------------------------------------- + # Waypoint teaching + # ---------------------------------------------------------------- + + def _maybe_record_waypoint(self): + """Record waypoint if robot moved > _TEACH_MIN_DIST.""" + if self._last_teach_x is None: + self._last_teach_x = self._robot_x + self._last_teach_y = self._robot_y + + dist = math.hypot( + self._robot_x - self._last_teach_x, + self._robot_y - self._last_teach_y) + + if dist >= _TEACH_MIN_DIST: + yaw_deg = math.degrees(self._robot_yaw) + self._current_route.add( + self._robot_x, self._robot_y, 0.0, yaw_deg) + self._last_teach_x = self._robot_x + self._last_teach_y = self._robot_y + + # ---------------------------------------------------------------- + # Publishers + # ---------------------------------------------------------------- + + def _publish_cmd_vel(self, linear: float, angular: float): + twist = Twist() + twist.linear.x = linear + twist.angular.z = angular + self._cmd_vel_pub.publish(twist) + + def _publish_target_pose(self, x, y, z): + msg = PoseStamped() + msg.header.stamp = self.get_clock().now().to_msg() + msg.header.frame_id = 'map' + msg.pose.position.x = x + msg.pose.position.y = y + msg.pose.position.z = z + self._target_pub.publish(msg) + + def _publish_status(self, status: str): + msg = String() + msg.data = status + self._status_pub.publish(msg) + self.get_logger().info(f'Nav status: {status}') + + +def main(args=None): + rclpy.init(args=args) + node = SocialNavNode() + try: + rclpy.spin(node) + except KeyboardInterrupt: + pass + finally: + node.destroy_node() + rclpy.shutdown() + + +if __name__ == '__main__': + main() diff --git a/jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/waypoint_teacher.py b/jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/waypoint_teacher.py new file mode 100644 index 0000000..62fb8f2 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/waypoint_teacher.py @@ -0,0 +1,91 @@ +"""waypoint_teacher.py -- Record and replay waypoint routes.""" + +import json +import time +import math +from pathlib import Path +from dataclasses import dataclass, asdict + + +@dataclass +class Waypoint: + x: float + y: float + z: float + yaw_deg: float + timestamp: float + label: str = '' + + +class WaypointRoute: + """A named sequence of waypoints.""" + + def __init__(self, name: str): + self.name = name + self.waypoints: list[Waypoint] = [] + self.created_at = time.time() + + def add(self, x, y, z, yaw_deg, label=''): + self.waypoints.append(Waypoint(x, y, z, yaw_deg, time.time(), label)) + + def to_dict(self): + return { + 'name': self.name, + 'created_at': self.created_at, + 'waypoints': [asdict(w) for w in self.waypoints], + } + + @classmethod + def from_dict(cls, d): + route = cls(d['name']) + route.created_at = d.get('created_at', 0) + route.waypoints = [Waypoint(**w) for w in d['waypoints']] + return route + + def save(self, routes_dir: str): + Path(routes_dir).mkdir(parents=True, exist_ok=True) + path = Path(routes_dir) / f'{self.name}.json' + path.write_text(json.dumps(self.to_dict(), indent=2)) + + @classmethod + def load(cls, routes_dir: str, name: str): + path = Path(routes_dir) / f'{name}.json' + return cls.from_dict(json.loads(path.read_text())) + + @staticmethod + def list_routes(routes_dir: str) -> list[str]: + d = Path(routes_dir) + if not d.exists(): + return [] + return [p.stem for p in d.glob('*.json')] + + +class WaypointReplayer: + """Iterates through waypoints, returning next target.""" + + def __init__(self, route: WaypointRoute, arrival_radius: float = 0.3): + self._route = route + self._idx = 0 + self._arrival_radius = arrival_radius + + def current_waypoint(self) -> Waypoint | None: + if self._idx < len(self._route.waypoints): + return self._route.waypoints[self._idx] + return None + + def check_arrived(self, robot_x, robot_y) -> bool: + wp = self.current_waypoint() + if wp is None: + return False + dist = math.hypot(robot_x - wp.x, robot_y - wp.y) + if dist < self._arrival_radius: + self._idx += 1 + return True + return False + + @property + def is_done(self) -> bool: + return self._idx >= len(self._route.waypoints) + + def reset(self): + self._idx = 0 diff --git a/jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/waypoint_teacher_node.py b/jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/waypoint_teacher_node.py new file mode 100644 index 0000000..b214326 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_nav/saltybot_social_nav/waypoint_teacher_node.py @@ -0,0 +1,135 @@ +""" +waypoint_teacher_node.py -- Standalone waypoint teacher ROS2 node. + +Listens to /social/speech/transcript for "teach route " and "stop teaching". +Records robot pose at configurable intervals. Saves/loads routes via WaypointRoute. +""" + +import math + +import rclpy +from rclpy.node import Node +from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy +from nav_msgs.msg import Odometry +from std_msgs.msg import String + +from .waypoint_teacher import WaypointRoute + + +class WaypointTeacherNode(Node): + """Standalone waypoint teaching node.""" + + def __init__(self): + super().__init__('waypoint_teacher') + + self.declare_parameter('routes_dir', '/mnt/nvme/saltybot/routes') + self.declare_parameter('recording_interval', 0.5) # meters + + self._routes_dir = self.get_parameter('routes_dir').value + self._interval = self.get_parameter('recording_interval').value + + self._teaching = False + self._route = None + self._last_x = None + self._last_y = None + self._robot_x = 0.0 + self._robot_y = 0.0 + self._robot_yaw = 0.0 + + best_effort_qos = QoSProfile( + reliability=ReliabilityPolicy.BEST_EFFORT, + history=HistoryPolicy.KEEP_LAST, depth=1) + + self.create_subscription( + Odometry, '/odom', self._on_odom, best_effort_qos) + self.create_subscription( + String, '/social/speech/transcript', + self._on_transcript, 10) + + self._status_pub = self.create_publisher( + String, '/social/waypoint/status', 10) + + # Record timer at 10Hz (check distance) + self._timer = self.create_timer(0.1, self._record_tick) + + self.get_logger().info( + f'Waypoint teacher ready (interval={self._interval}m, ' + f'dir={self._routes_dir})') + + def _on_odom(self, msg: Odometry): + self._robot_x = msg.pose.pose.position.x + self._robot_y = msg.pose.pose.position.y + q = msg.pose.pose.orientation + self._robot_yaw = math.atan2( + 2.0 * (q.w * q.z + q.x * q.y), + 1.0 - 2.0 * (q.y * q.y + q.z * q.z)) + + def _on_transcript(self, msg: String): + text = msg.data.strip().lower() + + import re + m = re.match(r'teach\s+route\s+(\w+)', text) + if m: + name = m.group(1) + self._route = WaypointRoute(name) + self._teaching = True + self._last_x = self._robot_x + self._last_y = self._robot_y + self._pub_status(f'RECORDING: {name}') + self.get_logger().info(f'Recording route: {name}') + return + + if 'stop teaching' in text: + if self._teaching and self._route: + self._route.save(self._routes_dir) + n = len(self._route.waypoints) + self._pub_status( + f'SAVED: {self._route.name} ({n} waypoints)') + self.get_logger().info( + f'Route saved: {self._route.name} ({n} pts)') + self._teaching = False + self._route = None + return + + if 'list routes' in text: + routes = WaypointRoute.list_routes(self._routes_dir) + self._pub_status(f'ROUTES: {", ".join(routes) or "(none)"}') + + def _record_tick(self): + if not self._teaching or self._route is None: + return + + if self._last_x is None: + self._last_x = self._robot_x + self._last_y = self._robot_y + + dist = math.hypot( + self._robot_x - self._last_x, + self._robot_y - self._last_y) + + if dist >= self._interval: + yaw_deg = math.degrees(self._robot_yaw) + self._route.add(self._robot_x, self._robot_y, 0.0, yaw_deg) + self._last_x = self._robot_x + self._last_y = self._robot_y + + def _pub_status(self, text: str): + msg = String() + msg.data = text + self._status_pub.publish(msg) + + +def main(args=None): + rclpy.init(args=args) + node = WaypointTeacherNode() + try: + rclpy.spin(node) + except KeyboardInterrupt: + pass + finally: + node.destroy_node() + rclpy.shutdown() + + +if __name__ == '__main__': + main() diff --git a/jetson/ros2_ws/src/saltybot_social_nav/scripts/build_midas_trt_engine.py b/jetson/ros2_ws/src/saltybot_social_nav/scripts/build_midas_trt_engine.py new file mode 100644 index 0000000..493d440 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_nav/scripts/build_midas_trt_engine.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 +""" +build_midas_trt_engine.py -- Build TensorRT FP16 engine for MiDaS_small from ONNX. + +Usage: + python3 build_midas_trt_engine.py \ + --onnx /mnt/nvme/saltybot/models/midas_small.onnx \ + --engine /mnt/nvme/saltybot/models/midas_small.engine \ + --fp16 + +Requires: tensorrt, pycuda +""" + +import argparse +import os +import sys + + +def build_engine(onnx_path: str, engine_path: str, fp16: bool = True): + try: + import tensorrt as trt + except ImportError: + print('ERROR: tensorrt not found. Install TensorRT first.') + sys.exit(1) + + logger = trt.Logger(trt.Logger.INFO) + builder = trt.Builder(logger) + network = builder.create_network( + 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) + parser = trt.OnnxParser(network, logger) + + print(f'Parsing ONNX model: {onnx_path}') + with open(onnx_path, 'rb') as f: + if not parser.parse(f.read()): + for i in range(parser.num_errors): + print(f' ONNX parse error: {parser.get_error(i)}') + sys.exit(1) + + config = builder.create_builder_config() + config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1GB + + if fp16 and builder.platform_has_fast_fp16: + config.set_flag(trt.BuilderFlag.FP16) + print('FP16 mode enabled') + elif fp16: + print('WARNING: FP16 not supported on this platform, using FP32') + + print('Building TensorRT engine (this may take several minutes)...') + engine_bytes = builder.build_serialized_network(network, config) + if engine_bytes is None: + print('ERROR: Failed to build engine') + sys.exit(1) + + os.makedirs(os.path.dirname(engine_path) or '.', exist_ok=True) + with open(engine_path, 'wb') as f: + f.write(engine_bytes) + + size_mb = len(engine_bytes) / (1024 * 1024) + print(f'Engine saved: {engine_path} ({size_mb:.1f} MB)') + + +def main(): + parser = argparse.ArgumentParser( + description='Build TensorRT FP16 engine for MiDaS_small') + parser.add_argument('--onnx', required=True, + help='Path to MiDaS ONNX model') + parser.add_argument('--engine', required=True, + help='Output TRT engine path') + parser.add_argument('--fp16', action='store_true', default=True, + help='Enable FP16 (default: True)') + parser.add_argument('--fp32', action='store_true', + help='Force FP32 (disable FP16)') + args = parser.parse_args() + + fp16 = not args.fp32 + build_engine(args.onnx, args.engine, fp16=fp16) + + +if __name__ == '__main__': + main() diff --git a/jetson/ros2_ws/src/saltybot_social_nav/setup.cfg b/jetson/ros2_ws/src/saltybot_social_nav/setup.cfg new file mode 100644 index 0000000..febc349 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_nav/setup.cfg @@ -0,0 +1,4 @@ +[develop] +script_dir=$base/lib/saltybot_social_nav +[install] +install_scripts=$base/lib/saltybot_social_nav diff --git a/jetson/ros2_ws/src/saltybot_social_nav/setup.py b/jetson/ros2_ws/src/saltybot_social_nav/setup.py new file mode 100644 index 0000000..bb044f4 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_nav/setup.py @@ -0,0 +1,31 @@ +from setuptools import setup +import os +from glob import glob + +package_name = 'saltybot_social_nav' + +setup( + name=package_name, + version='0.1.0', + packages=[package_name], + data_files=[ + ('share/ament_index/resource_index/packages', ['resource/' + package_name]), + ('share/' + package_name, ['package.xml']), + (os.path.join('share', package_name, 'launch'), glob('launch/*.py')), + (os.path.join('share', package_name, 'config'), glob('config/*.yaml')), + ], + install_requires=['setuptools'], + zip_safe=True, + maintainer='seb', + maintainer_email='seb@vayrette.com', + description='Social navigation for saltybot: follow modes, waypoint teaching, A* avoidance, MiDaS depth', + license='MIT', + tests_require=['pytest'], + entry_points={ + 'console_scripts': [ + 'social_nav = saltybot_social_nav.social_nav_node:main', + 'midas_depth = saltybot_social_nav.midas_depth_node:main', + 'waypoint_teacher = saltybot_social_nav.waypoint_teacher_node:main', + ], + }, +) diff --git a/jetson/ros2_ws/src/saltybot_social_nav/test/test_copyright.py b/jetson/ros2_ws/src/saltybot_social_nav/test/test_copyright.py new file mode 100644 index 0000000..5712f77 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_nav/test/test_copyright.py @@ -0,0 +1,12 @@ +# Copyright 2026 SaltyLab +# Licensed under MIT + +from ament_copyright.main import main +import pytest + + +@pytest.mark.copyright +@pytest.mark.linter +def test_copyright(): + rc = main(argv=['.', 'test']) + assert rc == 0, 'Found errors' diff --git a/jetson/ros2_ws/src/saltybot_social_nav/test/test_flake8.py b/jetson/ros2_ws/src/saltybot_social_nav/test/test_flake8.py new file mode 100644 index 0000000..07ce385 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_nav/test/test_flake8.py @@ -0,0 +1,14 @@ +# Copyright 2026 SaltyLab +# Licensed under MIT + +from ament_flake8.main import main_with_errors +import pytest + + +@pytest.mark.flake8 +@pytest.mark.linter +def test_flake8(): + rc, errors = main_with_errors(argv=[]) + assert rc == 0, \ + 'Found %d code style errors / warnings:\n' % len(errors) + \ + '\n'.join(errors) diff --git a/jetson/ros2_ws/src/saltybot_social_nav/test/test_pep257.py b/jetson/ros2_ws/src/saltybot_social_nav/test/test_pep257.py new file mode 100644 index 0000000..9431120 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_nav/test/test_pep257.py @@ -0,0 +1,12 @@ +# Copyright 2026 SaltyLab +# Licensed under MIT + +from ament_pep257.main import main +import pytest + + +@pytest.mark.pep257 +@pytest.mark.linter +def test_pep257(): + rc = main(argv=['.', 'test']) + assert rc == 0, 'Found code style errors / warnings' -- 2.47.2