feat(social): navigation + follow modes + MiDaS depth + waypoints (Issue #91)

- saltybot_social_msgs: full message/service definitions (standalone compilation)
- saltybot_social_nav: social navigation orchestrator
  - Follow modes: shadow/lead/side/orbit/loose/tight
  - Voice steering: mode switching + route commands via /social/speech/*
  - A* obstacle avoidance on Nav2/SLAM occupancy grid (8-directional, inflation)
  - MiDaS monocular depth for CSI cameras (TRT FP16 + ONNX fallback)
  - Waypoint teaching + replay with WaypointRoute persistence
  - High-speed EUC tracking (5.5 m/s = ~20 km/h)
  - Predictive position extrapolation (0.3s ahead at high speed)
- Launch: social_nav.launch.py (social_nav + midas_depth + waypoint_teacher)
- Config: social_nav_params.yaml
- Script: build_midas_trt_engine.py (ONNX -> TRT FP16)
This commit is contained in:
sl-perception 2026-03-01 23:15:00 -05:00
parent ac6fcb9a42
commit d872ea5e34
29 changed files with 1567 additions and 0 deletions

View File

@ -0,0 +1,24 @@
# Standalone ROS 2 interface package: compiles messages/services only, so it
# can be built without the rest of the saltybot workspace.
cmake_minimum_required(VERSION 3.8)
project(saltybot_social_msgs)
find_package(ament_cmake REQUIRED)
find_package(rosidl_default_generators REQUIRED)
find_package(std_msgs REQUIRED)
find_package(geometry_msgs REQUIRED)
find_package(builtin_interfaces REQUIRED)
# Generate C++/Python bindings for every interface file listed below.
rosidl_generate_interfaces(${PROJECT_NAME}
"msg/FaceDetection.msg"
"msg/FaceDetectionArray.msg"
"msg/FaceEmbedding.msg"
"msg/FaceEmbeddingArray.msg"
"msg/PersonState.msg"
"msg/PersonStateArray.msg"
"srv/EnrollPerson.srv"
"srv/ListPersons.srv"
"srv/DeletePerson.srv"
"srv/UpdatePerson.srv"
DEPENDENCIES std_msgs geometry_msgs builtin_interfaces
)
ament_package()

View File

@ -0,0 +1,10 @@
# One detected face with its (optional) recognized identity.
std_msgs/Header header        # stamp/frame of the source camera image
int32 face_id                 # detector/tracker-assigned ID
string person_name            # recognized name; empty when unknown
float32 confidence            # detection confidence
float32 recognition_score     # identity-match score
float32 bbox_x                # bounding box -- units per detector, presumably pixels (confirm)
float32 bbox_y
float32 bbox_w
float32 bbox_h
float32[10] landmarks         # 10 values -- presumably 5 (x, y) landmark pairs; confirm with detector

View File

@ -0,0 +1,2 @@
# All face detections from a single frame.
std_msgs/Header header
saltybot_social_msgs/FaceDetection[] faces

View File

@ -0,0 +1,5 @@
# Stored identity embedding for one enrolled person.
int32 person_id
string person_name
float32[] embedding                    # face embedding vector (length set by the recognition model)
builtin_interfaces/Time enrolled_at    # when the person was enrolled
int32 sample_count                     # number of samples averaged into the embedding

View File

@ -0,0 +1,2 @@
# Batch of stored identity embeddings (e.g. a full enrollment database dump).
std_msgs/Header header
saltybot_social_msgs/FaceEmbedding[] embeddings

View File

@ -0,0 +1,19 @@
# Fused tracking state for one person, combining face, speaker and UWB cues.
std_msgs/Header header
int32 person_id            # stable identity ID (-1 presumably means unknown; confirm)
string person_name
int32 face_id              # current face-track ID backing this person
string speaker_id          # associated voice/speaker ID, if any
string uwb_anchor_id       # associated UWB tag/anchor, if any
geometry_msgs/Point position
float32 distance           # range to the person (meters -- confirm)
float32 bearing_deg        # bearing to/of the person in degrees
uint8 state                # one of the STATE_* constants below
uint8 STATE_UNKNOWN=0
uint8 STATE_APPROACHING=1
uint8 STATE_ENGAGED=2
uint8 STATE_TALKING=3
uint8 STATE_LEAVING=4
uint8 STATE_ABSENT=5
float32 engagement_score   # how engaged the person appears
builtin_interfaces/Time last_seen
int32 camera_id            # which camera last observed this person

View File

@ -0,0 +1,3 @@
# Snapshot of all tracked persons plus which one currently holds attention.
std_msgs/Header header
saltybot_social_msgs/PersonState[] persons
int32 primary_attention_id    # person_id of the attention target (negative when none -- confirm convention)

View File

@ -0,0 +1,19 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_social_msgs</name>
<version>0.1.0</version>
<description>Custom ROS2 messages and services for saltybot social capabilities</description>
<maintainer email="seb@vayrette.com">seb</maintainer>
<license>MIT</license>
<buildtool_depend>ament_cmake</buildtool_depend>
<!-- Interface dependencies referenced by the .msg/.srv definitions -->
<depend>std_msgs</depend>
<depend>geometry_msgs</depend>
<depend>builtin_interfaces</depend>
<!-- rosidl codegen: generators at build time, runtime support at exec time -->
<build_depend>rosidl_default_generators</build_depend>
<exec_depend>rosidl_default_runtime</exec_depend>
<member_of_group>rosidl_interface_packages</member_of_group>
<export>
<build_type>ament_cmake</build_type>
</export>
</package>

View File

@ -0,0 +1,4 @@
# Delete one enrolled person by ID (field set matches DeletePerson.srv -- confirm filename).
int32 person_id
---
bool success
string message     # human-readable result / error detail

View File

@ -0,0 +1,7 @@
# Enroll a new person (field set matches EnrollPerson.srv -- confirm filename).
string name
string mode            # enrollment mode -- semantics defined by the enrollment node; confirm
int32 n_samples        # number of face samples to capture
---
bool success
string message
int32 person_id        # ID assigned to the newly enrolled person

View File

@ -0,0 +1,2 @@
# List all enrolled persons; request is empty (matches ListPersons.srv -- confirm filename).
---
saltybot_social_msgs/FaceEmbedding[] persons

View File

@ -0,0 +1,5 @@
# Rename an enrolled person (matches UpdatePerson.srv -- confirm filename).
int32 person_id
string new_name
---
bool success
string message

View File

@ -0,0 +1,22 @@
# Parameters for the social navigation stack (see social_nav.launch.py).
social_nav:
  ros__parameters:
    follow_mode: 'shadow'            # shadow/lead/side/orbit/loose/tight
    follow_distance: 1.2             # m behind the person (shadow/side/tight)
    lead_distance: 2.0               # m ahead of the person (lead mode)
    orbit_radius: 1.5                # m circle radius (orbit mode)
    max_linear_speed: 1.0            # m/s normal cap
    max_linear_speed_fast: 5.5       # m/s cap when person moves fast (EUC tracking)
    max_angular_speed: 1.0           # rad/s cap
    goal_tolerance: 0.3              # m -- stop when this close to target
    routes_dir: '/mnt/nvme/saltybot/routes'   # taught-route persistence dir
    home_x: 0.0                      # "go home" target, map frame
    home_y: 0.0
    map_resolution: 0.05             # m/cell fallback resolution
    obstacle_inflation_cells: 3      # square inflation radius for A*
midas_depth:
  ros__parameters:
    onnx_path: '/mnt/nvme/saltybot/models/midas_small.onnx'      # ONNX fallback model
    engine_path: '/mnt/nvme/saltybot/models/midas_small.engine'  # preferred TRT engine
    process_rate: 5.0                # timer Hz (round-robin across cameras)
    output_scale: 1.0                # multiplier applied to relative depth

View File

@ -0,0 +1,57 @@
"""Launch file for saltybot social navigation."""
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
    """Launch the social stack: social_nav, midas_depth, waypoint_teacher."""
    launch_args = [
        DeclareLaunchArgument('follow_mode', default_value='shadow',
                              description='Follow mode: shadow/lead/side/orbit/loose/tight'),
        DeclareLaunchArgument('follow_distance', default_value='1.2',
                              description='Follow distance in meters'),
        DeclareLaunchArgument('max_linear_speed', default_value='1.0',
                              description='Max linear speed (m/s)'),
        DeclareLaunchArgument('routes_dir',
                              default_value='/mnt/nvme/saltybot/routes',
                              description='Directory for saved routes'),
    ]
    # Orchestrator: follow modes, voice steering, A* avoidance.
    social_nav = Node(
        package='saltybot_social_nav',
        executable='social_nav',
        name='social_nav',
        output='screen',
        parameters=[{
            'follow_mode': LaunchConfiguration('follow_mode'),
            'follow_distance': LaunchConfiguration('follow_distance'),
            'max_linear_speed': LaunchConfiguration('max_linear_speed'),
            'routes_dir': LaunchConfiguration('routes_dir'),
        }],
    )
    # Monocular depth for the CSI cameras (TRT engine with ONNX fallback).
    midas_depth = Node(
        package='saltybot_social_nav',
        executable='midas_depth',
        name='midas_depth',
        output='screen',
        parameters=[{
            'onnx_path': '/mnt/nvme/saltybot/models/midas_small.onnx',
            'engine_path': '/mnt/nvme/saltybot/models/midas_small.engine',
            'process_rate': 5.0,
            'output_scale': 1.0,
        }],
    )
    # Waypoint recording/replay; shares routes_dir with social_nav.
    waypoint_teacher = Node(
        package='saltybot_social_nav',
        executable='waypoint_teacher',
        name='waypoint_teacher',
        output='screen',
        parameters=[{
            'routes_dir': LaunchConfiguration('routes_dir'),
            'recording_interval': 0.5,
        }],
    )
    return LaunchDescription(
        launch_args + [social_nav, midas_depth, waypoint_teacher])

View File

@ -0,0 +1,28 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_social_nav</name>
<version>0.1.0</version>
<description>Social navigation for saltybot: follow modes, waypoint teaching, A* avoidance, MiDaS depth</description>
<maintainer email="seb@vayrette.com">seb</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>std_msgs</depend>
<depend>geometry_msgs</depend>
<depend>nav_msgs</depend>
<depend>sensor_msgs</depend>
<depend>cv_bridge</depend>
<depend>tf2_ros</depend>
<depend>tf2_geometry_msgs</depend>
<depend>saltybot_social_msgs</depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>python3-pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -0,0 +1,82 @@
"""astar.py -- A* path planner for saltybot social navigation."""
import heapq
import numpy as np
def astar(grid: np.ndarray, start: tuple, goal: tuple,
obstacle_val: int = 100) -> list | None:
"""
A* on a 2D occupancy grid (row, col indexing).
Args:
grid: 2D numpy array, values 0=free, >=obstacle_val=obstacle
start: (row, col) start cell
goal: (row, col) goal cell
obstacle_val: cells with value >= obstacle_val are blocked
Returns:
List of (row, col) tuples from start to goal, or None if no path.
"""
rows, cols = grid.shape
def h(a, b):
return abs(a[0] - b[0]) + abs(a[1] - b[1]) # Manhattan heuristic
open_set = []
heapq.heappush(open_set, (h(start, goal), 0, start))
came_from = {}
g_score = {start: 0}
# 8-directional movement
neighbors_delta = [
(-1, -1), (-1, 0), (-1, 1),
(0, -1), (0, 1),
(1, -1), (1, 0), (1, 1),
]
while open_set:
_, cost, current = heapq.heappop(open_set)
if current == goal:
path = []
while current in came_from:
path.append(current)
current = came_from[current]
path.append(start)
return list(reversed(path))
if cost > g_score.get(current, float('inf')):
continue
for dr, dc in neighbors_delta:
nr, nc = current[0] + dr, current[1] + dc
if not (0 <= nr < rows and 0 <= nc < cols):
continue
if grid[nr, nc] >= obstacle_val:
continue
move_cost = 1.414 if (dr != 0 and dc != 0) else 1.0
new_g = g_score[current] + move_cost
neighbor = (nr, nc)
if new_g < g_score.get(neighbor, float('inf')):
g_score[neighbor] = new_g
f = new_g + h(neighbor, goal)
came_from[neighbor] = current
heapq.heappush(open_set, (f, new_g, neighbor))
return None # No path found
def inflate_obstacles(grid: np.ndarray, inflation_radius_cells: int) -> np.ndarray:
    """Return a copy of *grid* with obstacles grown for robot-footprint safety.

    Cells valued >= 50 are treated as occupied; every cell within a square
    of the given radius around them is marked lethal (100) in the copy.
    """
    from scipy.ndimage import binary_dilation
    size = 2 * inflation_radius_cells + 1
    footprint = np.ones((size, size), dtype=bool)
    grown = binary_dilation(grid >= 50, structure=footprint)
    out = grid.copy()
    out[grown] = 100
    return out

View File

@ -0,0 +1,82 @@
"""follow_modes.py -- Follow mode geometry for saltybot social navigation."""
import math
from enum import Enum
import numpy as np
class FollowMode(Enum):
    """Named follow behaviors; the string value doubles as the voice keyword."""
    SHADOW = 'shadow'  # stay directly behind at follow_distance
    LEAD = 'lead'      # move ahead of person by lead_distance
    SIDE = 'side'      # stay to the right (or left) at side_offset
    ORBIT = 'orbit'    # circle around person at orbit_radius
    LOOSE = 'loose'    # general follow, larger tolerance
    TIGHT = 'tight'    # close follow, small tolerance
def compute_shadow_target(person_pos, person_bearing_deg, follow_dist=1.2):
    """Return a point follow_dist behind the person, opposite their bearing.

    Bearings are compass-style degrees: x grows with sin, y with cos.
    """
    px, py, pz = person_pos
    behind = math.radians(person_bearing_deg + 180.0)
    return (px + follow_dist * math.sin(behind),
            py + follow_dist * math.cos(behind),
            pz)
def compute_lead_target(person_pos, person_bearing_deg, lead_dist=2.0):
    """Return a point lead_dist ahead of the person along their bearing."""
    heading = math.radians(person_bearing_deg)
    x, y, z = person_pos
    return (x + lead_dist * math.sin(heading),
            y + lead_dist * math.cos(heading),
            z)
def compute_side_target(person_pos, person_bearing_deg, side_dist=1.0, right=True):
    """Return a point side_dist to the person's right (left when right=False)."""
    offset_deg = person_bearing_deg + (90.0 if right else -90.0)
    ang = math.radians(offset_deg)
    x, y, z = person_pos
    return (x + side_dist * math.sin(ang),
            y + side_dist * math.cos(ang),
            z)
def compute_orbit_target(person_pos, orbit_angle_deg, orbit_radius=1.5):
    """Return the point at orbit_angle_deg on the circle around the person."""
    a = math.radians(orbit_angle_deg)
    x, y, z = person_pos
    return (x + orbit_radius * math.sin(a),
            y + orbit_radius * math.cos(a),
            z)
def compute_loose_target(person_pos, robot_pos, follow_dist=2.0, tolerance=0.8):
    """Relaxed follow with a deadband.

    While the robot is within follow_dist + tolerance of the person, its
    current position is returned unchanged (no motion). Beyond that, the
    target sits follow_dist short of the person along the robot->person line.
    """
    dx = person_pos[0] - robot_pos[0]
    dy = person_pos[1] - robot_pos[1]
    separation = math.hypot(dx, dy)
    if separation <= follow_dist + tolerance:
        return robot_pos
    frac = (separation - follow_dist) / separation
    return (robot_pos[0] + dx * frac,
            robot_pos[1] + dy * frac,
            person_pos[2])
def compute_tight_target(person_pos, follow_dist=0.6, person_bearing_deg=None):
    """Close follow: stay very near person.

    The original implementation hard-coded the offset along -y, ignoring
    which way the person is facing (inconsistent with shadow/lead/side).
    When person_bearing_deg is supplied, the target is placed follow_dist
    behind the person along their bearing, matching the other modes; the
    default (None) preserves the legacy -y behavior for existing callers.

    Args:
        person_pos: (x, y, z) person position.
        follow_dist: distance to keep, meters.
        person_bearing_deg: optional compass bearing of the person.

    Returns:
        (x, y, z) target position.
    """
    if person_bearing_deg is None:
        # Legacy behavior: fixed offset along -y.
        return (person_pos[0], person_pos[1] - follow_dist, person_pos[2])
    behind = math.radians(person_bearing_deg + 180.0)
    return (person_pos[0] + follow_dist * math.sin(behind),
            person_pos[1] + follow_dist * math.cos(behind),
            person_pos[2])
# Spoken phrase -> follow mode. Keys are lowercase because callers match
# against lowercased command strings and transcripts (substring matching),
# so multi-word phrases must appear exactly as spoken.
MODE_VOICE_COMMANDS = {
    'shadow': FollowMode.SHADOW,
    'follow me': FollowMode.SHADOW,
    'behind me': FollowMode.SHADOW,
    'lead': FollowMode.LEAD,
    'go ahead': FollowMode.LEAD,
    'lead me': FollowMode.LEAD,
    'side': FollowMode.SIDE,
    'stay beside': FollowMode.SIDE,
    'orbit': FollowMode.ORBIT,
    'circle me': FollowMode.ORBIT,
    'loose': FollowMode.LOOSE,
    'give me space': FollowMode.LOOSE,
    'tight': FollowMode.TIGHT,
    'stay close': FollowMode.TIGHT,
}

View File

@ -0,0 +1,231 @@
"""
midas_depth_node.py -- MiDaS monocular depth estimation for saltybot.
Uses MiDaS_small via ONNX Runtime or TensorRT FP16.
Provides relative depth estimates for cameras without active depth (CSI cameras).
Publishes /social/depth/cam{i}/image (sensor_msgs/Image, float32, relative depth)
"""
import os
import numpy as np
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from sensor_msgs.msg import Image
from cv_bridge import CvBridge
# MiDaS_small fixed network input size (H x W), used for buffer allocation
# and preprocessing resize.
_MIDAS_H = 256
_MIDAS_W = 256
# ImageNet RGB channel normalization (mean/std) applied in _preprocess.
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
class _TRTBackend:
    """TensorRT inference backend for MiDaS.

    Deserializes a prebuilt engine and allocates device buffers for a fixed
    1x3x256x256 float32 input and 1x256x256 float32 output, plus a private
    CUDA stream for async copies/execution.

    Raises:
        RuntimeError: if TensorRT/pycuda are unavailable or the engine fails
            to load -- the caller then falls back to the ONNX backend.
    """
    def __init__(self, engine_path: str, logger):
        self._logger = logger
        try:
            # Deferred imports: the node must start without TRT installed.
            import tensorrt as trt
            import pycuda.driver as cuda
            import pycuda.autoinit  # noqa: F401
            self._cuda = cuda
            rt_logger = trt.Logger(trt.Logger.WARNING)
            with open(engine_path, 'rb') as f:
                engine = trt.Runtime(rt_logger).deserialize_cuda_engine(f.read())
            self._context = engine.create_execution_context()
            # Allocate buffers (4 bytes per float32 element).
            self._d_input = cuda.mem_alloc(1 * 3 * _MIDAS_H * _MIDAS_W * 4)
            self._d_output = cuda.mem_alloc(1 * _MIDAS_H * _MIDAS_W * 4)
            self._h_output = np.empty((_MIDAS_H, _MIDAS_W), dtype=np.float32)
            self._stream = cuda.Stream()
            self._logger.info(f'TRT engine loaded: {engine_path}')
        except Exception as e:
            raise RuntimeError(f'TRT init failed: {e}')
    def infer(self, input_tensor: np.ndarray) -> np.ndarray:
        """Run one inference; returns a (256, 256) float32 map.

        host->device copy, execute, device->host copy all queued on one
        stream, then synchronized before returning.
        """
        self._cuda.memcpy_htod_async(
            self._d_input, input_tensor.ravel(), self._stream)
        self._context.execute_async_v2(
            bindings=[int(self._d_input), int(self._d_output)],
            stream_handle=self._stream.handle)
        self._cuda.memcpy_dtoh_async(
            self._h_output, self._d_output, self._stream)
        self._stream.synchronize()
        # Copy so callers are not aliased to the reused host buffer.
        return self._h_output.copy()
class _ONNXBackend:
    """ONNX Runtime inference backend for MiDaS (fallback when no TRT engine).

    Prefers the CUDA execution provider, falling back to CPU automatically.

    Raises:
        RuntimeError: if onnxruntime is unavailable or the model fails to load.
    """
    def __init__(self, onnx_path: str, logger):
        self._logger = logger
        try:
            # Deferred import: onnxruntime is optional at install time.
            import onnxruntime as ort
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
            self._session = ort.InferenceSession(onnx_path, providers=providers)
            self._input_name = self._session.get_inputs()[0].name
            self._logger.info(f'ONNX model loaded: {onnx_path}')
        except Exception as e:
            raise RuntimeError(f'ONNX init failed: {e}')
    def infer(self, input_tensor: np.ndarray) -> np.ndarray:
        """Run one inference; returns the squeezed first output."""
        result = self._session.run(None, {self._input_name: input_tensor})
        return result[0].squeeze()
class MiDaSDepthNode(Node):
    """MiDaS monocular depth estimation node.

    Caches the latest frame from each configured camera topic, and a timer
    at process_rate Hz processes one camera per tick (round-robin), so each
    individual camera is served at process_rate / n_cameras. The depth map
    is published as a 32FC1 image on /social/depth/cam{i}/image with the
    source frame's header. Output is MiDaS relative (inverse) depth scaled
    by output_scale -- not metric.
    """
    def __init__(self):
        super().__init__('midas_depth')
        # Parameters
        self.declare_parameter('onnx_path',
                               '/mnt/nvme/saltybot/models/midas_small.onnx')
        self.declare_parameter('engine_path',
                               '/mnt/nvme/saltybot/models/midas_small.engine')
        self.declare_parameter('camera_topics', [
            '/surround/cam0/image_raw',
            '/surround/cam1/image_raw',
            '/surround/cam2/image_raw',
            '/surround/cam3/image_raw',
        ])
        self.declare_parameter('output_scale', 1.0)
        self.declare_parameter('process_rate', 5.0)
        onnx_path = self.get_parameter('onnx_path').value
        engine_path = self.get_parameter('engine_path').value
        self._camera_topics = self.get_parameter('camera_topics').value
        self._output_scale = self.get_parameter('output_scale').value
        process_rate = self.get_parameter('process_rate').value
        # Initialize inference backend (TRT preferred, ONNX fallback)
        self._backend = None
        if os.path.exists(engine_path):
            try:
                self._backend = _TRTBackend(engine_path, self.get_logger())
            except RuntimeError:
                self.get_logger().warn('TRT failed, trying ONNX fallback')
        if self._backend is None and os.path.exists(onnx_path):
            try:
                self._backend = _ONNXBackend(onnx_path, self.get_logger())
            except RuntimeError:
                self.get_logger().error('Both TRT and ONNX backends failed')
        if self._backend is None:
            # Node keeps running but _infer will publish all-zero maps.
            self.get_logger().error(
                'No MiDaS model found. Depth estimation disabled.')
        self._bridge = CvBridge()
        # Latest frames per camera (round-robin processing)
        self._latest_frames = [None] * len(self._camera_topics)
        self._current_cam_idx = 0
        # QoS for camera subscriptions: best-effort, depth 1 -- always keep
        # only the freshest frame, never queue stale video.
        cam_qos = QoSProfile(
            reliability=ReliabilityPolicy.BEST_EFFORT,
            history=HistoryPolicy.KEEP_LAST,
            depth=1,
        )
        # Subscribe to each camera topic
        self._cam_subs = []
        for i, topic in enumerate(self._camera_topics):
            # idx=i default argument binds the loop variable per closure.
            sub = self.create_subscription(
                Image, topic,
                lambda msg, idx=i: self._on_image(msg, idx),
                cam_qos)
            self._cam_subs.append(sub)
        # Publishers: one per camera
        self._depth_pubs = []
        for i in range(len(self._camera_topics)):
            pub = self.create_publisher(
                Image, f'/social/depth/cam{i}/image', 10)
            self._depth_pubs.append(pub)
        # Timer: round-robin across cameras
        timer_period = 1.0 / process_rate
        self._timer = self.create_timer(timer_period, self._timer_callback)
        self.get_logger().info(
            f'MiDaS depth node started: {len(self._camera_topics)} cameras '
            f'@ {process_rate} Hz')
    def _on_image(self, msg: Image, cam_idx: int):
        """Cache latest frame for each camera."""
        self._latest_frames[cam_idx] = msg
    def _preprocess(self, bgr: np.ndarray) -> np.ndarray:
        """Preprocess BGR image to MiDaS input tensor [1,3,256,256].

        BGR->RGB, resize to the fixed network size, scale to [0,1],
        ImageNet-normalize, then HWC -> NCHW.
        """
        import cv2
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        resized = cv2.resize(rgb, (_MIDAS_W, _MIDAS_H),
                             interpolation=cv2.INTER_LINEAR)
        normalized = (resized.astype(np.float32) / 255.0 - _MEAN) / _STD
        # HWC -> CHW -> NCHW
        tensor = normalized.transpose(2, 0, 1)[np.newaxis, ...]
        return tensor.astype(np.float32)
    def _infer(self, tensor: np.ndarray) -> np.ndarray:
        """Run inference, returns [256,256] float32 relative inverse depth.

        With no backend available a zero map is returned so downstream
        consumers still receive well-formed messages.
        """
        if self._backend is None:
            return np.zeros((_MIDAS_H, _MIDAS_W), dtype=np.float32)
        return self._backend.infer(tensor)
    def _postprocess(self, raw: np.ndarray, orig_shape: tuple) -> np.ndarray:
        """Resize depth back to original image shape, apply output_scale."""
        import cv2
        h, w = orig_shape[:2]
        depth = cv2.resize(raw, (w, h), interpolation=cv2.INTER_LINEAR)
        depth = depth * self._output_scale
        return depth
    def _timer_callback(self):
        """Process one camera per tick (round-robin).

        NOTE: cached frames are not cleared after processing, so a stalled
        camera keeps republishing its last frame's depth.
        """
        if not self._camera_topics:
            return
        idx = self._current_cam_idx
        self._current_cam_idx = (idx + 1) % len(self._camera_topics)
        msg = self._latest_frames[idx]
        if msg is None:
            return  # nothing received from this camera yet
        try:
            bgr = self._bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')
        except Exception as e:
            self.get_logger().warn(f'cv_bridge error cam{idx}: {e}')
            return
        tensor = self._preprocess(bgr)
        raw_depth = self._infer(tensor)
        depth_map = self._postprocess(raw_depth, bgr.shape)
        # Publish as float32 Image, reusing the source header for sync.
        depth_msg = self._bridge.cv2_to_imgmsg(depth_map, encoding='32FC1')
        depth_msg.header = msg.header
        self._depth_pubs[idx].publish(depth_msg)
def main(args=None):
    """ROS 2 entry point: spin the MiDaS depth node until shutdown/Ctrl-C."""
    rclpy.init(args=args)
    node = MiDaSDepthNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        # Always release node/context resources, even on unexpected errors.
        node.destroy_node()
        rclpy.shutdown()
if __name__ == '__main__':
    main()

View File

@ -0,0 +1,584 @@
"""
social_nav_node.py -- Social navigation node for saltybot.
Orchestrates person following with multiple modes, voice commands,
waypoint teaching/replay, and A* obstacle avoidance.
Follow modes:
shadow -- stay directly behind at follow_distance
lead -- move ahead of person
side -- stay beside (default right)
orbit -- circle around person
loose -- relaxed follow with deadband
tight -- close follow
Waypoint teaching:
Voice command "teach route <name>" -> record mode ON
Voice command "stop teaching" -> save route
Voice command "replay route <name>" -> playback
Voice commands:
"follow me" / "shadow" -> SHADOW mode
"lead me" / "go ahead" -> LEAD mode
"stay beside" -> SIDE mode
"orbit" -> ORBIT mode
"give me space" -> LOOSE mode
"stay close" -> TIGHT mode
"stop" / "halt" -> STOP
"go home" -> navigate to home waypoint
"teach route <name>" -> start recording
"stop teaching" -> finish recording
"replay route <name>" -> playback recorded route
"""
import math
import time
import re
from collections import deque
import numpy as np
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from geometry_msgs.msg import Twist, PoseStamped
from nav_msgs.msg import OccupancyGrid, Odometry
from std_msgs.msg import String, Int32
from .follow_modes import (
FollowMode, MODE_VOICE_COMMANDS,
compute_shadow_target, compute_lead_target, compute_side_target,
compute_orbit_target, compute_loose_target, compute_tight_target,
)
from .astar import astar, inflate_obstacles
from .waypoint_teacher import WaypointRoute, WaypointReplayer
# Try importing social msgs; fallback gracefully
try:
from saltybot_social_msgs.msg import PersonStateArray
_HAS_SOCIAL_MSGS = True
except ImportError:
_HAS_SOCIAL_MSGS = False
# Proportional controller gains
_K_ANG = 2.0  # angular gain: rad/s commanded per rad of heading error
_K_LIN = 0.8  # linear gain: m/s commanded per m of distance error
_HIGH_SPEED_THRESHOLD = 3.0  # m/s person velocity triggers fast mode
_PREDICT_AHEAD_S = 0.3  # seconds to extrapolate position
_TEACH_MIN_DIST = 0.5  # meters between recorded waypoints
class SocialNavNode(Node):
"""Main social navigation orchestrator."""
def __init__(self):
    """Declare parameters, wire up pub/sub, and start the 20 Hz control loop."""
    super().__init__('social_nav')
    # -- Parameters --
    self.declare_parameter('follow_mode', 'shadow')
    self.declare_parameter('follow_distance', 1.2)
    self.declare_parameter('lead_distance', 2.0)
    self.declare_parameter('orbit_radius', 1.5)
    self.declare_parameter('max_linear_speed', 1.0)
    self.declare_parameter('max_linear_speed_fast', 5.5)
    self.declare_parameter('max_angular_speed', 1.0)
    self.declare_parameter('goal_tolerance', 0.3)
    self.declare_parameter('routes_dir', '/mnt/nvme/saltybot/routes')
    self.declare_parameter('home_x', 0.0)
    self.declare_parameter('home_y', 0.0)
    self.declare_parameter('map_resolution', 0.05)
    self.declare_parameter('obstacle_inflation_cells', 3)
    # FollowMode(value) raises ValueError on an unknown mode string.
    self._follow_mode = FollowMode(
        self.get_parameter('follow_mode').value)
    self._follow_distance = self.get_parameter('follow_distance').value
    self._lead_distance = self.get_parameter('lead_distance').value
    self._orbit_radius = self.get_parameter('orbit_radius').value
    self._max_lin = self.get_parameter('max_linear_speed').value
    self._max_lin_fast = self.get_parameter('max_linear_speed_fast').value
    self._max_ang = self.get_parameter('max_angular_speed').value
    self._goal_tol = self.get_parameter('goal_tolerance').value
    self._routes_dir = self.get_parameter('routes_dir').value
    self._home_x = self.get_parameter('home_x').value
    self._home_y = self.get_parameter('home_y').value
    self._map_resolution = self.get_parameter('map_resolution').value
    self._inflation_cells = self.get_parameter(
        'obstacle_inflation_cells').value
    # -- State --
    # Robot pose in the odom frame (updated by _on_odom).
    self._robot_x = 0.0
    self._robot_y = 0.0
    self._robot_yaw = 0.0
    self._target_person_pos = None  # (x, y, z)
    self._target_person_bearing = 0.0
    self._target_person_id = -1  # -1: follow the perception stack's choice
    self._person_history = deque(maxlen=5)  # for velocity estimation
    self._stopped = False
    self._go_home = False
    # Occupancy grid for A*
    self._occ_grid = None
    self._occ_origin = (0.0, 0.0)
    self._occ_resolution = 0.05
    self._astar_path = None
    # Orbit state
    self._orbit_angle = 0.0
    # Waypoint teaching / replay
    self._teaching = False
    self._current_route = None
    self._last_teach_x = None
    self._last_teach_y = None
    self._replayer = None
    # -- QoS profiles --
    # Best-effort/depth-1 for high-rate sensor-ish streams; reliable for
    # latched-style state such as the map and mode announcements.
    best_effort_qos = QoSProfile(
        reliability=ReliabilityPolicy.BEST_EFFORT,
        history=HistoryPolicy.KEEP_LAST, depth=1)
    reliable_qos = QoSProfile(
        reliability=ReliabilityPolicy.RELIABLE,
        history=HistoryPolicy.KEEP_LAST, depth=1)
    # -- Subscriptions --
    if _HAS_SOCIAL_MSGS:
        self.create_subscription(
            PersonStateArray, '/social/persons',
            self._on_persons, best_effort_qos)
    else:
        self.get_logger().warn(
            'saltybot_social_msgs not found; '
            'using /person/target fallback')
        self.create_subscription(
            PoseStamped, '/person/target',
            self._on_person_target, best_effort_qos)
    self.create_subscription(
        String, '/social/speech/command',
        self._on_voice_command, 10)
    self.create_subscription(
        String, '/social/speech/transcript',
        self._on_transcript, 10)
    self.create_subscription(
        OccupancyGrid, '/map',
        self._on_map, reliable_qos)
    self.create_subscription(
        Odometry, '/odom',
        self._on_odom, best_effort_qos)
    self.create_subscription(
        Int32, '/social/attention/target_id',
        self._on_target_id, 10)
    # -- Publishers --
    self._cmd_vel_pub = self.create_publisher(
        Twist, '/cmd_vel', best_effort_qos)
    self._mode_pub = self.create_publisher(
        String, '/social/nav/mode', reliable_qos)
    self._target_pub = self.create_publisher(
        PoseStamped, '/social/nav/target_pos', 10)
    self._status_pub = self.create_publisher(
        String, '/social/nav/status', best_effort_qos)
    # -- Main loop timer (20 Hz) --
    self._timer = self.create_timer(0.05, self._control_loop)
    self.get_logger().info(
        f'Social nav started: mode={self._follow_mode.value}, '
        f'dist={self._follow_distance}m')
# ----------------------------------------------------------------
# Subscriptions
# ----------------------------------------------------------------
def _on_persons(self, msg):
    """Pick our tracked person out of a PersonStateArray update.

    An explicit target (set via /social/attention/target_id) overrides the
    perception stack's primary-attention choice; with no target at all,
    the first listed person is used.
    """
    wanted = msg.primary_attention_id
    if self._target_person_id >= 0:
        wanted = self._target_person_id
    for person in msg.persons:
        is_match = person.person_id == wanted
        take_first = wanted < 0 and len(msg.persons) > 0
        if is_match or take_first:
            self._update_person_position(
                (person.position.x, person.position.y, person.position.z),
                person.bearing_deg)
            break
def _on_person_target(self, msg: PoseStamped):
    """Fallback person input: a single PoseStamped target."""
    q = msg.pose.orientation
    # Planar yaw from the quaternion serves as the person's bearing.
    heading = math.atan2(2.0 * (q.w * q.z + q.x * q.y),
                         1.0 - 2.0 * (q.y * q.y + q.z * q.z))
    position = (msg.pose.position.x,
                msg.pose.position.y,
                msg.pose.position.z)
    self._update_person_position(position, math.degrees(heading))
def _update_person_position(self, pos, bearing_deg):
    """Record the latest person fix and append it to the velocity history."""
    self._target_person_pos = pos
    self._target_person_bearing = bearing_deg
    # (timestamp, x, y) samples feed velocity estimation/prediction.
    self._person_history.append((time.time(), pos[0], pos[1]))
def _on_odom(self, msg: Odometry):
    """Track the robot's planar pose (x, y, yaw) from odometry."""
    pose = msg.pose.pose
    self._robot_x = pose.position.x
    self._robot_y = pose.position.y
    q = pose.orientation
    # Yaw-only extraction from the quaternion.
    siny = 2.0 * (q.w * q.z + q.x * q.y)
    cosy = 1.0 - 2.0 * (q.y * q.y + q.z * q.z)
    self._robot_yaw = math.atan2(siny, cosy)
def _on_map(self, msg: OccupancyGrid):
    """Cache the latest occupancy grid (row-major, rows along +y) for A*."""
    info = msg.info
    cells = np.array(msg.data, dtype=np.int8).reshape(
        (info.height, info.width))
    # Unknown cells (-1) are deliberately treated as free so the planner
    # may traverse unexplored space.
    cells[cells < 0] = 0
    self._occ_grid = cells.astype(np.int32)
    self._occ_origin = (info.origin.position.x,
                        info.origin.position.y)
    self._occ_resolution = info.resolution
def _on_target_id(self, msg: Int32):
    """Explicitly select which person to follow (negative clears override)."""
    new_id = msg.data
    self._target_person_id = new_id
    self.get_logger().info(f'Target person ID set to {new_id}')
def _on_voice_command(self, msg: String):
    """Discrete voice commands: stop/resume plus follow-mode switching."""
    cmd = msg.data.strip().lower()
    if cmd in ('stop', 'halt'):
        # A hard stop also cancels any route replay in progress.
        self._stopped = True
        self._replayer = None
        self._publish_status('STOPPED')
        return
    if cmd in ('resume', 'go', 'start'):
        self._stopped = False
        self._publish_status('RESUMED')
        return
    new_mode = MODE_VOICE_COMMANDS.get(cmd)
    if new_mode is None:
        return  # unrecognized command: ignore
    self._follow_mode = new_mode
    self._stopped = False
    announcement = String()
    announcement.data = new_mode.value
    self._mode_pub.publish(announcement)
    self._publish_status(f'MODE: {new_mode.value}')
def _on_transcript(self, msg: String):
    """Handle free-form voice transcripts.

    Checked in priority order: teach-route start, teach-route stop, route
    replay, go-home, and finally substring matches against the follow-mode
    phrase table. First match wins and returns.
    """
    text = msg.data.strip().lower()
    # "teach route <name>" -- begin recording waypoints from the current pose.
    m = re.match(r'teach\s+route\s+(\w+)', text)
    if m:
        name = m.group(1)
        self._teaching = True
        self._current_route = WaypointRoute(name)
        # Seed the last-recorded pose so waypoint spacing starts from here.
        self._last_teach_x = self._robot_x
        self._last_teach_y = self._robot_y
        self._publish_status(f'TEACHING: {name}')
        self.get_logger().info(f'Recording route: {name}')
        return
    # "stop teaching" -- persist the active route (if any) and reset state.
    if 'stop teaching' in text:
        if self._teaching and self._current_route:
            self._current_route.save(self._routes_dir)
            self._publish_status(
                f'SAVED: {self._current_route.name} '
                f'({len(self._current_route.waypoints)} pts)')
            self.get_logger().info(
                f'Route saved: {self._current_route.name}')
        # State is cleared even when nothing was being recorded.
        self._teaching = False
        self._current_route = None
        return
    # "replay route <name>" -- load a saved route and start playback.
    m = re.match(r'replay\s+route\s+(\w+)', text)
    if m:
        name = m.group(1)
        try:
            route = WaypointRoute.load(self._routes_dir, name)
            self._replayer = WaypointReplayer(route)
            self._stopped = False
            self._publish_status(f'REPLAY: {name}')
            self.get_logger().info(f'Replaying route: {name}')
        except FileNotFoundError:
            self._publish_status(f'ROUTE NOT FOUND: {name}')
        return
    # "go home" -- navigate back to the configured home position.
    if 'go home' in text:
        self._go_home = True
        self._stopped = False
        self._publish_status('GO HOME')
        return
    # Also try mode commands from transcript (substring match; dict
    # insertion order decides ties between overlapping phrases).
    for phrase, mode in MODE_VOICE_COMMANDS.items():
        if phrase in text:
            self._follow_mode = mode
            self._stopped = False
            mode_msg = String()
            mode_msg.data = self._follow_mode.value
            self._mode_pub.publish(mode_msg)
            self._publish_status(f'MODE: {self._follow_mode.value}')
            return
# ----------------------------------------------------------------
# Control loop
# ----------------------------------------------------------------
def _control_loop(self):
    """Main 20Hz control loop.

    Ordering matters: waypoint recording runs even while stopped; a stop
    command zeroes velocity before any target is computed; A* only replaces
    the steering target when the map actually yields a path.
    """
    # Record waypoint if teaching
    if self._teaching and self._current_route:
        self._maybe_record_waypoint()
    # Publish zero velocity if stopped
    if self._stopped:
        self._publish_cmd_vel(0.0, 0.0)
        return
    # Determine navigation target
    target = self._get_nav_target()
    if target is None:
        # No person / route / goal available: hold position.
        self._publish_cmd_vel(0.0, 0.0)
        return
    tx, ty, tz = target
    # Publish debug target
    self._publish_target_pose(tx, ty, tz)
    # Check if arrived
    dist_to_target = math.hypot(tx - self._robot_x, ty - self._robot_y)
    if dist_to_target < self._goal_tol:
        self._publish_cmd_vel(0.0, 0.0)
        if self._go_home:
            self._go_home = False
            self._publish_status('HOME REACHED')
        return
    # Try A* path if map available; on success steer toward a lookahead
    # cell on the path instead of the raw target.
    if self._occ_grid is not None:
        path_target = self._plan_astar(tx, ty)
        if path_target:
            tx, ty = path_target
    # Determine speed limit: unlock the fast cap (EUC tracking) only when
    # the person themselves is moving above the threshold.
    max_lin = self._max_lin
    person_vel = self._estimate_person_velocity()
    if person_vel > _HIGH_SPEED_THRESHOLD:
        max_lin = self._max_lin_fast
    # Compute and publish cmd_vel
    lin, ang = self._compute_cmd_vel(
        self._robot_x, self._robot_y, self._robot_yaw,
        tx, ty, max_lin)
    self._publish_cmd_vel(lin, ang)
def _get_nav_target(self):
    """Pick the current navigation goal.

    Priority: route replay > go-home > person following (per follow mode).
    Returns an (x, y, z) tuple or None when there is nothing to chase.
    """
    # 1) Route replay has highest priority.
    if self._replayer and not self._replayer.is_done:
        self._replayer.check_arrived(self._robot_x, self._robot_y)
        waypoint = self._replayer.current_waypoint()
        if waypoint:
            return (waypoint.x, waypoint.y, waypoint.z)
        self._replayer = None
        self._publish_status('REPLAY DONE')
        return None
    # 2) "go home" command.
    if self._go_home:
        return (self._home_x, self._home_y, 0.0)
    # 3) Person following -- requires a tracked person.
    if self._target_person_pos is None:
        return None
    # Extrapolated position keeps up with fast-moving targets.
    px, py, pz = self._predict_person_position()
    bearing = self._target_person_bearing
    here = (self._robot_x, self._robot_y, 0.0)
    mode = self._follow_mode
    if mode == FollowMode.SHADOW:
        return compute_shadow_target(
            (px, py, pz), bearing, self._follow_distance)
    if mode == FollowMode.LEAD:
        return compute_lead_target(
            (px, py, pz), bearing, self._lead_distance)
    if mode == FollowMode.SIDE:
        return compute_side_target(
            (px, py, pz), bearing, self._follow_distance)
    if mode == FollowMode.ORBIT:
        # Advance 1 degree per call; at the 20 Hz loop that is ~18 s/lap.
        self._orbit_angle = (self._orbit_angle + 1.0) % 360.0
        return compute_orbit_target(
            (px, py, pz), self._orbit_angle, self._orbit_radius)
    if mode == FollowMode.LOOSE:
        return compute_loose_target(
            (px, py, pz), here, self._follow_distance)
    if mode == FollowMode.TIGHT:
        return compute_tight_target(
            (px, py, pz), self._follow_distance)
    return (px, py, pz)
def _predict_person_position(self):
    """Extrapolate the person 0.3 s ahead when they are moving fast.

    Velocity comes from the 3rd-most-recent and newest history samples;
    prediction only kicks in above _HIGH_SPEED_THRESHOLD so slow targets
    are not over-shot.
    """
    if self._target_person_pos is None:
        return (0.0, 0.0, 0.0)
    px, py, pz = self._target_person_pos
    if len(self._person_history) < 3:
        return (px, py, pz)
    t_old, x_old, y_old = self._person_history[-3]
    t_new, x_new, y_new = self._person_history[-1]
    dt = t_new - t_old
    if dt <= 0.01:
        return (px, py, pz)  # samples too close for a stable estimate
    vx = (x_new - x_old) / dt
    vy = (y_new - y_old) / dt
    if math.hypot(vx, vy) > _HIGH_SPEED_THRESHOLD:
        px += vx * _PREDICT_AHEAD_S
        py += vy * _PREDICT_AHEAD_S
    return (px, py, pz)
def _estimate_person_velocity(self) -> float:
"""Estimate person speed from recent position history."""
if len(self._person_history) < 2:
return 0.0
t0, x0, y0 = self._person_history[-2]
t1, x1, y1 = self._person_history[-1]
dt = t1 - t0
if dt < 0.01:
return 0.0
return math.hypot(x1 - x0, y1 - y0) / dt
def _plan_astar(self, target_x, target_y):
"""Run A* on occupancy grid, return next waypoint in world coords."""
grid = self._occ_grid
res = self._occ_resolution
ox, oy = self._occ_origin
# World to grid
def w2g(wx, wy):
return (int((wy - oy) / res), int((wx - ox) / res))
start = w2g(self._robot_x, self._robot_y)
goal = w2g(target_x, target_y)
rows, cols = grid.shape
if not (0 <= start[0] < rows and 0 <= start[1] < cols):
return None
if not (0 <= goal[0] < rows and 0 <= goal[1] < cols):
return None
inflated = inflate_obstacles(grid, self._inflation_cells)
path = astar(inflated, start, goal)
if path and len(path) > 1:
# Follow a lookahead point (3 steps ahead or end)
lookahead_idx = min(3, len(path) - 1)
r, c = path[lookahead_idx]
wx = ox + c * res + res / 2.0
wy = oy + r * res + res / 2.0
self._astar_path = path
return (wx, wy)
return None
def _compute_cmd_vel(self, rx, ry, ryaw, tx, ty, max_lin):
"""Proportional controller: compute linear and angular velocity."""
dx = tx - rx
dy = ty - ry
dist = math.hypot(dx, dy)
angle_to_target = math.atan2(dy, dx)
angle_error = angle_to_target - ryaw
# Normalize angle error to [-pi, pi]
while angle_error > math.pi:
angle_error -= 2.0 * math.pi
while angle_error < -math.pi:
angle_error += 2.0 * math.pi
angular_vel = _K_ANG * angle_error
angular_vel = max(-self._max_ang,
min(self._max_ang, angular_vel))
# Reduce linear speed when turning hard
angle_factor = max(0.0, 1.0 - abs(angle_error) / (math.pi / 2.0))
linear_vel = _K_LIN * dist * angle_factor
linear_vel = max(0.0, min(max_lin, linear_vel))
return (linear_vel, angular_vel)
# ----------------------------------------------------------------
# Waypoint teaching
# ----------------------------------------------------------------
def _maybe_record_waypoint(self):
"""Record waypoint if robot moved > _TEACH_MIN_DIST."""
if self._last_teach_x is None:
self._last_teach_x = self._robot_x
self._last_teach_y = self._robot_y
dist = math.hypot(
self._robot_x - self._last_teach_x,
self._robot_y - self._last_teach_y)
if dist >= _TEACH_MIN_DIST:
yaw_deg = math.degrees(self._robot_yaw)
self._current_route.add(
self._robot_x, self._robot_y, 0.0, yaw_deg)
self._last_teach_x = self._robot_x
self._last_teach_y = self._robot_y
# ----------------------------------------------------------------
# Publishers
# ----------------------------------------------------------------
def _publish_cmd_vel(self, linear: float, angular: float):
twist = Twist()
twist.linear.x = linear
twist.angular.z = angular
self._cmd_vel_pub.publish(twist)
def _publish_target_pose(self, x, y, z):
msg = PoseStamped()
msg.header.stamp = self.get_clock().now().to_msg()
msg.header.frame_id = 'map'
msg.pose.position.x = x
msg.pose.position.y = y
msg.pose.position.z = z
self._target_pub.publish(msg)
def _publish_status(self, status: str):
msg = String()
msg.data = status
self._status_pub.publish(msg)
self.get_logger().info(f'Nav status: {status}')
def main(args=None):
    """Entry point: spin the social navigation node until shutdown."""
    rclpy.init(args=args)
    nav_node = SocialNavNode()
    try:
        rclpy.spin(nav_node)
    except KeyboardInterrupt:
        pass  # Ctrl-C is a normal way to stop the node.
    finally:
        nav_node.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()

View File

@ -0,0 +1,91 @@
"""waypoint_teacher.py -- Record and replay waypoint routes."""
import json
import time
import math
from pathlib import Path
from dataclasses import dataclass, asdict
@dataclass
class Waypoint:
    """A single recorded pose sample along a taught route."""

    x: float  # world X coordinate (metres)
    y: float  # world Y coordinate (metres)
    z: float  # world Z coordinate (metres); recorders currently pass 0.0
    yaw_deg: float  # heading in degrees at record time
    timestamp: float  # epoch seconds when the waypoint was added
    label: str = ''  # optional human-readable tag
class WaypointRoute:
    """A named, ordered sequence of waypoints with JSON persistence."""

    def __init__(self, name: str):
        self.name = name
        self.waypoints: list[Waypoint] = []
        self.created_at = time.time()

    def add(self, x, y, z, yaw_deg, label=''):
        """Append a waypoint stamped with the current time."""
        wp = Waypoint(x, y, z, yaw_deg, time.time(), label)
        self.waypoints.append(wp)

    def to_dict(self):
        """Return a JSON-serializable representation of the route."""
        return {
            'name': self.name,
            'created_at': self.created_at,
            'waypoints': [asdict(wp) for wp in self.waypoints],
        }

    @classmethod
    def from_dict(cls, d):
        """Rebuild a route from a dict produced by to_dict()."""
        route = cls(d['name'])
        route.created_at = d.get('created_at', 0)
        route.waypoints = [Waypoint(**entry) for entry in d['waypoints']]
        return route

    def save(self, routes_dir: str):
        """Write the route as <routes_dir>/<name>.json, creating dirs."""
        target = Path(routes_dir)
        target.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(self.to_dict(), indent=2)
        (target / f'{self.name}.json').write_text(payload)

    @classmethod
    def load(cls, routes_dir: str, name: str):
        """Load <routes_dir>/<name>.json (FileNotFoundError if absent)."""
        raw = (Path(routes_dir) / f'{name}.json').read_text()
        return cls.from_dict(json.loads(raw))

    @staticmethod
    def list_routes(routes_dir: str) -> list[str]:
        """Return the names of all saved routes; [] if the dir is missing."""
        base = Path(routes_dir)
        if not base.exists():
            return []
        return [p.stem for p in base.glob('*.json')]
class WaypointReplayer:
    """Steps through a route's waypoints as the robot reaches each one."""

    def __init__(self, route: WaypointRoute, arrival_radius: float = 0.3):
        self._route = route
        self._idx = 0  # index of the waypoint currently being driven to
        self._arrival_radius = arrival_radius

    def current_waypoint(self) -> Waypoint | None:
        """Return the waypoint being driven to, or None when finished."""
        waypoints = self._route.waypoints
        if self._idx >= len(waypoints):
            return None
        return waypoints[self._idx]

    def check_arrived(self, robot_x, robot_y) -> bool:
        """Advance to the next waypoint if the robot is within radius.

        Returns True exactly when an arrival was detected on this call.
        """
        wp = self.current_waypoint()
        if wp is None:
            return False
        if math.hypot(robot_x - wp.x, robot_y - wp.y) >= self._arrival_radius:
            return False
        self._idx += 1
        return True

    @property
    def is_done(self) -> bool:
        """True once every waypoint has been reached (or route is empty)."""
        return self._idx >= len(self._route.waypoints)

    def reset(self):
        """Restart replay from the first waypoint."""
        self._idx = 0

View File

@ -0,0 +1,135 @@
"""
waypoint_teacher_node.py -- Standalone waypoint teacher ROS2 node.
Listens to /social/speech/transcript for "teach route <name>" and "stop teaching".
Records robot pose at configurable intervals. Saves/loads routes via WaypointRoute.
"""
import math
import re

import rclpy
from nav_msgs.msg import Odometry
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from std_msgs.msg import String

from .waypoint_teacher import WaypointRoute
class WaypointTeacherNode(Node):
    """Standalone waypoint teaching node.

    Voice commands on /social/speech/transcript:
      * "teach route <name>"  -- start recording a new route
      * "stop teaching"       -- save the current route to disk
      * "list routes"         -- publish the names of saved routes
    While teaching, the robot pose (from /odom) is checked at 10 Hz and a
    waypoint is appended each time the robot has travelled at least
    ``recording_interval`` metres.
    """

    # Compiled once at class creation instead of re-importing/re-matching
    # inside the transcript callback; matched against lowercased text.
    _TEACH_ROUTE_RE = re.compile(r'teach\s+route\s+(\w+)')

    def __init__(self):
        super().__init__('waypoint_teacher')
        self.declare_parameter('routes_dir', '/mnt/nvme/saltybot/routes')
        self.declare_parameter('recording_interval', 0.5)  # meters
        self._routes_dir = self.get_parameter('routes_dir').value
        self._interval = self.get_parameter('recording_interval').value
        self._teaching = False  # True while a route is being recorded
        self._route = None      # WaypointRoute under construction
        self._last_x = None     # pose of the last recorded waypoint
        self._last_y = None
        self._robot_x = 0.0     # latest odometry pose
        self._robot_y = 0.0
        self._robot_yaw = 0.0
        # Odometry is high-rate sensor data: best-effort, keep only latest.
        best_effort_qos = QoSProfile(
            reliability=ReliabilityPolicy.BEST_EFFORT,
            history=HistoryPolicy.KEEP_LAST, depth=1)
        self.create_subscription(
            Odometry, '/odom', self._on_odom, best_effort_qos)
        self.create_subscription(
            String, '/social/speech/transcript',
            self._on_transcript, 10)
        self._status_pub = self.create_publisher(
            String, '/social/waypoint/status', 10)
        # Record timer at 10Hz (check distance)
        self._timer = self.create_timer(0.1, self._record_tick)
        self.get_logger().info(
            f'Waypoint teacher ready (interval={self._interval}m, '
            f'dir={self._routes_dir})')

    def _on_odom(self, msg: Odometry):
        """Cache the latest planar pose (x, y, yaw) from odometry."""
        self._robot_x = msg.pose.pose.position.x
        self._robot_y = msg.pose.pose.position.y
        q = msg.pose.pose.orientation
        # Yaw (Z-axis rotation) extracted directly from the quaternion.
        self._robot_yaw = math.atan2(
            2.0 * (q.w * q.z + q.x * q.y),
            1.0 - 2.0 * (q.y * q.y + q.z * q.z))

    def _on_transcript(self, msg: String):
        """Dispatch voice commands: teach, stop teaching, list routes."""
        text = msg.data.strip().lower()
        m = self._TEACH_ROUTE_RE.match(text)
        if m:
            name = m.group(1)
            self._route = WaypointRoute(name)
            self._teaching = True
            # Anchor the distance-interval check at the current pose.
            self._last_x = self._robot_x
            self._last_y = self._robot_y
            self._pub_status(f'RECORDING: {name}')
            self.get_logger().info(f'Recording route: {name}')
            return
        if 'stop teaching' in text:
            if self._teaching and self._route:
                self._route.save(self._routes_dir)
                n = len(self._route.waypoints)
                self._pub_status(
                    f'SAVED: {self._route.name} ({n} waypoints)')
                self.get_logger().info(
                    f'Route saved: {self._route.name} ({n} pts)')
            self._teaching = False
            self._route = None
            return
        if 'list routes' in text:
            routes = WaypointRoute.list_routes(self._routes_dir)
            self._pub_status(f'ROUTES: {", ".join(routes) or "(none)"}')

    def _record_tick(self):
        """10 Hz timer: append a waypoint after >= interval metres moved."""
        if not self._teaching or self._route is None:
            return
        if self._last_x is None:
            # Defensive: anchor if teaching started without an odom fix.
            self._last_x = self._robot_x
            self._last_y = self._robot_y
        dist = math.hypot(
            self._robot_x - self._last_x,
            self._robot_y - self._last_y)
        if dist >= self._interval:
            yaw_deg = math.degrees(self._robot_yaw)
            self._route.add(self._robot_x, self._robot_y, 0.0, yaw_deg)
            self._last_x = self._robot_x
            self._last_y = self._robot_y

    def _pub_status(self, text: str):
        """Publish a short status string on /social/waypoint/status."""
        msg = String()
        msg.data = text
        self._status_pub.publish(msg)
def main(args=None):
    """Entry point: run the waypoint teacher node until interrupted."""
    rclpy.init(args=args)
    teacher = WaypointTeacherNode()
    try:
        rclpy.spin(teacher)
    except KeyboardInterrupt:
        pass  # clean Ctrl-C exit
    finally:
        teacher.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()

View File

@ -0,0 +1,80 @@
#!/usr/bin/env python3
"""
build_midas_trt_engine.py -- Build TensorRT FP16 engine for MiDaS_small from ONNX.
Usage:
python3 build_midas_trt_engine.py \
--onnx /mnt/nvme/saltybot/models/midas_small.onnx \
--engine /mnt/nvme/saltybot/models/midas_small.engine \
--fp16
Requires: tensorrt, pycuda
"""
import argparse
import os
import sys
def build_engine(onnx_path: str, engine_path: str, fp16: bool = True):
    """Parse an ONNX model and serialize a TensorRT engine to disk.

    Exits the process with status 1 when TensorRT is unavailable, the
    ONNX file fails to parse, or engine building fails.
    """
    try:
        import tensorrt as trt
    except ImportError:
        print('ERROR: tensorrt not found. Install TensorRT first.')
        sys.exit(1)

    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    # Explicit-batch network definition is required by the ONNX parser.
    flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(flags)
    parser = trt.OnnxParser(network, logger)

    print(f'Parsing ONNX model: {onnx_path}')
    with open(onnx_path, 'rb') as f:
        parsed_ok = parser.parse(f.read())
    if not parsed_ok:
        for i in range(parser.num_errors):
            print(f'  ONNX parse error: {parser.get_error(i)}')
        sys.exit(1)

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1GB
    # FP16 only when requested AND supported; otherwise warn and use FP32.
    if fp16 and builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)
        print('FP16 mode enabled')
    elif fp16:
        print('WARNING: FP16 not supported on this platform, using FP32')

    print('Building TensorRT engine (this may take several minutes)...')
    engine_bytes = builder.build_serialized_network(network, config)
    if engine_bytes is None:
        print('ERROR: Failed to build engine')
        sys.exit(1)

    os.makedirs(os.path.dirname(engine_path) or '.', exist_ok=True)
    with open(engine_path, 'wb') as f:
        f.write(engine_bytes)
    size_mb = len(engine_bytes) / (1024 * 1024)
    print(f'Engine saved: {engine_path} ({size_mb:.1f} MB)')
def main():
    """CLI entry point: parse arguments and build the engine.

    FP16 is the default precision. ``--fp16`` is accepted for
    explicitness / backward compatibility but adds nothing beyond the
    default; ``--fp32`` takes precedence when both flags are given.
    """
    parser = argparse.ArgumentParser(
        description='Build TensorRT FP16 engine for MiDaS_small')
    parser.add_argument('--onnx', required=True,
                        help='Path to MiDaS ONNX model')
    parser.add_argument('--engine', required=True,
                        help='Output TRT engine path')
    # Kept for CLI compatibility: FP16 is already the default, so this
    # flag is a no-op unless --fp32 overrides it.
    parser.add_argument('--fp16', action='store_true', default=True,
                        help='Enable FP16 (default: enabled)')
    parser.add_argument('--fp32', action='store_true',
                        help='Force FP32 (disable FP16)')
    args = parser.parse_args()
    # --fp32 wins over --fp16; otherwise FP16 is used.
    build_engine(args.onnx, args.engine, fp16=not args.fp32)


if __name__ == '__main__':
    main()

View File

@ -0,0 +1,4 @@
[develop]
script_dir=$base/lib/saltybot_social_nav
[install]
install_scripts=$base/lib/saltybot_social_nav

View File

@ -0,0 +1,31 @@
from setuptools import setup
import os
from glob import glob
package_name = 'saltybot_social_nav'
setup(
    name=package_name,
    version='0.1.0',
    packages=[package_name],
    data_files=[
        # ament resource index marker so ROS 2 can discover the package.
        ('share/ament_index/resource_index/packages', ['resource/' + package_name]),
        ('share/' + package_name, ['package.xml']),
        # Install launch files and YAML configs into the package share dir.
        (os.path.join('share', package_name, 'launch'), glob('launch/*.py')),
        (os.path.join('share', package_name, 'config'), glob('config/*.yaml')),
    ],
    install_requires=['setuptools'],
    zip_safe=True,
    maintainer='seb',
    maintainer_email='seb@vayrette.com',
    description='Social navigation for saltybot: follow modes, waypoint teaching, A* avoidance, MiDaS depth',
    license='MIT',
    tests_require=['pytest'],
    entry_points={
        # Executables runnable via `ros2 run saltybot_social_nav <name>`.
        'console_scripts': [
            'social_nav = saltybot_social_nav.social_nav_node:main',
            'midas_depth = saltybot_social_nav.midas_depth_node:main',
            'waypoint_teacher = saltybot_social_nav.waypoint_teacher_node:main',
        ],
    },
)

View File

@ -0,0 +1,12 @@
# Copyright 2026 SaltyLab
# Licensed under MIT
from ament_copyright.main import main
import pytest
@pytest.mark.copyright
@pytest.mark.linter
def test_copyright():
    """Run ament_copyright over the package to verify license headers."""
    rc = main(argv=['.', 'test'])
    assert rc == 0, 'Found errors'

View File

@ -0,0 +1,14 @@
# Copyright 2026 SaltyLab
# Licensed under MIT
from ament_flake8.main import main_with_errors
import pytest
@pytest.mark.flake8
@pytest.mark.linter
def test_flake8():
    """Run flake8 over the package and fail listing any style violations."""
    rc, errors = main_with_errors(argv=[])
    assert rc == 0, \
        'Found %d code style errors / warnings:\n' % len(errors) + \
        '\n'.join(errors)

View File

@ -0,0 +1,12 @@
# Copyright 2026 SaltyLab
# Licensed under MIT
from ament_pep257.main import main
import pytest
@pytest.mark.pep257
@pytest.mark.linter
def test_pep257():
    """Run ament_pep257 docstring-style checks over the package."""
    rc = main(argv=['.', 'test'])
    assert rc == 0, 'Found code style errors / warnings'