Compare commits

..

1 Commits

Author SHA1 Message Date
0612eedbcd feat(social): face detection + recognition (SCRFD + ArcFace TRT FP16, Issue #80)
Add two new ROS2 packages for the social sprint:

saltybot_social_msgs (ament_cmake):
- FaceDetection, FaceDetectionArray, FaceEmbedding, FaceEmbeddingArray
- PersonState, PersonStateArray
- EnrollPerson, ListPersons, DeletePerson, UpdatePerson services

saltybot_social_face (ament_python):
- SCRFDDetector: SCRFD face detection with TRT FP16 + ONNX fallback
  - 640x640 input, 3-stride anchor decoding, NMS
- ArcFaceRecognizer: 512-dim embedding extraction with gallery matching
  - 5-point landmark alignment to 112x112, cosine similarity
- FaceGallery: thread-safe persistent gallery (npz + JSON sidecar)
- FaceRecognitionNode: ROS2 node subscribing /camera/color/image_raw,
  publishing /social/faces/detections, /social/faces/embeddings
- Enrollment via /social/enroll service (N-sample face averaging)
- Launch file, config YAML, TRT engine builder script

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-01 23:11:20 -05:00
43 changed files with 184 additions and 2392 deletions

View File

@ -1,8 +0,0 @@
person_state_tracker:
ros__parameters:
engagement_distance: 2.0
absent_timeout: 5.0
prune_timeout: 30.0
update_rate: 10.0
n_cameras: 4
uwb_enabled: false

View File

@ -1,35 +0,0 @@
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
    """Launch the person_state_tracker node with tunable engagement parameters."""
    engagement_arg = DeclareLaunchArgument(
        'engagement_distance',
        default_value='2.0',
        description='Distance in metres below which a person is considered engaged',
    )
    absent_arg = DeclareLaunchArgument(
        'absent_timeout',
        default_value='5.0',
        description='Seconds without detection before marking person as ABSENT',
    )
    uwb_arg = DeclareLaunchArgument(
        'uwb_enabled',
        default_value='false',
        description='Whether UWB anchor data is available',
    )
    tracker_node = Node(
        package='saltybot_social',
        executable='person_state_tracker',
        name='person_state_tracker',
        output='screen',
        parameters=[{
            'engagement_distance': LaunchConfiguration('engagement_distance'),
            'absent_timeout': LaunchConfiguration('absent_timeout'),
            'uwb_enabled': LaunchConfiguration('uwb_enabled'),
        }],
    )
    return LaunchDescription([engagement_arg, absent_arg, uwb_arg, tracker_node])

View File

@ -1,25 +0,0 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_social</name>
<version>0.1.0</version>
<description>Multi-modal person identity fusion and state tracking for saltybot</description>
<maintainer email="seb@vayrette.com">seb</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>std_msgs</depend>
<depend>geometry_msgs</depend>
<depend>sensor_msgs</depend>
<depend>vision_msgs</depend>
<depend>saltybot_social_msgs</depend>
<depend>tf2_ros</depend>
<depend>tf2_geometry_msgs</depend>
<depend>cv_bridge</depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>python3-pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -1,201 +0,0 @@
import math
import time
from typing import List, Optional, Tuple
from saltybot_social.person_state import PersonTrack, State
class IdentityFuser:
    """Multi-modal identity fusion: face_id + speaker_id + UWB anchor -> unified person_id.

    Each modality update (face, speaker, UWB) either matches an existing
    PersonTrack -- by modality ID first, then by 3D proximity -- or allocates
    a new track, and returns the unified person_id.
    """

    def __init__(self, position_window: float = 3.0):
        """
        Args:
            position_window: seconds within which a track's last position is
                fresh enough to be used for proximity matching.
        """
        self._tracks: dict[int, PersonTrack] = {}
        self._next_id: int = 1
        self._position_window = position_window
        self._max_history = 20  # bounded per-track distance history for trend detection

    def _distance_3d(self, a: Optional[tuple], b: Optional[tuple]) -> float:
        """Euclidean distance between two 3D points; inf when either is unknown."""
        if a is None or b is None:
            return float('inf')
        return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))

    def _compute_distance_and_bearing(self, position: tuple) -> Tuple[float, float]:
        """Return (range_m, bearing_deg) of *position* relative to the robot origin.

        Bearing is atan2(y, x) in degrees, i.e. 0 along +x.
        """
        x, y, z = position
        distance = math.sqrt(x * x + y * y + z * z)
        bearing = math.degrees(math.atan2(y, x))
        return distance, bearing

    def _allocate_id(self) -> int:
        """Hand out the next monotonically increasing person_id."""
        pid = self._next_id
        self._next_id += 1
        return pid

    def _apply_position(self, track: PersonTrack, position_3d: tuple) -> None:
        """Update a track's position, range/bearing, and bounded distance history.

        Shared by face, UWB and proximity updates (previously triplicated).
        """
        track.position = position_3d
        dist, bearing = self._compute_distance_and_bearing(position_3d)
        track.distance = dist
        track.bearing_deg = bearing
        track.history_distances.append(dist)
        if len(track.history_distances) > self._max_history:
            track.history_distances = track.history_distances[-self._max_history:]

    def _find_nearby_track(self, position_3d: tuple, now: float,
                           threshold: float = 0.5) -> Optional[PersonTrack]:
        """Nearest recently-seen track within *threshold* metres, or None."""
        best_track = None
        best_dist = threshold
        for track in self._tracks.values():
            d = self._distance_3d(track.position, position_3d)
            if d < best_dist and (now - track.last_seen) < self._position_window:
                best_dist = d
                best_track = track
        return best_track

    def update_face(self, face_id: int, name: str,
                    position_3d: Optional[tuple], camera_id: int,
                    now: float) -> int:
        """Match by face_id first. If face_id >= 0, find existing track or create new.

        Returns the person_id of the matched or newly created track.
        """
        if face_id >= 0:
            for track in self._tracks.values():
                if track.face_id == face_id:
                    # Keep an existing real name rather than overwriting with 'unknown'.
                    track.name = name if name and name != 'unknown' else track.name
                    if position_3d is not None:
                        self._apply_position(track, position_3d)
                    track.camera_id = camera_id
                    track.last_seen = now
                    track.last_face_seen = now
                    return track.person_id
        # No existing track with this face_id; try proximity match for face_id < 0
        if face_id < 0 and position_3d is not None:
            best_track = self._find_nearby_track(position_3d, now)
            if best_track is not None:
                self._apply_position(best_track, position_3d)
                best_track.camera_id = camera_id
                best_track.last_seen = now
                return best_track.person_id
        # Create new track
        pid = self._allocate_id()
        track = PersonTrack(person_id=pid, face_id=face_id, name=name,
                            camera_id=camera_id, last_seen=now, last_face_seen=now)
        if position_3d is not None:
            self._apply_position(track, position_3d)
        self._tracks[pid] = track
        return pid

    def update_speaker(self, speaker_id: str, now: float) -> int:
        """Find track with nearest recently-seen position, assign speaker_id.

        Returns the person_id of the matched or newly created track.
        """
        # First check if any track already has this speaker_id
        for track in self._tracks.values():
            if track.speaker_id == speaker_id:
                track.last_seen = now
                return track.person_id
        # Assign to nearest recently-seen track (distance 0 means "no position").
        best_track = None
        best_dist = float('inf')
        for track in self._tracks.values():
            if (now - track.last_seen) < self._position_window and track.distance > 0:
                if track.distance < best_dist:
                    best_dist = track.distance
                    best_track = track
        if best_track is not None:
            best_track.speaker_id = speaker_id
            best_track.last_seen = now
            return best_track.person_id
        # No suitable track found; create a new one
        pid = self._allocate_id()
        track = PersonTrack(person_id=pid, speaker_id=speaker_id,
                            last_seen=now)
        self._tracks[pid] = track
        return pid

    def update_uwb(self, uwb_anchor_id: str, position_3d: tuple,
                   now: float) -> int:
        """Match by anchor ID, else by proximity (nearest track within 0.5m).

        Returns the person_id of the matched or newly created track.
        """
        # Check existing UWB assignment
        for track in self._tracks.values():
            if track.uwb_anchor_id == uwb_anchor_id:
                self._apply_position(track, position_3d)
                track.last_seen = now
                return track.person_id
        # Proximity match
        best_track = self._find_nearby_track(position_3d, now)
        if best_track is not None:
            best_track.uwb_anchor_id = uwb_anchor_id
            self._apply_position(best_track, position_3d)
            best_track.last_seen = now
            return best_track.person_id
        # Create new track
        pid = self._allocate_id()
        track = PersonTrack(person_id=pid, uwb_anchor_id=uwb_anchor_id,
                            position=position_3d, last_seen=now)
        self._apply_position(track, position_3d)
        self._tracks[pid] = track
        return pid

    def get_all_tracks(self) -> List[PersonTrack]:
        """Snapshot list of all current tracks (order unspecified)."""
        return list(self._tracks.values())

    @staticmethod
    def compute_attention(tracks: List[PersonTrack]) -> int:
        """Focus on nearest engaged/talking person, or -1 if none."""
        candidates = [t for t in tracks
                      if t.state in (State.ENGAGED, State.TALKING) and t.distance > 0]
        if not candidates:
            return -1
        # Prefer talking, then nearest engaged
        talking = [t for t in candidates if t.state == State.TALKING]
        if talking:
            return min(talking, key=lambda t: t.distance).person_id
        return min(candidates, key=lambda t: t.distance).person_id

    def prune_absent(self, absent_timeout: float = 30.0):
        """Remove tracks absent longer than timeout (measured against monotonic now)."""
        now = time.monotonic()
        to_remove = [pid for pid, t in self._tracks.items()
                     if t.state == State.ABSENT and (now - t.last_seen) > absent_timeout]
        for pid in to_remove:
            del self._tracks[pid]

    @staticmethod
    def detect_group(tracks: List[PersonTrack]) -> bool:
        """Returns True if >= 2 persons within 1.5m of each other."""
        active = [t for t in tracks
                  if t.position is not None and t.state != State.ABSENT]
        for i, a in enumerate(active):
            for b in active[i + 1:]:
                dx = a.position[0] - b.position[0]
                dy = a.position[1] - b.position[1]
                dz = a.position[2] - b.position[2]
                if math.sqrt(dx * dx + dy * dy + dz * dz) < 1.5:
                    return True
        return False

View File

@ -1,54 +0,0 @@
import time
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Optional
class State(IntEnum):
    """Person interaction lifecycle (mirrors the STATE_* constants in PersonState.msg)."""
    UNKNOWN = 0
    APPROACHING = 1
    ENGAGED = 2
    TALKING = 3
    LEAVING = 4
    ABSENT = 5


@dataclass
class PersonTrack:
    """Mutable record of one tracked person, fused across modalities."""
    person_id: int
    face_id: int = -1
    speaker_id: str = ''
    uwb_anchor_id: str = ''
    name: str = 'unknown'
    state: State = State.UNKNOWN
    position: Optional[tuple] = None  # (x, y, z) in base_link
    distance: float = 0.0
    bearing_deg: float = 0.0
    engagement_score: float = 0.0
    camera_id: int = -1
    last_seen: float = field(default_factory=time.monotonic)
    last_face_seen: float = 0.0
    history_distances: list = field(default_factory=list)  # last N distances for trend

    def update_state(self, now: float, engagement_distance: float = 2.0,
                     absent_timeout: float = 5.0):
        """Transition state machine based on distance trend and time.

        Priority: staleness -> active speech -> distance trend -> raw distance.
        """
        if now - self.last_seen > absent_timeout:
            # Nothing has updated this track recently enough.
            self.state = State.ABSENT
        elif self.speaker_id:
            # Any assigned speaker ID means the person is actively talking.
            self.state = State.TALKING
        elif len(self.history_distances) >= 3:
            # Compare latest distance against two samples back for a trend.
            trend = self.history_distances[-1] - self.history_distances[-3]
            if self.distance < engagement_distance:
                self.state = State.ENGAGED
            elif trend < -0.2:    # moving closer
                self.state = State.APPROACHING
            elif trend > 0.3:     # moving away
                self.state = State.LEAVING
            else:
                # Far away and roughly stationary.
                self.state = State.UNKNOWN
        elif self.distance > 0:
            # Too little history for a trend: classify on range alone.
            self.state = State.ENGAGED if self.distance < engagement_distance else State.APPROACHING

View File

@ -1,273 +0,0 @@
import time
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from std_msgs.msg import Int32, String
from geometry_msgs.msg import PoseArray, PoseStamped, Point
from sensor_msgs.msg import Image
from builtin_interfaces.msg import Time as TimeMsg
from saltybot_social_msgs.msg import (
FaceDetection,
FaceDetectionArray,
PersonState,
PersonStateArray,
)
from saltybot_social.identity_fuser import IdentityFuser
from saltybot_social.person_state import State
class PersonStateTrackerNode(Node):
    """Main ROS2 node for multi-modal person tracking.

    Fuses face detections, diarized speaker IDs, optional UWB anchor poses,
    and the external YOLOv8 person target into unified tracks (IdentityFuser),
    then publishes /social/persons (PersonStateArray) and
    /social/attention/target_id (Int32) at a fixed rate.
    """

    def __init__(self):
        super().__init__('person_state_tracker')
        # Parameters
        self.declare_parameter('engagement_distance', 2.0)  # metres; ENGAGED threshold
        self.declare_parameter('absent_timeout', 5.0)       # seconds before ABSENT
        self.declare_parameter('prune_timeout', 30.0)       # seconds before track removal
        self.declare_parameter('update_rate', 10.0)         # publish rate, Hz
        self.declare_parameter('n_cameras', 4)              # surround camera count
        self.declare_parameter('uwb_enabled', False)
        self._engagement_distance = self.get_parameter('engagement_distance').value
        self._absent_timeout = self.get_parameter('absent_timeout').value
        self._prune_timeout = self.get_parameter('prune_timeout').value
        update_rate = self.get_parameter('update_rate').value
        n_cameras = self.get_parameter('n_cameras').value
        uwb_enabled = self.get_parameter('uwb_enabled').value
        self._fuser = IdentityFuser()
        # QoS profiles
        # Sensor-style streams: drop rather than block the publisher.
        best_effort_qos = QoSProfile(
            reliability=ReliabilityPolicy.BEST_EFFORT,
            history=HistoryPolicy.KEEP_LAST,
            depth=5
        )
        # Latest-value topic where delivery matters (attention target).
        reliable_qos = QoSProfile(
            reliability=ReliabilityPolicy.RELIABLE,
            history=HistoryPolicy.KEEP_LAST,
            depth=1
        )
        # Subscriptions
        self.create_subscription(
            FaceDetectionArray,
            '/social/faces/detections',
            self._on_face_detections,
            best_effort_qos
        )
        self.create_subscription(
            String,
            '/social/speech/speaker_id',
            self._on_speaker_id,
            10
        )
        if uwb_enabled:
            self.create_subscription(
                PoseArray,
                '/uwb/positions',
                self._on_uwb,
                10
            )
        # Camera subscriptions (monitor topic existence)
        self.create_subscription(
            Image, '/camera/color/image_raw',
            self._on_camera_image, best_effort_qos)
        for i in range(n_cameras):
            self.create_subscription(
                Image, f'/surround/cam{i}/image_raw',
                self._on_camera_image, best_effort_qos)
        # Existing person tracker position
        self.create_subscription(
            PoseStamped,
            '/person/target',
            self._on_person_target,
            10
        )
        # Publishers
        self._persons_pub = self.create_publisher(
            PersonStateArray, '/social/persons', best_effort_qos)
        self._attention_pub = self.create_publisher(
            Int32, '/social/attention/target_id', reliable_qos)
        # Timer
        timer_period = 1.0 / update_rate
        self.create_timer(timer_period, self._on_timer)
        self.get_logger().info(
            f'PersonStateTracker started: engagement={self._engagement_distance}m, '
            f'absent_timeout={self._absent_timeout}s, rate={update_rate}Hz, '
            f'cameras={n_cameras}, uwb={uwb_enabled}')

    def _on_face_detections(self, msg: FaceDetectionArray):
        """Feed each detected face into the identity fuser."""
        now = time.monotonic()
        for face in msg.faces:
            position_3d = None
            # Use bbox center as rough position estimate if depth is available
            # Real 3D position would come from depth camera projection
            if face.bbox_x > 0 or face.bbox_y > 0:
                # Approximate: use bbox center as bearing proxy, distance from bbox size
                # This is a placeholder; real impl would use depth image lookup
                # NOTE(review): assumes bbox fields are normalized image coords
                # (0-1) -- confirm against the face detector's output.
                position_3d = (
                    max(0.5, 1.0 / max(face.bbox_w, 0.01)),  # rough depth from bbox width
                    face.bbox_x + face.bbox_w / 2.0 - 0.5,   # x offset from center
                    face.bbox_y + face.bbox_h / 2.0 - 0.5    # y offset from center
                )
            camera_id = -1
            # Recover a trailing camera index from frame_id (e.g. 'cam2_...').
            if hasattr(face.header, 'frame_id') and face.header.frame_id:
                frame = face.header.frame_id
                if 'cam' in frame:
                    try:
                        camera_id = int(frame.split('cam')[-1].split('_')[0])
                    except ValueError:
                        pass
            self._fuser.update_face(
                face_id=face.face_id,
                name=face.person_name,
                position_3d=position_3d,
                camera_id=camera_id,
                now=now
            )

    def _on_speaker_id(self, msg: String):
        """Route a diarized speaker ID to the fuser."""
        now = time.monotonic()
        speaker_id = msg.data
        # Parse "speaker_<id>" format or raw id
        if speaker_id.startswith('speaker_'):
            speaker_id = speaker_id[len('speaker_'):]
        self._fuser.update_speaker(speaker_id, now)

    def _on_uwb(self, msg: PoseArray):
        """Feed UWB poses to the fuser; anchor IDs derived from frame_id + index."""
        now = time.monotonic()
        for i, pose in enumerate(msg.poses):
            position = (pose.position.x, pose.position.y, pose.position.z)
            anchor_id = f'uwb_{i}'
            if hasattr(msg, 'header') and msg.header.frame_id:
                anchor_id = f'{msg.header.frame_id}_{i}'
            self._fuser.update_uwb(anchor_id, position, now)

    def _on_camera_image(self, msg: Image):
        # Monitor only -- no processing here
        pass

    def _on_person_target(self, msg: PoseStamped):
        """Position-only update from the external person tracker (no identity)."""
        now = time.monotonic()
        position_3d = (
            msg.pose.position.x,
            msg.pose.position.y,
            msg.pose.position.z
        )
        # Update position for unidentified person (from YOLOv8 tracker)
        self._fuser.update_face(
            face_id=-1,
            name='unknown',
            position_3d=position_3d,
            camera_id=-1,
            now=now
        )

    def _on_timer(self):
        """Periodic tick: advance state machines, score engagement, prune, publish."""
        now = time.monotonic()
        tracks = self._fuser.get_all_tracks()
        # Update state machine for all tracks
        for track in tracks:
            track.update_state(
                now,
                engagement_distance=self._engagement_distance,
                absent_timeout=self._absent_timeout
            )
        # Compute engagement scores
        for track in tracks:
            if track.state == State.ABSENT:
                track.engagement_score = 0.0
            elif track.state == State.TALKING:
                track.engagement_score = 1.0
            elif track.state == State.ENGAGED:
                # Linear falloff with range inside the engagement radius.
                track.engagement_score = max(0.0, 1.0 - track.distance / self._engagement_distance)
            elif track.state == State.APPROACHING:
                track.engagement_score = 0.3
            elif track.state == State.LEAVING:
                track.engagement_score = 0.1
            else:
                track.engagement_score = 0.0
        # Prune long-absent tracks
        # NOTE(review): `tracks` was snapshotted before pruning, so a track
        # removed here is still published one final time below (as ABSENT).
        self._fuser.prune_absent(self._prune_timeout)
        # Compute attention target
        attention_id = IdentityFuser.compute_attention(tracks)
        # Build and publish PersonStateArray
        msg = PersonStateArray()
        msg.header.stamp = self.get_clock().now().to_msg()
        msg.header.frame_id = 'base_link'
        msg.primary_attention_id = attention_id
        for track in tracks:
            ps = PersonState()
            ps.header.stamp = msg.header.stamp
            ps.header.frame_id = 'base_link'
            ps.person_id = track.person_id
            ps.person_name = track.name
            ps.face_id = track.face_id
            ps.speaker_id = track.speaker_id
            ps.uwb_anchor_id = track.uwb_anchor_id
            if track.position is not None:
                ps.position = Point(
                    x=float(track.position[0]),
                    y=float(track.position[1]),
                    z=float(track.position[2]))
            ps.distance = float(track.distance)
            ps.bearing_deg = float(track.bearing_deg)
            ps.state = int(track.state)
            ps.engagement_score = float(track.engagement_score)
            ps.last_seen = self._mono_to_ros_time(track.last_seen)
            ps.camera_id = track.camera_id
            msg.persons.append(ps)
        self._persons_pub.publish(msg)
        # Publish attention target
        att_msg = Int32()
        att_msg.data = attention_id
        self._attention_pub.publish(att_msg)

    def _mono_to_ros_time(self, mono: float) -> TimeMsg:
        """Convert monotonic timestamp to approximate ROS time."""
        # Offset from monotonic to wall clock
        # (recomputed per call, so converting the same mono value twice can
        # differ by scheduling jitter -- approximate by design)
        offset = time.time() - time.monotonic()
        wall = mono + offset
        sec = int(wall)
        nsec = int((wall - sec) * 1e9)
        t = TimeMsg()
        t.sec = sec
        t.nanosec = nsec
        return t
def main(args=None):
    """CLI entry point: init rclpy, spin the tracker node, clean shutdown."""
    rclpy.init(args=args)
    node = PersonStateTrackerNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Ctrl-C is a normal way to stop the node; fall through to cleanup.
        pass
    finally:
        node.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()

View File

@ -1,4 +0,0 @@
[develop]
script_dir=$base/lib/saltybot_social
[install]
install_scripts=$base/lib/saltybot_social

View File

@ -1,32 +0,0 @@
from setuptools import find_packages, setup
import os
from glob import glob
package_name = 'saltybot_social'

setup(
    name=package_name,
    version='0.1.0',
    packages=find_packages(exclude=['test']),
    # ament resource index marker, package.xml, and launch/config assets
    # installed to the ROS 2 share directory.
    data_files=[
        ('share/ament_index/resource_index/packages',
         ['resource/' + package_name]),
        ('share/' + package_name, ['package.xml']),
        (os.path.join('share', package_name, 'launch'),
         glob(os.path.join('launch', '*launch.[pxy][yma]*'))),
        (os.path.join('share', package_name, 'config'),
         glob(os.path.join('config', '*.yaml'))),
    ],
    install_requires=['setuptools'],
    zip_safe=True,
    maintainer='seb',
    maintainer_email='seb@vayrette.com',
    description='Multi-modal person identity fusion and state tracking for saltybot',
    license='MIT',
    tests_require=['pytest'],
    entry_points={
        # `ros2 run saltybot_social person_state_tracker`
        'console_scripts': [
            'person_state_tracker = saltybot_social.person_state_tracker_node:main',
        ],
    },
)

View File

@ -1,6 +1,10 @@
cmake_minimum_required(VERSION 3.8)
project(saltybot_social_msgs)
if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
add_compile_options(-Wall -Wextra -Wpedantic)
endif()
find_package(ament_cmake REQUIRED)
find_package(rosidl_default_generators REQUIRED)
find_package(std_msgs REQUIRED)

View File

@ -1,10 +1,12 @@
std_msgs/Header header
int32 face_id
string person_name
float32 confidence
float32 recognition_score
int32 face_id # -1 if unknown
string person_name # "" if unknown
float32 confidence # detection confidence 0-1
float32 recognition_score # cosine similarity 0-1 (0 if unknown)
# Bounding box in pixels
float32 bbox_x
float32 bbox_y
float32 bbox_w
float32 bbox_h
# 5-point landmarks [x0,y0, x1,y1, x2,y2, x3,y3, x4,y4] = left_eye, right_eye, nose, left_mouth, right_mouth
float32[10] landmarks

View File

@ -1,5 +1,5 @@
int32 person_id
string person_name
float32[] embedding
float32[] embedding # 512-dim ArcFace embedding
builtin_interfaces/Time enrolled_at
int32 sample_count

View File

@ -2,18 +2,18 @@ std_msgs/Header header
int32 person_id
string person_name
int32 face_id
string speaker_id
string uwb_anchor_id
geometry_msgs/Point position
float32 distance
float32 bearing_deg
uint8 state
string speaker_id # from audio, "" if unknown
string uwb_anchor_id # "" if no UWB
geometry_msgs/Point position # in base_link frame, zeros if unknown
float32 distance # metres
float32 bearing_deg # degrees, 0=forward
uint8 state # see STATE_* constants
uint8 STATE_UNKNOWN=0
uint8 STATE_APPROACHING=1
uint8 STATE_ENGAGED=2
uint8 STATE_TALKING=3
uint8 STATE_LEAVING=4
uint8 STATE_ABSENT=5
float32 engagement_score
float32 engagement_score # 0-1, attention model output
builtin_interfaces/Time last_seen
int32 camera_id
int32 camera_id # which CSI camera last saw them (-1=depth cam)

View File

@ -1,3 +1,3 @@
std_msgs/Header header
saltybot_social_msgs/PersonState[] persons
int32 primary_attention_id
int32 primary_attention_id # person_id of focus target (-1 if none)

View File

@ -3,16 +3,21 @@
<package format="3">
<name>saltybot_social_msgs</name>
<version>0.1.0</version>
<description>Custom ROS2 messages and services for saltybot social capabilities</description>
<description>Custom message and service definitions for the SaltyBot social sprint</description>
<maintainer email="seb@vayrette.com">seb</maintainer>
<license>MIT</license>
<buildtool_depend>ament_cmake</buildtool_depend>
<buildtool_depend>rosidl_default_generators</buildtool_depend>
<depend>std_msgs</depend>
<depend>geometry_msgs</depend>
<depend>builtin_interfaces</depend>
<build_depend>rosidl_default_generators</build_depend>
<exec_depend>rosidl_default_runtime</exec_depend>
<member_of_group>rosidl_interface_packages</member_of_group>
<export>
<build_type>ament_cmake</build_type>
</export>

View File

@ -1,6 +1,6 @@
string name
string mode
int32 n_samples
string mode # "face", "voice", "both"
int32 n_samples # number of face crops to average (default 10)
---
bool success
string message

View File

@ -1,22 +0,0 @@
social_nav:
ros__parameters:
follow_mode: 'shadow'
follow_distance: 1.2
lead_distance: 2.0
orbit_radius: 1.5
max_linear_speed: 1.0
max_linear_speed_fast: 5.5
max_angular_speed: 1.0
goal_tolerance: 0.3
routes_dir: '/mnt/nvme/saltybot/routes'
home_x: 0.0
home_y: 0.0
map_resolution: 0.05
obstacle_inflation_cells: 3
midas_depth:
ros__parameters:
onnx_path: '/mnt/nvme/saltybot/models/midas_small.onnx'
engine_path: '/mnt/nvme/saltybot/models/midas_small.engine'
process_rate: 5.0
output_scale: 1.0

View File

@ -1,57 +0,0 @@
"""Launch file for saltybot social navigation."""
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
    """Social-nav launch: follow controller, MiDaS depth node, waypoint teacher."""
    launch_args = [
        DeclareLaunchArgument('follow_mode', default_value='shadow',
                              description='Follow mode: shadow/lead/side/orbit/loose/tight'),
        DeclareLaunchArgument('follow_distance', default_value='1.2',
                              description='Follow distance in meters'),
        DeclareLaunchArgument('max_linear_speed', default_value='1.0',
                              description='Max linear speed (m/s)'),
        DeclareLaunchArgument('routes_dir',
                              default_value='/mnt/nvme/saltybot/routes',
                              description='Directory for saved routes'),
    ]
    nav_node = Node(
        package='saltybot_social_nav',
        executable='social_nav',
        name='social_nav',
        output='screen',
        parameters=[{
            'follow_mode': LaunchConfiguration('follow_mode'),
            'follow_distance': LaunchConfiguration('follow_distance'),
            'max_linear_speed': LaunchConfiguration('max_linear_speed'),
            'routes_dir': LaunchConfiguration('routes_dir'),
        }],
    )
    depth_node = Node(
        package='saltybot_social_nav',
        executable='midas_depth',
        name='midas_depth',
        output='screen',
        parameters=[{
            'onnx_path': '/mnt/nvme/saltybot/models/midas_small.onnx',
            'engine_path': '/mnt/nvme/saltybot/models/midas_small.engine',
            'process_rate': 5.0,
            'output_scale': 1.0,
        }],
    )
    teacher_node = Node(
        package='saltybot_social_nav',
        executable='waypoint_teacher',
        name='waypoint_teacher',
        output='screen',
        parameters=[{
            'routes_dir': LaunchConfiguration('routes_dir'),
            'recording_interval': 0.5,
        }],
    )
    return LaunchDescription(launch_args + [nav_node, depth_node, teacher_node])

View File

@ -1,28 +0,0 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_social_nav</name>
<version>0.1.0</version>
<description>Social navigation for saltybot: follow modes, waypoint teaching, A* avoidance, MiDaS depth</description>
<maintainer email="seb@vayrette.com">seb</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>std_msgs</depend>
<depend>geometry_msgs</depend>
<depend>nav_msgs</depend>
<depend>sensor_msgs</depend>
<depend>cv_bridge</depend>
<depend>tf2_ros</depend>
<depend>tf2_geometry_msgs</depend>
<depend>saltybot_social_msgs</depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>python3-pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -1,82 +0,0 @@
"""astar.py -- A* path planner for saltybot social navigation."""
import heapq
import numpy as np
def astar(grid: np.ndarray, start: tuple, goal: tuple,
obstacle_val: int = 100) -> list | None:
"""
A* on a 2D occupancy grid (row, col indexing).
Args:
grid: 2D numpy array, values 0=free, >=obstacle_val=obstacle
start: (row, col) start cell
goal: (row, col) goal cell
obstacle_val: cells with value >= obstacle_val are blocked
Returns:
List of (row, col) tuples from start to goal, or None if no path.
"""
rows, cols = grid.shape
def h(a, b):
return abs(a[0] - b[0]) + abs(a[1] - b[1]) # Manhattan heuristic
open_set = []
heapq.heappush(open_set, (h(start, goal), 0, start))
came_from = {}
g_score = {start: 0}
# 8-directional movement
neighbors_delta = [
(-1, -1), (-1, 0), (-1, 1),
(0, -1), (0, 1),
(1, -1), (1, 0), (1, 1),
]
while open_set:
_, cost, current = heapq.heappop(open_set)
if current == goal:
path = []
while current in came_from:
path.append(current)
current = came_from[current]
path.append(start)
return list(reversed(path))
if cost > g_score.get(current, float('inf')):
continue
for dr, dc in neighbors_delta:
nr, nc = current[0] + dr, current[1] + dc
if not (0 <= nr < rows and 0 <= nc < cols):
continue
if grid[nr, nc] >= obstacle_val:
continue
move_cost = 1.414 if (dr != 0 and dc != 0) else 1.0
new_g = g_score[current] + move_cost
neighbor = (nr, nc)
if new_g < g_score.get(neighbor, float('inf')):
g_score[neighbor] = new_g
f = new_g + h(neighbor, goal)
came_from[neighbor] = current
heapq.heappush(open_set, (f, new_g, neighbor))
return None # No path found
def inflate_obstacles(grid: np.ndarray, inflation_radius_cells: int) -> np.ndarray:
    """Inflate obstacles for robot footprint safety.

    Cells with value >= 50 are treated as obstacles; every cell within the
    square (Chebyshev) radius of one is set to 100 in a copy of the grid.
    """
    from scipy.ndimage import binary_dilation
    size = 2 * inflation_radius_cells + 1
    structure = np.ones((size, size), dtype=bool)
    grown = binary_dilation(grid >= 50, structure=structure)
    inflated_grid = grid.copy()
    inflated_grid[grown] = 100
    return inflated_grid

View File

@ -1,82 +0,0 @@
"""follow_modes.py -- Follow mode geometry for saltybot social navigation."""
import math
from enum import Enum
import numpy as np
class FollowMode(Enum):
    """Named follow behaviours; values match the voice-command vocabulary
    (see MODE_VOICE_COMMANDS) and the `follow_mode` launch parameter."""
    SHADOW = 'shadow'  # stay directly behind at follow_distance
    LEAD = 'lead'      # move ahead of person by lead_distance
    SIDE = 'side'      # stay to the right (or left) at side_offset
    ORBIT = 'orbit'    # circle around person at orbit_radius
    LOOSE = 'loose'    # general follow, larger tolerance
    TIGHT = 'tight'    # close follow, small tolerance
def compute_shadow_target(person_pos, person_bearing_deg, follow_dist=1.2):
    """Target position: behind person along their movement direction."""
    # Flip the bearing 180 deg so the offset points back toward the follower.
    # NOTE(review): offsets use sin for x / cos for y (bearing from +y axis)
    # -- confirm against the producer of person_bearing_deg.
    theta = math.radians(person_bearing_deg + 180.0)
    offset_x = follow_dist * math.sin(theta)
    offset_y = follow_dist * math.cos(theta)
    return (person_pos[0] + offset_x, person_pos[1] + offset_y, person_pos[2])
def compute_lead_target(person_pos, person_bearing_deg, lead_dist=2.0):
    """Target position: ahead of person."""
    # NOTE(review): same sin/cos axis convention as compute_shadow_target.
    theta = math.radians(person_bearing_deg)
    return (person_pos[0] + lead_dist * math.sin(theta),
            person_pos[1] + lead_dist * math.cos(theta),
            person_pos[2])
def compute_side_target(person_pos, person_bearing_deg, side_dist=1.0, right=True):
    """Target position: to the right (or left) of person."""
    # Rotate the bearing +/-90 deg to point laterally from the person.
    lateral_deg = person_bearing_deg + (90.0 if right else -90.0)
    theta = math.radians(lateral_deg)
    return (person_pos[0] + side_dist * math.sin(theta),
            person_pos[1] + side_dist * math.cos(theta),
            person_pos[2])
def compute_orbit_target(person_pos, orbit_angle_deg, orbit_radius=1.5):
    """Target on circle of radius orbit_radius around person."""
    theta = math.radians(orbit_angle_deg)
    dx = orbit_radius * math.sin(theta)
    dy = orbit_radius * math.cos(theta)
    return (person_pos[0] + dx, person_pos[1] + dy, person_pos[2])
def compute_loose_target(person_pos, robot_pos, follow_dist=2.0, tolerance=0.8):
    """Only move if farther than follow_dist + tolerance."""
    delta_x = person_pos[0] - robot_pos[0]
    delta_y = person_pos[1] - robot_pos[1]
    separation = math.hypot(delta_x, delta_y)
    # Inside the deadband: stay put (return the robot's own pose unchanged).
    if separation <= follow_dist + tolerance:
        return robot_pos
    # Move along the robot->person line, stopping follow_dist short of them.
    fraction = (separation - follow_dist) / separation
    return (robot_pos[0] + delta_x * fraction,
            robot_pos[1] + delta_y * fraction,
            person_pos[2])
def compute_tight_target(person_pos, follow_dist=0.6):
    """Close follow: stay very near person."""
    # Fixed offset straight down the -y axis from the person.
    x, y, z = person_pos
    return (x, y - follow_dist, z)
# Spoken-phrase -> FollowMode lookup. Multiple natural phrasings map to the
# same mode; keys are expected as lower-case transcript fragments.
MODE_VOICE_COMMANDS = {
    'shadow': FollowMode.SHADOW,
    'follow me': FollowMode.SHADOW,
    'behind me': FollowMode.SHADOW,
    'lead': FollowMode.LEAD,
    'go ahead': FollowMode.LEAD,
    'lead me': FollowMode.LEAD,
    'side': FollowMode.SIDE,
    'stay beside': FollowMode.SIDE,
    'orbit': FollowMode.ORBIT,
    'circle me': FollowMode.ORBIT,
    'loose': FollowMode.LOOSE,
    'give me space': FollowMode.LOOSE,
    'tight': FollowMode.TIGHT,
    'stay close': FollowMode.TIGHT,
}

View File

@ -1,231 +0,0 @@
"""
midas_depth_node.py -- MiDaS monocular depth estimation for saltybot.
Uses MiDaS_small via ONNX Runtime or TensorRT FP16.
Provides relative depth estimates for cameras without active depth (CSI cameras).
Publishes /social/depth/cam{i}/image (sensor_msgs/Image, float32, relative depth)
"""
import os
import numpy as np
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from sensor_msgs.msg import Image
from cv_bridge import CvBridge
# MiDaS_small input size
_MIDAS_H = 256
_MIDAS_W = 256
# ImageNet normalization
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
class _TRTBackend:
    """TensorRT inference backend for MiDaS."""

    def __init__(self, engine_path: str, logger):
        """Deserialize the serialized TRT engine and pre-allocate device buffers.

        Raises:
            RuntimeError: if TensorRT/pycuda are missing or the engine fails
                to load (any underlying exception is wrapped).
        """
        self._logger = logger
        try:
            import tensorrt as trt
            import pycuda.driver as cuda
            import pycuda.autoinit  # noqa: F401
            self._cuda = cuda
            rt_logger = trt.Logger(trt.Logger.WARNING)
            with open(engine_path, 'rb') as f:
                engine = trt.Runtime(rt_logger).deserialize_cuda_engine(f.read())
            self._context = engine.create_execution_context()
            # Allocate buffers
            # Fixed batch of 1; 4 bytes per float32 element.
            self._d_input = cuda.mem_alloc(1 * 3 * _MIDAS_H * _MIDAS_W * 4)
            self._d_output = cuda.mem_alloc(1 * _MIDAS_H * _MIDAS_W * 4)
            self._h_output = np.empty((_MIDAS_H, _MIDAS_W), dtype=np.float32)
            self._stream = cuda.Stream()
            self._logger.info(f'TRT engine loaded: {engine_path}')
        except Exception as e:
            raise RuntimeError(f'TRT init failed: {e}')

    def infer(self, input_tensor: np.ndarray) -> np.ndarray:
        """Run one forward pass: host->device copy, execute, device->host copy.

        Returns a copy of the (H, W) float32 output so callers can hold it
        past the next inference (the host buffer is reused).
        """
        self._cuda.memcpy_htod_async(
            self._d_input, input_tensor.ravel(), self._stream)
        self._context.execute_async_v2(
            bindings=[int(self._d_input), int(self._d_output)],
            stream_handle=self._stream.handle)
        self._cuda.memcpy_dtoh_async(
            self._h_output, self._d_output, self._stream)
        self._stream.synchronize()
        return self._h_output.copy()
class _ONNXBackend:
"""ONNX Runtime inference backend for MiDaS."""
def __init__(self, onnx_path: str, logger):
self._logger = logger
try:
import onnxruntime as ort
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
self._session = ort.InferenceSession(onnx_path, providers=providers)
self._input_name = self._session.get_inputs()[0].name
self._logger.info(f'ONNX model loaded: {onnx_path}')
except Exception as e:
raise RuntimeError(f'ONNX init failed: {e}')
def infer(self, input_tensor: np.ndarray) -> np.ndarray:
result = self._session.run(None, {self._input_name: input_tensor})
return result[0].squeeze()
class MiDaSDepthNode(Node):
    """MiDaS monocular depth estimation node.

    Subscribes to N camera image topics and, on a fixed-rate timer,
    processes ONE camera per tick (round-robin) through MiDaS_small.
    The resulting relative inverse-depth map is resized back to the
    source resolution and published as a 32FC1 Image on
    /social/depth/cam{i}/image with the source header preserved.
    """

    def __init__(self):
        super().__init__('midas_depth')
        # Parameters
        self.declare_parameter('onnx_path',
                               '/mnt/nvme/saltybot/models/midas_small.onnx')
        self.declare_parameter('engine_path',
                               '/mnt/nvme/saltybot/models/midas_small.engine')
        self.declare_parameter('camera_topics', [
            '/surround/cam0/image_raw',
            '/surround/cam1/image_raw',
            '/surround/cam2/image_raw',
            '/surround/cam3/image_raw',
        ])
        self.declare_parameter('output_scale', 1.0)
        self.declare_parameter('process_rate', 5.0)
        onnx_path = self.get_parameter('onnx_path').value
        engine_path = self.get_parameter('engine_path').value
        self._camera_topics = self.get_parameter('camera_topics').value
        self._output_scale = self.get_parameter('output_scale').value
        process_rate = self.get_parameter('process_rate').value
        # Initialize inference backend (TRT preferred, ONNX fallback)
        self._backend = None
        if os.path.exists(engine_path):
            try:
                self._backend = _TRTBackend(engine_path, self.get_logger())
            except RuntimeError:
                self.get_logger().warn('TRT failed, trying ONNX fallback')
        if self._backend is None and os.path.exists(onnx_path):
            try:
                self._backend = _ONNXBackend(onnx_path, self.get_logger())
            except RuntimeError:
                self.get_logger().error('Both TRT and ONNX backends failed')
        if self._backend is None:
            # Node stays up; _infer() will return all-zero maps.
            self.get_logger().error(
                'No MiDaS model found. Depth estimation disabled.')
        self._bridge = CvBridge()
        # Latest frames per camera (round-robin processing)
        self._latest_frames = [None] * len(self._camera_topics)
        self._current_cam_idx = 0
        # QoS for camera subscriptions: best-effort, keep only the newest
        # frame — stale frames are useless for depth.
        cam_qos = QoSProfile(
            reliability=ReliabilityPolicy.BEST_EFFORT,
            history=HistoryPolicy.KEEP_LAST,
            depth=1,
        )
        # Subscribe to each camera topic.  idx=i default-binds the loop
        # variable so each lambda keeps its own camera index.
        self._cam_subs = []
        for i, topic in enumerate(self._camera_topics):
            sub = self.create_subscription(
                Image, topic,
                lambda msg, idx=i: self._on_image(msg, idx),
                cam_qos)
            self._cam_subs.append(sub)
        # Publishers: one per camera
        self._depth_pubs = []
        for i in range(len(self._camera_topics)):
            pub = self.create_publisher(
                Image, f'/social/depth/cam{i}/image', 10)
            self._depth_pubs.append(pub)
        # Timer: round-robin across cameras
        # NOTE(review): process_rate <= 0 would make this divide by zero
        # (or produce a negative period) — confirm the config never sets it.
        timer_period = 1.0 / process_rate
        self._timer = self.create_timer(timer_period, self._timer_callback)
        self.get_logger().info(
            f'MiDaS depth node started: {len(self._camera_topics)} cameras '
            f'@ {process_rate} Hz')

    def _on_image(self, msg: Image, cam_idx: int):
        """Cache latest frame for each camera."""
        self._latest_frames[cam_idx] = msg

    def _preprocess(self, bgr: np.ndarray) -> np.ndarray:
        """Preprocess BGR image to MiDaS input tensor [1,3,256,256].

        BGR->RGB, resize to 256x256, scale to [0,1], ImageNet-normalize,
        then HWC -> NCHW.
        """
        import cv2
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        resized = cv2.resize(rgb, (_MIDAS_W, _MIDAS_H),
                             interpolation=cv2.INTER_LINEAR)
        normalized = (resized.astype(np.float32) / 255.0 - _MEAN) / _STD
        # HWC -> CHW -> NCHW
        tensor = normalized.transpose(2, 0, 1)[np.newaxis, ...]
        return tensor.astype(np.float32)

    def _infer(self, tensor: np.ndarray) -> np.ndarray:
        """Run inference, returns [256,256] float32 relative inverse depth.

        Returns an all-zero map when no backend could be initialized.
        """
        if self._backend is None:
            return np.zeros((_MIDAS_H, _MIDAS_W), dtype=np.float32)
        return self._backend.infer(tensor)

    def _postprocess(self, raw: np.ndarray, orig_shape: tuple) -> np.ndarray:
        """Resize depth back to original image shape, apply output_scale."""
        import cv2
        h, w = orig_shape[:2]
        depth = cv2.resize(raw, (w, h), interpolation=cv2.INTER_LINEAR)
        depth = depth * self._output_scale
        return depth

    def _timer_callback(self):
        """Process one camera per tick (round-robin)."""
        if not self._camera_topics:
            return
        idx = self._current_cam_idx
        self._current_cam_idx = (idx + 1) % len(self._camera_topics)
        msg = self._latest_frames[idx]
        if msg is None:
            # No frame received yet for this camera; skip this tick.
            return
        try:
            bgr = self._bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')
        except Exception as e:
            self.get_logger().warn(f'cv_bridge error cam{idx}: {e}')
            return
        tensor = self._preprocess(bgr)
        raw_depth = self._infer(tensor)
        depth_map = self._postprocess(raw_depth, bgr.shape)
        # Publish as float32 Image, reusing the source header so downstream
        # consumers can associate depth with the originating frame.
        depth_msg = self._bridge.cv2_to_imgmsg(depth_map, encoding='32FC1')
        depth_msg.header = msg.header
        self._depth_pubs[idx].publish(depth_msg)
def main(args=None):
    """Entry point: spin the MiDaS depth node until shutdown."""
    rclpy.init(args=args)
    depth_node = MiDaSDepthNode()
    try:
        rclpy.spin(depth_node)
    except KeyboardInterrupt:
        pass  # Ctrl-C is a normal way to stop the node
    finally:
        depth_node.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()

View File

@ -1,584 +0,0 @@
"""
social_nav_node.py -- Social navigation node for saltybot.
Orchestrates person following with multiple modes, voice commands,
waypoint teaching/replay, and A* obstacle avoidance.
Follow modes:
shadow -- stay directly behind at follow_distance
lead -- move ahead of person
side -- stay beside (default right)
orbit -- circle around person
loose -- relaxed follow with deadband
tight -- close follow
Waypoint teaching:
Voice command "teach route <name>" -> record mode ON
Voice command "stop teaching" -> save route
Voice command "replay route <name>" -> playback
Voice commands:
"follow me" / "shadow" -> SHADOW mode
"lead me" / "go ahead" -> LEAD mode
"stay beside" -> SIDE mode
"orbit" -> ORBIT mode
"give me space" -> LOOSE mode
"stay close" -> TIGHT mode
"stop" / "halt" -> STOP
"go home" -> navigate to home waypoint
"teach route <name>" -> start recording
"stop teaching" -> finish recording
"replay route <name>" -> playback recorded route
"""
import math
import time
import re
from collections import deque
import numpy as np
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from geometry_msgs.msg import Twist, PoseStamped
from nav_msgs.msg import OccupancyGrid, Odometry
from std_msgs.msg import String, Int32
from .follow_modes import (
FollowMode, MODE_VOICE_COMMANDS,
compute_shadow_target, compute_lead_target, compute_side_target,
compute_orbit_target, compute_loose_target, compute_tight_target,
)
from .astar import astar, inflate_obstacles
from .waypoint_teacher import WaypointRoute, WaypointReplayer
# Try importing social msgs; fallback gracefully to a plain PoseStamped
# person topic when the message package is not built (see __init__).
try:
    from saltybot_social_msgs.msg import PersonStateArray
    _HAS_SOCIAL_MSGS = True
except ImportError:
    _HAS_SOCIAL_MSGS = False
# Proportional controller gains (used by _compute_cmd_vel)
_K_ANG = 2.0  # angular gain
_K_LIN = 0.8  # linear gain
# Above this estimated person speed the loop switches to max_linear_speed_fast
# and _predict_person_position extrapolates the target forward.
_HIGH_SPEED_THRESHOLD = 3.0  # m/s person velocity triggers fast mode
_PREDICT_AHEAD_S = 0.3  # seconds to extrapolate position
_TEACH_MIN_DIST = 0.5  # meters between recorded waypoints
class SocialNavNode(Node):
    """Main social navigation orchestrator.

    Subscribes to tracked persons (or a single target pose fallback),
    voice commands/transcripts, odometry and the occupancy grid, and
    publishes /cmd_vel plus mode/target/status topics.  A 20 Hz control
    loop picks a navigation target by priority (route replay > go-home >
    person following in the active mode), optionally refines it with A*
    over the inflated occupancy grid, and drives a proportional
    controller toward it.
    """

    def __init__(self):
        super().__init__('social_nav')
        # -- Parameters --
        self.declare_parameter('follow_mode', 'shadow')
        self.declare_parameter('follow_distance', 1.2)
        self.declare_parameter('lead_distance', 2.0)
        self.declare_parameter('orbit_radius', 1.5)
        self.declare_parameter('max_linear_speed', 1.0)
        self.declare_parameter('max_linear_speed_fast', 5.5)
        self.declare_parameter('max_angular_speed', 1.0)
        self.declare_parameter('goal_tolerance', 0.3)
        self.declare_parameter('routes_dir', '/mnt/nvme/saltybot/routes')
        self.declare_parameter('home_x', 0.0)
        self.declare_parameter('home_y', 0.0)
        self.declare_parameter('map_resolution', 0.05)
        self.declare_parameter('obstacle_inflation_cells', 3)
        self._follow_mode = FollowMode(
            self.get_parameter('follow_mode').value)
        self._follow_distance = self.get_parameter('follow_distance').value
        self._lead_distance = self.get_parameter('lead_distance').value
        self._orbit_radius = self.get_parameter('orbit_radius').value
        self._max_lin = self.get_parameter('max_linear_speed').value
        self._max_lin_fast = self.get_parameter('max_linear_speed_fast').value
        self._max_ang = self.get_parameter('max_angular_speed').value
        self._goal_tol = self.get_parameter('goal_tolerance').value
        self._routes_dir = self.get_parameter('routes_dir').value
        self._home_x = self.get_parameter('home_x').value
        self._home_y = self.get_parameter('home_y').value
        self._map_resolution = self.get_parameter('map_resolution').value
        self._inflation_cells = self.get_parameter(
            'obstacle_inflation_cells').value
        # -- State --
        self._robot_x = 0.0
        self._robot_y = 0.0
        self._robot_yaw = 0.0
        self._target_person_pos = None  # (x, y, z)
        self._target_person_bearing = 0.0
        self._target_person_id = -1  # -1 = no explicit target selected
        self._person_history = deque(maxlen=5)  # for velocity estimation
        self._stopped = False
        self._go_home = False
        # Occupancy grid for A*
        self._occ_grid = None
        self._occ_origin = (0.0, 0.0)
        self._occ_resolution = 0.05
        self._astar_path = None  # last planned grid path (debug)
        # Orbit state (degrees; advanced each tick in ORBIT mode)
        self._orbit_angle = 0.0
        # Waypoint teaching / replay
        self._teaching = False
        self._current_route = None
        self._last_teach_x = None
        self._last_teach_y = None
        self._replayer = None
        # -- QoS profiles --
        best_effort_qos = QoSProfile(
            reliability=ReliabilityPolicy.BEST_EFFORT,
            history=HistoryPolicy.KEEP_LAST, depth=1)
        reliable_qos = QoSProfile(
            reliability=ReliabilityPolicy.RELIABLE,
            history=HistoryPolicy.KEEP_LAST, depth=1)
        # -- Subscriptions --
        if _HAS_SOCIAL_MSGS:
            self.create_subscription(
                PersonStateArray, '/social/persons',
                self._on_persons, best_effort_qos)
        else:
            self.get_logger().warn(
                'saltybot_social_msgs not found; '
                'using /person/target fallback')
            self.create_subscription(
                PoseStamped, '/person/target',
                self._on_person_target, best_effort_qos)
        self.create_subscription(
            String, '/social/speech/command',
            self._on_voice_command, 10)
        self.create_subscription(
            String, '/social/speech/transcript',
            self._on_transcript, 10)
        self.create_subscription(
            OccupancyGrid, '/map',
            self._on_map, reliable_qos)
        self.create_subscription(
            Odometry, '/odom',
            self._on_odom, best_effort_qos)
        self.create_subscription(
            Int32, '/social/attention/target_id',
            self._on_target_id, 10)
        # -- Publishers --
        self._cmd_vel_pub = self.create_publisher(
            Twist, '/cmd_vel', best_effort_qos)
        self._mode_pub = self.create_publisher(
            String, '/social/nav/mode', reliable_qos)
        self._target_pub = self.create_publisher(
            PoseStamped, '/social/nav/target_pos', 10)
        self._status_pub = self.create_publisher(
            String, '/social/nav/status', best_effort_qos)
        # -- Main loop timer (20 Hz) --
        self._timer = self.create_timer(0.05, self._control_loop)
        self.get_logger().info(
            f'Social nav started: mode={self._follow_mode.value}, '
            f'dist={self._follow_distance}m')

    # ----------------------------------------------------------------
    # Subscriptions
    # ----------------------------------------------------------------
    def _on_persons(self, msg):
        """Handle PersonStateArray from social perception.

        An explicitly selected id (self._target_person_id >= 0) overrides
        the perception stack's primary_attention_id; with no valid id
        (target_id < 0) the first tracked person is used.
        """
        target_id = msg.primary_attention_id
        if self._target_person_id >= 0:
            target_id = self._target_person_id
        for p in msg.persons:
            if p.person_id == target_id or (
                    target_id < 0 and len(msg.persons) > 0):
                pos = (p.position.x, p.position.y, p.position.z)
                self._update_person_position(pos, p.bearing_deg)
                break

    def _on_person_target(self, msg: PoseStamped):
        """Fallback: single person target pose."""
        pos = (msg.pose.position.x, msg.pose.position.y,
               msg.pose.position.z)
        # Estimate bearing from quaternion yaw
        q = msg.pose.orientation
        yaw = math.atan2(2.0 * (q.w * q.z + q.x * q.y),
                         1.0 - 2.0 * (q.y * q.y + q.z * q.z))
        self._update_person_position(pos, math.degrees(yaw))

    def _update_person_position(self, pos, bearing_deg):
        """Update person tracking state and record history.

        History is a (wall-time, x, y) deque used later for velocity
        estimation and forward prediction.
        """
        now = time.time()
        self._target_person_pos = pos
        self._target_person_bearing = bearing_deg
        self._person_history.append((now, pos[0], pos[1]))

    def _on_odom(self, msg: Odometry):
        """Update robot pose from odometry."""
        self._robot_x = msg.pose.pose.position.x
        self._robot_y = msg.pose.pose.position.y
        q = msg.pose.pose.orientation
        # Planar yaw extracted directly from the quaternion.
        self._robot_yaw = math.atan2(
            2.0 * (q.w * q.z + q.x * q.y),
            1.0 - 2.0 * (q.y * q.y + q.z * q.z))

    def _on_map(self, msg: OccupancyGrid):
        """Cache occupancy grid for A* planning."""
        w, h = msg.info.width, msg.info.height
        data = np.array(msg.data, dtype=np.int8).reshape((h, w))
        # Convert -1 (unknown) to free (0) for planning
        data[data < 0] = 0
        self._occ_grid = data.astype(np.int32)
        self._occ_origin = (msg.info.origin.position.x,
                            msg.info.origin.position.y)
        self._occ_resolution = msg.info.resolution

    def _on_target_id(self, msg: Int32):
        """Switch target person."""
        self._target_person_id = msg.data
        self.get_logger().info(f'Target person ID set to {msg.data}')

    def _on_voice_command(self, msg: String):
        """Handle discrete voice commands for mode switching."""
        cmd = msg.data.strip().lower()
        if cmd in ('stop', 'halt'):
            self._stopped = True
            # Stop also cancels any route replay in progress.
            self._replayer = None
            self._publish_status('STOPPED')
            return
        if cmd in ('resume', 'go', 'start'):
            self._stopped = False
            self._publish_status('RESUMED')
            return
        matched = MODE_VOICE_COMMANDS.get(cmd)
        if matched:
            # Mode switch implicitly resumes motion.
            self._follow_mode = matched
            self._stopped = False
            mode_msg = String()
            mode_msg.data = self._follow_mode.value
            self._mode_pub.publish(mode_msg)
            self._publish_status(f'MODE: {self._follow_mode.value}')

    def _on_transcript(self, msg: String):
        """Handle free-form voice transcripts for route teaching.

        Checked in order: "teach route <name>", "stop teaching",
        "replay route <name>", "go home", then substring matches against
        the mode command phrases.
        """
        text = msg.data.strip().lower()
        # "teach route <name>"
        m = re.match(r'teach\s+route\s+(\w+)', text)
        if m:
            name = m.group(1)
            self._teaching = True
            self._current_route = WaypointRoute(name)
            self._last_teach_x = self._robot_x
            self._last_teach_y = self._robot_y
            self._publish_status(f'TEACHING: {name}')
            self.get_logger().info(f'Recording route: {name}')
            return
        # "stop teaching"
        if 'stop teaching' in text:
            if self._teaching and self._current_route:
                self._current_route.save(self._routes_dir)
                self._publish_status(
                    f'SAVED: {self._current_route.name} '
                    f'({len(self._current_route.waypoints)} pts)')
                self.get_logger().info(
                    f'Route saved: {self._current_route.name}')
            self._teaching = False
            self._current_route = None
            return
        # "replay route <name>"
        m = re.match(r'replay\s+route\s+(\w+)', text)
        if m:
            name = m.group(1)
            try:
                route = WaypointRoute.load(self._routes_dir, name)
                self._replayer = WaypointReplayer(route)
                self._stopped = False
                self._publish_status(f'REPLAY: {name}')
                self.get_logger().info(f'Replaying route: {name}')
            except FileNotFoundError:
                self._publish_status(f'ROUTE NOT FOUND: {name}')
            return
        # "go home"
        if 'go home' in text:
            self._go_home = True
            self._stopped = False
            self._publish_status('GO HOME')
            return
        # Also try mode commands from transcript (first substring wins).
        for phrase, mode in MODE_VOICE_COMMANDS.items():
            if phrase in text:
                self._follow_mode = mode
                self._stopped = False
                mode_msg = String()
                mode_msg.data = self._follow_mode.value
                self._mode_pub.publish(mode_msg)
                self._publish_status(f'MODE: {self._follow_mode.value}')
                return

    # ----------------------------------------------------------------
    # Control loop
    # ----------------------------------------------------------------
    def _control_loop(self):
        """Main 20Hz control loop.

        Records teaching waypoints, then: stopped -> zero velocity;
        no target -> zero velocity; within goal tolerance -> stop (and
        clear go-home); otherwise optionally refine the target with A*
        and command proportional velocities toward it.
        """
        # Record waypoint if teaching
        if self._teaching and self._current_route:
            self._maybe_record_waypoint()
        # Publish zero velocity if stopped
        if self._stopped:
            self._publish_cmd_vel(0.0, 0.0)
            return
        # Determine navigation target
        target = self._get_nav_target()
        if target is None:
            self._publish_cmd_vel(0.0, 0.0)
            return
        tx, ty, tz = target
        # Publish debug target
        self._publish_target_pose(tx, ty, tz)
        # Check if arrived
        dist_to_target = math.hypot(tx - self._robot_x, ty - self._robot_y)
        if dist_to_target < self._goal_tol:
            self._publish_cmd_vel(0.0, 0.0)
            if self._go_home:
                self._go_home = False
                self._publish_status('HOME REACHED')
            return
        # Try A* path if map available
        if self._occ_grid is not None:
            path_target = self._plan_astar(tx, ty)
            if path_target:
                tx, ty = path_target
        # Determine speed limit (fast mode when the person is moving fast)
        max_lin = self._max_lin
        person_vel = self._estimate_person_velocity()
        if person_vel > _HIGH_SPEED_THRESHOLD:
            max_lin = self._max_lin_fast
        # Compute and publish cmd_vel
        lin, ang = self._compute_cmd_vel(
            self._robot_x, self._robot_y, self._robot_yaw,
            tx, ty, max_lin)
        self._publish_cmd_vel(lin, ang)

    def _get_nav_target(self):
        """Determine current navigation target based on mode/state.

        Priority: route replay > go-home > person following.  Returns an
        (x, y, z) world target or None when there is nothing to chase.
        """
        # Route replay takes priority
        if self._replayer and not self._replayer.is_done:
            self._replayer.check_arrived(self._robot_x, self._robot_y)
            wp = self._replayer.current_waypoint()
            if wp:
                return (wp.x, wp.y, wp.z)
            else:
                self._replayer = None
                self._publish_status('REPLAY DONE')
                return None
        # Go home
        if self._go_home:
            return (self._home_x, self._home_y, 0.0)
        # Person following
        if self._target_person_pos is None:
            return None
        # Predict person position ahead for high-speed tracking
        px, py, pz = self._predict_person_position()
        bearing = self._target_person_bearing
        robot_pos = (self._robot_x, self._robot_y, 0.0)
        if self._follow_mode == FollowMode.SHADOW:
            return compute_shadow_target(
                (px, py, pz), bearing, self._follow_distance)
        elif self._follow_mode == FollowMode.LEAD:
            return compute_lead_target(
                (px, py, pz), bearing, self._lead_distance)
        elif self._follow_mode == FollowMode.SIDE:
            return compute_side_target(
                (px, py, pz), bearing, self._follow_distance)
        elif self._follow_mode == FollowMode.ORBIT:
            # Advance 1 degree per tick -> ~20 deg/s at the 20 Hz loop.
            self._orbit_angle = (self._orbit_angle + 1.0) % 360.0
            return compute_orbit_target(
                (px, py, pz), self._orbit_angle, self._orbit_radius)
        elif self._follow_mode == FollowMode.LOOSE:
            return compute_loose_target(
                (px, py, pz), robot_pos, self._follow_distance)
        elif self._follow_mode == FollowMode.TIGHT:
            return compute_tight_target(
                (px, py, pz), self._follow_distance)
        return (px, py, pz)

    def _predict_person_position(self):
        """Extrapolate person position using velocity from recent samples.

        Only extrapolates (by _PREDICT_AHEAD_S) when the estimated speed
        exceeds _HIGH_SPEED_THRESHOLD; otherwise returns the raw position.
        """
        if self._target_person_pos is None:
            return (0.0, 0.0, 0.0)
        px, py, pz = self._target_person_pos
        if len(self._person_history) >= 3:
            # Use last 3 samples for velocity estimation
            t0, x0, y0 = self._person_history[-3]
            t1, x1, y1 = self._person_history[-1]
            dt = t1 - t0
            if dt > 0.01:
                vx = (x1 - x0) / dt
                vy = (y1 - y0) / dt
                speed = math.hypot(vx, vy)
                if speed > _HIGH_SPEED_THRESHOLD:
                    px += vx * _PREDICT_AHEAD_S
                    py += vy * _PREDICT_AHEAD_S
        return (px, py, pz)

    def _estimate_person_velocity(self) -> float:
        """Estimate person speed (m/s) from the last two history samples."""
        if len(self._person_history) < 2:
            return 0.0
        t0, x0, y0 = self._person_history[-2]
        t1, x1, y1 = self._person_history[-1]
        dt = t1 - t0
        if dt < 0.01:
            # Samples too close together; a quotient would be noise.
            return 0.0
        return math.hypot(x1 - x0, y1 - y0) / dt

    def _plan_astar(self, target_x, target_y):
        """Run A* on occupancy grid, return next waypoint in world coords.

        Returns None when start/goal fall outside the grid or no path is
        found, in which case the caller drives straight at the target.
        """
        grid = self._occ_grid
        res = self._occ_resolution
        ox, oy = self._occ_origin
        # World to grid (row = y axis, col = x axis)
        def w2g(wx, wy):
            return (int((wy - oy) / res), int((wx - ox) / res))
        start = w2g(self._robot_x, self._robot_y)
        goal = w2g(target_x, target_y)
        rows, cols = grid.shape
        if not (0 <= start[0] < rows and 0 <= start[1] < cols):
            return None
        if not (0 <= goal[0] < rows and 0 <= goal[1] < cols):
            return None
        inflated = inflate_obstacles(grid, self._inflation_cells)
        path = astar(inflated, start, goal)
        if path and len(path) > 1:
            # Follow a lookahead point (3 steps ahead or end)
            lookahead_idx = min(3, len(path) - 1)
            r, c = path[lookahead_idx]
            # Cell centre in world coordinates.
            wx = ox + c * res + res / 2.0
            wy = oy + r * res + res / 2.0
            self._astar_path = path
            return (wx, wy)
        return None

    def _compute_cmd_vel(self, rx, ry, ryaw, tx, ty, max_lin):
        """Proportional controller: compute linear and angular velocity.

        Angular velocity is clamped to +-max_angular_speed; linear speed
        is scaled down as the heading error grows so the robot turns in
        place rather than arcing wide.
        """
        dx = tx - rx
        dy = ty - ry
        dist = math.hypot(dx, dy)
        angle_to_target = math.atan2(dy, dx)
        angle_error = angle_to_target - ryaw
        # Normalize angle error to [-pi, pi]
        while angle_error > math.pi:
            angle_error -= 2.0 * math.pi
        while angle_error < -math.pi:
            angle_error += 2.0 * math.pi
        angular_vel = _K_ANG * angle_error
        angular_vel = max(-self._max_ang,
                          min(self._max_ang, angular_vel))
        # Reduce linear speed when turning hard (zero at >=90 deg error)
        angle_factor = max(0.0, 1.0 - abs(angle_error) / (math.pi / 2.0))
        linear_vel = _K_LIN * dist * angle_factor
        linear_vel = max(0.0, min(max_lin, linear_vel))
        return (linear_vel, angular_vel)

    # ----------------------------------------------------------------
    # Waypoint teaching
    # ----------------------------------------------------------------
    def _maybe_record_waypoint(self):
        """Record waypoint if robot moved > _TEACH_MIN_DIST."""
        if self._last_teach_x is None:
            self._last_teach_x = self._robot_x
            self._last_teach_y = self._robot_y
        dist = math.hypot(
            self._robot_x - self._last_teach_x,
            self._robot_y - self._last_teach_y)
        if dist >= _TEACH_MIN_DIST:
            yaw_deg = math.degrees(self._robot_yaw)
            self._current_route.add(
                self._robot_x, self._robot_y, 0.0, yaw_deg)
            self._last_teach_x = self._robot_x
            self._last_teach_y = self._robot_y

    # ----------------------------------------------------------------
    # Publishers
    # ----------------------------------------------------------------
    def _publish_cmd_vel(self, linear: float, angular: float):
        """Publish a Twist with the given forward and yaw rates."""
        twist = Twist()
        twist.linear.x = linear
        twist.angular.z = angular
        self._cmd_vel_pub.publish(twist)

    def _publish_target_pose(self, x, y, z):
        """Publish the current nav target (map frame) for debugging."""
        msg = PoseStamped()
        msg.header.stamp = self.get_clock().now().to_msg()
        msg.header.frame_id = 'map'
        msg.pose.position.x = x
        msg.pose.position.y = y
        msg.pose.position.z = z
        self._target_pub.publish(msg)

    def _publish_status(self, status: str):
        """Publish and log a human-readable status string."""
        msg = String()
        msg.data = status
        self._status_pub.publish(msg)
        self.get_logger().info(f'Nav status: {status}')
def main(args=None):
    """Entry point: spin the social navigation node until shutdown."""
    rclpy.init(args=args)
    nav_node = SocialNavNode()
    try:
        rclpy.spin(nav_node)
    except KeyboardInterrupt:
        pass  # Ctrl-C is a normal way to stop the node
    finally:
        nav_node.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()

View File

@ -1,91 +0,0 @@
"""waypoint_teacher.py -- Record and replay waypoint routes."""
import json
import time
import math
from pathlib import Path
from dataclasses import dataclass, asdict
@dataclass
class Waypoint:
x: float
y: float
z: float
yaw_deg: float
timestamp: float
label: str = ''
class WaypointRoute:
"""A named sequence of waypoints."""
def __init__(self, name: str):
self.name = name
self.waypoints: list[Waypoint] = []
self.created_at = time.time()
def add(self, x, y, z, yaw_deg, label=''):
self.waypoints.append(Waypoint(x, y, z, yaw_deg, time.time(), label))
def to_dict(self):
return {
'name': self.name,
'created_at': self.created_at,
'waypoints': [asdict(w) for w in self.waypoints],
}
@classmethod
def from_dict(cls, d):
route = cls(d['name'])
route.created_at = d.get('created_at', 0)
route.waypoints = [Waypoint(**w) for w in d['waypoints']]
return route
def save(self, routes_dir: str):
Path(routes_dir).mkdir(parents=True, exist_ok=True)
path = Path(routes_dir) / f'{self.name}.json'
path.write_text(json.dumps(self.to_dict(), indent=2))
@classmethod
def load(cls, routes_dir: str, name: str):
path = Path(routes_dir) / f'{name}.json'
return cls.from_dict(json.loads(path.read_text()))
@staticmethod
def list_routes(routes_dir: str) -> list[str]:
d = Path(routes_dir)
if not d.exists():
return []
return [p.stem for p in d.glob('*.json')]
class WaypointReplayer:
"""Iterates through waypoints, returning next target."""
def __init__(self, route: WaypointRoute, arrival_radius: float = 0.3):
self._route = route
self._idx = 0
self._arrival_radius = arrival_radius
def current_waypoint(self) -> Waypoint | None:
if self._idx < len(self._route.waypoints):
return self._route.waypoints[self._idx]
return None
def check_arrived(self, robot_x, robot_y) -> bool:
wp = self.current_waypoint()
if wp is None:
return False
dist = math.hypot(robot_x - wp.x, robot_y - wp.y)
if dist < self._arrival_radius:
self._idx += 1
return True
return False
@property
def is_done(self) -> bool:
return self._idx >= len(self._route.waypoints)
def reset(self):
self._idx = 0

View File

@ -1,135 +0,0 @@
"""
waypoint_teacher_node.py -- Standalone waypoint teacher ROS2 node.
Listens to /social/speech/transcript for "teach route <name>" and "stop teaching".
Records robot pose at configurable intervals. Saves/loads routes via WaypointRoute.
"""
import math
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from nav_msgs.msg import Odometry
from std_msgs.msg import String
from .waypoint_teacher import WaypointRoute
class WaypointTeacherNode(Node):
    """Standalone waypoint teaching node.

    Listens for "teach route <name>", "stop teaching" and "list routes"
    on /social/speech/transcript, samples the odometry pose every time
    the robot moves recording_interval metres, and persists routes via
    WaypointRoute under routes_dir.
    """

    def __init__(self):
        super().__init__('waypoint_teacher')
        self.declare_parameter('routes_dir', '/mnt/nvme/saltybot/routes')
        self.declare_parameter('recording_interval', 0.5)  # meters
        self._routes_dir = self.get_parameter('routes_dir').value
        self._interval = self.get_parameter('recording_interval').value
        # Teaching state
        self._teaching = False
        self._route = None
        self._last_x = None  # pose at which the last waypoint was recorded
        self._last_y = None
        # Latest robot pose from /odom
        self._robot_x = 0.0
        self._robot_y = 0.0
        self._robot_yaw = 0.0
        best_effort_qos = QoSProfile(
            reliability=ReliabilityPolicy.BEST_EFFORT,
            history=HistoryPolicy.KEEP_LAST, depth=1)
        self.create_subscription(
            Odometry, '/odom', self._on_odom, best_effort_qos)
        self.create_subscription(
            String, '/social/speech/transcript',
            self._on_transcript, 10)
        self._status_pub = self.create_publisher(
            String, '/social/waypoint/status', 10)
        # Record timer at 10Hz (check distance)
        self._timer = self.create_timer(0.1, self._record_tick)
        self.get_logger().info(
            f'Waypoint teacher ready (interval={self._interval}m, '
            f'dir={self._routes_dir})')

    def _on_odom(self, msg: Odometry):
        """Cache the latest robot pose (x, y, yaw) from odometry."""
        self._robot_x = msg.pose.pose.position.x
        self._robot_y = msg.pose.pose.position.y
        q = msg.pose.pose.orientation
        # Planar yaw extracted directly from the quaternion.
        self._robot_yaw = math.atan2(
            2.0 * (q.w * q.z + q.x * q.y),
            1.0 - 2.0 * (q.y * q.y + q.z * q.z))

    def _on_transcript(self, msg: String):
        """Parse transcripts for teach / stop-teaching / list commands."""
        text = msg.data.strip().lower()
        import re
        m = re.match(r'teach\s+route\s+(\w+)', text)
        if m:
            name = m.group(1)
            self._route = WaypointRoute(name)
            self._teaching = True
            self._last_x = self._robot_x
            self._last_y = self._robot_y
            self._pub_status(f'RECORDING: {name}')
            self.get_logger().info(f'Recording route: {name}')
            return
        if 'stop teaching' in text:
            if self._teaching and self._route:
                self._route.save(self._routes_dir)
                n = len(self._route.waypoints)
                self._pub_status(
                    f'SAVED: {self._route.name} ({n} waypoints)')
                self.get_logger().info(
                    f'Route saved: {self._route.name} ({n} pts)')
            self._teaching = False
            self._route = None
            return
        if 'list routes' in text:
            routes = WaypointRoute.list_routes(self._routes_dir)
            self._pub_status(f'ROUTES: {", ".join(routes) or "(none)"}')

    def _record_tick(self):
        """10 Hz tick: append a waypoint once the robot moved far enough."""
        if not self._teaching or self._route is None:
            return
        if self._last_x is None:
            self._last_x = self._robot_x
            self._last_y = self._robot_y
        dist = math.hypot(
            self._robot_x - self._last_x,
            self._robot_y - self._last_y)
        if dist >= self._interval:
            yaw_deg = math.degrees(self._robot_yaw)
            self._route.add(self._robot_x, self._robot_y, 0.0, yaw_deg)
            self._last_x = self._robot_x
            self._last_y = self._robot_y

    def _pub_status(self, text: str):
        """Publish a human-readable status string."""
        msg = String()
        msg.data = text
        self._status_pub.publish(msg)
def main(args=None):
    """Entry point: spin the waypoint teacher node until shutdown."""
    rclpy.init(args=args)
    teacher = WaypointTeacherNode()
    try:
        rclpy.spin(teacher)
    except KeyboardInterrupt:
        pass  # Ctrl-C is a normal way to stop the node
    finally:
        teacher.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()

View File

@ -1,80 +0,0 @@
#!/usr/bin/env python3
"""
build_midas_trt_engine.py -- Build TensorRT FP16 engine for MiDaS_small from ONNX.
Usage:
python3 build_midas_trt_engine.py \
--onnx /mnt/nvme/saltybot/models/midas_small.onnx \
--engine /mnt/nvme/saltybot/models/midas_small.engine \
--fp16
Requires: tensorrt, pycuda
"""
import argparse
import os
import sys
def build_engine(onnx_path: str, engine_path: str, fp16: bool = True):
    """Parse an ONNX model and serialize a TensorRT engine to disk.

    Exits the process with status 1 when TensorRT is missing, the ONNX
    model fails to parse, or engine building fails.  FP16 is enabled
    only when requested AND the platform supports it.
    """
    try:
        import tensorrt as trt
    except ImportError:
        print('ERROR: tensorrt not found. Install TensorRT first.')
        sys.exit(1)

    trt_logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(trt_logger)
    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(network_flags)
    onnx_parser = trt.OnnxParser(network, trt_logger)

    print(f'Parsing ONNX model: {onnx_path}')
    with open(onnx_path, 'rb') as f:
        parsed_ok = onnx_parser.parse(f.read())
    if not parsed_ok:
        for i in range(onnx_parser.num_errors):
            print(f' ONNX parse error: {onnx_parser.get_error(i)}')
        sys.exit(1)

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1GB
    if fp16 and builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)
        print('FP16 mode enabled')
    elif fp16:
        print('WARNING: FP16 not supported on this platform, using FP32')

    print('Building TensorRT engine (this may take several minutes)...')
    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        print('ERROR: Failed to build engine')
        sys.exit(1)

    os.makedirs(os.path.dirname(engine_path) or '.', exist_ok=True)
    with open(engine_path, 'wb') as f:
        f.write(serialized)
    size_mb = len(serialized) / (1024 * 1024)
    print(f'Engine saved: {engine_path} ({size_mb:.1f} MB)')
def main():
    """CLI entry point: parse arguments and build the engine.

    Note: ``--fp16`` is accepted for interface compatibility but is
    effectively a no-op — ``store_true`` combined with ``default=True``
    means ``args.fp16`` is always True, so the effective precision is
    controlled solely by ``--fp32``.
    """
    parser = argparse.ArgumentParser(
        description='Build TensorRT FP16 engine for MiDaS_small')
    parser.add_argument('--onnx', required=True,
                        help='Path to MiDaS ONNX model')
    parser.add_argument('--engine', required=True,
                        help='Output TRT engine path')
    parser.add_argument('--fp16', action='store_true', default=True,
                        help='Enable FP16 (default: True)')
    parser.add_argument('--fp32', action='store_true',
                        help='Force FP32 (disable FP16)')
    args = parser.parse_args()
    # FP16 unless the user explicitly forced FP32 (see docstring).
    fp16 = not args.fp32
    build_engine(args.onnx, args.engine, fp16=fp16)


if __name__ == '__main__':
    main()

View File

@ -1,4 +0,0 @@
[develop]
script_dir=$base/lib/saltybot_social_nav
[install]
install_scripts=$base/lib/saltybot_social_nav

View File

@ -1,31 +0,0 @@
from setuptools import setup
import os
from glob import glob
package_name = 'saltybot_social_nav'
setup(
    name=package_name,
    version='0.1.0',
    packages=[package_name],
    data_files=[
        # ament resource index marker so ROS2 tooling can discover the package
        ('share/ament_index/resource_index/packages', ['resource/' + package_name]),
        ('share/' + package_name, ['package.xml']),
        # install launch files and YAML configs into the package share dir
        (os.path.join('share', package_name, 'launch'), glob('launch/*.py')),
        (os.path.join('share', package_name, 'config'), glob('config/*.yaml')),
    ],
    install_requires=['setuptools'],
    zip_safe=True,
    maintainer='seb',
    maintainer_email='seb@vayrette.com',
    description='Social navigation for saltybot: follow modes, waypoint teaching, A* avoidance, MiDaS depth',
    license='MIT',
    tests_require=['pytest'],
    entry_points={
        # executables exposed to `ros2 run saltybot_social_nav <name>`
        'console_scripts': [
            'social_nav = saltybot_social_nav.social_nav_node:main',
            'midas_depth = saltybot_social_nav.midas_depth_node:main',
            'waypoint_teacher = saltybot_social_nav.waypoint_teacher_node:main',
        ],
    },
)

View File

@ -1,12 +0,0 @@
# Copyright 2026 SaltyLab
# Licensed under MIT
from ament_copyright.main import main
import pytest
@pytest.mark.copyright
@pytest.mark.linter
def test_copyright():
    """Run the ament copyright linter over the package."""
    assert main(argv=['.', 'test']) == 0, 'Found errors'

View File

@ -1,14 +0,0 @@
# Copyright 2026 SaltyLab
# Licensed under MIT
from ament_flake8.main import main_with_errors
import pytest
@pytest.mark.flake8
@pytest.mark.linter
def test_flake8():
    """Lint the package with flake8 via ament, listing every finding."""
    rc, errors = main_with_errors(argv=[])
    detail = '\n'.join(errors)
    assert rc == 0, (
        'Found %d code style errors / warnings:\n' % len(errors) + detail
    )

View File

@ -1,12 +0,0 @@
# Copyright 2026 SaltyLab
# Licensed under MIT
from ament_pep257.main import main
import pytest
@pytest.mark.pep257
@pytest.mark.linter
def test_pep257():
    """Check docstring style with ament's pep257 runner."""
    assert main(argv=['.', 'test']) == 0, 'Found code style errors / warnings'

View File

@ -1,24 +1,57 @@
# uwb_config.yaml — MaUWB ESP32-S3 DW3000 UWB integration (Issue #90)
# uwb_config.yaml — MaUWB ESP32-S3 DW3000 UWB follow-me system
#
# Hardware layout:
# Anchor-0 (port side) → USB port_a, y = +anchor_separation/2
# Anchor-1 (starboard side) → USB port_b, y = -anchor_separation/2
# Tag on person → belt clip, ~0.9m above ground
#
# Run with:
# ros2 launch saltybot_uwb uwb.launch.py
# Override at launch:
# ros2 launch saltybot_uwb uwb.launch.py port_a:=/dev/ttyUSB2
port_a: /dev/uwb-anchor0
port_b: /dev/uwb-anchor1
baudrate: 115200
# ── Serial ports ──────────────────────────────────────────────────────────────
# Set udev rules to get stable symlinks:
# /dev/uwb-anchor0 → port_a
# /dev/uwb-anchor1 → port_b
# (See jetson/docs/pinout.md for udev setup)
port_a: /dev/uwb-anchor0 # Anchor-0 (port)
port_b: /dev/uwb-anchor1 # Anchor-1 (starboard)
baudrate: 115200 # MaUWB default — do not change
anchor_separation: 0.25
anchor_height: 0.80
tag_height: 0.90
# ── Anchor geometry ────────────────────────────────────────────────────────────
# anchor_separation: centre-to-centre distance between anchors (metres)
# Must match physical mounting. Larger = more accurate lateral resolution.
anchor_separation: 0.25 # metres (25cm)
range_timeout_s: 1.0
max_range_m: 8.0
min_range_m: 0.05
# anchor_height: height of anchors above ground (metres)
# Orin stem mount ≈ 0.80m on the saltybot platform
anchor_height: 0.80 # metres
# tag_height: height of person's belt-clip tag above ground (metres)
tag_height: 0.90 # metres (adjust per user)
# ── Range validity ─────────────────────────────────────────────────────────────
# range_timeout_s: stale anchor — excluded from triangulation after this gap
range_timeout_s: 1.0 # seconds
# max_range_m: discard ranges beyond this (DW3000 indoor practical limit ≈8m)
max_range_m: 8.0 # metres
# min_range_m: discard ranges below this (likely multipath artefacts)
min_range_m: 0.05 # metres
# ── Kalman filter ──────────────────────────────────────────────────────────────
# kf_process_noise: Q scalar — how dynamic the person's motion is
# Higher → faster response, more jitter
kf_process_noise: 0.1
kf_meas_noise: 0.3
range_rate: 100.0
bearing_rate: 10.0
# kf_meas_noise: R scalar — how noisy the UWB ranging is
# DW3000 indoor accuracy ≈ 10cm RMS → 0.1m → R ≈ 0.01
# Use 0.3 to be conservative on first deployment
kf_meas_noise: 0.3
enrolled_tag_ids: [""]
# ── Publish rate ──────────────────────────────────────────────────────────────
# Should match or exceed the AT+RANGE? poll rate from both anchors.
# 20Hz means 50ms per cycle; each anchor query takes ~10ms → headroom ok.
publish_rate: 20.0 # Hz

View File

@ -1,14 +1,10 @@
"""
uwb.launch.py Launch UWB driver node for MaUWB ESP32-S3 DW3000 (Issue #90).
Topics:
/uwb/ranges 100 Hz raw TWR ranges
/uwb/bearing 10 Hz Kalman-fused bearing
/uwb/target 10 Hz triangulated PoseStamped (backwards compat)
uwb.launch.py Launch UWB driver node for MaUWB ESP32-S3 follow-me.
Usage:
ros2 launch saltybot_uwb uwb.launch.py
ros2 launch saltybot_uwb uwb.launch.py enrolled_tag_ids:="['0xDEADBEEF']"
ros2 launch saltybot_uwb uwb.launch.py port_a:=/dev/ttyUSB2 port_b:=/dev/ttyUSB3
ros2 launch saltybot_uwb uwb.launch.py anchor_separation:=0.30 publish_rate:=10.0
"""
import os
@ -35,9 +31,7 @@ def generate_launch_description():
DeclareLaunchArgument("min_range_m", default_value="0.05"),
DeclareLaunchArgument("kf_process_noise", default_value="0.1"),
DeclareLaunchArgument("kf_meas_noise", default_value="0.3"),
DeclareLaunchArgument("range_rate", default_value="100.0"),
DeclareLaunchArgument("bearing_rate", default_value="10.0"),
DeclareLaunchArgument("enrolled_tag_ids", default_value="['']"),
DeclareLaunchArgument("publish_rate", default_value="20.0"),
Node(
package="saltybot_uwb",
@ -58,9 +52,7 @@ def generate_launch_description():
"min_range_m": LaunchConfiguration("min_range_m"),
"kf_process_noise": LaunchConfiguration("kf_process_noise"),
"kf_meas_noise": LaunchConfiguration("kf_meas_noise"),
"range_rate": LaunchConfiguration("range_rate"),
"bearing_rate": LaunchConfiguration("bearing_rate"),
"enrolled_tag_ids": LaunchConfiguration("enrolled_tag_ids"),
"publish_rate": LaunchConfiguration("publish_rate"),
},
],
),

View File

@ -29,26 +29,6 @@ Returns (x_t, y_t); caller should treat negative x_t as 0.
import math
# ── Bearing helper ────────────────────────────────────────────────────────────
def bearing_from_pos(x: float, y: float) -> float:
    """
    Horizontal bearing (radians) to a point expressed in base_link.

    Parameters
    ----------
    x : forward distance to the target (metres)
    y : lateral offset (metres; positive = left of robot / CCW)

    Returns
    -------
    bearing_rad : float
        Bearing in the range -π .. +π. Zero means directly ahead;
        positive values mean the target is to the left (CCW).
    """
    bearing_rad = math.atan2(y, x)
    return bearing_rad
# ── Triangulation ─────────────────────────────────────────────────────────────
def triangulate_2anchor(

View File

@ -1,52 +1,32 @@
"""
uwb_driver_node.py ROS2 node for MaUWB ESP32-S3 DW3000 UWB integration.
uwb_driver_node.py ROS2 node for MaUWB ESP32-S3 DW3000 follow-me system.
Hardware
2× MaUWB ESP32-S3 DW3000 anchors on robot stem (USB Orin)
2× MaUWB ESP32-S3 DW3000 anchors on robot stem (USB Orin Nano)
- Anchor-0: port side (y = +sep/2)
- Anchor-1: starboard (y = -sep/2)
1× MaUWB tag per enrolled person (belt clip)
1× MaUWB tag on person (belt clip)
AT command interface (115200 8N1)
Query:
AT+RANGE?\r\n
Query: AT+RANGE?\r\n
Response (from anchors):
+RANGE:<anchor_id>,<range_mm>[,<rssi>]\r\n
Response (from anchors, TWR protocol):
+RANGE:<anchor_id>,<range_mm>[,<rssi>[,<tag_addr>]]\r\n
Tag pairing (optional targets a specific enrolled tag):
AT+RANGE_ADDR=<tag_addr>\r\n anchor only ranges with that tag
Config:
AT+anchor_tag=ANCHOR\r\n set module as anchor
AT+anchor_tag=TAG\r\n set module as tag
Publishes
/uwb/ranges (saltybot_uwb_msgs/UwbRangeArray) 100 Hz raw anchor ranges
/uwb/bearing (saltybot_uwb_msgs/UwbBearing) 10 Hz Kalman-fused bearing
/uwb/target (geometry_msgs/PoseStamped) 10 Hz triangulated position
(kept for backwards compat)
/uwb/target (geometry_msgs/PoseStamped) triangulated person position in base_link
/uwb/ranges (saltybot_uwb_msgs/UwbRangeArray) raw ranges from both anchors
Tag pairing
Set enrolled_tag_ids to a list of tag address strings (e.g. ["0x1234ABCD"]).
When non-empty, ranges from unrecognised tags are silently discarded.
The matched tag address is stamped in UwbRange.tag_id and UwbBearing.tag_id.
When enrolled_tag_ids is empty, all ranges are accepted (tag_id = "").
Parameters
port_a, port_b serial ports for anchor-0 / anchor-1
baudrate 115200 (default)
anchor_separation centre-to-centre anchor spacing (m)
anchor_height anchor mounting height (m)
tag_height person tag height (m)
range_timeout_s stale-anchor threshold (s)
max_range_m / min_range_m validity window (m)
kf_process_noise Kalman Q scalar
kf_meas_noise Kalman R scalar
range_rate Hz /uwb/ranges publish rate (default 100)
bearing_rate Hz /uwb/bearing publish rate (default 10)
enrolled_tag_ids list[str] accepted tag addresses; [] = accept all
Safety
If a range is stale (> range_timeout_s), that anchor is excluded from
triangulation. If both anchors are stale, /uwb/target is not published.
Usage
@ -64,8 +44,8 @@ from rclpy.node import Node
from geometry_msgs.msg import PoseStamped
from std_msgs.msg import Header
from saltybot_uwb_msgs.msg import UwbRange, UwbRangeArray, UwbBearing
from saltybot_uwb.ranging_math import triangulate_2anchor, KalmanFilter2D, bearing_from_pos
from saltybot_uwb_msgs.msg import UwbRange, UwbRangeArray
from saltybot_uwb.ranging_math import triangulate_2anchor, KalmanFilter2D
try:
import serial
@ -74,31 +54,26 @@ except ImportError:
_SERIAL_AVAILABLE = False
# +RANGE:<id>,<mm> or +RANGE:<id>,<mm>,<rssi> or +RANGE:<id>,<mm>,<rssi>,<tag>
# Regex: +RANGE:<id>,<mm> or +RANGE:<id>,<mm>,<rssi>
_RANGE_RE = re.compile(
r"\+RANGE\s*:\s*(\d+)\s*,\s*(\d+)"
r"(?:\s*,\s*(-?[\d.]+)"
r"(?:\s*,\s*([\w:x]+))?"
r")?",
r"\+RANGE\s*:\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(-?[\d.]+))?",
re.IGNORECASE,
)
class SerialReader(threading.Thread):
"""
Background thread: polls one anchor's UART at maximum TWR rate,
fires callback on every valid +RANGE response.
Supports optional tag pairing via AT+RANGE_ADDR command.
Background thread: polls an anchor's UART, fires callback on every
valid +RANGE response.
"""
def __init__(self, port, baudrate, anchor_id, callback, logger, tag_addr=None):
def __init__(self, port, baudrate, anchor_id, callback, logger):
super().__init__(daemon=True)
self._port = port
self._baudrate = baudrate
self._anchor_id = anchor_id
self._callback = callback
self._logger = logger
self._tag_addr = tag_addr
self._running = False
self._ser = None
@ -111,10 +86,7 @@ class SerialReader(threading.Thread):
)
self._logger.info(
f"Anchor-{self._anchor_id}: opened {self._port}"
+ (f" paired with tag {self._tag_addr}" if self._tag_addr else "")
)
if self._tag_addr:
self._send_pairing_cmd()
self._read_loop()
except Exception as exc:
self._logger.warn(
@ -124,24 +96,12 @@ class SerialReader(threading.Thread):
self._ser.close()
time.sleep(2.0)
def _send_pairing_cmd(self):
    """Configure the anchor to range only with the paired tag."""
    # Best-effort: a pairing failure is logged but does not abort the reader
    # thread — the anchor then keeps ranging with whatever tag it sees.
    try:
        # MaUWB AT command restricting TWR to a single tag address.
        cmd = f"AT+RANGE_ADDR={self._tag_addr}\r\n".encode("ascii")
        self._ser.write(cmd)
        # Short pause so the module can apply the setting before the
        # first AT+RANGE? query — presumably sufficient; TODO confirm.
        time.sleep(0.1)
        self._logger.info(
            f"Anchor-{self._anchor_id}: sent tag pairing {self._tag_addr}"
        )
    except Exception as exc:
        self._logger.warn(
            f"Anchor-{self._anchor_id}: pairing cmd failed: {exc}"
        )
def _read_loop(self):
while self._running:
try:
# Query the anchor
self._ser.write(b"AT+RANGE?\r\n")
# Read up to 10 lines waiting for a +RANGE response
for _ in range(10):
raw = self._ser.readline()
if not raw:
@ -151,14 +111,13 @@ class SerialReader(threading.Thread):
if m:
range_mm = int(m.group(2))
rssi = float(m.group(3)) if m.group(3) else 0.0
tag_addr = m.group(4) if m.group(4) else ""
self._callback(self._anchor_id, range_mm, rssi, tag_addr)
self._callback(self._anchor_id, range_mm, rssi)
break
except Exception as exc:
self._logger.warn(
f"Anchor-{self._anchor_id} read error: {exc}"
)
break
break # trigger reconnect
def stop(self):
self._running = False
@ -171,51 +130,48 @@ class UwbDriverNode(Node):
def __init__(self):
super().__init__("uwb_driver")
self.declare_parameter("port_a", "/dev/uwb-anchor0")
self.declare_parameter("port_b", "/dev/uwb-anchor1")
self.declare_parameter("baudrate", 115200)
self.declare_parameter("anchor_separation", 0.25)
self.declare_parameter("anchor_height", 0.80)
self.declare_parameter("tag_height", 0.90)
self.declare_parameter("range_timeout_s", 1.0)
self.declare_parameter("max_range_m", 8.0)
self.declare_parameter("min_range_m", 0.05)
self.declare_parameter("kf_process_noise", 0.1)
self.declare_parameter("kf_meas_noise", 0.3)
self.declare_parameter("range_rate", 100.0)
self.declare_parameter("bearing_rate", 10.0)
self.declare_parameter("enrolled_tag_ids", [""])
# ── Parameters ────────────────────────────────────────────────────────
self.declare_parameter("port_a", "/dev/ttyUSB0")
self.declare_parameter("port_b", "/dev/ttyUSB1")
self.declare_parameter("baudrate", 115200)
self.declare_parameter("anchor_separation", 0.25)
self.declare_parameter("anchor_height", 0.80)
self.declare_parameter("tag_height", 0.90)
self.declare_parameter("range_timeout_s", 1.0)
self.declare_parameter("max_range_m", 8.0)
self.declare_parameter("min_range_m", 0.05)
self.declare_parameter("kf_process_noise", 0.1)
self.declare_parameter("kf_meas_noise", 0.3)
self.declare_parameter("publish_rate", 20.0)
self._p = self._load_params()
raw_ids = self.get_parameter("enrolled_tag_ids").value
self._enrolled_tags = [t.strip() for t in raw_ids if t.strip()]
paired_tag = self._enrolled_tags[0] if self._enrolled_tags else None
self._lock = threading.Lock()
self._ranges: dict = {}
self._kf = KalmanFilter2D(
# ── State (protected by lock) ──────────────────────────────────────
self._lock = threading.Lock()
self._ranges = {} # anchor_id → (range_m, rssi, timestamp)
self._kf = KalmanFilter2D(
process_noise=self._p["kf_process_noise"],
measurement_noise=self._p["kf_meas_noise"],
dt=1.0 / self._p["bearing_rate"],
dt=1.0 / self._p["publish_rate"],
)
self._ranges_pub = self.create_publisher(UwbRangeArray, "/uwb/ranges", 10)
self._bearing_pub = self.create_publisher(UwbBearing, "/uwb/bearing", 10)
self._target_pub = self.create_publisher(PoseStamped, "/uwb/target", 10)
# ── Publishers ────────────────────────────────────────────────────
self._target_pub = self.create_publisher(
PoseStamped, "/uwb/target", 10)
self._ranges_pub = self.create_publisher(
UwbRangeArray, "/uwb/ranges", 10)
# ── Serial readers ────────────────────────────────────────────────
if _SERIAL_AVAILABLE:
self._reader_a = SerialReader(
self._p["port_a"], self._p["baudrate"],
anchor_id=0, callback=self._range_cb,
logger=self.get_logger(),
tag_addr=paired_tag,
)
self._reader_b = SerialReader(
self._p["port_b"], self._p["baudrate"],
anchor_id=1, callback=self._range_cb,
logger=self.get_logger(),
tag_addr=paired_tag,
)
self._reader_a.start()
self._reader_b.start()
@ -224,21 +180,19 @@ class UwbDriverNode(Node):
"pyserial not installed — running in simulation mode (no serial I/O)"
)
self._range_timer = self.create_timer(
1.0 / self._p["range_rate"], self._range_publish_cb
)
self._bearing_timer = self.create_timer(
1.0 / self._p["bearing_rate"], self._bearing_publish_cb
# ── Publish timer ─────────────────────────────────────────────────
self._timer = self.create_timer(
1.0 / self._p["publish_rate"], self._publish_cb
)
self.get_logger().info(
f"UWB driver ready sep={self._p['anchor_separation']}m "
f"ports={self._p['port_a']},{self._p['port_b']} "
f"range={self._p['range_rate']}Hz "
f"bearing={self._p['bearing_rate']}Hz "
f"enrolled_tags={self._enrolled_tags or ['<any>']}"
f"rate={self._p['publish_rate']}Hz"
)
# ── Helpers ───────────────────────────────────────────────────────────────
def _load_params(self):
return {
"port_a": self.get_parameter("port_a").value,
@ -252,106 +206,90 @@ class UwbDriverNode(Node):
"min_range_m": self.get_parameter("min_range_m").value,
"kf_process_noise": self.get_parameter("kf_process_noise").value,
"kf_meas_noise": self.get_parameter("kf_meas_noise").value,
"range_rate": self.get_parameter("range_rate").value,
"bearing_rate": self.get_parameter("bearing_rate").value,
"publish_rate": self.get_parameter("publish_rate").value,
}
def _is_enrolled(self, tag_addr: str) -> bool:
    """Return True when tag_addr is accepted.

    An empty enrolment list means pairing is disabled: every tag is accepted.
    """
    return (not self._enrolled_tags) or (tag_addr in self._enrolled_tags)
# ── Callbacks ─────────────────────────────────────────────────────────────
def _range_cb(self, anchor_id: int, range_mm: int, rssi: float, tag_addr: str):
if not self._is_enrolled(tag_addr):
return
def _range_cb(self, anchor_id: int, range_mm: int, rssi: float):
"""Called from serial reader threads — thread-safe update."""
range_m = range_mm / 1000.0
p = self._p
if range_m < p["min_range_m"] or range_m > p["max_range_m"]:
return
with self._lock:
self._ranges[anchor_id] = (range_m, rssi, tag_addr, time.monotonic())
self._ranges[anchor_id] = (range_m, rssi, time.monotonic())
def _range_publish_cb(self):
"""100 Hz: publish current raw ranges as UwbRangeArray."""
now = time.monotonic()
def _publish_cb(self):
now = time.monotonic()
timeout = self._p["range_timeout_s"]
sep = self._p["anchor_separation"]
with self._lock:
valid = {
aid: entry
for aid, entry in self._ranges.items()
if (now - entry[3]) <= timeout
}
# Collect valid (non-stale) ranges
valid = {}
for aid, (r, rssi, t) in self._ranges.items():
if now - t <= timeout:
valid[aid] = (r, rssi, t)
# Build and publish UwbRangeArray regardless (even if partial)
hdr = Header()
hdr.stamp = self.get_clock().now().to_msg()
hdr.frame_id = "base_link"
arr = UwbRangeArray()
arr.header = hdr
for aid, (r, rssi, tag_id, _) in valid.items():
for aid, (r, rssi, _) in valid.items():
entry = UwbRange()
entry.header = hdr
entry.anchor_id = aid
entry.range_m = float(r)
entry.raw_mm = int(round(r * 1000.0))
entry.rssi = float(rssi)
entry.tag_id = tag_id
arr.ranges.append(entry)
self._ranges_pub.publish(arr)
def _bearing_publish_cb(self):
"""10 Hz: Kalman predict+update, publish fused bearing."""
now = time.monotonic()
timeout = self._p["range_timeout_s"]
sep = self._p["anchor_separation"]
with self._lock:
valid = {
aid: entry
for aid, entry in self._ranges.items()
if (now - entry[3]) <= timeout
}
if not valid:
# Need both anchors to triangulate
if 0 not in valid or 1 not in valid:
return
both_fresh = 0 in valid and 1 in valid
confidence = 1.0 if both_fresh else 0.5
active_tag = valid[next(iter(valid))][2]
dt = 1.0 / self._p["bearing_rate"]
r0 = valid[0][0]
r1 = valid[1][0]
try:
x_t, y_t = triangulate_2anchor(
r0=r0,
r1=r1,
sep=sep,
anchor_z=self._p["anchor_height"],
tag_z=self._p["tag_height"],
)
except (ValueError, ZeroDivisionError) as exc:
self.get_logger().warn(f"Triangulation error: {exc}")
return
# Kalman filter update
dt = 1.0 / self._p["publish_rate"]
self._kf.predict(dt=dt)
if both_fresh:
r0 = valid[0][0]
r1 = valid[1][0]
try:
x_t, y_t = triangulate_2anchor(
r0=r0, r1=r1, sep=sep,
anchor_z=self._p["anchor_height"],
tag_z=self._p["tag_height"],
)
except (ValueError, ZeroDivisionError) as exc:
self.get_logger().warn(f"Triangulation error: {exc}")
return
self._kf.update(x_t, y_t)
self._kf.update(x_t, y_t)
kx, ky = self._kf.position()
bearing = bearing_from_pos(kx, ky)
range_m = math.sqrt(kx * kx + ky * ky)
hdr = Header()
hdr.stamp = self.get_clock().now().to_msg()
hdr.frame_id = "base_link"
brg_msg = UwbBearing()
brg_msg.header = hdr
brg_msg.bearing_rad = float(bearing)
brg_msg.range_m = float(range_m)
brg_msg.confidence = float(confidence)
brg_msg.tag_id = active_tag
self._bearing_pub.publish(brg_msg)
# Publish PoseStamped in base_link
pose = PoseStamped()
pose.header = hdr
pose.pose.position.x = kx
pose.pose.position.y = ky
pose.pose.position.z = 0.0
yaw = bearing
# Orientation: face the person (yaw = atan2(y, x))
yaw = math.atan2(ky, kx)
pose.pose.orientation.z = math.sin(yaw / 2.0)
pose.pose.orientation.w = math.cos(yaw / 2.0)
self._target_pub.publish(pose)
# ── Entry point ───────────────────────────────────────────────────────────────
def main(args=None):
rclpy.init(args=args)
node = UwbDriverNode()

View File

@ -7,7 +7,7 @@ No ROS2 / serial / GPU dependencies — runs with plain pytest.
import math
import pytest
from saltybot_uwb.ranging_math import triangulate_2anchor, KalmanFilter2D, bearing_from_pos
from saltybot_uwb.ranging_math import triangulate_2anchor, KalmanFilter2D
# ── triangulate_2anchor ───────────────────────────────────────────────────────
@ -172,47 +172,3 @@ class TestKalmanFilter2D:
x, y = kf.position()
assert not math.isnan(x)
assert not math.isnan(y)
# ── bearing_from_pos ──────────────────────────────────────────────────────────
class TestBearingFromPos:
    """Unit tests for bearing_from_pos (pure math, no ROS2 needed)."""

    def test_directly_ahead_zero_bearing(self):
        """A target straight ahead (x=2, y=0) yields bearing ≈ 0."""
        assert abs(bearing_from_pos(2.0, 0.0)) < 0.001

    def test_left_gives_positive_bearing(self):
        """A target to the left (y>0) yields a positive (CCW) bearing."""
        bearing = bearing_from_pos(1.0, 1.0)
        assert bearing > 0.0
        assert abs(bearing - math.pi / 4.0) < 0.001

    def test_right_gives_negative_bearing(self):
        """A target to the right (y<0) yields a negative (CW) bearing."""
        bearing = bearing_from_pos(1.0, -1.0)
        assert bearing < 0.0
        assert abs(bearing + math.pi / 4.0) < 0.001

    def test_directly_left_ninety_degrees(self):
        """A target directly left (x=0, y=1) yields bearing = +π/2."""
        assert abs(bearing_from_pos(0.0, 1.0) - math.pi / 2.0) < 0.001

    def test_directly_right_minus_ninety_degrees(self):
        """A target directly right (x=0, y=-1) yields bearing = -π/2."""
        assert abs(bearing_from_pos(0.0, -1.0) + math.pi / 2.0) < 0.001

    def test_range_pi_to_minus_pi(self):
        """The bearing always lies within -π .. +π for any quadrant."""
        coords = [-2.0, -0.1, 0.1, 2.0]
        assert all(
            -math.pi <= bearing_from_pos(cx, cy) <= math.pi
            for cx in coords
            for cy in coords
        )

    def test_no_nan_for_tiny_distance(self):
        """A target extremely close to the origin must not produce NaN."""
        assert not math.isnan(bearing_from_pos(0.001, 0.001))

View File

@ -8,7 +8,6 @@ find_package(rosidl_default_generators REQUIRED)
rosidl_generate_interfaces(${PROJECT_NAME}
"msg/UwbRange.msg"
"msg/UwbRangeArray.msg"
"msg/UwbBearing.msg"
DEPENDENCIES std_msgs
)

View File

@ -1,18 +0,0 @@
# UwbBearing.msg — 10 Hz Kalman-fused bearing estimate from dual-anchor UWB.
#
# bearing_rad : horizontal bearing to tag in base_link frame (radians)
# positive = tag to the left (CCW), negative = right (CW)
# range: -π .. +π
# range_m : Kalman-filtered horizontal range to tag (metres)
# confidence : quality indicator
# 1.0 = both anchors fresh
# 0.5 = single anchor only (bearing less reliable)
# 0.0 = stale / no data (message not published in this state)
# tag_id : enrolled tag identifier that produced this estimate
std_msgs/Header header
float32 bearing_rad
float32 range_m
float32 confidence
string tag_id

View File

@ -4,7 +4,6 @@
# range_m : measured horizontal range in metres (after height correction)
# raw_mm : raw TWR range from AT+RANGE? response, millimetres
# rssi : received signal strength (dBm), 0 if not reported by module
# tag_id : enrolled tag identifier; empty string if tag pairing is disabled
std_msgs/Header header
@ -12,4 +11,3 @@ uint8 anchor_id
float32 range_m
uint32 raw_mm
float32 rssi
string tag_id