Compare commits

7 Commits

Author SHA1 Message Date
f61a03b3c5 feat(social): face detection + recognition (SCRFD + ArcFace TRT FP16, Issue #80)
Add two new ROS2 packages for the social sprint:

saltybot_social_msgs (ament_cmake):
- FaceDetection, FaceDetectionArray, FaceEmbedding, FaceEmbeddingArray
- PersonState, PersonStateArray
- EnrollPerson, ListPersons, DeletePerson, UpdatePerson services

saltybot_social_face (ament_python):
- SCRFDDetector: SCRFD face detection with TRT FP16 + ONNX fallback
  - 640x640 input, 3-stride anchor decoding, NMS
- ArcFaceRecognizer: 512-dim embedding extraction with gallery matching
  - 5-point landmark alignment to 112x112, cosine similarity
- FaceGallery: thread-safe persistent gallery (npz + JSON sidecar)
- FaceRecognitionNode: ROS2 node subscribing /camera/color/image_raw,
  publishing /social/faces/detections, /social/faces/embeddings
- Enrollment via /social/enroll service (N-sample face averaging)
- Launch file, config YAML, TRT engine builder script

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-01 23:31:48 -05:00
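
The enrollment and matching steps described in this commit (N-sample face averaging, cosine similarity against a gallery) can be sketched roughly as follows. This is an illustrative sketch only, not the code from the commit: the helper names `l2_normalize`, `average_embedding`, and `best_match` are hypothetical, plain Python lists stand in for the 512-dim NumPy embeddings, and the 0.35 threshold is taken from the `recognition_threshold` default in the config.

```python
import math
from typing import Dict, List, Sequence, Tuple


def l2_normalize(v: Sequence[float]) -> List[float]:
    """Scale a vector to unit length; a zero vector is returned unchanged."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else list(v)


def average_embedding(samples: Sequence[Sequence[float]]) -> List[float]:
    """Average N normalized samples, then re-normalize the mean (assumed enrollment rule)."""
    normalized = [l2_normalize(s) for s in samples]
    mean = [sum(col) / len(normalized) for col in zip(*normalized)]
    return l2_normalize(mean)


def best_match(query: Sequence[float],
               gallery: Dict[int, Sequence[float]],
               threshold: float = 0.35) -> Tuple[int, float]:
    """Return (person_id, similarity); person_id is -1 when nothing reaches the threshold."""
    q = l2_normalize(query)
    best_id, best_sim = -1, -1.0
    for pid, emb in gallery.items():
        # For unit-length vectors the dot product equals the cosine similarity.
        sim = sum(a * b for a, b in zip(q, emb))
        if sim > best_sim:
            best_id, best_sim = pid, sim
    if best_sim < threshold:
        return -1, best_sim
    return best_id, best_sim
```

Re-normalizing after averaging keeps the gallery entry on the unit sphere, so the dot product in `best_match` stays a true cosine similarity.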
d9c983f666 Merge pull request 'feat(social): navigation & path planning #91' (#97) from sl-perception/social-nav into main 2026-03-01 23:30:40 -05:00
54e9274405 Merge pull request 'feat(uwb): MaUWB ESP32-S3 DW3000 dual-anchor bearing driver (Issue #90)' (#99) from sl-firmware/uwb-integration into main 2026-03-01 23:30:12 -05:00
b432492785 Merge pull request 'feat(social): multi-modal person state tracker #82' (#93) from sl-perception/social-person-state into main 2026-03-01 23:30:04 -05:00
9a68dfdb2e feat(uwb): MaUWB ESP32-S3 DW3000 dual-anchor bearing driver (Issue #90)
## Summary
- saltybot_uwb_msgs: add UwbBearing.msg, add tag_id to UwbRange.msg,
  register UwbBearing in CMakeLists.txt
- ranging_math.py: add bearing_from_pos(x, y) helper (atan2-based)
- uwb_driver_node.py: dual-rate architecture
    • 100 Hz /uwb/ranges  — raw TWR ranges with tag_id attribution
    • 10 Hz  /uwb/bearing — Kalman-fused bearing + range estimate
    • enrolled_tag_ids parameter for tag pairing filter
    • AT+RANGE_ADDR=<tag> pairing command on connect
- uwb_config.yaml: range_rate / bearing_rate / enrolled_tag_ids params
- uwb.launch.py: expose new params as launch arguments
- test_ranging_math.py: 7 new bearing_from_pos unit tests

Closes #90

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-01 23:25:08 -05:00
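
The atan2-based bearing helper mentioned in this commit is not shown in the diff. A minimal sketch of what such a helper might look like is given below, assuming x points forward, y points left, and the result is in degrees; the name `bearing_from_pos_sketch` and the choice of units are assumptions, and the real `bearing_from_pos(x, y)` may differ.

```python
import math


def bearing_from_pos_sketch(x: float, y: float) -> float:
    """Bearing to (x, y) in degrees: 0 straight ahead (+x), positive toward +y.

    Hypothetical stand-in for the commit's bearing_from_pos helper.
    atan2 handles all four quadrants and x == 0 without special cases.
    """
    return math.degrees(math.atan2(y, x))
```

Using `atan2(y, x)` rather than `atan(y / x)` avoids a division by zero directly to the side of the robot and keeps the sign correct behind it.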
d872ea5e34 feat(social): navigation + follow modes + MiDaS depth + waypoints (Issue #91)
- saltybot_social_msgs: full message/service definitions (standalone compilation)
- saltybot_social_nav: social navigation orchestrator
  - Follow modes: shadow/lead/side/orbit/loose/tight
  - Voice steering: mode switching + route commands via /social/speech/*
  - A* obstacle avoidance on Nav2/SLAM occupancy grid (8-directional, inflation)
  - MiDaS monocular depth for CSI cameras (TRT FP16 + ONNX fallback)
  - Waypoint teaching + replay with WaypointRoute persistence
  - High-speed EUC tracking (5.5 m/s = ~20 km/h)
  - Predictive position extrapolation (0.3s ahead at high speed)
- Launch: social_nav.launch.py (social_nav + midas_depth + waypoint_teacher)
- Config: social_nav_params.yaml
- Script: build_midas_trt_engine.py (ONNX -> TRT FP16)
2026-03-01 23:15:00 -05:00
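
The predictive position extrapolation listed above (0.3 s ahead at high speed) is not part of the files shown here. One plausible reading is a constant-velocity prediction that only kicks in above some speed, sketched below. The function name, the `min_speed` cutoff of 2.0 m/s, and the 2D tuple representation are all assumptions made for illustration.

```python
import math
from typing import Tuple

Vec2 = Tuple[float, float]


def predict_position(pos: Vec2, vel: Vec2,
                     lookahead_s: float = 0.3,
                     min_speed: float = 2.0) -> Vec2:
    """Constant-velocity extrapolation of a target position.

    Below min_speed (an assumed cutoff) the measured position is returned
    unchanged; above it the target is projected lookahead_s seconds ahead.
    """
    speed = math.hypot(vel[0], vel[1])
    if speed < min_speed:
        return pos
    return (pos[0] + vel[0] * lookahead_s, pos[1] + vel[1] * lookahead_s)


# Unit check for the figure quoted in the commit: 5.5 m/s is 19.8 km/h.
TOP_SPEED_KMH = 5.5 * 3.6
```

At the quoted top speed of 5.5 m/s, a 0.3 s lookahead shifts the target by about 1.65 m along its direction of travel.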
84790412d6 feat(social): multi-modal person state tracker (Issue #82) 2026-03-01 23:08:22 -05:00
61 changed files with 4122 additions and 155 deletions

View File

@ -0,0 +1,8 @@
person_state_tracker:
  ros__parameters:
    engagement_distance: 2.0
    absent_timeout: 5.0
    prune_timeout: 30.0
    update_rate: 10.0
    n_cameras: 4
    uwb_enabled: false

View File

@ -0,0 +1,35 @@
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node


def generate_launch_description():
    return LaunchDescription([
        DeclareLaunchArgument(
            'engagement_distance',
            default_value='2.0',
            description='Distance in metres below which a person is considered engaged'
        ),
        DeclareLaunchArgument(
            'absent_timeout',
            default_value='5.0',
            description='Seconds without detection before marking person as ABSENT'
        ),
        DeclareLaunchArgument(
            'uwb_enabled',
            default_value='false',
            description='Whether UWB anchor data is available'
        ),
        Node(
            package='saltybot_social',
            executable='person_state_tracker',
            name='person_state_tracker',
            output='screen',
            parameters=[{
                'engagement_distance': LaunchConfiguration('engagement_distance'),
                'absent_timeout': LaunchConfiguration('absent_timeout'),
                'uwb_enabled': LaunchConfiguration('uwb_enabled'),
            }],
        ),
    ])

View File

@ -0,0 +1,25 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_social</name>
<version>0.1.0</version>
<description>Multi-modal person identity fusion and state tracking for saltybot</description>
<maintainer email="seb@vayrette.com">seb</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>std_msgs</depend>
<depend>geometry_msgs</depend>
<depend>sensor_msgs</depend>
<depend>vision_msgs</depend>
<depend>saltybot_social_msgs</depend>
<depend>tf2_ros</depend>
<depend>tf2_geometry_msgs</depend>
<depend>cv_bridge</depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>python3-pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -0,0 +1,201 @@
import math
import time
from typing import List, Optional, Tuple

from saltybot_social.person_state import PersonTrack, State


class IdentityFuser:
    """Multi-modal identity fusion: face_id + speaker_id + UWB anchor -> unified person_id."""

    def __init__(self, position_window: float = 3.0):
        self._tracks: dict[int, PersonTrack] = {}
        self._next_id: int = 1
        self._position_window = position_window
        self._max_history = 20

    def _distance_3d(self, a: Optional[tuple], b: Optional[tuple]) -> float:
        if a is None or b is None:
            return float('inf')
        return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))

    def _compute_distance_and_bearing(self, position: tuple) -> Tuple[float, float]:
        x, y, z = position
        distance = math.sqrt(x * x + y * y + z * z)
        bearing = math.degrees(math.atan2(y, x))
        return distance, bearing

    def _allocate_id(self) -> int:
        pid = self._next_id
        self._next_id += 1
        return pid

    def update_face(self, face_id: int, name: str,
                    position_3d: Optional[tuple], camera_id: int,
                    now: float) -> int:
        """Match by face_id first. If face_id >= 0, find existing track or create new."""
        if face_id >= 0:
            for track in self._tracks.values():
                if track.face_id == face_id:
                    track.name = name if name and name != 'unknown' else track.name
                    if position_3d is not None:
                        track.position = position_3d
                        dist, bearing = self._compute_distance_and_bearing(position_3d)
                        track.distance = dist
                        track.bearing_deg = bearing
                        track.history_distances.append(dist)
                        if len(track.history_distances) > self._max_history:
                            track.history_distances = track.history_distances[-self._max_history:]
                    track.camera_id = camera_id
                    track.last_seen = now
                    track.last_face_seen = now
                    return track.person_id

        # No existing track with this face_id; try proximity match for face_id < 0
        if face_id < 0 and position_3d is not None:
            best_track = None
            best_dist = 0.5  # 0.5m proximity threshold
            for track in self._tracks.values():
                d = self._distance_3d(track.position, position_3d)
                if d < best_dist and (now - track.last_seen) < self._position_window:
                    best_dist = d
                    best_track = track
            if best_track is not None:
                best_track.position = position_3d
                dist, bearing = self._compute_distance_and_bearing(position_3d)
                best_track.distance = dist
                best_track.bearing_deg = bearing
                best_track.history_distances.append(dist)
                if len(best_track.history_distances) > self._max_history:
                    best_track.history_distances = best_track.history_distances[-self._max_history:]
                best_track.camera_id = camera_id
                best_track.last_seen = now
                return best_track.person_id

        # Create new track
        pid = self._allocate_id()
        track = PersonTrack(person_id=pid, face_id=face_id, name=name,
                            camera_id=camera_id, last_seen=now, last_face_seen=now)
        if position_3d is not None:
            track.position = position_3d
            dist, bearing = self._compute_distance_and_bearing(position_3d)
            track.distance = dist
            track.bearing_deg = bearing
            track.history_distances.append(dist)
        self._tracks[pid] = track
        return pid

    def update_speaker(self, speaker_id: str, now: float) -> int:
        """Find track with nearest recently-seen position, assign speaker_id."""
        # First check if any track already has this speaker_id
        for track in self._tracks.values():
            if track.speaker_id == speaker_id:
                track.last_seen = now
                return track.person_id

        # Assign to nearest recently-seen track
        best_track = None
        best_dist = float('inf')
        for track in self._tracks.values():
            if (now - track.last_seen) < self._position_window and track.distance > 0:
                if track.distance < best_dist:
                    best_dist = track.distance
                    best_track = track
        if best_track is not None:
            best_track.speaker_id = speaker_id
            best_track.last_seen = now
            return best_track.person_id

        # No suitable track found; create a new one
        pid = self._allocate_id()
        track = PersonTrack(person_id=pid, speaker_id=speaker_id,
                            last_seen=now)
        self._tracks[pid] = track
        return pid

    def update_uwb(self, uwb_anchor_id: str, position_3d: tuple,
                   now: float) -> int:
        """Match by proximity (nearest track within 0.5m)."""
        # Check existing UWB assignment
        for track in self._tracks.values():
            if track.uwb_anchor_id == uwb_anchor_id:
                track.position = position_3d
                dist, bearing = self._compute_distance_and_bearing(position_3d)
                track.distance = dist
                track.bearing_deg = bearing
                track.history_distances.append(dist)
                if len(track.history_distances) > self._max_history:
                    track.history_distances = track.history_distances[-self._max_history:]
                track.last_seen = now
                return track.person_id

        # Proximity match
        best_track = None
        best_dist = 0.5
        for track in self._tracks.values():
            d = self._distance_3d(track.position, position_3d)
            if d < best_dist and (now - track.last_seen) < self._position_window:
                best_dist = d
                best_track = track
        if best_track is not None:
            best_track.uwb_anchor_id = uwb_anchor_id
            best_track.position = position_3d
            dist, bearing = self._compute_distance_and_bearing(position_3d)
            best_track.distance = dist
            best_track.bearing_deg = bearing
            best_track.history_distances.append(dist)
            if len(best_track.history_distances) > self._max_history:
                best_track.history_distances = best_track.history_distances[-self._max_history:]
            best_track.last_seen = now
            return best_track.person_id

        # Create new track
        pid = self._allocate_id()
        track = PersonTrack(person_id=pid, uwb_anchor_id=uwb_anchor_id,
                            position=position_3d, last_seen=now)
        dist, bearing = self._compute_distance_and_bearing(position_3d)
        track.distance = dist
        track.bearing_deg = bearing
        track.history_distances.append(dist)
        self._tracks[pid] = track
        return pid

    def get_all_tracks(self) -> List[PersonTrack]:
        return list(self._tracks.values())

    @staticmethod
    def compute_attention(tracks: List[PersonTrack]) -> int:
        """Focus on nearest engaged/talking person, or -1 if none."""
        candidates = [t for t in tracks
                      if t.state in (State.ENGAGED, State.TALKING) and t.distance > 0]
        if not candidates:
            return -1
        # Prefer talking, then nearest engaged
        talking = [t for t in candidates if t.state == State.TALKING]
        if talking:
            return min(talking, key=lambda t: t.distance).person_id
        return min(candidates, key=lambda t: t.distance).person_id

    def prune_absent(self, absent_timeout: float = 30.0):
        """Remove tracks absent longer than timeout."""
        now = time.monotonic()
        to_remove = [pid for pid, t in self._tracks.items()
                     if t.state == State.ABSENT and (now - t.last_seen) > absent_timeout]
        for pid in to_remove:
            del self._tracks[pid]

    @staticmethod
    def detect_group(tracks: List[PersonTrack]) -> bool:
        """Returns True if >= 2 persons within 1.5m of each other."""
        active = [t for t in tracks
                  if t.position is not None and t.state != State.ABSENT]
        for i, a in enumerate(active):
            for b in active[i + 1:]:
                dx = a.position[0] - b.position[0]
                dy = a.position[1] - b.position[1]
                dz = a.position[2] - b.position[2]
                if math.sqrt(dx * dx + dy * dy + dz * dz) < 1.5:
                    return True
        return False
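
The proximity rule that `update_face` and `update_uwb` both apply (nearest track within 0.5 m that was seen inside the position window) can be isolated as a small standalone function. The sketch below is illustrative: `nearest_within` and its tuple-based track representation are hypothetical and not part of the package, and `math.dist` requires Python 3.8 or later.

```python
import math
from typing import Optional, Sequence, Tuple

Point3 = Tuple[float, float, float]
# (track_id, last known position, last_seen timestamp in seconds)
Candidate = Tuple[int, Point3, float]


def nearest_within(candidates: Sequence[Candidate], query: Point3, now: float,
                   max_dist: float = 0.5, window: float = 3.0) -> Optional[int]:
    """Return the id of the closest fresh track within max_dist, or None."""
    best_id, best_d = None, max_dist
    for track_id, pos, last_seen in candidates:
        d = math.dist(pos, query)
        # A track qualifies only if it is both close enough and recently seen.
        if d < best_d and (now - last_seen) < window:
            best_id, best_d = track_id, d
    return best_id
```

A stale track is skipped even when it is the closest one, which is what prevents a new observation from reviving a person who left several seconds ago.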

View File

@ -0,0 +1,54 @@
import time
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Optional


class State(IntEnum):
    UNKNOWN = 0
    APPROACHING = 1
    ENGAGED = 2
    TALKING = 3
    LEAVING = 4
    ABSENT = 5


@dataclass
class PersonTrack:
    person_id: int
    face_id: int = -1
    speaker_id: str = ''
    uwb_anchor_id: str = ''
    name: str = 'unknown'
    state: State = State.UNKNOWN
    position: Optional[tuple] = None  # (x, y, z) in base_link
    distance: float = 0.0
    bearing_deg: float = 0.0
    engagement_score: float = 0.0
    camera_id: int = -1
    last_seen: float = field(default_factory=time.monotonic)
    last_face_seen: float = 0.0
    history_distances: list = field(default_factory=list)  # last N distances for trend

    def update_state(self, now: float, engagement_distance: float = 2.0,
                     absent_timeout: float = 5.0):
        """Transition state machine based on distance trend and time."""
        age = now - self.last_seen
        if age > absent_timeout:
            self.state = State.ABSENT
            return
        if self.speaker_id:
            self.state = State.TALKING
            return
        if len(self.history_distances) >= 3:
            trend = self.history_distances[-1] - self.history_distances[-3]
            if self.distance < engagement_distance:
                self.state = State.ENGAGED
            elif trend < -0.2:  # moving closer
                self.state = State.APPROACHING
            elif trend > 0.3:  # moving away
                self.state = State.LEAVING
            else:
                # Outside engagement distance (checked above) with no clear trend
                self.state = State.UNKNOWN
        elif self.distance > 0:
            self.state = State.ENGAGED if self.distance < engagement_distance else State.APPROACHING
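
The distance-trend part of `update_state` can be exercised on its own. The sketch below restates only that part as a pure function over a distance history; it deliberately omits the absent-timeout and speaker checks, assumes the history is non-empty with the current distance as its last entry, and uses the hypothetical name `classify` and string labels in place of the `State` enum.

```python
from typing import Sequence


def classify(distances: Sequence[float], engagement_distance: float = 2.0) -> str:
    """Classify a track from its recent distances (metres, oldest first)."""
    current = distances[-1]
    if current < engagement_distance:
        return 'ENGAGED'
    if len(distances) >= 3:
        # Change over the last two updates; negative means getting closer.
        trend = distances[-1] - distances[-3]
        if trend < -0.2:
            return 'APPROACHING'
        if trend > 0.3:
            return 'LEAVING'
        return 'UNKNOWN'
    # Too little history for a trend: a distant new track counts as approaching.
    return 'APPROACHING'


print(classify([4.0, 3.5, 3.0]))  # closing by 1.0 m over two updates
print(classify([3.0, 3.2, 3.5]))  # receding by 0.5 m over two updates
```

The asymmetric thresholds (-0.2 m versus +0.3 m) mean a person has to move away slightly more decisively than they have to approach before the label changes.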

View File

@ -0,0 +1,273 @@
import time

import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from std_msgs.msg import Int32, String
from geometry_msgs.msg import PoseArray, PoseStamped, Point
from sensor_msgs.msg import Image
from builtin_interfaces.msg import Time as TimeMsg
from saltybot_social_msgs.msg import (
    FaceDetection,
    FaceDetectionArray,
    PersonState,
    PersonStateArray,
)

from saltybot_social.identity_fuser import IdentityFuser
from saltybot_social.person_state import State


class PersonStateTrackerNode(Node):
    """Main ROS2 node for multi-modal person tracking."""

    def __init__(self):
        super().__init__('person_state_tracker')

        # Parameters
        self.declare_parameter('engagement_distance', 2.0)
        self.declare_parameter('absent_timeout', 5.0)
        self.declare_parameter('prune_timeout', 30.0)
        self.declare_parameter('update_rate', 10.0)
        self.declare_parameter('n_cameras', 4)
        self.declare_parameter('uwb_enabled', False)

        self._engagement_distance = self.get_parameter('engagement_distance').value
        self._absent_timeout = self.get_parameter('absent_timeout').value
        self._prune_timeout = self.get_parameter('prune_timeout').value
        update_rate = self.get_parameter('update_rate').value
        n_cameras = self.get_parameter('n_cameras').value
        uwb_enabled = self.get_parameter('uwb_enabled').value

        self._fuser = IdentityFuser()

        # QoS profiles
        best_effort_qos = QoSProfile(
            reliability=ReliabilityPolicy.BEST_EFFORT,
            history=HistoryPolicy.KEEP_LAST,
            depth=5
        )
        reliable_qos = QoSProfile(
            reliability=ReliabilityPolicy.RELIABLE,
            history=HistoryPolicy.KEEP_LAST,
            depth=1
        )

        # Subscriptions
        self.create_subscription(
            FaceDetectionArray,
            '/social/faces/detections',
            self._on_face_detections,
            best_effort_qos
        )
        self.create_subscription(
            String,
            '/social/speech/speaker_id',
            self._on_speaker_id,
            10
        )
        if uwb_enabled:
            self.create_subscription(
                PoseArray,
                '/uwb/positions',
                self._on_uwb,
                10
            )

        # Camera subscriptions (monitor topic existence)
        self.create_subscription(
            Image, '/camera/color/image_raw',
            self._on_camera_image, best_effort_qos)
        for i in range(n_cameras):
            self.create_subscription(
                Image, f'/surround/cam{i}/image_raw',
                self._on_camera_image, best_effort_qos)

        # Existing person tracker position
        self.create_subscription(
            PoseStamped,
            '/person/target',
            self._on_person_target,
            10
        )

        # Publishers
        self._persons_pub = self.create_publisher(
            PersonStateArray, '/social/persons', best_effort_qos)
        self._attention_pub = self.create_publisher(
            Int32, '/social/attention/target_id', reliable_qos)

        # Timer
        timer_period = 1.0 / update_rate
        self.create_timer(timer_period, self._on_timer)

        self.get_logger().info(
            f'PersonStateTracker started: engagement={self._engagement_distance}m, '
            f'absent_timeout={self._absent_timeout}s, rate={update_rate}Hz, '
            f'cameras={n_cameras}, uwb={uwb_enabled}')

    def _on_face_detections(self, msg: FaceDetectionArray):
        now = time.monotonic()
        for face in msg.faces:
            position_3d = None
            # Use bbox center as rough position estimate if depth is available
            # Real 3D position would come from depth camera projection
            if face.bbox_x > 0 or face.bbox_y > 0:
                # Approximate: use bbox center as bearing proxy, distance from bbox size
                # This is a placeholder; real impl would use depth image lookup
                position_3d = (
                    max(0.5, 1.0 / max(face.bbox_w, 0.01)),  # rough depth from bbox width
                    face.bbox_x + face.bbox_w / 2.0 - 0.5,  # x offset from center
                    face.bbox_y + face.bbox_h / 2.0 - 0.5  # y offset from center
                )

            camera_id = -1
            if hasattr(face.header, 'frame_id') and face.header.frame_id:
                frame = face.header.frame_id
                if 'cam' in frame:
                    try:
                        camera_id = int(frame.split('cam')[-1].split('_')[0])
                    except ValueError:
                        pass

            self._fuser.update_face(
                face_id=face.face_id,
                name=face.person_name,
                position_3d=position_3d,
                camera_id=camera_id,
                now=now
            )

    def _on_speaker_id(self, msg: String):
        now = time.monotonic()
        speaker_id = msg.data
        # Parse "speaker_<id>" format or raw id
        if speaker_id.startswith('speaker_'):
            speaker_id = speaker_id[len('speaker_'):]
        self._fuser.update_speaker(speaker_id, now)

    def _on_uwb(self, msg: PoseArray):
        now = time.monotonic()
        for i, pose in enumerate(msg.poses):
            position = (pose.position.x, pose.position.y, pose.position.z)
            anchor_id = f'uwb_{i}'
            if hasattr(msg, 'header') and msg.header.frame_id:
                anchor_id = f'{msg.header.frame_id}_{i}'
            self._fuser.update_uwb(anchor_id, position, now)

    def _on_camera_image(self, msg: Image):
        # Monitor only -- no processing here
        pass

    def _on_person_target(self, msg: PoseStamped):
        now = time.monotonic()
        position_3d = (
            msg.pose.position.x,
            msg.pose.position.y,
            msg.pose.position.z
        )
        # Update position for unidentified person (from YOLOv8 tracker)
        self._fuser.update_face(
            face_id=-1,
            name='unknown',
            position_3d=position_3d,
            camera_id=-1,
            now=now
        )

    def _on_timer(self):
        now = time.monotonic()
        tracks = self._fuser.get_all_tracks()

        # Update state machine for all tracks
        for track in tracks:
            track.update_state(
                now,
                engagement_distance=self._engagement_distance,
                absent_timeout=self._absent_timeout
            )

        # Compute engagement scores
        for track in tracks:
            if track.state == State.ABSENT:
                track.engagement_score = 0.0
            elif track.state == State.TALKING:
                track.engagement_score = 1.0
            elif track.state == State.ENGAGED:
                track.engagement_score = max(0.0, 1.0 - track.distance / self._engagement_distance)
            elif track.state == State.APPROACHING:
                track.engagement_score = 0.3
            elif track.state == State.LEAVING:
                track.engagement_score = 0.1
            else:
                track.engagement_score = 0.0

        # Prune long-absent tracks
        self._fuser.prune_absent(self._prune_timeout)

        # Compute attention target
        attention_id = IdentityFuser.compute_attention(tracks)

        # Build and publish PersonStateArray
        msg = PersonStateArray()
        msg.header.stamp = self.get_clock().now().to_msg()
        msg.header.frame_id = 'base_link'
        msg.primary_attention_id = attention_id
        for track in tracks:
            ps = PersonState()
            ps.header.stamp = msg.header.stamp
            ps.header.frame_id = 'base_link'
            ps.person_id = track.person_id
            ps.person_name = track.name
            ps.face_id = track.face_id
            ps.speaker_id = track.speaker_id
            ps.uwb_anchor_id = track.uwb_anchor_id
            if track.position is not None:
                ps.position = Point(
                    x=float(track.position[0]),
                    y=float(track.position[1]),
                    z=float(track.position[2]))
            ps.distance = float(track.distance)
            ps.bearing_deg = float(track.bearing_deg)
            ps.state = int(track.state)
            ps.engagement_score = float(track.engagement_score)
            ps.last_seen = self._mono_to_ros_time(track.last_seen)
            ps.camera_id = track.camera_id
            msg.persons.append(ps)
        self._persons_pub.publish(msg)

        # Publish attention target
        att_msg = Int32()
        att_msg.data = attention_id
        self._attention_pub.publish(att_msg)

    def _mono_to_ros_time(self, mono: float) -> TimeMsg:
        """Convert monotonic timestamp to approximate ROS time."""
        # Offset from monotonic to wall clock
        offset = time.time() - time.monotonic()
        wall = mono + offset
        sec = int(wall)
        nsec = int((wall - sec) * 1e9)
        t = TimeMsg()
        t.sec = sec
        t.nanosec = nsec
        return t


def main(args=None):
    rclpy.init(args=args)
    node = PersonStateTrackerNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()
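
The monotonic-to-wall-clock conversion in `_mono_to_ros_time` can also be done in integer nanoseconds, which avoids rounding a large float epoch value. The sketch below is a standalone variant under that assumption; `mono_to_wall_ns` is a hypothetical name, it returns a plain `(sec, nanosec)` pair instead of a ROS message, and like the original it yields system wall time, which may not match the node clock when simulated time is in use.

```python
import time
from typing import Tuple


def mono_to_wall_ns(mono_s: float) -> Tuple[int, int]:
    """Map a time.monotonic() reading to (sec, nanosec) of wall-clock time."""
    # Offset between the two clocks, sampled now, kept as an exact integer.
    offset_ns = time.time_ns() - time.monotonic_ns()
    wall_ns = int(mono_s * 1e9) + offset_ns
    # divmod keeps nanosec in [0, 1e9) so it fits the message field directly.
    return divmod(wall_ns, 1_000_000_000)


sec, nanosec = mono_to_wall_ns(time.monotonic())
```

The offset is re-sampled on every call, so the result follows any wall-clock adjustment made since the monotonic reading was taken.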

View File

@ -0,0 +1,4 @@
[develop]
script_dir=$base/lib/saltybot_social
[install]
install_scripts=$base/lib/saltybot_social

View File

@ -0,0 +1,32 @@
from setuptools import find_packages, setup
import os
from glob import glob

package_name = 'saltybot_social'

setup(
    name=package_name,
    version='0.1.0',
    packages=find_packages(exclude=['test']),
    data_files=[
        ('share/ament_index/resource_index/packages',
            ['resource/' + package_name]),
        ('share/' + package_name, ['package.xml']),
        (os.path.join('share', package_name, 'launch'),
            glob(os.path.join('launch', '*launch.[pxy][yma]*'))),
        (os.path.join('share', package_name, 'config'),
            glob(os.path.join('config', '*.yaml'))),
    ],
    install_requires=['setuptools'],
    zip_safe=True,
    maintainer='seb',
    maintainer_email='seb@vayrette.com',
    description='Multi-modal person identity fusion and state tracking for saltybot',
    license='MIT',
    tests_require=['pytest'],
    entry_points={
        'console_scripts': [
            'person_state_tracker = saltybot_social.person_state_tracker_node:main',
        ],
    },
)

View File

@ -0,0 +1,11 @@
face_recognizer:
  ros__parameters:
    scrfd_engine_path: '/mnt/nvme/saltybot/models/scrfd_2.5g.engine'
    scrfd_onnx_path: '/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx'
    arcface_engine_path: '/mnt/nvme/saltybot/models/arcface_r50.engine'
    arcface_onnx_path: '/mnt/nvme/saltybot/models/arcface_r50.onnx'
    gallery_dir: '/mnt/nvme/saltybot/gallery'
    recognition_threshold: 0.35
    publish_debug_image: false
    max_faces: 10
    scrfd_conf_thresh: 0.5

View File

@ -0,0 +1,80 @@
"""
face_recognition.launch.py -- Launch file for the SCRFD + ArcFace face recognition node.

Launches the face_recognizer node with configurable model paths and parameters.
The RealSense camera must be running separately (e.g., via realsense.launch.py).
"""
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node


def generate_launch_description():
    """Generate launch description for face recognition pipeline."""
    return LaunchDescription([
        DeclareLaunchArgument(
            'scrfd_engine_path',
            default_value='/mnt/nvme/saltybot/models/scrfd_2.5g.engine',
            description='Path to SCRFD TensorRT engine file',
        ),
        DeclareLaunchArgument(
            'scrfd_onnx_path',
            default_value='/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx',
            description='Path to SCRFD ONNX model file (fallback)',
        ),
        DeclareLaunchArgument(
            'arcface_engine_path',
            default_value='/mnt/nvme/saltybot/models/arcface_r50.engine',
            description='Path to ArcFace TensorRT engine file',
        ),
        DeclareLaunchArgument(
            'arcface_onnx_path',
            default_value='/mnt/nvme/saltybot/models/arcface_r50.onnx',
            description='Path to ArcFace ONNX model file (fallback)',
        ),
        DeclareLaunchArgument(
            'gallery_dir',
            default_value='/mnt/nvme/saltybot/gallery',
            description='Directory for persistent face gallery storage',
        ),
        DeclareLaunchArgument(
            'recognition_threshold',
            default_value='0.35',
            description='Cosine similarity threshold for face recognition',
        ),
        DeclareLaunchArgument(
            'publish_debug_image',
            default_value='false',
            description='Publish annotated debug image to /social/faces/debug_image',
        ),
        DeclareLaunchArgument(
            'max_faces',
            default_value='10',
            description='Maximum faces to process per frame',
        ),
        DeclareLaunchArgument(
            'scrfd_conf_thresh',
            default_value='0.5',
            description='SCRFD detection confidence threshold',
        ),
        Node(
            package='saltybot_social_face',
            executable='face_recognition',
            name='face_recognizer',
            output='screen',
            parameters=[{
                'scrfd_engine_path': LaunchConfiguration('scrfd_engine_path'),
                'scrfd_onnx_path': LaunchConfiguration('scrfd_onnx_path'),
                'arcface_engine_path': LaunchConfiguration('arcface_engine_path'),
                'arcface_onnx_path': LaunchConfiguration('arcface_onnx_path'),
                'gallery_dir': LaunchConfiguration('gallery_dir'),
                'recognition_threshold': LaunchConfiguration('recognition_threshold'),
                'publish_debug_image': LaunchConfiguration('publish_debug_image'),
                'max_faces': LaunchConfiguration('max_faces'),
                'scrfd_conf_thresh': LaunchConfiguration('scrfd_conf_thresh'),
            }],
        ),
    ])

View File

@ -0,0 +1,27 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_social_face</name>
<version>0.1.0</version>
<description>SCRFD face detection and ArcFace recognition for SaltyBot social interactions</description>
<maintainer email="seb@vayrette.com">seb</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>sensor_msgs</depend>
<depend>cv_bridge</depend>
<depend>image_transport</depend>
<depend>saltybot_social_msgs</depend>
<exec_depend>python3-numpy</exec_depend>
<exec_depend>python3-opencv</exec_depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>python3-pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -0,0 +1 @@
"""SaltyBot social face detection and recognition package."""

View File

@ -0,0 +1,316 @@
"""
arcface_recognizer.py -- ArcFace face embedding extraction and gallery matching.

Performs face alignment using 5-point landmarks (insightface standard reference),
extracts 512-dimensional embeddings via ArcFace (TRT FP16 or ONNX fallback),
and matches against a persistent gallery using cosine similarity.
"""
import os
import logging
from typing import Optional

import numpy as np
import cv2

logger = logging.getLogger(__name__)

# InsightFace standard reference landmarks for 112x112 alignment
ARCFACE_SRC = np.array([
    [38.2946, 51.6963],  # left eye
    [73.5318, 51.5014],  # right eye
    [56.0252, 71.7366],  # nose
    [41.5493, 92.3655],  # left mouth
    [70.7299, 92.2041],  # right mouth
], dtype=np.float32)


# -- Inference backends --------------------------------------------------------

class _TRTBackend:
    """TensorRT inference engine for ArcFace."""

    def __init__(self, engine_path: str):
        import tensorrt as trt
        import pycuda.driver as cuda
        import pycuda.autoinit  # noqa: F401

        self._cuda = cuda
        trt_logger = trt.Logger(trt.Logger.WARNING)
        with open(engine_path, 'rb') as f, trt.Runtime(trt_logger) as runtime:
            self._engine = runtime.deserialize_cuda_engine(f.read())
        self._context = self._engine.create_execution_context()

        self._inputs = []
        self._outputs = []
        self._bindings = []
        for i in range(self._engine.num_io_tensors):
            name = self._engine.get_tensor_name(i)
            dtype = trt.nptype(self._engine.get_tensor_dtype(name))
            shape = tuple(self._engine.get_tensor_shape(name))
            nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
            host_mem = cuda.pagelocked_empty(shape, dtype)
            device_mem = cuda.mem_alloc(nbytes)
            self._bindings.append(int(device_mem))
            if self._engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                self._inputs.append({'host': host_mem, 'device': device_mem})
            else:
                self._outputs.append({'host': host_mem, 'device': device_mem,
                                      'shape': shape})
        self._stream = cuda.Stream()

    def infer(self, input_data: np.ndarray) -> np.ndarray:
        """Run inference and return the embedding vector."""
        # The host buffer is allocated with the tensor's full shape, so the input
        # must be reshaped to match; a flat array would not broadcast into it.
        host_in = self._inputs[0]['host']
        np.copyto(host_in, input_data.reshape(host_in.shape))
        self._cuda.memcpy_htod_async(
            self._inputs[0]['device'], self._inputs[0]['host'], self._stream)
        self._context.execute_async_v2(self._bindings, self._stream.handle)
        for out in self._outputs:
            self._cuda.memcpy_dtoh_async(out['host'], out['device'], self._stream)
        self._stream.synchronize()
        return self._outputs[0]['host'].reshape(self._outputs[0]['shape']).copy()


class _ONNXBackend:
    """ONNX Runtime inference (CUDA EP with CPU fallback)."""

    def __init__(self, onnx_path: str):
        import onnxruntime as ort
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        self._session = ort.InferenceSession(onnx_path, providers=providers)
        self._input_name = self._session.get_inputs()[0].name

    def infer(self, input_data: np.ndarray) -> np.ndarray:
        """Run inference and return the embedding vector."""
        results = self._session.run(None, {self._input_name: input_data})
        return results[0]


# -- Face alignment ------------------------------------------------------------

def align_face(bgr: np.ndarray, landmarks_10: list[float]) -> np.ndarray:
    """Align a face to 112x112 using 5-point landmarks.

    Args:
        bgr: Source BGR image.
        landmarks_10: Flat list of 10 floats [x0,y0, x1,y1, ..., x4,y4].

    Returns:
        Aligned BGR face crop of shape (112, 112, 3).
    """
    src_pts = np.array(landmarks_10, dtype=np.float32).reshape(5, 2)
    M, _ = cv2.estimateAffinePartial2D(src_pts, ARCFACE_SRC)
    if M is None:
        # Fallback: simple crop and resize from bbox-like region
        cx = np.mean(src_pts[:, 0])
        cy = np.mean(src_pts[:, 1])
        spread = max(np.ptp(src_pts[:, 0]), np.ptp(src_pts[:, 1])) * 1.5
        half = spread / 2
        x1 = max(0, int(cx - half))
        y1 = max(0, int(cy - half))
        x2 = min(bgr.shape[1], int(cx + half))
        y2 = min(bgr.shape[0], int(cy + half))
        crop = bgr[y1:y2, x1:x2]
        return cv2.resize(crop, (112, 112), interpolation=cv2.INTER_LINEAR)
    aligned = cv2.warpAffine(bgr, M, (112, 112), borderMode=cv2.BORDER_REPLICATE)
    return aligned


# -- Main recognizer class -----------------------------------------------------

class ArcFaceRecognizer:
    """ArcFace face embedding extractor and gallery matcher.

    Args:
        engine_path: Path to TensorRT engine file.
        onnx_path: Path to ONNX model file (used if engine not available).
    """

    def __init__(self, engine_path: str = '', onnx_path: str = ''):
        self._backend: Optional[_TRTBackend | _ONNXBackend] = None
        self.gallery: dict[int, dict] = {}

        # Try TRT first, then ONNX
        if engine_path and os.path.isfile(engine_path):
            try:
                self._backend = _TRTBackend(engine_path)
                logger.info('ArcFace TensorRT backend loaded: %s', engine_path)
                return
            except Exception as e:
                logger.warning('ArcFace TRT load failed (%s), trying ONNX', e)
        if onnx_path and os.path.isfile(onnx_path):
            try:
                self._backend = _ONNXBackend(onnx_path)
                logger.info('ArcFace ONNX backend loaded: %s', onnx_path)
                return
            except Exception as e:
                logger.error('ArcFace ONNX load failed: %s', e)
        logger.error('No ArcFace model loaded. Recognition will be unavailable.')

    @property
    def is_loaded(self) -> bool:
        """Return True if a backend is loaded and ready."""
        return self._backend is not None

    def embed(self, bgr_face_112x112: np.ndarray) -> np.ndarray:
        """Extract 512-dim L2-normalized embedding from a 112x112 aligned face.

        Args:
            bgr_face_112x112: Aligned face crop, BGR, shape (112, 112, 3).

        Returns:
            L2-normalized embedding of shape (512,).
        """
        if self._backend is None:
            return np.zeros(512, dtype=np.float32)

        # Preprocess: BGR->RGB, /255, subtract 0.5, divide 0.5 -> [1,3,112,112]
        rgb = cv2.cvtColor(bgr_face_112x112, cv2.COLOR_BGR2RGB).astype(np.float32)
        rgb = rgb / 255.0
        rgb = (rgb - 0.5) / 0.5
        blob = rgb.transpose(2, 0, 1)[np.newaxis]  # [1, 3, 112, 112]
        blob = np.ascontiguousarray(blob)

        output = self._backend.infer(blob)
        embedding = output.flatten()[:512].astype(np.float32)

        # L2 normalize
        norm = np.linalg.norm(embedding)
        if norm > 0:
            embedding = embedding / norm
        return embedding

    def align_and_embed(self, bgr_image: np.ndarray, landmarks_10: list[float]) -> np.ndarray:
        """Align face using landmarks and extract embedding.

        Args:
            bgr_image: Full BGR image.
            landmarks_10: Flat list of 10 floats from SCRFD detection.

        Returns:
            L2-normalized embedding of shape (512,).
        """
        aligned = align_face(bgr_image, landmarks_10)
        return self.embed(aligned)

    def load_gallery(self, gallery_path: str) -> None:
        """Load gallery from .npz file with JSON metadata sidecar.

        Args:
            gallery_path: Path to the .npz gallery file.
        """
        import json

        if not os.path.isfile(gallery_path):
            logger.info('No gallery file at %s, starting empty.', gallery_path)
            self.gallery = {}
            return

        data = np.load(gallery_path, allow_pickle=False)
        meta_path = gallery_path.replace('.npz', '_meta.json')
        if os.path.isfile(meta_path):
            with open(meta_path, 'r') as f:
                meta = json.load(f)
        else:
            meta = {}

        self.gallery = {}
        for key in data.files:
            pid = int(key)
            embedding = data[key].astype(np.float32)
            norm = np.linalg.norm(embedding)
            if norm > 0:
                embedding = embedding / norm
            info = meta.get(str(pid), {})
            self.gallery[pid] = {
                'name': info.get('name', f'person_{pid}'),
                'embedding': embedding,
                'samples': info.get('samples', 1),
                'enrolled_at': info.get('enrolled_at', 0.0),
            }
        logger.info('Gallery loaded: %d persons from %s', len(self.gallery), gallery_path)

    def save_gallery(self, gallery_path: str) -> None:
        """Save gallery to .npz file with JSON metadata sidecar.

        Args:
            gallery_path: Path to the .npz gallery file.
        """
        import json

        arrays = {}
        meta = {}
        for pid, info in self.gallery.items():
            arrays[str(pid)] = info['embedding']
            meta[str(pid)] = {
                'name': info['name'],
                'samples': info['samples'],
                'enrolled_at': info['enrolled_at'],
            }
        os.makedirs(os.path.dirname(gallery_path) or '.', exist_ok=True)
        np.savez(gallery_path, **arrays)
        meta_path = gallery_path.replace('.npz', '_meta.json')
        with open(meta_path, 'w') as f:
            json.dump(meta, f, indent=2)
        logger.info('Gallery saved: %d persons to %s', len(self.gallery), gallery_path)

    def match(self, embedding: np.ndarray, threshold: float = 0.35) -> tuple[int, str, float]:
        """Match an embedding against the gallery.

        Args:
            embedding: L2-normalized query embedding of shape (512,).
threshold: Minimum cosine similarity for a match.
Returns:
(person_id, person_name, score) or (-1, '', 0.0) if no match.
"""
if not self.gallery:
return (-1, '', 0.0)
best_pid = -1
best_name = ''
best_score = 0.0
for pid, info in self.gallery.items():
score = float(np.dot(embedding, info['embedding']))
if score > best_score:
best_score = score
best_pid = pid
best_name = info['name']
if best_score >= threshold:
return (best_pid, best_name, best_score)
return (-1, '', 0.0)
def enroll(self, person_id: int, person_name: str, embeddings_list: list[np.ndarray]) -> None:
"""Enroll a person by averaging multiple embeddings.
Args:
person_id: Unique integer ID for this person.
person_name: Human-readable name.
embeddings_list: List of L2-normalized embeddings to average.
"""
import time as _time
if not embeddings_list:
return
mean_emb = np.mean(embeddings_list, axis=0).astype(np.float32)
norm = np.linalg.norm(mean_emb)
if norm > 0:
mean_emb = mean_emb / norm
self.gallery[person_id] = {
'name': person_name,
'embedding': mean_emb,
'samples': len(embeddings_list),
'enrolled_at': _time.time(),
}
logger.info('Enrolled person %d (%s) with %d samples.',
person_id, person_name, len(embeddings_list))
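The enrollment and matching steps in `ArcFaceRecognizer` can be illustrated in isolation. The following is a minimal sketch, not the node's implementation: `l2_normalize`, `enroll_mean`, and `best_match` are hypothetical helper names, and the 4-dim vectors are synthetic stand-ins for 512-dim ArcFace embeddings. It assumes all embeddings are unit length, so the dot product equals cosine similarity.

```python
import numpy as np


def l2_normalize(v: np.ndarray) -> np.ndarray:
    """Return v scaled to unit length (unchanged if it is all zeros)."""
    norm = np.linalg.norm(v)
    return v / norm if norm > 0 else v


def enroll_mean(samples: list[np.ndarray]) -> np.ndarray:
    """Average several unit embeddings and re-normalize the mean."""
    return l2_normalize(np.mean(samples, axis=0).astype(np.float32))


def best_match(query: np.ndarray, gallery: dict[int, np.ndarray],
               threshold: float = 0.35) -> tuple[int, float]:
    """Return (person_id, score) of the most similar entry, or (-1, 0.0)."""
    if not gallery:
        return -1, 0.0
    ids = list(gallery)
    matrix = np.stack([gallery[i] for i in ids])  # shape [N, D]
    scores = matrix @ query  # cosine similarity for unit vectors
    k = int(np.argmax(scores))
    if scores[k] >= threshold:
        return ids[k], float(scores[k])
    return -1, 0.0


# Synthetic 4-dim embeddings (hypothetical data, for illustration only).
alice = enroll_mean([
    l2_normalize(np.array([1.0, 0.1, 0.0, 0.0], dtype=np.float32)),
    l2_normalize(np.array([1.0, -0.1, 0.0, 0.0], dtype=np.float32)),
])
bob = l2_normalize(np.array([0.0, 0.0, 1.0, 0.0], dtype=np.float32))
gallery = {1: alice, 2: bob}

query = l2_normalize(np.array([0.9, 0.05, 0.1, 0.0], dtype=np.float32))
pid, score = best_match(query, gallery)

stranger = l2_normalize(np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32))
stranger_pid, _ = best_match(stranger, gallery)
```

Stacking the gallery into one matrix replaces the per-entry loop with a single matrix-vector product. For any positive threshold the result agrees with the loop in `match`, since a best score at or below zero is rejected either way.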

@ -0,0 +1,78 @@
#!/usr/bin/env python3
"""
enrollment_cli.py -- CLI tool for enrolling persons via the /social/enroll service.
Usage:
ros2 run saltybot_social_face enrollment_cli -- --name Alice --mode face --samples 10
"""
import argparse
import sys
import rclpy
from rclpy.node import Node
from saltybot_social_msgs.srv import EnrollPerson
class EnrollmentCLI(Node):
"""Simple CLI node that calls the EnrollPerson service."""
def __init__(self, name: str, mode: str, n_samples: int):
super().__init__('enrollment_cli')
self._client = self.create_client(EnrollPerson, '/social/enroll')
self.get_logger().info('Waiting for /social/enroll service...')
if not self._client.wait_for_service(timeout_sec=10.0):
self.get_logger().error('Service /social/enroll not available.')
return
request = EnrollPerson.Request()
request.name = name
request.mode = mode
request.n_samples = n_samples
# rclpy loggers take one pre-formatted string, not printf-style arguments
self.get_logger().info(
f'Enrolling "{name}" (mode={mode}, samples={n_samples})...')
future = self._client.call_async(request)
rclpy.spin_until_future_complete(self, future, timeout_sec=120.0)
if future.result() is not None:
result = future.result()
if result.success:
self.get_logger().info(
f'Enrollment successful: person_id={result.person_id}, '
f'{result.message}')
else:
self.get_logger().error(
f'Enrollment failed: {result.message}')
else:
self.get_logger().error('Enrollment service call timed out or failed.')
def main(args=None):
"""Entry point for enrollment CLI."""
parser = argparse.ArgumentParser(description='Enroll a person for face recognition.')
parser.add_argument('--name', type=str, required=True,
help='Name of the person to enroll.')
parser.add_argument('--mode', type=str, default='face',
choices=['face', 'voice', 'both'],
help='Enrollment mode (default: face).')
parser.add_argument('--samples', type=int, default=10,
help='Number of face samples to collect (default: 10).')
# Parse only known args so ROS2 remapping args pass through
parsed, remaining = parser.parse_known_args(args=sys.argv[1:])
rclpy.init(args=remaining)
node = EnrollmentCLI(parsed.name, parsed.mode, parsed.samples)
try:
pass # Node does all work in __init__
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()
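The argument handling in `main()` relies on `parse_known_args` to separate the tool's own flags from anything meant for ROS 2. A minimal sketch of that split, runnable without rclpy, is shown here; `split_cli_args` is a hypothetical helper name and the remap argument in the example is made up for illustration.

```python
import argparse


def split_cli_args(argv: list[str]) -> tuple[argparse.Namespace, list[str]]:
    """Parse the enrollment flags and return the unrecognized rest separately."""
    parser = argparse.ArgumentParser(
        description='Enroll a person for face recognition.')
    parser.add_argument('--name', type=str, required=True)
    parser.add_argument('--mode', type=str, default='face',
                        choices=['face', 'voice', 'both'])
    parser.add_argument('--samples', type=int, default=10)
    # parse_known_args does not fail on unknown arguments; it returns them
    # as a second value so they can be forwarded elsewhere.
    return parser.parse_known_args(args=argv)


# Hypothetical command line mixing tool flags with ROS 2 arguments.
argv = ['--name', 'Alice', '--samples', '5',
        '--ros-args', '-r', '__node:=enroll_alice']
parsed, remaining = split_cli_args(argv)
```

In the CLI, `remaining` is what gets passed to `rclpy.init(args=...)`, while `parsed` carries the enrollment request fields.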

@ -0,0 +1,206 @@
"""
face_gallery.py -- Persistent face embedding gallery backed by numpy .npz + JSON.
Thread-safe gallery storage for face recognition. Embeddings are stored in a
.npz file, with a sidecar metadata.json containing names, sample counts, and
enrollment timestamps. Auto-increment IDs start at 1.
"""
import json
import logging
import os
import threading
import time
from typing import Optional
import numpy as np
logger = logging.getLogger(__name__)
class FaceGallery:
"""Persistent, thread-safe face embedding gallery.
Args:
gallery_dir: Directory for gallery.npz and metadata.json files.
"""
def __init__(self, gallery_dir: str):
self._gallery_dir = gallery_dir
self._npz_path = os.path.join(gallery_dir, 'gallery.npz')
self._meta_path = os.path.join(gallery_dir, 'metadata.json')
self._gallery: dict[int, dict] = {}
self._next_id = 1
self._lock = threading.Lock()
def load(self) -> None:
"""Load gallery from disk. Populates internal gallery dict."""
with self._lock:
self._gallery = {}
self._next_id = 1
if not os.path.isfile(self._npz_path):
logger.info('No gallery file at %s, starting empty.', self._npz_path)
return
data = np.load(self._npz_path, allow_pickle=False)
meta: dict = {}
if os.path.isfile(self._meta_path):
with open(self._meta_path, 'r') as f:
meta = json.load(f)
for key in data.files:
pid = int(key)
embedding = data[key].astype(np.float32)
norm = np.linalg.norm(embedding)
if norm > 0:
embedding = embedding / norm
info = meta.get(str(pid), {})
self._gallery[pid] = {
'name': info.get('name', f'person_{pid}'),
'embedding': embedding,
'samples': info.get('samples', 1),
'enrolled_at': info.get('enrolled_at', 0.0),
}
if pid >= self._next_id:
self._next_id = pid + 1
logger.info('Gallery loaded: %d persons from %s',
len(self._gallery), self._npz_path)
def save(self) -> None:
"""Save gallery to disk (npz + JSON sidecar)."""
with self._lock:
os.makedirs(self._gallery_dir, exist_ok=True)
arrays = {}
meta = {}
for pid, info in self._gallery.items():
arrays[str(pid)] = info['embedding']
meta[str(pid)] = {
'name': info['name'],
'samples': info['samples'],
'enrolled_at': info['enrolled_at'],
}
np.savez(self._npz_path, **arrays)
with open(self._meta_path, 'w') as f:
json.dump(meta, f, indent=2)
logger.info('Gallery saved: %d persons to %s',
len(self._gallery), self._npz_path)
def add_person(self, name: str, embedding: np.ndarray, samples: int = 1) -> int:
"""Add a new person to the gallery.
Args:
name: Person's name.
embedding: L2-normalized 512-dim embedding.
samples: Number of samples used to compute the embedding.
Returns:
Assigned person_id (auto-increment integer).
"""
with self._lock:
pid = self._next_id
self._next_id += 1
emb = embedding.astype(np.float32)
norm = np.linalg.norm(emb)
if norm > 0:
emb = emb / norm
self._gallery[pid] = {
'name': name,
'embedding': emb,
'samples': samples,
'enrolled_at': time.time(),
}
logger.info('Added person %d (%s), %d samples.', pid, name, samples)
return pid
def update_name(self, person_id: int, new_name: str) -> bool:
"""Update a person's name.
Args:
person_id: The ID of the person to update.
new_name: New name string.
Returns:
True if the person was found and updated.
"""
with self._lock:
if person_id not in self._gallery:
return False
self._gallery[person_id]['name'] = new_name
return True
def delete_person(self, person_id: int) -> bool:
"""Remove a person from the gallery.
Args:
person_id: The ID of the person to delete.
Returns:
True if the person was found and removed.
"""
with self._lock:
if person_id not in self._gallery:
return False
del self._gallery[person_id]
logger.info('Deleted person %d.', person_id)
return True
def get_all(self) -> list[dict]:
"""Get all gallery entries.
Returns:
List of dicts with keys: person_id, name, embedding, samples, enrolled_at.
"""
with self._lock:
result = []
for pid, info in self._gallery.items():
result.append({
'person_id': pid,
'name': info['name'],
'embedding': info['embedding'].copy(),
'samples': info['samples'],
'enrolled_at': info['enrolled_at'],
})
return result
def match(self, query_embedding: np.ndarray, threshold: float = 0.35) -> tuple[int, str, float]:
"""Match a query embedding against the gallery using cosine similarity.
Args:
query_embedding: L2-normalized 512-dim embedding.
threshold: Minimum cosine similarity for a match.
Returns:
(person_id, name, score) or (-1, '', 0.0) if no match.
"""
with self._lock:
if not self._gallery:
return (-1, '', 0.0)
best_pid = -1
best_name = ''
best_score = 0.0
query = query_embedding.astype(np.float32)
norm = np.linalg.norm(query)
if norm > 0:
query = query / norm
for pid, info in self._gallery.items():
score = float(np.dot(query, info['embedding']))
if score > best_score:
best_score = score
best_pid = pid
best_name = info['name']
if best_score >= threshold:
return (best_pid, best_name, best_score)
return (-1, '', 0.0)
def __len__(self) -> int:
with self._lock:
return len(self._gallery)
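The persistence scheme used by `FaceGallery` (embeddings in an `.npz` archive keyed by person ID, with names and sample counts in a JSON sidecar) can be exercised on its own. This is a minimal sketch under simplifying assumptions: `save_gallery` and `load_gallery` are hypothetical free functions, there is no locking, and the 3-dim embeddings are synthetic.

```python
import json
import os
import tempfile

import numpy as np


def save_gallery(gallery_dir: str, gallery: dict[int, dict]) -> None:
    """Write embeddings to gallery.npz and metadata to metadata.json."""
    os.makedirs(gallery_dir, exist_ok=True)
    arrays = {str(pid): info['embedding'] for pid, info in gallery.items()}
    meta = {str(pid): {'name': info['name'], 'samples': info['samples']}
            for pid, info in gallery.items()}
    np.savez(os.path.join(gallery_dir, 'gallery.npz'), **arrays)
    with open(os.path.join(gallery_dir, 'metadata.json'), 'w') as f:
        json.dump(meta, f, indent=2)


def load_gallery(gallery_dir: str) -> dict[int, dict]:
    """Read the gallery back; missing metadata falls back to defaults."""
    npz_path = os.path.join(gallery_dir, 'gallery.npz')
    meta_path = os.path.join(gallery_dir, 'metadata.json')
    gallery: dict[int, dict] = {}
    if not os.path.isfile(npz_path):
        return gallery
    meta: dict = {}
    if os.path.isfile(meta_path):
        with open(meta_path, 'r') as f:
            meta = json.load(f)
    # allow_pickle=False refuses pickled objects, so only plain arrays load.
    with np.load(npz_path, allow_pickle=False) as data:
        for key in data.files:
            info = meta.get(key, {})
            gallery[int(key)] = {
                'name': info.get('name', f'person_{key}'),
                'samples': info.get('samples', 1),
                'embedding': data[key].astype(np.float32),
            }
    return gallery


original = {
    1: {'name': 'Alice', 'samples': 10,
        'embedding': np.array([1.0, 0.0, 0.0], dtype=np.float32)},
    7: {'name': 'Bob', 'samples': 5,
        'embedding': np.array([0.0, 1.0, 0.0], dtype=np.float32)},
}

with tempfile.TemporaryDirectory() as tmp:
    save_gallery(tmp, original)
    restored = load_gallery(tmp)
    empty = load_gallery(os.path.join(tmp, 'does_not_exist'))

next_id = max(restored) + 1  # auto-increment continues after the highest ID
```

Keeping metadata outside the `.npz` lets the archive be loaded with `allow_pickle=False`, which is the same choice the gallery code makes.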

@ -0,0 +1,431 @@
"""
face_recognition_node.py -- ROS2 node for SCRFD face detection + ArcFace recognition.
Pipeline:
1. Subscribe to /camera/color/image_raw (RealSense D435i color stream).
2. Run SCRFD face detection (TensorRT FP16 or ONNX fallback).
3. For each detected face, align and extract ArcFace embedding.
4. Match embedding against persistent gallery.
5. Publish FaceDetectionArray with identified faces.
Services:
/social/enroll -- Enroll a new person (collects N face samples).
/social/persons/list -- List all enrolled persons.
/social/persons/delete -- Delete a person from the gallery.
/social/persons/update -- Update a person's name.
"""
import time
import numpy as np
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy, DurabilityPolicy
import cv2
from cv_bridge import CvBridge
from sensor_msgs.msg import Image
from builtin_interfaces.msg import Time
from saltybot_social_msgs.msg import (
FaceDetection,
FaceDetectionArray,
FaceEmbedding,
FaceEmbeddingArray,
)
from saltybot_social_msgs.srv import (
EnrollPerson,
ListPersons,
DeletePerson,
UpdatePerson,
)
from .scrfd_detector import SCRFDDetector
from .arcface_recognizer import ArcFaceRecognizer
from .face_gallery import FaceGallery
class FaceRecognitionNode(Node):
"""ROS2 node: SCRFD face detection + ArcFace gallery matching."""
def __init__(self):
super().__init__('face_recognizer')
self._bridge = CvBridge()
self._frame_count = 0
self._fps_t0 = time.monotonic()
# -- Parameters --------------------------------------------------------
self.declare_parameter('scrfd_engine_path',
'/mnt/nvme/saltybot/models/scrfd_2.5g.engine')
self.declare_parameter('scrfd_onnx_path',
'/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx')
self.declare_parameter('arcface_engine_path',
'/mnt/nvme/saltybot/models/arcface_r50.engine')
self.declare_parameter('arcface_onnx_path',
'/mnt/nvme/saltybot/models/arcface_r50.onnx')
self.declare_parameter('gallery_dir', '/mnt/nvme/saltybot/gallery')
self.declare_parameter('recognition_threshold', 0.35)
self.declare_parameter('publish_debug_image', False)
self.declare_parameter('max_faces', 10)
self.declare_parameter('scrfd_conf_thresh', 0.5)
self._recognition_threshold = self.get_parameter('recognition_threshold').value
self._pub_debug_flag = self.get_parameter('publish_debug_image').value
self._max_faces = self.get_parameter('max_faces').value
# -- Models ------------------------------------------------------------
self._detector = SCRFDDetector(
engine_path=self.get_parameter('scrfd_engine_path').value,
onnx_path=self.get_parameter('scrfd_onnx_path').value,
conf_thresh=self.get_parameter('scrfd_conf_thresh').value,
)
self._recognizer = ArcFaceRecognizer(
engine_path=self.get_parameter('arcface_engine_path').value,
onnx_path=self.get_parameter('arcface_onnx_path').value,
)
# -- Gallery -----------------------------------------------------------
gallery_dir = self.get_parameter('gallery_dir').value
self._gallery = FaceGallery(gallery_dir)
self._gallery.load()
self.get_logger().info(f'Gallery loaded: {len(self._gallery)} persons.')
# -- Enrollment state --------------------------------------------------
self._enrolling = None # {name, samples_needed, collected: [embeddings]}
# -- QoS profiles ------------------------------------------------------
best_effort_qos = QoSProfile(
reliability=ReliabilityPolicy.BEST_EFFORT,
history=HistoryPolicy.KEEP_LAST,
depth=1,
)
reliable_qos = QoSProfile(
reliability=ReliabilityPolicy.RELIABLE,
durability=DurabilityPolicy.TRANSIENT_LOCAL,
history=HistoryPolicy.KEEP_LAST,
depth=1,
)
# -- Subscribers -------------------------------------------------------
self.create_subscription(
Image,
'/camera/color/image_raw',
self._on_image,
best_effort_qos,
)
# -- Publishers --------------------------------------------------------
self._pub_detections = self.create_publisher(
FaceDetectionArray, '/social/faces/detections', best_effort_qos)
self._pub_embeddings = self.create_publisher(
FaceEmbeddingArray, '/social/faces/embeddings', reliable_qos)
if self._pub_debug_flag:
self._pub_debug_img = self.create_publisher(
Image, '/social/faces/debug_image', best_effort_qos)
# -- Services ----------------------------------------------------------
self.create_service(EnrollPerson, '/social/enroll', self._handle_enroll)
self.create_service(ListPersons, '/social/persons/list', self._handle_list)
self.create_service(DeletePerson, '/social/persons/delete', self._handle_delete)
self.create_service(UpdatePerson, '/social/persons/update', self._handle_update)
# Publish initial gallery state
self._publish_gallery_embeddings()
self.get_logger().info('FaceRecognitionNode ready.')
# -- Image callback --------------------------------------------------------
def _on_image(self, msg: Image):
"""Process incoming camera frame: detect, recognize, publish."""
try:
bgr = self._bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')
except Exception as e:
self.get_logger().error(f'Image decode error: {e}',
throttle_duration_sec=5.0)
return
# Detect faces
detections = self._detector.detect(bgr)
# Limit face count
if len(detections) > self._max_faces:
detections = sorted(detections, key=lambda d: d['score'], reverse=True)
detections = detections[:self._max_faces]
# Build output message
det_array = FaceDetectionArray()
det_array.header = msg.header
for det in detections:
# Extract embedding and match gallery
embedding = self._recognizer.align_and_embed(bgr, det['kps'])
pid, pname, score = self._gallery.match(
embedding, self._recognition_threshold)
# Handle enrollment: collect embedding from largest face
if self._enrolling is not None:
self._enrollment_collect(det, embedding)
# Build FaceDetection message
face_msg = FaceDetection()
face_msg.header = msg.header
face_msg.face_id = pid
face_msg.person_name = pname
face_msg.confidence = det['score']
face_msg.recognition_score = score
bbox = det['bbox']
face_msg.bbox_x = bbox[0]
face_msg.bbox_y = bbox[1]
face_msg.bbox_w = bbox[2] - bbox[0]
face_msg.bbox_h = bbox[3] - bbox[1]
kps = det['kps']
for i in range(10):
face_msg.landmarks[i] = kps[i]
det_array.faces.append(face_msg)
self._pub_detections.publish(det_array)
# Debug image
if self._pub_debug_flag and hasattr(self, '_pub_debug_img'):
debug_img = self._draw_debug(bgr, detections, det_array.faces)
self._pub_debug_img.publish(
self._bridge.cv2_to_imgmsg(debug_img, encoding='bgr8'))
# FPS logging
self._frame_count += 1
if self._frame_count % 30 == 0:
elapsed = time.monotonic() - self._fps_t0
fps = 30.0 / elapsed if elapsed > 0 else 0.0
self._fps_t0 = time.monotonic()
self.get_logger().info(
f'FPS: {fps:.1f} | faces: {len(detections)}')
# -- Enrollment logic ------------------------------------------------------
def _enrollment_collect(self, det: dict, embedding: np.ndarray):
"""Collect an embedding sample during enrollment (largest face only)."""
if self._enrolling is None:
return
# Only collect from the largest face (by bbox area)
bbox = det['bbox']
area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
if not hasattr(self, '_enroll_best_area'):
self._enroll_best_area = 0.0
self._enroll_best_embedding = None
if area > self._enroll_best_area:
self._enroll_best_area = area
self._enroll_best_embedding = embedding
def _enrollment_frame_end(self):
"""Called at end of each frame to finalize enrollment sample collection."""
if self._enrolling is None or self._enroll_best_embedding is None:
return
self._enrolling['collected'].append(self._enroll_best_embedding)
self._enroll_best_area = 0.0
self._enroll_best_embedding = None
collected = len(self._enrolling['collected'])
needed = self._enrolling['samples_needed']
self.get_logger().info(
f'Enrollment: {collected}/{needed} samples for "{self._enrolling["name"]}".')
if collected >= needed:
# Finalize enrollment
name = self._enrolling['name']
embeddings = self._enrolling['collected']
mean_emb = np.mean(embeddings, axis=0).astype(np.float32)
norm = np.linalg.norm(mean_emb)
if norm > 0:
mean_emb = mean_emb / norm
pid = self._gallery.add_person(name, mean_emb, samples=len(embeddings))
self._gallery.save()
self._publish_gallery_embeddings()
self.get_logger().info(f'Enrollment complete: person {pid} ({name}).')
# Store result for the service callback
self._enrolling['result_pid'] = pid
self._enrolling['done'] = True
self._enrolling = None
# -- Service handlers ------------------------------------------------------
def _handle_enroll(self, request, response):
"""Handle EnrollPerson service: start collecting face samples."""
name = request.name.strip()
if not name:
response.success = False
response.message = 'Name cannot be empty.'
response.person_id = -1
return response
n_samples = request.n_samples if request.n_samples > 0 else 10
self.get_logger().info(
f'Starting enrollment for "{name}" ({n_samples} samples).')
# Set enrollment state — frames will collect embeddings
self._enrolling = {
'name': name,
'samples_needed': n_samples,
'collected': [],
'done': False,
'result_pid': -1,
}
self._enroll_best_area = 0.0
self._enroll_best_embedding = None
# Poll until enrollment is done (blocking service). Keep a local reference
# to the state dict: _enrollment_frame_end() resets self._enrolling to None
# on success, so reading self._enrolling here would raise AttributeError.
# NOTE: spinning the node from inside its own service callback is
# re-entrant; depending on the executor this may need a separate callback
# group or executor to work reliably.
state = self._enrolling
timeout_sec = n_samples * 2.0 + 10.0 # generous timeout
t0 = time.monotonic()
while not state['done']:
# Finalize any pending frame collection
self._enrollment_frame_end()
if state['done']:
break
if time.monotonic() - t0 > timeout_sec:
self._enrolling = None
response.success = False
response.message = f'Enrollment timed out after {timeout_sec:.0f}s.'
response.person_id = -1
return response
rclpy.spin_once(self, timeout_sec=0.1)
response.success = True
response.message = f'Enrolled "{name}" with {n_samples} samples.'
response.person_id = state['result_pid']
return response
def _handle_list(self, request, response):
"""Handle ListPersons service: return all gallery entries."""
entries = self._gallery.get_all()
for entry in entries:
emb_msg = FaceEmbedding()
emb_msg.person_id = entry['person_id']
emb_msg.person_name = entry['name']
emb_msg.embedding = entry['embedding'].tolist()
emb_msg.sample_count = entry['samples']
secs = int(entry['enrolled_at'])
nsecs = int((entry['enrolled_at'] - secs) * 1e9)
emb_msg.enrolled_at = Time(sec=secs, nanosec=nsecs)
response.persons.append(emb_msg)
return response
def _handle_delete(self, request, response):
"""Handle DeletePerson service: remove a person from the gallery."""
if self._gallery.delete_person(request.person_id):
self._gallery.save()
self._publish_gallery_embeddings()
response.success = True
response.message = f'Deleted person {request.person_id}.'
else:
response.success = False
response.message = f'Person {request.person_id} not found.'
return response
def _handle_update(self, request, response):
"""Handle UpdatePerson service: rename a person."""
new_name = request.new_name.strip()
if not new_name:
response.success = False
response.message = 'New name cannot be empty.'
return response
if self._gallery.update_name(request.person_id, new_name):
self._gallery.save()
self._publish_gallery_embeddings()
response.success = True
response.message = f'Updated person {request.person_id} to "{new_name}".'
else:
response.success = False
response.message = f'Person {request.person_id} not found.'
return response
# -- Gallery publishing ----------------------------------------------------
def _publish_gallery_embeddings(self):
"""Publish current gallery as FaceEmbeddingArray (latched-like)."""
entries = self._gallery.get_all()
msg = FaceEmbeddingArray()
msg.header.stamp = self.get_clock().now().to_msg()
for entry in entries:
emb_msg = FaceEmbedding()
emb_msg.person_id = entry['person_id']
emb_msg.person_name = entry['name']
emb_msg.embedding = entry['embedding'].tolist()
emb_msg.sample_count = entry['samples']
secs = int(entry['enrolled_at'])
nsecs = int((entry['enrolled_at'] - secs) * 1e9)
emb_msg.enrolled_at = Time(sec=secs, nanosec=nsecs)
msg.embeddings.append(emb_msg)
self._pub_embeddings.publish(msg)
# -- Debug image -----------------------------------------------------------
def _draw_debug(self, bgr: np.ndarray, detections: list[dict],
face_msgs: list) -> np.ndarray:
"""Draw bounding boxes, landmarks, and names on the image."""
vis = bgr.copy()
for det, face_msg in zip(detections, face_msgs):
bbox = det['bbox']
x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
# Color: green if recognized, yellow if unknown
if face_msg.face_id >= 0:
color = (0, 255, 0)
label = f'{face_msg.person_name} ({face_msg.recognition_score:.2f})'
else:
color = (0, 255, 255)
label = f'unknown ({face_msg.confidence:.2f})'
cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2)
cv2.putText(vis, label, (x1, y1 - 8),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
# Draw landmarks
kps = det['kps']
for k in range(5):
px, py = int(kps[k * 2]), int(kps[k * 2 + 1])
cv2.circle(vis, (px, py), 2, (0, 0, 255), -1)
return vis
# -- Entry point ---------------------------------------------------------------
def main(args=None):
"""ROS2 entry point for face_recognition node."""
rclpy.init(args=args)
node = FaceRecognitionNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()
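The enrollment bookkeeping in the node (pick the largest face in each frame, commit one sample per frame, finish after N samples) can be separated from ROS entirely. The following is a minimal sketch assuming a hypothetical `EnrollmentCollector` class that the image callback would drive; it is not the node's code, and the 2-dim embeddings are synthetic.

```python
import numpy as np


class EnrollmentCollector:
    """Collect one embedding per frame (largest face) until n_samples exist."""

    def __init__(self, name: str, n_samples: int):
        self.name = name
        self.n_samples = n_samples
        self.collected: list[np.ndarray] = []
        self._best_area = 0.0
        self._best_embedding = None

    def offer(self, bbox: list[float], embedding: np.ndarray) -> None:
        """Consider one detected face; bbox is [x1, y1, x2, y2]."""
        area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
        if area > self._best_area:
            self._best_area = area
            self._best_embedding = embedding

    def end_frame(self) -> bool:
        """Commit this frame's largest face; return True once done."""
        if self._best_embedding is not None:
            self.collected.append(self._best_embedding)
        self._best_area = 0.0
        self._best_embedding = None
        return len(self.collected) >= self.n_samples

    def mean_embedding(self) -> np.ndarray:
        """L2-normalized mean of all collected samples."""
        mean = np.mean(self.collected, axis=0).astype(np.float32)
        norm = np.linalg.norm(mean)
        return mean / norm if norm > 0 else mean


collector = EnrollmentCollector('Alice', n_samples=2)

# Frame 1: two faces, the larger one should be kept.
collector.offer([0, 0, 10, 10], np.array([0.0, -1.0], dtype=np.float32))
collector.offer([0, 0, 50, 50], np.array([1.0, 0.0], dtype=np.float32))
done_after_first = collector.end_frame()

# Frame 2: a single face.
collector.offer([0, 0, 40, 40], np.array([0.0, 1.0], dtype=np.float32))
done_after_second = collector.end_frame()

mean = collector.mean_embedding()
```

Calling `end_frame()` from the image callback after the per-face loop would keep sampling tied to frames rather than to the service handler's polling loop.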

@ -0,0 +1,350 @@
"""
scrfd_detector.py -- SCRFD face detection with TensorRT FP16 + ONNX fallback.
SCRFD (Sample and Computation Redistribution for Face Detection) produces
9 output tensors across 3 strides (8, 16, 32), each with score, bbox, and
keypoint branches. This module handles anchor generation, bbox/keypoint
decoding, and NMS to produce final face detections.
"""
import os
import logging
from typing import Optional
import numpy as np
import cv2
logger = logging.getLogger(__name__)
_STRIDES = [8, 16, 32]
_NUM_ANCHORS = 2 # anchors per cell per stride
# -- Inference backends --------------------------------------------------------
class _TRTBackend:
"""TensorRT inference engine for SCRFD."""
def __init__(self, engine_path: str):
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit # noqa: F401
self._cuda = cuda
trt_logger = trt.Logger(trt.Logger.WARNING)
with open(engine_path, 'rb') as f, trt.Runtime(trt_logger) as runtime:
self._engine = runtime.deserialize_cuda_engine(f.read())
self._context = self._engine.create_execution_context()
self._inputs = []
self._outputs = []
self._output_names = []
self._bindings = []
for i in range(self._engine.num_io_tensors):
name = self._engine.get_tensor_name(i)
dtype = trt.nptype(self._engine.get_tensor_dtype(name))
shape = tuple(self._engine.get_tensor_shape(name))
nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
host_mem = cuda.pagelocked_empty(shape, dtype)
device_mem = cuda.mem_alloc(nbytes)
self._bindings.append(int(device_mem))
if self._engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
self._inputs.append({'host': host_mem, 'device': device_mem})
else:
self._outputs.append({'host': host_mem, 'device': device_mem,
'shape': shape})
self._output_names.append(name)
self._stream = cuda.Stream()
def infer(self, input_data: np.ndarray) -> list[np.ndarray]:
"""Run inference and return output tensors."""
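# The host buffer has the full tensor shape, so reshape instead of ravel
host_in = self._inputs[0]['host']
np.copyto(host_in, input_data.reshape(host_in.shape))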
self._cuda.memcpy_htod_async(
self._inputs[0]['device'], self._inputs[0]['host'], self._stream)
self._context.execute_async_v2(self._bindings, self._stream.handle)
results = []
for out in self._outputs:
self._cuda.memcpy_dtoh_async(out['host'], out['device'], self._stream)
self._stream.synchronize()
for out in self._outputs:
results.append(out['host'].reshape(out['shape']).copy())
return results
class _ONNXBackend:
"""ONNX Runtime inference (CUDA EP with CPU fallback)."""
def __init__(self, onnx_path: str):
import onnxruntime as ort
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
self._session = ort.InferenceSession(onnx_path, providers=providers)
self._input_name = self._session.get_inputs()[0].name
self._output_names = [o.name for o in self._session.get_outputs()]
def infer(self, input_data: np.ndarray) -> list[np.ndarray]:
"""Run inference and return output tensors."""
return self._session.run(None, {self._input_name: input_data})
# -- NMS ----------------------------------------------------------------------
def _nms(boxes: np.ndarray, scores: np.ndarray, iou_thresh: float) -> list[int]:
"""Non-maximum suppression. boxes: [N, 4] as x1,y1,x2,y2."""
if len(boxes) == 0:
return []
x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
areas = (x2 - x1) * (y2 - y1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(int(i))
if order.size == 1:
break
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-6)
remaining = np.where(iou <= iou_thresh)[0]
order = order[remaining + 1]
return keep
# -- Anchor generation ---------------------------------------------------------
def _generate_anchors(input_h: int, input_w: int) -> dict[int, np.ndarray]:
"""Generate anchor centers for each stride.
Returns dict mapping stride -> array of shape [H*W*num_anchors, 2],
where each row is (cx, cy) in input pixel coordinates.
"""
anchors = {}
for stride in _STRIDES:
feat_h = input_h // stride
feat_w = input_w // stride
grid_y, grid_x = np.mgrid[:feat_h, :feat_w]
# One center per cell, repeated so anchors of the same cell are adjacent:
# [anchor0_cell0, anchor1_cell0, anchor0_cell1, ...]
centers = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1).astype(np.float32)
centers = np.repeat(centers, _NUM_ANCHORS, axis=0) # [H*W*2, 2]
centers = (centers + 0.5) * stride
anchors[stride] = centers
return anchors
# -- Main detector class -------------------------------------------------------
class SCRFDDetector:
"""SCRFD face detector with TensorRT FP16 and ONNX fallback.
Args:
engine_path: Path to TensorRT engine file.
onnx_path: Path to ONNX model file (used if engine not available).
conf_thresh: Minimum confidence for detections.
nms_iou: IoU threshold for NMS.
input_size: Model input resolution (square).
"""
def __init__(
self,
engine_path: str = '',
onnx_path: str = '',
conf_thresh: float = 0.5,
nms_iou: float = 0.4,
input_size: int = 640,
):
self._conf_thresh = conf_thresh
self._nms_iou = nms_iou
self._input_size = input_size
self._backend: Optional[_TRTBackend | _ONNXBackend] = None
self._anchors = _generate_anchors(input_size, input_size)
# Try TRT first, then ONNX
if engine_path and os.path.isfile(engine_path):
try:
self._backend = _TRTBackend(engine_path)
logger.info('SCRFD TensorRT backend loaded: %s', engine_path)
return
except Exception as e:
logger.warning('SCRFD TRT load failed (%s), trying ONNX', e)
if onnx_path and os.path.isfile(onnx_path):
try:
self._backend = _ONNXBackend(onnx_path)
logger.info('SCRFD ONNX backend loaded: %s', onnx_path)
return
except Exception as e:
logger.error('SCRFD ONNX load failed: %s', e)
logger.error('No SCRFD model loaded. Detection will be unavailable.')
@property
def is_loaded(self) -> bool:
"""Return True if a backend is loaded and ready."""
return self._backend is not None
def detect(self, bgr: np.ndarray) -> list[dict]:
"""Detect faces in a BGR image.
Args:
bgr: Input image in BGR format, shape (H, W, 3).
Returns:
List of dicts with keys:
bbox: [x1, y1, x2, y2] in original image coordinates
kps: [x0,y0, x1,y1, ..., x4,y4] 10 floats, 5 landmarks
score: detection confidence
"""
if self._backend is None:
return []
orig_h, orig_w = bgr.shape[:2]
tensor, scale, pad_w, pad_h = self._preprocess(bgr)
outputs = self._backend.infer(tensor)
detections = self._decode_outputs(outputs)
detections = self._rescale(detections, scale, pad_w, pad_h, orig_w, orig_h)
return detections
def _preprocess(self, bgr: np.ndarray) -> tuple[np.ndarray, float, int, int]:
"""Resize to input_size x input_size with letterbox padding, normalize."""
h, w = bgr.shape[:2]
size = self._input_size
scale = min(size / h, size / w)
new_w, new_h = int(w * scale), int(h * scale)
pad_w = (size - new_w) // 2
pad_h = (size - new_h) // 2
resized = cv2.resize(bgr, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
canvas = np.full((size, size, 3), 0, dtype=np.uint8)
canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized
# Normalize: subtract 127.5, divide 128.0
blob = canvas.astype(np.float32)
blob = (blob - 127.5) / 128.0
# HWC -> NCHW
blob = blob.transpose(2, 0, 1)[np.newaxis]
blob = np.ascontiguousarray(blob)
return blob, scale, pad_w, pad_h
def _decode_outputs(self, outputs: list[np.ndarray]) -> list[dict]:
"""Decode SCRFD 9-output format into face detections.
SCRFD outputs 9 tensors, 3 per stride (score, bbox, kps):
score_8, bbox_8, kps_8, score_16, bbox_16, kps_16, score_32, bbox_32, kps_32
"""
all_scores = []
all_bboxes = []
all_kps = []
for idx, stride in enumerate(_STRIDES):
score_out = outputs[idx * 3].squeeze() # [H*W*num_anchors]
bbox_out = outputs[idx * 3 + 1].squeeze() # [H*W*num_anchors, 4]
kps_out = outputs[idx * 3 + 2].squeeze() # [H*W*num_anchors, 10]
if score_out.ndim == 0:
continue
# Flatten scores so the mask and the anchor count use the same length
score_out = score_out.ravel()
n = score_out.shape[0]
if bbox_out.ndim == 1:
bbox_out = bbox_out.reshape(-1, 4)
if kps_out.ndim == 1:
kps_out = kps_out.reshape(-1, 10)
# Filter by confidence
mask = score_out > self._conf_thresh
if not mask.any():
continue
scores = score_out[mask]
bboxes = bbox_out[mask]
kps = kps_out[mask]
anchors = self._anchors[stride]
# Trim surplus anchors; skip this stride if there are too few to match the outputs
if anchors.shape[0] > n:
anchors = anchors[:n]
elif anchors.shape[0] < n:
continue
anchors = anchors[mask]
# Decode bboxes: center = anchor + pred[:2]*stride, size = exp(pred[2:])*stride
cx = anchors[:, 0] + bboxes[:, 0] * stride
cy = anchors[:, 1] + bboxes[:, 1] * stride
w = np.exp(bboxes[:, 2]) * stride
h = np.exp(bboxes[:, 3]) * stride
x1 = cx - w / 2.0
y1 = cy - h / 2.0
x2 = cx + w / 2.0
y2 = cy + h / 2.0
decoded_bboxes = np.stack([x1, y1, x2, y2], axis=1)
# Decode keypoints: kp = anchor + pred * stride
decoded_kps = np.zeros_like(kps)
for k in range(5):
decoded_kps[:, k * 2] = anchors[:, 0] + kps[:, k * 2] * stride
decoded_kps[:, k * 2 + 1] = anchors[:, 1] + kps[:, k * 2 + 1] * stride
all_scores.append(scores)
all_bboxes.append(decoded_bboxes)
all_kps.append(decoded_kps)
if not all_scores:
return []
scores = np.concatenate(all_scores)
bboxes = np.concatenate(all_bboxes)
kps = np.concatenate(all_kps)
# NMS
keep = _nms(bboxes, scores, self._nms_iou)
results = []
for i in keep:
results.append({
'bbox': bboxes[i].tolist(),
'kps': kps[i].tolist(),
'score': float(scores[i]),
})
return results
def _rescale(
self,
detections: list[dict],
scale: float,
pad_w: int,
pad_h: int,
orig_w: int,
orig_h: int,
) -> list[dict]:
"""Rescale detections from model input space to original image space."""
for det in detections:
bbox = det['bbox']
bbox[0] = min(float(orig_w), max(0.0, (bbox[0] - pad_w) / scale))
bbox[1] = min(float(orig_h), max(0.0, (bbox[1] - pad_h) / scale))
bbox[2] = min(float(orig_w), max(0.0, (bbox[2] - pad_w) / scale))
bbox[3] = min(float(orig_h), max(0.0, (bbox[3] - pad_h) / scale))
det['bbox'] = bbox
kps = det['kps']
for k in range(5):
kps[k * 2] = (kps[k * 2] - pad_w) / scale
kps[k * 2 + 1] = (kps[k * 2 + 1] - pad_h) / scale
det['kps'] = kps
return detections
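The letterbox mapping used by `_preprocess` and undone by `_rescale` can be checked in isolation. The following is a minimal sketch, not the package's code: the helper names `letterbox_params`, `to_model_space` and `to_image_space` are hypothetical, and the 1280x720 frame size is only an assumed example.

```python
def letterbox_params(h: int, w: int, size: int = 640) -> tuple[float, int, int]:
    """Return (scale, pad_w, pad_h) for fitting an h x w image into size x size."""
    scale = min(size / h, size / w)
    new_w, new_h = int(w * scale), int(h * scale)
    return scale, (size - new_w) // 2, (size - new_h) // 2


def to_model_space(pt, scale, pad_w, pad_h):
    """Map an (x, y) point from original-image pixels to letterboxed pixels."""
    return pt[0] * scale + pad_w, pt[1] * scale + pad_h


def to_image_space(pt, scale, pad_w, pad_h):
    """Inverse mapping: remove the padding, then undo the scale."""
    return (pt[0] - pad_w) / scale, (pt[1] - pad_h) / scale


# Assumed example: a 1280x720 (w x h) frame and a 640x640 model input.
scale, pad_w, pad_h = letterbox_params(720, 1280, 640)
model_pt = to_model_space((1000.0, 500.0), scale, pad_w, pad_h)
x, y = to_image_space(model_pt, scale, pad_w, pad_h)
```

Because the padding is subtracted before dividing by the scale, a point mapped into model space and back should return to its original pixel coordinates.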

@ -0,0 +1,112 @@
#!/usr/bin/env python3
"""
build_face_trt_engines.py -- Build TensorRT FP16 engines for SCRFD and ArcFace.
Converts ONNX model files to optimized TensorRT engines with FP16 precision
for fast inference on Jetson Orin Nano Super.
Usage:
python3 build_face_trt_engines.py \
--scrfd-onnx /path/to/scrfd_2.5g_bnkps.onnx \
--arcface-onnx /path/to/arcface_r50.onnx \
--output-dir /mnt/nvme/saltybot/models \
--fp16 --workspace-mb 1024
"""
import argparse
import os
import time
def build_engine(onnx_path: str, engine_path: str, fp16: bool, workspace_mb: int):
"""Build a TensorRT engine from an ONNX model.
Args:
onnx_path: Path to the source ONNX model file.
engine_path: Output path for the serialized TensorRT engine.
fp16: Enable FP16 precision.
workspace_mb: Maximum workspace size in megabytes.
"""
import tensorrt as trt
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
print(f'Parsing ONNX model: {onnx_path}')
t0 = time.monotonic()
with open(onnx_path, 'rb') as f:
if not parser.parse(f.read()):
for i in range(parser.num_errors):
print(f' ONNX parse error: {parser.get_error(i)}')
raise RuntimeError(f'Failed to parse {onnx_path}')
parse_time = time.monotonic() - t0
print(f' Parsed in {parse_time:.1f}s')
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
workspace_mb * (1 << 20))
if fp16:
if builder.platform_has_fast_fp16:
config.set_flag(trt.BuilderFlag.FP16)
print(' FP16 enabled.')
else:
print(' Warning: FP16 not supported on this platform, using FP32.')
print('Building engine (this may take several minutes)...')
t0 = time.monotonic()
serialized = builder.build_serialized_network(network, config)
build_time = time.monotonic() - t0
if serialized is None:
raise RuntimeError('Engine build failed.')
os.makedirs(os.path.dirname(engine_path) or '.', exist_ok=True)
with open(engine_path, 'wb') as f:
f.write(serialized)
size_mb = os.path.getsize(engine_path) / (1 << 20)
print(f' Engine saved: {engine_path} ({size_mb:.1f} MB, built in {build_time:.1f}s)')
def main():
"""Main entry point for TRT engine building."""
parser = argparse.ArgumentParser(
description='Build TensorRT FP16 engines for SCRFD and ArcFace.')
parser.add_argument('--scrfd-onnx', type=str, default='',
help='Path to SCRFD ONNX model.')
parser.add_argument('--arcface-onnx', type=str, default='',
help='Path to ArcFace ONNX model.')
parser.add_argument('--output-dir', type=str,
default='/mnt/nvme/saltybot/models',
help='Output directory for engine files.')
parser.add_argument('--fp16', action='store_true', default=True,
help='Enable FP16 precision (default: True).')
parser.add_argument('--no-fp16', action='store_false', dest='fp16',
help='Disable FP16 (use FP32 only).')
parser.add_argument('--workspace-mb', type=int, default=1024,
help='TRT workspace size in MB (default: 1024).')
args = parser.parse_args()
if not args.scrfd_onnx and not args.arcface_onnx:
parser.error('At least one of --scrfd-onnx or --arcface-onnx is required.')
if args.scrfd_onnx:
engine_path = os.path.join(args.output_dir, 'scrfd_2.5g.engine')
print('\n=== Building SCRFD engine ===')
build_engine(args.scrfd_onnx, engine_path, args.fp16, args.workspace_mb)
if args.arcface_onnx:
engine_path = os.path.join(args.output_dir, 'arcface_r50.engine')
print('\n=== Building ArcFace engine ===')
build_engine(args.arcface_onnx, engine_path, args.fp16, args.workspace_mb)
print('\nDone.')
if __name__ == '__main__':
main()
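In the script above, `--fp16` is a `store_true` flag whose default is already `True`, so passing it changes nothing; only `--no-fp16` has an effect. One possible alternative, sketched here as an assumption rather than the author's choice, is `argparse.BooleanOptionalAction` (Python 3.9+), which generates both spellings from a single declaration. The `make_parser` helper is hypothetical.

```python
import argparse


def make_parser() -> argparse.ArgumentParser:
    """Build a parser with a paired --fp16 / --no-fp16 switch."""
    parser = argparse.ArgumentParser(
        description='Sketch of the FP16 flag handling only.')
    # BooleanOptionalAction registers --fp16 and --no-fp16 together.
    parser.add_argument('--fp16', action=argparse.BooleanOptionalAction,
                        default=True,
                        help='Use FP16 precision (default: enabled).')
    return parser


default_args = make_parser().parse_args([])
disabled_args = make_parser().parse_args(['--no-fp16'])
explicit_args = make_parser().parse_args(['--fp16'])
```

The resulting `args.fp16` value can be passed to `build_engine` exactly as in the original `main()`.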

@ -0,0 +1,4 @@
[develop]
script_dir=$base/lib/saltybot_social_face
[install]
install_scripts=$base/lib/saltybot_social_face

@ -0,0 +1,30 @@
"""Setup for saltybot_social_face package."""
from setuptools import find_packages, setup
package_name = 'saltybot_social_face'
setup(
name=package_name,
version='0.1.0',
packages=find_packages(exclude=['test']),
data_files=[
('share/ament_index/resource_index/packages', ['resource/' + package_name]),
('share/' + package_name, ['package.xml']),
('share/' + package_name + '/launch', ['launch/face_recognition.launch.py']),
('share/' + package_name + '/config', ['config/face_recognition_params.yaml']),
],
install_requires=['setuptools'],
zip_safe=True,
maintainer='seb',
maintainer_email='seb@vayrette.com',
description='SCRFD face detection and ArcFace recognition for SaltyBot social interactions',
license='MIT',
tests_require=['pytest'],
entry_points={
'console_scripts': [
'face_recognition = saltybot_social_face.face_recognition_node:main',
'enrollment_cli = saltybot_social_face.enrollment_cli:main',
],
},
)

@ -0,0 +1,24 @@
cmake_minimum_required(VERSION 3.8)
project(saltybot_social_msgs)
find_package(ament_cmake REQUIRED)
find_package(rosidl_default_generators REQUIRED)
find_package(std_msgs REQUIRED)
find_package(geometry_msgs REQUIRED)
find_package(builtin_interfaces REQUIRED)
rosidl_generate_interfaces(${PROJECT_NAME}
"msg/FaceDetection.msg"
"msg/FaceDetectionArray.msg"
"msg/FaceEmbedding.msg"
"msg/FaceEmbeddingArray.msg"
"msg/PersonState.msg"
"msg/PersonStateArray.msg"
"srv/EnrollPerson.srv"
"srv/ListPersons.srv"
"srv/DeletePerson.srv"
"srv/UpdatePerson.srv"
DEPENDENCIES std_msgs geometry_msgs builtin_interfaces
)
ament_package()

@ -0,0 +1,10 @@
std_msgs/Header header
int32 face_id
string person_name
float32 confidence
float32 recognition_score
float32 bbox_x
float32 bbox_y
float32 bbox_w
float32 bbox_h
float32[10] landmarks
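The detector returns boxes as `[x1, y1, x2, y2]`, while this message stores `bbox_x`, `bbox_y`, `bbox_w`, `bbox_h`. A minimal conversion sketch follows. It assumes that `bbox_x`/`bbox_y` denote the top-left corner (the message definition does not say), and it uses a hypothetical dataclass `FaceDetectionFields` as a plain-Python stand-in for the generated message class.

```python
from dataclasses import dataclass, field


@dataclass
class FaceDetectionFields:
    """Hypothetical stand-in for the geometry fields of FaceDetection.msg."""
    confidence: float = 0.0
    bbox_x: float = 0.0
    bbox_y: float = 0.0
    bbox_w: float = 0.0
    bbox_h: float = 0.0
    landmarks: list = field(default_factory=lambda: [0.0] * 10)


def detection_to_fields(det: dict) -> FaceDetectionFields:
    """Convert a detector dict (corner format) to x/y/width/height fields."""
    x1, y1, x2, y2 = det['bbox']
    return FaceDetectionFields(
        confidence=float(det['score']),
        bbox_x=float(x1),          # assumed: top-left corner
        bbox_y=float(y1),
        bbox_w=float(x2 - x1),
        bbox_h=float(y2 - y1),
        landmarks=[float(v) for v in det['kps']],
    )


sample = {'bbox': [10.0, 20.0, 110.0, 170.0],
          'kps': [float(i) for i in range(10)],
          'score': 0.9}
fields = detection_to_fields(sample)
```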

@ -0,0 +1,2 @@
std_msgs/Header header
saltybot_social_msgs/FaceDetection[] faces

@ -0,0 +1,5 @@
int32 person_id
string person_name
float32[] embedding
builtin_interfaces/Time enrolled_at
int32 sample_count

@ -0,0 +1,2 @@
std_msgs/Header header
saltybot_social_msgs/FaceEmbedding[] embeddings

@ -0,0 +1,19 @@
std_msgs/Header header
int32 person_id
string person_name
int32 face_id
string speaker_id
string uwb_anchor_id
geometry_msgs/Point position
float32 distance
float32 bearing_deg
uint8 state
uint8 STATE_UNKNOWN=0
uint8 STATE_APPROACHING=1
uint8 STATE_ENGAGED=2
uint8 STATE_TALKING=3
uint8 STATE_LEAVING=4
uint8 STATE_ABSENT=5
float32 engagement_score
builtin_interfaces/Time last_seen
int32 camera_id

@ -0,0 +1,3 @@
std_msgs/Header header
saltybot_social_msgs/PersonState[] persons
int32 primary_attention_id

@ -0,0 +1,19 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_social_msgs</name>
<version>0.1.0</version>
<description>Custom ROS2 messages and services for saltybot social capabilities</description>
<maintainer email="seb@vayrette.com">seb</maintainer>
<license>MIT</license>
<buildtool_depend>ament_cmake</buildtool_depend>
<depend>std_msgs</depend>
<depend>geometry_msgs</depend>
<depend>builtin_interfaces</depend>
<build_depend>rosidl_default_generators</build_depend>
<exec_depend>rosidl_default_runtime</exec_depend>
<member_of_group>rosidl_interface_packages</member_of_group>
<export>
<build_type>ament_cmake</build_type>
</export>
</package>

@ -0,0 +1,4 @@
int32 person_id
---
bool success
string message

@ -0,0 +1,7 @@
string name
string mode
int32 n_samples
---
bool success
string message
int32 person_id

@ -0,0 +1,2 @@
---
saltybot_social_msgs/FaceEmbedding[] persons

@ -0,0 +1,5 @@
int32 person_id
string new_name
---
bool success
string message

@ -0,0 +1,22 @@
social_nav:
ros__parameters:
follow_mode: 'shadow'
follow_distance: 1.2
lead_distance: 2.0
orbit_radius: 1.5
max_linear_speed: 1.0
max_linear_speed_fast: 5.5
max_angular_speed: 1.0
goal_tolerance: 0.3
routes_dir: '/mnt/nvme/saltybot/routes'
home_x: 0.0
home_y: 0.0
map_resolution: 0.05
obstacle_inflation_cells: 3
midas_depth:
ros__parameters:
onnx_path: '/mnt/nvme/saltybot/models/midas_small.onnx'
engine_path: '/mnt/nvme/saltybot/models/midas_small.engine'
process_rate: 5.0
output_scale: 1.0

@ -0,0 +1,57 @@
"""Launch file for saltybot social navigation."""
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
return LaunchDescription([
DeclareLaunchArgument('follow_mode', default_value='shadow',
description='Follow mode: shadow/lead/side/orbit/loose/tight'),
DeclareLaunchArgument('follow_distance', default_value='1.2',
description='Follow distance in meters'),
DeclareLaunchArgument('max_linear_speed', default_value='1.0',
description='Max linear speed (m/s)'),
DeclareLaunchArgument('routes_dir',
default_value='/mnt/nvme/saltybot/routes',
description='Directory for saved routes'),
Node(
package='saltybot_social_nav',
executable='social_nav',
name='social_nav',
output='screen',
parameters=[{
'follow_mode': LaunchConfiguration('follow_mode'),
'follow_distance': LaunchConfiguration('follow_distance'),
'max_linear_speed': LaunchConfiguration('max_linear_speed'),
'routes_dir': LaunchConfiguration('routes_dir'),
}],
),
Node(
package='saltybot_social_nav',
executable='midas_depth',
name='midas_depth',
output='screen',
parameters=[{
'onnx_path': '/mnt/nvme/saltybot/models/midas_small.onnx',
'engine_path': '/mnt/nvme/saltybot/models/midas_small.engine',
'process_rate': 5.0,
'output_scale': 1.0,
}],
),
Node(
package='saltybot_social_nav',
executable='waypoint_teacher',
name='waypoint_teacher',
output='screen',
parameters=[{
'routes_dir': LaunchConfiguration('routes_dir'),
'recording_interval': 0.5,
}],
),
])

@ -0,0 +1,28 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_social_nav</name>
<version>0.1.0</version>
<description>Social navigation for saltybot: follow modes, waypoint teaching, A* avoidance, MiDaS depth</description>
<maintainer email="seb@vayrette.com">seb</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>std_msgs</depend>
<depend>geometry_msgs</depend>
<depend>nav_msgs</depend>
<depend>sensor_msgs</depend>
<depend>cv_bridge</depend>
<depend>tf2_ros</depend>
<depend>tf2_geometry_msgs</depend>
<depend>saltybot_social_msgs</depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>python3-pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

@ -0,0 +1,82 @@
"""astar.py -- A* path planner for saltybot social navigation."""
import heapq
import numpy as np
def astar(grid: np.ndarray, start: tuple, goal: tuple,
obstacle_val: int = 100) -> list | None:
"""
A* on a 2D occupancy grid (row, col indexing).
Args:
grid: 2D numpy array, values 0=free, >=obstacle_val=obstacle
start: (row, col) start cell
goal: (row, col) goal cell
obstacle_val: cells with value >= obstacle_val are blocked
Returns:
List of (row, col) tuples from start to goal, or None if no path.
"""
rows, cols = grid.shape
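def h(a, b):
# Octile distance: admissible for 8-connected moves (Manhattan overestimates diagonals)
dr, dc = abs(a[0] - b[0]), abs(a[1] - b[1])
return max(dr, dc) + 0.414 * min(dr, dc)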
open_set = []
heapq.heappush(open_set, (h(start, goal), 0, start))
came_from = {}
g_score = {start: 0}
# 8-directional movement
neighbors_delta = [
(-1, -1), (-1, 0), (-1, 1),
(0, -1), (0, 1),
(1, -1), (1, 0), (1, 1),
]
while open_set:
_, cost, current = heapq.heappop(open_set)
if current == goal:
path = []
while current in came_from:
path.append(current)
current = came_from[current]
path.append(start)
return list(reversed(path))
if cost > g_score.get(current, float('inf')):
continue
for dr, dc in neighbors_delta:
nr, nc = current[0] + dr, current[1] + dc
if not (0 <= nr < rows and 0 <= nc < cols):
continue
if grid[nr, nc] >= obstacle_val:
continue
move_cost = 1.414 if (dr != 0 and dc != 0) else 1.0
new_g = g_score[current] + move_cost
neighbor = (nr, nc)
if new_g < g_score.get(neighbor, float('inf')):
g_score[neighbor] = new_g
f = new_g + h(neighbor, goal)
came_from[neighbor] = current
heapq.heappush(open_set, (f, new_g, neighbor))
return None # No path found
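def inflate_obstacles(grid: np.ndarray, inflation_radius_cells: int) -> np.ndarray:
"""Inflate obstacles for robot footprint safety.
Cells with occupancy >= 50 are dilated by a square kernel of half-width
inflation_radius_cells; every covered cell is set to 100 in the returned copy.
"""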
from scipy.ndimage import binary_dilation
obstacle_mask = grid >= 50
kernel = np.ones(
(2 * inflation_radius_cells + 1, 2 * inflation_radius_cells + 1),
dtype=bool,
)
inflated = binary_dilation(obstacle_mask, structure=kernel)
result = grid.copy()
result[inflated] = 100
return result
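With 8-directional moves, a Manhattan heuristic can overestimate the remaining cost (a single diagonal step costs about 1.414 but has Manhattan distance 2), which may yield non-shortest paths. The sketch below illustrates the octile-distance alternative on a toy grid. It is a compact, dependency-free variant written for illustration: `astar_demo` and `octile` are hypothetical names, and the grid is a nested list rather than the NumPy array used by the planner above.

```python
import heapq
import math

SQRT2 = math.sqrt(2.0)


def octile(a, b):
    """Octile distance: exact cost on an obstacle-free 8-connected grid."""
    dr, dc = abs(a[0] - b[0]), abs(a[1] - b[1])
    return max(dr, dc) + (SQRT2 - 1.0) * min(dr, dc)


def astar_demo(grid, start, goal, obstacle_val=100):
    """Compact 8-connected A*; returns (path, cost) or (None, inf)."""
    rows, cols = len(grid), len(grid[0])
    open_set = [(octile(start, goal), 0.0, start)]
    came_from = {}
    g_score = {start: 0.0}
    while open_set:
        _, g, current = heapq.heappop(open_set)
        if current == goal:
            path = [current]
            while current in came_from:
                current = came_from[current]
                path.append(current)
            return path[::-1], g
        if g > g_score[current]:
            continue  # stale heap entry, a cheaper route was found later
        for dr in (-1, 0, 1):
            for dc in (-1, 0, 1):
                if dr == 0 and dc == 0:
                    continue
                nr, nc = current[0] + dr, current[1] + dc
                if not (0 <= nr < rows and 0 <= nc < cols):
                    continue
                if grid[nr][nc] >= obstacle_val:
                    continue
                new_g = g + (SQRT2 if dr and dc else 1.0)
                if new_g < g_score.get((nr, nc), math.inf):
                    g_score[(nr, nc)] = new_g
                    came_from[(nr, nc)] = current
                    f = new_g + octile((nr, nc), goal)
                    heapq.heappush(open_set, (f, new_g, (nr, nc)))
    return None, math.inf


# 5x5 grid with a wall in column 2 that is open only in the bottom row.
grid = [[0] * 5 for _ in range(5)]
for r in range(4):
    grid[r][2] = 100
path, cost = astar_demo(grid, (0, 0), (0, 4))
```

Every route in this example has to pass through the single gap at `(4, 2)`, so the returned cost should be four straight and four diagonal steps.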

@ -0,0 +1,82 @@
"""follow_modes.py -- Follow mode geometry for saltybot social navigation."""
import math
from enum import Enum
class FollowMode(Enum):
SHADOW = 'shadow' # stay directly behind at follow_distance
LEAD = 'lead' # move ahead of person by lead_distance
SIDE = 'side' # stay to the right (or left) at side_offset
ORBIT = 'orbit' # circle around person at orbit_radius
LOOSE = 'loose' # general follow, larger tolerance
TIGHT = 'tight' # close follow, small tolerance
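def compute_shadow_target(person_pos, person_bearing_deg, follow_dist=1.2):
"""Target position: behind person along their movement direction.
The bearing is in degrees, with 0 pointing along +y and 90 along +x,
as implied by the sin/cos terms below.
"""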
bearing_rad = math.radians(person_bearing_deg + 180.0)
tx = person_pos[0] + follow_dist * math.sin(bearing_rad)
ty = person_pos[1] + follow_dist * math.cos(bearing_rad)
return (tx, ty, person_pos[2])
def compute_lead_target(person_pos, person_bearing_deg, lead_dist=2.0):
"""Target position: ahead of person."""
bearing_rad = math.radians(person_bearing_deg)
tx = person_pos[0] + lead_dist * math.sin(bearing_rad)
ty = person_pos[1] + lead_dist * math.cos(bearing_rad)
return (tx, ty, person_pos[2])
def compute_side_target(person_pos, person_bearing_deg, side_dist=1.0, right=True):
"""Target position: to the right (or left) of person."""
sign = 1.0 if right else -1.0
bearing_rad = math.radians(person_bearing_deg + sign * 90.0)
tx = person_pos[0] + side_dist * math.sin(bearing_rad)
ty = person_pos[1] + side_dist * math.cos(bearing_rad)
return (tx, ty, person_pos[2])
def compute_orbit_target(person_pos, orbit_angle_deg, orbit_radius=1.5):
"""Target on circle of radius orbit_radius around person."""
angle_rad = math.radians(orbit_angle_deg)
tx = person_pos[0] + orbit_radius * math.sin(angle_rad)
ty = person_pos[1] + orbit_radius * math.cos(angle_rad)
return (tx, ty, person_pos[2])
def compute_loose_target(person_pos, robot_pos, follow_dist=2.0, tolerance=0.8):
"""Only move if farther than follow_dist + tolerance."""
dx = person_pos[0] - robot_pos[0]
dy = person_pos[1] - robot_pos[1]
dist = math.hypot(dx, dy)
if dist <= follow_dist + tolerance:
return robot_pos
# Target on the robot-to-person line, follow_dist short of the person
scale = (dist - follow_dist) / dist
return (robot_pos[0] + dx * scale, robot_pos[1] + dy * scale, person_pos[2])
def compute_tight_target(person_pos, follow_dist=0.6):
"""Close follow: target sits follow_dist from the person along -y, independent of bearing."""
return (person_pos[0], person_pos[1] - follow_dist, person_pos[2])
MODE_VOICE_COMMANDS = {
'shadow': FollowMode.SHADOW,
'follow me': FollowMode.SHADOW,
'behind me': FollowMode.SHADOW,
'lead': FollowMode.LEAD,
'go ahead': FollowMode.LEAD,
'lead me': FollowMode.LEAD,
'side': FollowMode.SIDE,
'stay beside': FollowMode.SIDE,
'orbit': FollowMode.ORBIT,
'circle me': FollowMode.ORBIT,
'loose': FollowMode.LOOSE,
'give me space': FollowMode.LOOSE,
'tight': FollowMode.TIGHT,
'stay close': FollowMode.TIGHT,
}
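The shadow, lead and side helpers all offset the person's position along a bearing using `sin` for x and `cos` for y, which implies a compass-style convention (0 degrees along +y, 90 degrees along +x). The sketch below makes that convention explicit. The helper `offset_along_bearing` and the sample coordinates are hypothetical and only restate the geometry above.

```python
import math


def offset_along_bearing(pos, bearing_deg, dist):
    """Move dist from pos along a compass-style bearing (0 deg = +y, 90 deg = +x)."""
    rad = math.radians(bearing_deg)
    return (pos[0] + dist * math.sin(rad),
            pos[1] + dist * math.cos(rad),
            pos[2])


def shadow_target(pos, bearing_deg, follow_dist=1.2):
    """Point follow_dist behind the person: offset along the reversed bearing."""
    return offset_along_bearing(pos, bearing_deg + 180.0, follow_dist)


person = (2.0, 3.0, 0.0)
behind_north = shadow_target(person, 0.0)   # person heading along +y
behind_east = shadow_target(person, 90.0)   # person heading along +x
```

A caller that supplies a yaw angle measured counter-clockwise from +x (the usual ROS convention) would need to convert it before using these helpers.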

@ -0,0 +1,231 @@
"""
midas_depth_node.py -- MiDaS monocular depth estimation for saltybot.
Uses MiDaS_small via ONNX Runtime or TensorRT FP16.
Provides relative depth estimates for cameras without active depth (CSI cameras).
Publishes /social/depth/cam{i}/image (sensor_msgs/Image, float32, relative depth)
"""
import os
import numpy as np
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from sensor_msgs.msg import Image
from cv_bridge import CvBridge
# MiDaS_small input size
_MIDAS_H = 256
_MIDAS_W = 256
# ImageNet normalization
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
class _TRTBackend:
"""TensorRT inference backend for MiDaS."""
def __init__(self, engine_path: str, logger):
self._logger = logger
try:
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit # noqa: F401
self._cuda = cuda
rt_logger = trt.Logger(trt.Logger.WARNING)
with open(engine_path, 'rb') as f:
engine = trt.Runtime(rt_logger).deserialize_cuda_engine(f.read())
self._context = engine.create_execution_context()
# Allocate buffers
self._d_input = cuda.mem_alloc(1 * 3 * _MIDAS_H * _MIDAS_W * 4)
self._d_output = cuda.mem_alloc(1 * _MIDAS_H * _MIDAS_W * 4)
self._h_output = np.empty((_MIDAS_H, _MIDAS_W), dtype=np.float32)
self._stream = cuda.Stream()
self._logger.info(f'TRT engine loaded: {engine_path}')
except Exception as e:
raise RuntimeError(f'TRT init failed: {e}') from e
def infer(self, input_tensor: np.ndarray) -> np.ndarray:
self._cuda.memcpy_htod_async(
self._d_input, input_tensor.ravel(), self._stream)
self._context.execute_async_v2(
bindings=[int(self._d_input), int(self._d_output)],
stream_handle=self._stream.handle)
self._cuda.memcpy_dtoh_async(
self._h_output, self._d_output, self._stream)
self._stream.synchronize()
return self._h_output.copy()
class _ONNXBackend:
"""ONNX Runtime inference backend for MiDaS."""
def __init__(self, onnx_path: str, logger):
self._logger = logger
try:
import onnxruntime as ort
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
self._session = ort.InferenceSession(onnx_path, providers=providers)
self._input_name = self._session.get_inputs()[0].name
self._logger.info(f'ONNX model loaded: {onnx_path}')
except Exception as e:
raise RuntimeError(f'ONNX init failed: {e}') from e
def infer(self, input_tensor: np.ndarray) -> np.ndarray:
result = self._session.run(None, {self._input_name: input_tensor})
return result[0].squeeze()
class MiDaSDepthNode(Node):
"""MiDaS monocular depth estimation node."""
def __init__(self):
super().__init__('midas_depth')
# Parameters
self.declare_parameter('onnx_path',
'/mnt/nvme/saltybot/models/midas_small.onnx')
self.declare_parameter('engine_path',
'/mnt/nvme/saltybot/models/midas_small.engine')
self.declare_parameter('camera_topics', [
'/surround/cam0/image_raw',
'/surround/cam1/image_raw',
'/surround/cam2/image_raw',
'/surround/cam3/image_raw',
])
self.declare_parameter('output_scale', 1.0)
self.declare_parameter('process_rate', 5.0)
onnx_path = self.get_parameter('onnx_path').value
engine_path = self.get_parameter('engine_path').value
self._camera_topics = self.get_parameter('camera_topics').value
self._output_scale = self.get_parameter('output_scale').value
process_rate = self.get_parameter('process_rate').value
# Initialize inference backend (TRT preferred, ONNX fallback)
self._backend = None
if os.path.exists(engine_path):
try:
self._backend = _TRTBackend(engine_path, self.get_logger())
except RuntimeError:
self.get_logger().warn('TRT failed, trying ONNX fallback')
if self._backend is None and os.path.exists(onnx_path):
try:
self._backend = _ONNXBackend(onnx_path, self.get_logger())
except RuntimeError:
self.get_logger().error('ONNX backend failed to initialize')
if self._backend is None:
self.get_logger().error(
'No MiDaS backend available. Depth estimation disabled.')
self._bridge = CvBridge()
# Latest frames per camera (round-robin processing)
self._latest_frames = [None] * len(self._camera_topics)
self._current_cam_idx = 0
# QoS for camera subscriptions
cam_qos = QoSProfile(
reliability=ReliabilityPolicy.BEST_EFFORT,
history=HistoryPolicy.KEEP_LAST,
depth=1,
)
# Subscribe to each camera topic
self._cam_subs = []
for i, topic in enumerate(self._camera_topics):
sub = self.create_subscription(
Image, topic,
lambda msg, idx=i: self._on_image(msg, idx),
cam_qos)
self._cam_subs.append(sub)
# Publishers: one per camera
self._depth_pubs = []
for i in range(len(self._camera_topics)):
pub = self.create_publisher(
Image, f'/social/depth/cam{i}/image', 10)
self._depth_pubs.append(pub)
# Timer: round-robin across cameras
timer_period = 1.0 / process_rate
self._timer = self.create_timer(timer_period, self._timer_callback)
self.get_logger().info(
f'MiDaS depth node started: {len(self._camera_topics)} cameras '
f'@ {process_rate} Hz')
def _on_image(self, msg: Image, cam_idx: int):
"""Cache latest frame for each camera."""
self._latest_frames[cam_idx] = msg
def _preprocess(self, bgr: np.ndarray) -> np.ndarray:
"""Preprocess BGR image to MiDaS input tensor [1,3,256,256]."""
import cv2
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
resized = cv2.resize(rgb, (_MIDAS_W, _MIDAS_H),
interpolation=cv2.INTER_LINEAR)
normalized = (resized.astype(np.float32) / 255.0 - _MEAN) / _STD
# HWC -> CHW -> NCHW
tensor = normalized.transpose(2, 0, 1)[np.newaxis, ...]
return tensor.astype(np.float32)
def _infer(self, tensor: np.ndarray) -> np.ndarray:
"""Run inference, returns [256,256] float32 relative inverse depth."""
if self._backend is None:
return np.zeros((_MIDAS_H, _MIDAS_W), dtype=np.float32)
return self._backend.infer(tensor)
def _postprocess(self, raw: np.ndarray, orig_shape: tuple) -> np.ndarray:
"""Resize depth back to original image shape, apply output_scale."""
import cv2
h, w = orig_shape[:2]
depth = cv2.resize(raw, (w, h), interpolation=cv2.INTER_LINEAR)
depth = depth * self._output_scale
return depth
def _timer_callback(self):
"""Process one camera per tick (round-robin)."""
if not self._camera_topics:
return
idx = self._current_cam_idx
self._current_cam_idx = (idx + 1) % len(self._camera_topics)
msg = self._latest_frames[idx]
if msg is None:
return
try:
bgr = self._bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')
except Exception as e:
self.get_logger().warn(f'cv_bridge error cam{idx}: {e}')
return
tensor = self._preprocess(bgr)
raw_depth = self._infer(tensor)
depth_map = self._postprocess(raw_depth, bgr.shape)
# Publish as float32 Image
depth_msg = self._bridge.cv2_to_imgmsg(depth_map, encoding='32FC1')
depth_msg.header = msg.header
self._depth_pubs[idx].publish(depth_msg)
def main(args=None):
rclpy.init(args=args)
node = MiDaSDepthNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()
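Because the timer processes one camera per tick, the effective update rate of each depth topic is `process_rate` divided by the number of cameras. The sketch below models the index arithmetic from `_timer_callback`. The function names are hypothetical; the four-camera, 5 Hz figures are the node's defaults.

```python
def round_robin_schedule(num_cams: int, ticks: int, start: int = 0) -> list[int]:
    """Return the camera index handled on each timer tick."""
    order = []
    idx = start
    for _ in range(ticks):
        order.append(idx)
        idx = (idx + 1) % num_cams  # same wrap-around as _current_cam_idx
    return order


def per_camera_rate(process_rate_hz: float, num_cams: int) -> float:
    """Effective per-camera update rate when one camera is served per tick."""
    return process_rate_hz / num_cams


schedule = round_robin_schedule(num_cams=4, ticks=6)
rate = per_camera_rate(process_rate_hz=5.0, num_cams=4)
```

With the defaults, each camera's depth image is refreshed at 1.25 Hz, which may matter for consumers that expect the full 5 Hz on every topic.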

@ -0,0 +1,584 @@
"""
social_nav_node.py -- Social navigation node for saltybot.
Orchestrates person following with multiple modes, voice commands,
waypoint teaching/replay, and A* obstacle avoidance.
Follow modes:
shadow -- stay directly behind at follow_distance
lead -- move ahead of person
side -- stay beside (default right)
orbit -- circle around person
loose -- relaxed follow with deadband
tight -- close follow
Waypoint teaching:
Voice command "teach route <name>" -> record mode ON
Voice command "stop teaching" -> save route
Voice command "replay route <name>" -> playback
Voice commands:
"follow me" / "shadow" -> SHADOW mode
"lead me" / "go ahead" -> LEAD mode
"stay beside" -> SIDE mode
"orbit" -> ORBIT mode
"give me space" -> LOOSE mode
"stay close" -> TIGHT mode
"stop" / "halt" -> STOP
"resume" / "go" / "start" -> clear STOP and continue
"go home" -> navigate to home waypoint
"teach route <name>" -> start recording
"stop teaching" -> finish recording
"replay route <name>" -> playback recorded route
"""
import math
import time
import re
from collections import deque
import numpy as np
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from geometry_msgs.msg import Twist, PoseStamped
from nav_msgs.msg import OccupancyGrid, Odometry
from std_msgs.msg import String, Int32
from .follow_modes import (
FollowMode, MODE_VOICE_COMMANDS,
compute_shadow_target, compute_lead_target, compute_side_target,
compute_orbit_target, compute_loose_target, compute_tight_target,
)
from .astar import astar, inflate_obstacles
from .waypoint_teacher import WaypointRoute, WaypointReplayer
# Try importing social msgs; fallback gracefully
try:
from saltybot_social_msgs.msg import PersonStateArray
_HAS_SOCIAL_MSGS = True
except ImportError:
_HAS_SOCIAL_MSGS = False
# Proportional controller gains
_K_ANG = 2.0 # angular gain
_K_LIN = 0.8 # linear gain
_HIGH_SPEED_THRESHOLD = 3.0 # m/s person velocity triggers fast mode
_PREDICT_AHEAD_S = 0.3 # seconds to extrapolate position
_TEACH_MIN_DIST = 0.5 # meters between recorded waypoints
class SocialNavNode(Node):
"""Main social navigation orchestrator."""
def __init__(self):
super().__init__('social_nav')
# -- Parameters --
self.declare_parameter('follow_mode', 'shadow')
self.declare_parameter('follow_distance', 1.2)
self.declare_parameter('lead_distance', 2.0)
self.declare_parameter('orbit_radius', 1.5)
self.declare_parameter('max_linear_speed', 1.0)
self.declare_parameter('max_linear_speed_fast', 5.5)
self.declare_parameter('max_angular_speed', 1.0)
self.declare_parameter('goal_tolerance', 0.3)
self.declare_parameter('routes_dir', '/mnt/nvme/saltybot/routes')
self.declare_parameter('home_x', 0.0)
self.declare_parameter('home_y', 0.0)
self.declare_parameter('map_resolution', 0.05)
self.declare_parameter('obstacle_inflation_cells', 3)
self._follow_mode = FollowMode(
self.get_parameter('follow_mode').value)
self._follow_distance = self.get_parameter('follow_distance').value
self._lead_distance = self.get_parameter('lead_distance').value
self._orbit_radius = self.get_parameter('orbit_radius').value
self._max_lin = self.get_parameter('max_linear_speed').value
self._max_lin_fast = self.get_parameter('max_linear_speed_fast').value
self._max_ang = self.get_parameter('max_angular_speed').value
self._goal_tol = self.get_parameter('goal_tolerance').value
self._routes_dir = self.get_parameter('routes_dir').value
self._home_x = self.get_parameter('home_x').value
self._home_y = self.get_parameter('home_y').value
self._map_resolution = self.get_parameter('map_resolution').value
self._inflation_cells = self.get_parameter(
'obstacle_inflation_cells').value
# -- State --
self._robot_x = 0.0
self._robot_y = 0.0
self._robot_yaw = 0.0
self._target_person_pos = None # (x, y, z)
self._target_person_bearing = 0.0
self._target_person_id = -1
self._person_history = deque(maxlen=5) # for velocity estimation
self._stopped = False
self._go_home = False
# Occupancy grid for A*
self._occ_grid = None
self._occ_origin = (0.0, 0.0)
self._occ_resolution = 0.05
self._astar_path = None
# Orbit state
self._orbit_angle = 0.0
# Waypoint teaching / replay
self._teaching = False
self._current_route = None
self._last_teach_x = None
self._last_teach_y = None
self._replayer = None
# -- QoS profiles --
best_effort_qos = QoSProfile(
reliability=ReliabilityPolicy.BEST_EFFORT,
history=HistoryPolicy.KEEP_LAST, depth=1)
reliable_qos = QoSProfile(
reliability=ReliabilityPolicy.RELIABLE,
history=HistoryPolicy.KEEP_LAST, depth=1)
# -- Subscriptions --
if _HAS_SOCIAL_MSGS:
self.create_subscription(
PersonStateArray, '/social/persons',
self._on_persons, best_effort_qos)
else:
self.get_logger().warn(
'saltybot_social_msgs not found; '
'using /person/target fallback')
self.create_subscription(
PoseStamped, '/person/target',
self._on_person_target, best_effort_qos)
self.create_subscription(
String, '/social/speech/command',
self._on_voice_command, 10)
self.create_subscription(
String, '/social/speech/transcript',
self._on_transcript, 10)
self.create_subscription(
OccupancyGrid, '/map',
self._on_map, reliable_qos)
self.create_subscription(
Odometry, '/odom',
self._on_odom, best_effort_qos)
self.create_subscription(
Int32, '/social/attention/target_id',
self._on_target_id, 10)
# -- Publishers --
self._cmd_vel_pub = self.create_publisher(
Twist, '/cmd_vel', best_effort_qos)
self._mode_pub = self.create_publisher(
String, '/social/nav/mode', reliable_qos)
self._target_pub = self.create_publisher(
PoseStamped, '/social/nav/target_pos', 10)
self._status_pub = self.create_publisher(
String, '/social/nav/status', best_effort_qos)
# -- Main loop timer (20 Hz) --
self._timer = self.create_timer(0.05, self._control_loop)
self.get_logger().info(
f'Social nav started: mode={self._follow_mode.value}, '
f'dist={self._follow_distance}m')
# ----------------------------------------------------------------
# Subscriptions
# ----------------------------------------------------------------
def _on_persons(self, msg):
"""Handle PersonStateArray from social perception."""
target_id = msg.primary_attention_id
if self._target_person_id >= 0:
target_id = self._target_person_id
for p in msg.persons:
if p.person_id == target_id or target_id < 0:
pos = (p.position.x, p.position.y, p.position.z)
self._update_person_position(pos, p.bearing_deg)
break
def _on_person_target(self, msg: PoseStamped):
"""Fallback: single person target pose."""
pos = (msg.pose.position.x, msg.pose.position.y,
msg.pose.position.z)
# Estimate bearing from quaternion yaw
q = msg.pose.orientation
yaw = math.atan2(2.0 * (q.w * q.z + q.x * q.y),
1.0 - 2.0 * (q.y * q.y + q.z * q.z))
self._update_person_position(pos, math.degrees(yaw))
def _update_person_position(self, pos, bearing_deg):
"""Update person tracking state and record history."""
now = time.time()
self._target_person_pos = pos
self._target_person_bearing = bearing_deg
self._person_history.append((now, pos[0], pos[1]))
def _on_odom(self, msg: Odometry):
"""Update robot pose from odometry."""
self._robot_x = msg.pose.pose.position.x
self._robot_y = msg.pose.pose.position.y
q = msg.pose.pose.orientation
self._robot_yaw = math.atan2(
2.0 * (q.w * q.z + q.x * q.y),
1.0 - 2.0 * (q.y * q.y + q.z * q.z))
def _on_map(self, msg: OccupancyGrid):
"""Cache occupancy grid for A* planning."""
w, h = msg.info.width, msg.info.height
data = np.array(msg.data, dtype=np.int8).reshape((h, w))
# Convert -1 (unknown) to free (0) for planning
data[data < 0] = 0
self._occ_grid = data.astype(np.int32)
self._occ_origin = (msg.info.origin.position.x,
msg.info.origin.position.y)
self._occ_resolution = msg.info.resolution
def _on_target_id(self, msg: Int32):
"""Switch target person."""
self._target_person_id = msg.data
self.get_logger().info(f'Target person ID set to {msg.data}')
def _on_voice_command(self, msg: String):
"""Handle discrete voice commands for mode switching."""
cmd = msg.data.strip().lower()
if cmd in ('stop', 'halt'):
self._stopped = True
self._replayer = None
self._publish_status('STOPPED')
return
if cmd in ('resume', 'go', 'start'):
self._stopped = False
self._publish_status('RESUMED')
return
matched = MODE_VOICE_COMMANDS.get(cmd)
if matched:
self._follow_mode = matched
self._stopped = False
mode_msg = String()
mode_msg.data = self._follow_mode.value
self._mode_pub.publish(mode_msg)
self._publish_status(f'MODE: {self._follow_mode.value}')
def _on_transcript(self, msg: String):
"""Handle free-form voice transcripts for route teaching."""
text = msg.data.strip().lower()
# "teach route <name>"
m = re.match(r'teach\s+route\s+(\w+)', text)
if m:
name = m.group(1)
self._teaching = True
self._current_route = WaypointRoute(name)
self._last_teach_x = self._robot_x
self._last_teach_y = self._robot_y
self._publish_status(f'TEACHING: {name}')
self.get_logger().info(f'Recording route: {name}')
return
# "stop teaching"
if 'stop teaching' in text:
if self._teaching and self._current_route:
self._current_route.save(self._routes_dir)
self._publish_status(
f'SAVED: {self._current_route.name} '
f'({len(self._current_route.waypoints)} pts)')
self.get_logger().info(
f'Route saved: {self._current_route.name}')
self._teaching = False
self._current_route = None
return
# "replay route <name>"
m = re.match(r'replay\s+route\s+(\w+)', text)
if m:
name = m.group(1)
try:
route = WaypointRoute.load(self._routes_dir, name)
self._replayer = WaypointReplayer(route)
self._stopped = False
self._publish_status(f'REPLAY: {name}')
self.get_logger().info(f'Replaying route: {name}')
except FileNotFoundError:
self._publish_status(f'ROUTE NOT FOUND: {name}')
return
# "go home"
if 'go home' in text:
self._go_home = True
self._stopped = False
self._publish_status('GO HOME')
return
# Also try mode commands from transcript
for phrase, mode in MODE_VOICE_COMMANDS.items():
if phrase in text:
self._follow_mode = mode
self._stopped = False
mode_msg = String()
mode_msg.data = self._follow_mode.value
self._mode_pub.publish(mode_msg)
self._publish_status(f'MODE: {self._follow_mode.value}')
return
# ----------------------------------------------------------------
# Control loop
# ----------------------------------------------------------------
def _control_loop(self):
"""Main 20Hz control loop."""
# Record waypoint if teaching
if self._teaching and self._current_route:
self._maybe_record_waypoint()
# Publish zero velocity if stopped
if self._stopped:
self._publish_cmd_vel(0.0, 0.0)
return
# Determine navigation target
target = self._get_nav_target()
if target is None:
self._publish_cmd_vel(0.0, 0.0)
return
tx, ty, tz = target
# Publish debug target
self._publish_target_pose(tx, ty, tz)
# Check if arrived
dist_to_target = math.hypot(tx - self._robot_x, ty - self._robot_y)
if dist_to_target < self._goal_tol:
self._publish_cmd_vel(0.0, 0.0)
if self._go_home:
self._go_home = False
self._publish_status('HOME REACHED')
return
# Try A* path if map available
if self._occ_grid is not None:
path_target = self._plan_astar(tx, ty)
if path_target:
tx, ty = path_target
# Determine speed limit
max_lin = self._max_lin
person_vel = self._estimate_person_velocity()
if person_vel > _HIGH_SPEED_THRESHOLD:
max_lin = self._max_lin_fast
# Compute and publish cmd_vel
lin, ang = self._compute_cmd_vel(
self._robot_x, self._robot_y, self._robot_yaw,
tx, ty, max_lin)
self._publish_cmd_vel(lin, ang)
def _get_nav_target(self):
"""Determine current navigation target based on mode/state."""
# Route replay takes priority
if self._replayer and not self._replayer.is_done:
self._replayer.check_arrived(self._robot_x, self._robot_y)
wp = self._replayer.current_waypoint()
if wp:
return (wp.x, wp.y, wp.z)
else:
self._replayer = None
self._publish_status('REPLAY DONE')
return None
# Go home
if self._go_home:
return (self._home_x, self._home_y, 0.0)
# Person following
if self._target_person_pos is None:
return None
# Predict person position ahead for high-speed tracking
px, py, pz = self._predict_person_position()
bearing = self._target_person_bearing
robot_pos = (self._robot_x, self._robot_y, 0.0)
if self._follow_mode == FollowMode.SHADOW:
return compute_shadow_target(
(px, py, pz), bearing, self._follow_distance)
elif self._follow_mode == FollowMode.LEAD:
return compute_lead_target(
(px, py, pz), bearing, self._lead_distance)
elif self._follow_mode == FollowMode.SIDE:
return compute_side_target(
(px, py, pz), bearing, self._follow_distance)
elif self._follow_mode == FollowMode.ORBIT:
self._orbit_angle = (self._orbit_angle + 1.0) % 360.0
return compute_orbit_target(
(px, py, pz), self._orbit_angle, self._orbit_radius)
elif self._follow_mode == FollowMode.LOOSE:
return compute_loose_target(
(px, py, pz), robot_pos, self._follow_distance)
elif self._follow_mode == FollowMode.TIGHT:
return compute_tight_target(
(px, py, pz), self._follow_distance)
return (px, py, pz)
def _predict_person_position(self):
"""Extrapolate person position using velocity from recent samples."""
if self._target_person_pos is None:
return (0.0, 0.0, 0.0)
px, py, pz = self._target_person_pos
if len(self._person_history) >= 3:
# Use last 3 samples for velocity estimation
t0, x0, y0 = self._person_history[-3]
t1, x1, y1 = self._person_history[-1]
dt = t1 - t0
if dt > 0.01:
vx = (x1 - x0) / dt
vy = (y1 - y0) / dt
speed = math.hypot(vx, vy)
if speed > _HIGH_SPEED_THRESHOLD:
px += vx * _PREDICT_AHEAD_S
py += vy * _PREDICT_AHEAD_S
return (px, py, pz)
def _estimate_person_velocity(self) -> float:
"""Estimate person speed from recent position history."""
if len(self._person_history) < 2:
return 0.0
t0, x0, y0 = self._person_history[-2]
t1, x1, y1 = self._person_history[-1]
dt = t1 - t0
if dt < 0.01:
return 0.0
return math.hypot(x1 - x0, y1 - y0) / dt
def _plan_astar(self, target_x, target_y):
"""Run A* on occupancy grid, return next waypoint in world coords."""
grid = self._occ_grid
res = self._occ_resolution
ox, oy = self._occ_origin
# World to grid
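def w2g(wx, wy):
# Use floor, not int(): int() truncates toward zero, so a point just
# outside the map's lower edge would wrongly map to row/column 0.
return (math.floor((wy - oy) / res), math.floor((wx - ox) / res))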
start = w2g(self._robot_x, self._robot_y)
goal = w2g(target_x, target_y)
rows, cols = grid.shape
if not (0 <= start[0] < rows and 0 <= start[1] < cols):
return None
if not (0 <= goal[0] < rows and 0 <= goal[1] < cols):
return None
inflated = inflate_obstacles(grid, self._inflation_cells)
path = astar(inflated, start, goal)
if path and len(path) > 1:
# Follow a lookahead point (3 steps ahead or end)
lookahead_idx = min(3, len(path) - 1)
r, c = path[lookahead_idx]
wx = ox + c * res + res / 2.0
wy = oy + r * res + res / 2.0
self._astar_path = path
return (wx, wy)
return None
def _compute_cmd_vel(self, rx, ry, ryaw, tx, ty, max_lin):
"""Proportional controller: compute linear and angular velocity."""
dx = tx - rx
dy = ty - ry
dist = math.hypot(dx, dy)
angle_to_target = math.atan2(dy, dx)
angle_error = angle_to_target - ryaw
# Normalize angle error to [-pi, pi]
while angle_error > math.pi:
angle_error -= 2.0 * math.pi
while angle_error < -math.pi:
angle_error += 2.0 * math.pi
angular_vel = _K_ANG * angle_error
angular_vel = max(-self._max_ang,
min(self._max_ang, angular_vel))
# Reduce linear speed when turning hard
angle_factor = max(0.0, 1.0 - abs(angle_error) / (math.pi / 2.0))
linear_vel = _K_LIN * dist * angle_factor
linear_vel = max(0.0, min(max_lin, linear_vel))
return (linear_vel, angular_vel)
# ----------------------------------------------------------------
# Waypoint teaching
# ----------------------------------------------------------------
def _maybe_record_waypoint(self):
"""Record waypoint if robot moved > _TEACH_MIN_DIST."""
if self._last_teach_x is None:
self._last_teach_x = self._robot_x
self._last_teach_y = self._robot_y
dist = math.hypot(
self._robot_x - self._last_teach_x,
self._robot_y - self._last_teach_y)
if dist >= _TEACH_MIN_DIST:
yaw_deg = math.degrees(self._robot_yaw)
self._current_route.add(
self._robot_x, self._robot_y, 0.0, yaw_deg)
self._last_teach_x = self._robot_x
self._last_teach_y = self._robot_y
# ----------------------------------------------------------------
# Publishers
# ----------------------------------------------------------------
def _publish_cmd_vel(self, linear: float, angular: float):
twist = Twist()
twist.linear.x = linear
twist.angular.z = angular
self._cmd_vel_pub.publish(twist)
def _publish_target_pose(self, x, y, z):
msg = PoseStamped()
msg.header.stamp = self.get_clock().now().to_msg()
msg.header.frame_id = 'map'
msg.pose.position.x = x
msg.pose.position.y = y
msg.pose.position.z = z
self._target_pub.publish(msg)
def _publish_status(self, status: str):
msg = String()
msg.data = status
self._status_pub.publish(msg)
self.get_logger().info(f'Nav status: {status}')
def main(args=None):
rclpy.init(args=args)
node = SocialNavNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()
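
The proportional controller in `_compute_cmd_vel` can be exercised on its own. The sketch below is a standalone approximation, not the node's code: the gain and limit values (`K_LIN`, `K_ANG`, `MAX_ANG`) are assumptions, since the real `_K_LIN`, `_K_ANG` and `max_ang` are defined outside the lines shown here, and the angle wrap uses `atan2(sin, cos)` in place of the two `while` loops.

```python
import math

# Hypothetical gains and limits, chosen only for illustration.
K_LIN = 0.5
K_ANG = 1.5
MAX_ANG = 1.0


def normalize_angle(angle: float) -> float:
    """Wrap an angle to [-pi, pi] without looping."""
    return math.atan2(math.sin(angle), math.cos(angle))


def compute_cmd_vel(rx, ry, ryaw, tx, ty, max_lin):
    """Return (linear, angular) velocity toward the target (tx, ty)."""
    dist = math.hypot(tx - rx, ty - ry)
    angle_error = normalize_angle(math.atan2(ty - ry, tx - rx) - ryaw)
    angular = max(-MAX_ANG, min(MAX_ANG, K_ANG * angle_error))
    # Linear speed fades to zero as the heading error approaches 90 degrees.
    angle_factor = max(0.0, 1.0 - abs(angle_error) / (math.pi / 2.0))
    linear = max(0.0, min(max_lin, K_LIN * dist * angle_factor))
    return linear, angular


ahead = compute_cmd_vel(0.0, 0.0, 0.0, 2.0, 0.0, 0.8)
behind = compute_cmd_vel(0.0, 0.0, 0.0, -1.0, 0.0, 0.8)
```

With these assumed gains, a target straight ahead saturates the linear speed at `max_lin` with no turn, while a target directly behind produces a pure rotation, because the angle factor clamps the linear term to zero.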


@ -0,0 +1,91 @@
"""waypoint_teacher.py -- Record and replay waypoint routes."""
import json
import time
import math
from pathlib import Path
from dataclasses import dataclass, asdict
@dataclass
class Waypoint:
x: float
y: float
z: float
yaw_deg: float
timestamp: float
label: str = ''
class WaypointRoute:
"""A named sequence of waypoints."""
def __init__(self, name: str):
self.name = name
self.waypoints: list[Waypoint] = []
self.created_at = time.time()
def add(self, x, y, z, yaw_deg, label=''):
self.waypoints.append(Waypoint(x, y, z, yaw_deg, time.time(), label))
def to_dict(self):
return {
'name': self.name,
'created_at': self.created_at,
'waypoints': [asdict(w) for w in self.waypoints],
}
@classmethod
def from_dict(cls, d):
route = cls(d['name'])
route.created_at = d.get('created_at', 0)
route.waypoints = [Waypoint(**w) for w in d['waypoints']]
return route
def save(self, routes_dir: str):
Path(routes_dir).mkdir(parents=True, exist_ok=True)
path = Path(routes_dir) / f'{self.name}.json'
path.write_text(json.dumps(self.to_dict(), indent=2))
@classmethod
def load(cls, routes_dir: str, name: str):
path = Path(routes_dir) / f'{name}.json'
return cls.from_dict(json.loads(path.read_text()))
@staticmethod
def list_routes(routes_dir: str) -> list[str]:
d = Path(routes_dir)
if not d.exists():
return []
return [p.stem for p in d.glob('*.json')]
class WaypointReplayer:
"""Iterates through waypoints, returning next target."""
def __init__(self, route: WaypointRoute, arrival_radius: float = 0.3):
self._route = route
self._idx = 0
self._arrival_radius = arrival_radius
def current_waypoint(self) -> Waypoint | None:
if self._idx < len(self._route.waypoints):
return self._route.waypoints[self._idx]
return None
def check_arrived(self, robot_x, robot_y) -> bool:
wp = self.current_waypoint()
if wp is None:
return False
dist = math.hypot(robot_x - wp.x, robot_y - wp.y)
if dist < self._arrival_radius:
self._idx += 1
return True
return False
@property
def is_done(self) -> bool:
return self._idx >= len(self._route.waypoints)
def reset(self):
self._idx = 0
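
The route persistence above is a plain JSON round trip of dataclass fields. A minimal standalone sketch of that round trip follows; `save_route` and `load_route` are hypothetical helper names used only for this illustration (the module itself exposes `WaypointRoute.save` and `WaypointRoute.load`), and a temporary directory stands in for the configured routes directory.

```python
import json
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path


@dataclass
class Waypoint:
    x: float
    y: float
    z: float
    yaw_deg: float
    timestamp: float
    label: str = ''


def save_route(routes_dir, name, waypoints):
    """Write waypoints to <routes_dir>/<name>.json (hypothetical helper)."""
    Path(routes_dir).mkdir(parents=True, exist_ok=True)
    path = Path(routes_dir) / f'{name}.json'
    payload = {'name': name, 'waypoints': [asdict(w) for w in waypoints]}
    path.write_text(json.dumps(payload, indent=2))
    return path


def load_route(routes_dir, name):
    """Read waypoints back from <routes_dir>/<name>.json."""
    data = json.loads((Path(routes_dir) / f'{name}.json').read_text())
    return [Waypoint(**w) for w in data['waypoints']]


with tempfile.TemporaryDirectory() as tmp:
    original = [
        Waypoint(0.0, 0.0, 0.0, 0.0, 1.0),
        Waypoint(0.5, 0.0, 0.0, 90.0, 2.0, 'door'),
    ]
    save_route(tmp, 'kitchen', original)
    restored = load_route(tmp, 'kitchen')
```

Because `Waypoint(**w)` rebuilds each entry from its keys, adding a field without a default to the dataclass would make older route files fail to load.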


@ -0,0 +1,135 @@
"""
waypoint_teacher_node.py -- Standalone waypoint teacher ROS2 node.
Listens to /social/speech/transcript for "teach route <name>" and "stop teaching".
Records robot pose at configurable intervals. Saves/loads routes via WaypointRoute.
"""
import math
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from nav_msgs.msg import Odometry
from std_msgs.msg import String
from .waypoint_teacher import WaypointRoute
class WaypointTeacherNode(Node):
"""Standalone waypoint teaching node."""
def __init__(self):
super().__init__('waypoint_teacher')
self.declare_parameter('routes_dir', '/mnt/nvme/saltybot/routes')
self.declare_parameter('recording_interval', 0.5) # meters
self._routes_dir = self.get_parameter('routes_dir').value
self._interval = self.get_parameter('recording_interval').value
self._teaching = False
self._route = None
self._last_x = None
self._last_y = None
self._robot_x = 0.0
self._robot_y = 0.0
self._robot_yaw = 0.0
best_effort_qos = QoSProfile(
reliability=ReliabilityPolicy.BEST_EFFORT,
history=HistoryPolicy.KEEP_LAST, depth=1)
self.create_subscription(
Odometry, '/odom', self._on_odom, best_effort_qos)
self.create_subscription(
String, '/social/speech/transcript',
self._on_transcript, 10)
self._status_pub = self.create_publisher(
String, '/social/waypoint/status', 10)
# Record timer at 10Hz (check distance)
self._timer = self.create_timer(0.1, self._record_tick)
self.get_logger().info(
f'Waypoint teacher ready (interval={self._interval}m, '
f'dir={self._routes_dir})')
def _on_odom(self, msg: Odometry):
self._robot_x = msg.pose.pose.position.x
self._robot_y = msg.pose.pose.position.y
q = msg.pose.pose.orientation
self._robot_yaw = math.atan2(
2.0 * (q.w * q.z + q.x * q.y),
1.0 - 2.0 * (q.y * q.y + q.z * q.z))
def _on_transcript(self, msg: String):
text = msg.data.strip().lower()
import re
m = re.match(r'teach\s+route\s+(\w+)', text)
if m:
name = m.group(1)
self._route = WaypointRoute(name)
self._teaching = True
self._last_x = self._robot_x
self._last_y = self._robot_y
self._pub_status(f'RECORDING: {name}')
self.get_logger().info(f'Recording route: {name}')
return
if 'stop teaching' in text:
if self._teaching and self._route:
self._route.save(self._routes_dir)
n = len(self._route.waypoints)
self._pub_status(
f'SAVED: {self._route.name} ({n} waypoints)')
self.get_logger().info(
f'Route saved: {self._route.name} ({n} pts)')
self._teaching = False
self._route = None
return
if 'list routes' in text:
routes = WaypointRoute.list_routes(self._routes_dir)
self._pub_status(f'ROUTES: {", ".join(routes) or "(none)"}')
def _record_tick(self):
if not self._teaching or self._route is None:
return
if self._last_x is None:
self._last_x = self._robot_x
self._last_y = self._robot_y
dist = math.hypot(
self._robot_x - self._last_x,
self._robot_y - self._last_y)
if dist >= self._interval:
yaw_deg = math.degrees(self._robot_yaw)
self._route.add(self._robot_x, self._robot_y, 0.0, yaw_deg)
self._last_x = self._robot_x
self._last_y = self._robot_y
def _pub_status(self, text: str):
msg = String()
msg.data = text
self._status_pub.publish(msg)
def main(args=None):
rclpy.init(args=args)
node = WaypointTeacherNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()
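
The distance-gated recording in `_record_tick` can be sketched without ROS. The function below is an illustration under the same rule as the node: the pose at the start of teaching only seeds the reference point, and a waypoint is kept each time the robot has moved at least `interval` metres from the last kept position. The name `record_by_distance` and the sample poses are assumptions made for this example.

```python
import math


def record_by_distance(poses, interval=0.5):
    """Return the poses that would be recorded as waypoints.

    The first pose only seeds the reference position; later poses are
    kept once they are at least `interval` metres from the last kept one.
    """
    kept = []
    last = None
    for x, y in poses:
        if last is None:
            last = (x, y)
            continue
        if math.hypot(x - last[0], y - last[1]) >= interval:
            kept.append((x, y))
            last = (x, y)
    return kept


track = [(0.0, 0.0), (0.2, 0.0), (0.5, 0.0), (0.7, 0.0), (1.0, 0.0)]
waypoints = record_by_distance(track)
```

One consequence of this rule is that the starting pose itself is never stored, so a replayed route begins at the first point `interval` metres away from where teaching started.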


@ -0,0 +1,80 @@
#!/usr/bin/env python3
"""
build_midas_trt_engine.py -- Build TensorRT FP16 engine for MiDaS_small from ONNX.
Usage:
python3 build_midas_trt_engine.py \
--onnx /mnt/nvme/saltybot/models/midas_small.onnx \
--engine /mnt/nvme/saltybot/models/midas_small.engine \
--fp16
Requires: tensorrt, pycuda
"""
import argparse
import os
import sys
def build_engine(onnx_path: str, engine_path: str, fp16: bool = True):
try:
import tensorrt as trt
except ImportError:
print('ERROR: tensorrt not found. Install TensorRT first.')
sys.exit(1)
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
print(f'Parsing ONNX model: {onnx_path}')
with open(onnx_path, 'rb') as f:
if not parser.parse(f.read()):
for i in range(parser.num_errors):
print(f' ONNX parse error: {parser.get_error(i)}')
sys.exit(1)
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1GB
if fp16 and builder.platform_has_fast_fp16:
config.set_flag(trt.BuilderFlag.FP16)
print('FP16 mode enabled')
elif fp16:
print('WARNING: FP16 not supported on this platform, using FP32')
print('Building TensorRT engine (this may take several minutes)...')
engine_bytes = builder.build_serialized_network(network, config)
if engine_bytes is None:
print('ERROR: Failed to build engine')
sys.exit(1)
os.makedirs(os.path.dirname(engine_path) or '.', exist_ok=True)
with open(engine_path, 'wb') as f:
f.write(engine_bytes)
size_mb = len(engine_bytes) / (1024 * 1024)
print(f'Engine saved: {engine_path} ({size_mb:.1f} MB)')
def main():
parser = argparse.ArgumentParser(
description='Build TensorRT FP16 engine for MiDaS_small')
parser.add_argument('--onnx', required=True,
help='Path to MiDaS ONNX model')
parser.add_argument('--engine', required=True,
help='Output TRT engine path')
parser.add_argument('--fp16', action='store_true', default=True,
help='Enable FP16 (already the default, so this flag is a '
'no-op; pass --fp32 to disable)')
parser.add_argument('--fp32', action='store_true',
help='Force FP32 (disable FP16)')
args = parser.parse_args()
fp16 = not args.fp32
build_engine(args.onnx, args.engine, fp16=fp16)
if __name__ == '__main__':
main()
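
In the script above, `--fp16` uses `store_true` with `default=True`, so only `--fp32` changes the outcome. One possible alternative, shown here as a sketch and not as the script's actual interface, is a single paired flag via `argparse.BooleanOptionalAction` (Python 3.9+), which generates `--fp16` and `--no-fp16` automatically. The function name `parse_precision` is hypothetical.

```python
import argparse


def parse_precision(argv):
    """Return True when FP16 should be enabled for the given arguments."""
    parser = argparse.ArgumentParser(description='Precision flag sketch')
    parser.add_argument(
        '--fp16',
        action=argparse.BooleanOptionalAction,
        default=True,
        help='Enable FP16 (use --no-fp16 to build in FP32)',
    )
    return parser.parse_args(argv).fp16


default_choice = parse_precision([])
disabled = parse_precision(['--no-fp16'])
```

This keeps one source of truth for the precision setting, at the cost of changing the command line that existing build scripts may already use.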


@ -0,0 +1,4 @@
[develop]
script_dir=$base/lib/saltybot_social_nav
[install]
install_scripts=$base/lib/saltybot_social_nav


@ -0,0 +1,31 @@
from setuptools import setup
import os
from glob import glob
package_name = 'saltybot_social_nav'
setup(
name=package_name,
version='0.1.0',
packages=[package_name],
data_files=[
('share/ament_index/resource_index/packages', ['resource/' + package_name]),
('share/' + package_name, ['package.xml']),
(os.path.join('share', package_name, 'launch'), glob('launch/*.py')),
(os.path.join('share', package_name, 'config'), glob('config/*.yaml')),
],
install_requires=['setuptools'],
zip_safe=True,
maintainer='seb',
maintainer_email='seb@vayrette.com',
description='Social navigation for saltybot: follow modes, waypoint teaching, A* avoidance, MiDaS depth',
license='MIT',
tests_require=['pytest'],
entry_points={
'console_scripts': [
'social_nav = saltybot_social_nav.social_nav_node:main',
'midas_depth = saltybot_social_nav.midas_depth_node:main',
'waypoint_teacher = saltybot_social_nav.waypoint_teacher_node:main',
],
},
)


@ -0,0 +1,12 @@
# Copyright 2026 SaltyLab
# Licensed under MIT
from ament_copyright.main import main
import pytest
@pytest.mark.copyright
@pytest.mark.linter
def test_copyright():
rc = main(argv=['.', 'test'])
assert rc == 0, 'Found errors'


@ -0,0 +1,14 @@
# Copyright 2026 SaltyLab
# Licensed under MIT
from ament_flake8.main import main_with_errors
import pytest
@pytest.mark.flake8
@pytest.mark.linter
def test_flake8():
rc, errors = main_with_errors(argv=[])
assert rc == 0, \
'Found %d code style errors / warnings:\n' % len(errors) + \
'\n'.join(errors)


@ -0,0 +1,12 @@
# Copyright 2026 SaltyLab
# Licensed under MIT
from ament_pep257.main import main
import pytest
@pytest.mark.pep257
@pytest.mark.linter
def test_pep257():
rc = main(argv=['.', 'test'])
assert rc == 0, 'Found code style errors / warnings'


@ -1,57 +1,24 @@
# uwb_config.yaml — MaUWB ESP32-S3 DW3000 UWB follow-me system
#
# Hardware layout:
# Anchor-0 (port side) → USB port_a, y = +anchor_separation/2
# Anchor-1 (starboard side) → USB port_b, y = -anchor_separation/2
# Tag on person → belt clip, ~0.9m above ground
# uwb_config.yaml — MaUWB ESP32-S3 DW3000 UWB integration (Issue #90)
#
# Run with:
# ros2 launch saltybot_uwb uwb.launch.py
# Override at launch:
# ros2 launch saltybot_uwb uwb.launch.py port_a:=/dev/ttyUSB2
# ── Serial ports ──────────────────────────────────────────────────────────────
# Set udev rules to get stable symlinks:
# /dev/uwb-anchor0 → port_a
# /dev/uwb-anchor1 → port_b
# (See jetson/docs/pinout.md for udev setup)
port_a: /dev/uwb-anchor0 # Anchor-0 (port)
port_b: /dev/uwb-anchor1 # Anchor-1 (starboard)
baudrate: 115200 # MaUWB default — do not change
port_a: /dev/uwb-anchor0
port_b: /dev/uwb-anchor1
baudrate: 115200
# ── Anchor geometry ────────────────────────────────────────────────────────────
# anchor_separation: centre-to-centre distance between anchors (metres)
# Must match physical mounting. Larger = more accurate lateral resolution.
anchor_separation: 0.25 # metres (25cm)
anchor_separation: 0.25
anchor_height: 0.80
tag_height: 0.90
# anchor_height: height of anchors above ground (metres)
# Orin stem mount ≈ 0.80m on the saltybot platform
anchor_height: 0.80 # metres
range_timeout_s: 1.0
max_range_m: 8.0
min_range_m: 0.05
# tag_height: height of person's belt-clip tag above ground (metres)
tag_height: 0.90 # metres (adjust per user)
# ── Range validity ─────────────────────────────────────────────────────────────
# range_timeout_s: stale anchor — excluded from triangulation after this gap
range_timeout_s: 1.0 # seconds
# max_range_m: discard ranges beyond this (DW3000 indoor practical limit ≈8m)
max_range_m: 8.0 # metres
# min_range_m: discard ranges below this (likely multipath artefacts)
min_range_m: 0.05 # metres
# ── Kalman filter ──────────────────────────────────────────────────────────────
# kf_process_noise: Q scalar — how dynamic the person's motion is
# Higher → faster response, more jitter
kf_process_noise: 0.1
kf_meas_noise: 0.3
# kf_meas_noise: R scalar — how noisy the UWB ranging is
# DW3000 indoor accuracy ≈ 10cm RMS → 0.1m → R ≈ 0.01
# Use 0.3 to be conservative on first deployment
kf_meas_noise: 0.3
range_rate: 100.0
bearing_rate: 10.0
# ── Publish rate ──────────────────────────────────────────────────────────────
# Should match or exceed the AT+RANGE? poll rate from both anchors.
# 20Hz means 50ms per cycle; each anchor query takes ~10ms → headroom ok.
publish_rate: 20.0 # Hz
enrolled_tag_ids: [""]


@ -1,10 +1,14 @@
"""
uwb.launch.py Launch UWB driver node for MaUWB ESP32-S3 follow-me.
uwb.launch.py Launch UWB driver node for MaUWB ESP32-S3 DW3000 (Issue #90).
Topics:
/uwb/ranges 100 Hz raw TWR ranges
/uwb/bearing 10 Hz Kalman-fused bearing
/uwb/target 10 Hz triangulated PoseStamped (backwards compat)
Usage:
ros2 launch saltybot_uwb uwb.launch.py
ros2 launch saltybot_uwb uwb.launch.py port_a:=/dev/ttyUSB2 port_b:=/dev/ttyUSB3
ros2 launch saltybot_uwb uwb.launch.py anchor_separation:=0.30 publish_rate:=10.0
ros2 launch saltybot_uwb uwb.launch.py enrolled_tag_ids:="['0xDEADBEEF']"
"""
import os
@ -31,7 +35,9 @@ def generate_launch_description():
DeclareLaunchArgument("min_range_m", default_value="0.05"),
DeclareLaunchArgument("kf_process_noise", default_value="0.1"),
DeclareLaunchArgument("kf_meas_noise", default_value="0.3"),
DeclareLaunchArgument("publish_rate", default_value="20.0"),
DeclareLaunchArgument("range_rate", default_value="100.0"),
DeclareLaunchArgument("bearing_rate", default_value="10.0"),
DeclareLaunchArgument("enrolled_tag_ids", default_value="['']"),
Node(
package="saltybot_uwb",
@ -52,7 +58,9 @@ def generate_launch_description():
"min_range_m": LaunchConfiguration("min_range_m"),
"kf_process_noise": LaunchConfiguration("kf_process_noise"),
"kf_meas_noise": LaunchConfiguration("kf_meas_noise"),
"publish_rate": LaunchConfiguration("publish_rate"),
"range_rate": LaunchConfiguration("range_rate"),
"bearing_rate": LaunchConfiguration("bearing_rate"),
"enrolled_tag_ids": LaunchConfiguration("enrolled_tag_ids"),
},
],
),


@ -29,6 +29,26 @@ Returns (x_t, y_t); caller should treat negative x_t as 0.
import math
# ── Bearing helper ────────────────────────────────────────────────────────────
def bearing_from_pos(x: float, y: float) -> float:
"""
Compute horizontal bearing to a point (x, y) in base_link.
Parameters
----------
x : forward distance (metres)
y : lateral offset (metres, positive = left of robot / CCW)
Returns
-------
bearing_rad : bearing in radians, range -π .. +π
positive = target to the left (CCW)
0 = directly ahead
"""
return math.atan2(y, x)
# ── Triangulation ─────────────────────────────────────────────────────────────
def triangulate_2anchor(
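
A short usage sketch for the bearing helper added in this hunk. The function body is copied from the diff; the degree conversion wrapper `bearing_deg` and the sample coordinates are assumptions added for illustration.

```python
import math


def bearing_from_pos(x: float, y: float) -> float:
    """Bearing in radians to (x, y) in base_link; positive = left (CCW)."""
    return math.atan2(y, x)


def bearing_deg(x: float, y: float) -> float:
    """Hypothetical convenience wrapper returning degrees."""
    return math.degrees(bearing_from_pos(x, y))


left = bearing_deg(1.0, 1.0)     # target ahead and to the left
right = bearing_deg(1.0, -1.0)   # target ahead and to the right
ahead = bearing_deg(2.0, 0.0)    # target directly ahead
```

Since `atan2` handles all four quadrants, a target behind the robot (negative `x`) yields a bearing beyond ±90 degrees without needing a special case.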


@ -1,32 +1,52 @@
"""
uwb_driver_node.py ROS2 node for MaUWB ESP32-S3 DW3000 follow-me system.
uwb_driver_node.py ROS2 node for MaUWB ESP32-S3 DW3000 UWB integration.
Hardware
2× MaUWB ESP32-S3 DW3000 anchors on robot stem (USB Orin Nano)
2× MaUWB ESP32-S3 DW3000 anchors on robot stem (USB Orin)
- Anchor-0: port side (y = +sep/2)
- Anchor-1: starboard (y = -sep/2)
1× MaUWB tag on person (belt clip)
1× MaUWB tag per enrolled person (belt clip)
AT command interface (115200 8N1)
Query: AT+RANGE?\r\n
Response (from anchors):
+RANGE:<anchor_id>,<range_mm>[,<rssi>]\r\n
Query:
AT+RANGE?\r\n
Config:
AT+anchor_tag=ANCHOR\r\n set module as anchor
AT+anchor_tag=TAG\r\n set module as tag
Response (from anchors, TWR protocol):
+RANGE:<anchor_id>,<range_mm>[,<rssi>[,<tag_addr>]]\r\n
Tag pairing (optional targets a specific enrolled tag):
AT+RANGE_ADDR=<tag_addr>\r\n anchor only ranges with that tag
Publishes
/uwb/target (geometry_msgs/PoseStamped) triangulated person position in base_link
/uwb/ranges (saltybot_uwb_msgs/UwbRangeArray) raw ranges from both anchors
/uwb/ranges (saltybot_uwb_msgs/UwbRangeArray) 100 Hz raw anchor ranges
/uwb/bearing (saltybot_uwb_msgs/UwbBearing) 10 Hz Kalman-fused bearing
/uwb/target (geometry_msgs/PoseStamped) 10 Hz triangulated position
(kept for backwards compat)
Safety
If a range is stale (> range_timeout_s), that anchor is excluded from
triangulation. If both anchors are stale, /uwb/target is not published.
Tag pairing
Set enrolled_tag_ids to a list of tag address strings (e.g. ["0x1234ABCD"]).
When non-empty, ranges from unrecognised tags are silently discarded.
The matched tag address is stamped in UwbRange.tag_id and UwbBearing.tag_id.
When enrolled_tag_ids is empty, all ranges are accepted (tag_id = "").
Parameters
port_a, port_b serial ports for anchor-0 / anchor-1
baudrate 115200 (default)
anchor_separation centre-to-centre anchor spacing (m)
anchor_height anchor mounting height (m)
tag_height person tag height (m)
range_timeout_s stale-anchor threshold (s)
max_range_m / min_range_m validity window (m)
kf_process_noise Kalman Q scalar
kf_meas_noise Kalman R scalar
range_rate Hz /uwb/ranges publish rate (default 100)
bearing_rate Hz /uwb/bearing publish rate (default 10)
enrolled_tag_ids list[str] accepted tag addresses; [] = accept all
Usage
@ -44,8 +64,8 @@ from rclpy.node import Node
from geometry_msgs.msg import PoseStamped
from std_msgs.msg import Header
from saltybot_uwb_msgs.msg import UwbRange, UwbRangeArray
from saltybot_uwb.ranging_math import triangulate_2anchor, KalmanFilter2D
from saltybot_uwb_msgs.msg import UwbRange, UwbRangeArray, UwbBearing
from saltybot_uwb.ranging_math import triangulate_2anchor, KalmanFilter2D, bearing_from_pos
try:
import serial
@ -54,26 +74,31 @@ except ImportError:
_SERIAL_AVAILABLE = False
# Regex: +RANGE:<id>,<mm> or +RANGE:<id>,<mm>,<rssi>
# +RANGE:<id>,<mm> or +RANGE:<id>,<mm>,<rssi> or +RANGE:<id>,<mm>,<rssi>,<tag>
_RANGE_RE = re.compile(
r"\+RANGE\s*:\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(-?[\d.]+))?",
r"\+RANGE\s*:\s*(\d+)\s*,\s*(\d+)"
r"(?:\s*,\s*(-?[\d.]+)"
r"(?:\s*,\s*([\w:x]+))?"
r")?",
re.IGNORECASE,
)
class SerialReader(threading.Thread):
"""
Background thread: polls an anchor's UART, fires callback on every
valid +RANGE response.
Background thread: polls one anchor's UART at maximum TWR rate,
fires callback on every valid +RANGE response.
Supports optional tag pairing via AT+RANGE_ADDR command.
"""
def __init__(self, port, baudrate, anchor_id, callback, logger):
def __init__(self, port, baudrate, anchor_id, callback, logger, tag_addr=None):
super().__init__(daemon=True)
self._port = port
self._baudrate = baudrate
self._anchor_id = anchor_id
self._callback = callback
self._logger = logger
self._tag_addr = tag_addr
self._running = False
self._ser = None
@ -86,7 +111,10 @@ class SerialReader(threading.Thread):
)
self._logger.info(
f"Anchor-{self._anchor_id}: opened {self._port}"
+ (f" paired with tag {self._tag_addr}" if self._tag_addr else "")
)
if self._tag_addr:
self._send_pairing_cmd()
self._read_loop()
except Exception as exc:
self._logger.warn(
@ -96,12 +124,24 @@ class SerialReader(threading.Thread):
self._ser.close()
time.sleep(2.0)
def _send_pairing_cmd(self):
"""Configure the anchor to range only with the paired tag."""
try:
cmd = f"AT+RANGE_ADDR={self._tag_addr}\r\n".encode("ascii")
self._ser.write(cmd)
time.sleep(0.1)
self._logger.info(
f"Anchor-{self._anchor_id}: sent tag pairing {self._tag_addr}"
)
except Exception as exc:
self._logger.warn(
f"Anchor-{self._anchor_id}: pairing cmd failed: {exc}"
)
def _read_loop(self):
while self._running:
try:
# Query the anchor
self._ser.write(b"AT+RANGE?\r\n")
# Read up to 10 lines waiting for a +RANGE response
for _ in range(10):
raw = self._ser.readline()
if not raw:
@ -111,13 +151,14 @@ class SerialReader(threading.Thread):
if m:
range_mm = int(m.group(2))
rssi = float(m.group(3)) if m.group(3) else 0.0
self._callback(self._anchor_id, range_mm, rssi)
tag_addr = m.group(4) if m.group(4) else ""
self._callback(self._anchor_id, range_mm, rssi, tag_addr)
break
except Exception as exc:
self._logger.warn(
f"Anchor-{self._anchor_id} read error: {exc}"
)
break # trigger reconnect
break
def stop(self):
self._running = False
@ -130,48 +171,51 @@ class UwbDriverNode(Node):
def __init__(self):
super().__init__("uwb_driver")
# ── Parameters ────────────────────────────────────────────────────────
self.declare_parameter("port_a", "/dev/ttyUSB0")
self.declare_parameter("port_b", "/dev/ttyUSB1")
self.declare_parameter("baudrate", 115200)
self.declare_parameter("anchor_separation", 0.25)
self.declare_parameter("anchor_height", 0.80)
self.declare_parameter("tag_height", 0.90)
self.declare_parameter("range_timeout_s", 1.0)
self.declare_parameter("max_range_m", 8.0)
self.declare_parameter("min_range_m", 0.05)
self.declare_parameter("kf_process_noise", 0.1)
self.declare_parameter("kf_meas_noise", 0.3)
self.declare_parameter("publish_rate", 20.0)
self.declare_parameter("port_a", "/dev/uwb-anchor0")
self.declare_parameter("port_b", "/dev/uwb-anchor1")
self.declare_parameter("baudrate", 115200)
self.declare_parameter("anchor_separation", 0.25)
self.declare_parameter("anchor_height", 0.80)
self.declare_parameter("tag_height", 0.90)
self.declare_parameter("range_timeout_s", 1.0)
self.declare_parameter("max_range_m", 8.0)
self.declare_parameter("min_range_m", 0.05)
self.declare_parameter("kf_process_noise", 0.1)
self.declare_parameter("kf_meas_noise", 0.3)
self.declare_parameter("range_rate", 100.0)
self.declare_parameter("bearing_rate", 10.0)
self.declare_parameter("enrolled_tag_ids", [""])
self._p = self._load_params()
# ── State (protected by lock) ──────────────────────────────────────
self._lock = threading.Lock()
self._ranges = {} # anchor_id → (range_m, rssi, timestamp)
self._kf = KalmanFilter2D(
raw_ids = self.get_parameter("enrolled_tag_ids").value
self._enrolled_tags = [t.strip() for t in raw_ids if t.strip()]
paired_tag = self._enrolled_tags[0] if self._enrolled_tags else None
self._lock = threading.Lock()
self._ranges: dict = {}
self._kf = KalmanFilter2D(
process_noise=self._p["kf_process_noise"],
measurement_noise=self._p["kf_meas_noise"],
dt=1.0 / self._p["publish_rate"],
dt=1.0 / self._p["bearing_rate"],
)
# ── Publishers ────────────────────────────────────────────────────
self._target_pub = self.create_publisher(
PoseStamped, "/uwb/target", 10)
self._ranges_pub = self.create_publisher(
UwbRangeArray, "/uwb/ranges", 10)
self._ranges_pub = self.create_publisher(UwbRangeArray, "/uwb/ranges", 10)
self._bearing_pub = self.create_publisher(UwbBearing, "/uwb/bearing", 10)
self._target_pub = self.create_publisher(PoseStamped, "/uwb/target", 10)
# ── Serial readers ────────────────────────────────────────────────
if _SERIAL_AVAILABLE:
self._reader_a = SerialReader(
self._p["port_a"], self._p["baudrate"],
anchor_id=0, callback=self._range_cb,
logger=self.get_logger(),
tag_addr=paired_tag,
)
self._reader_b = SerialReader(
self._p["port_b"], self._p["baudrate"],
anchor_id=1, callback=self._range_cb,
logger=self.get_logger(),
tag_addr=paired_tag,
)
self._reader_a.start()
self._reader_b.start()
@ -180,19 +224,21 @@ class UwbDriverNode(Node):
"pyserial not installed — running in simulation mode (no serial I/O)"
)
# ── Publish timer ─────────────────────────────────────────────────
self._timer = self.create_timer(
1.0 / self._p["publish_rate"], self._publish_cb
self._range_timer = self.create_timer(
1.0 / self._p["range_rate"], self._range_publish_cb
)
self._bearing_timer = self.create_timer(
1.0 / self._p["bearing_rate"], self._bearing_publish_cb
)
self.get_logger().info(
f"UWB driver ready sep={self._p['anchor_separation']}m "
f"ports={self._p['port_a']},{self._p['port_b']} "
f"rate={self._p['publish_rate']}Hz"
f"range={self._p['range_rate']}Hz "
f"bearing={self._p['bearing_rate']}Hz "
f"enrolled_tags={self._enrolled_tags or ['<any>']}"
)
# ── Helpers ───────────────────────────────────────────────────────────────
def _load_params(self):
return {
"port_a": self.get_parameter("port_a").value,
@ -206,90 +252,106 @@ class UwbDriverNode(Node):
"min_range_m": self.get_parameter("min_range_m").value,
"kf_process_noise": self.get_parameter("kf_process_noise").value,
"kf_meas_noise": self.get_parameter("kf_meas_noise").value,
"publish_rate": self.get_parameter("publish_rate").value,
"range_rate": self.get_parameter("range_rate").value,
"bearing_rate": self.get_parameter("bearing_rate").value,
}
# ── Callbacks ─────────────────────────────────────────────────────────────
def _is_enrolled(self, tag_addr: str) -> bool:
if not self._enrolled_tags:
return True
return tag_addr in self._enrolled_tags
def _range_cb(self, anchor_id: int, range_mm: int, rssi: float):
"""Called from serial reader threads — thread-safe update."""
def _range_cb(self, anchor_id: int, range_mm: int, rssi: float, tag_addr: str):
if not self._is_enrolled(tag_addr):
return
range_m = range_mm / 1000.0
p = self._p
if range_m < p["min_range_m"] or range_m > p["max_range_m"]:
return
with self._lock:
self._ranges[anchor_id] = (range_m, rssi, time.monotonic())
self._ranges[anchor_id] = (range_m, rssi, tag_addr, time.monotonic())
def _publish_cb(self):
now = time.monotonic()
def _range_publish_cb(self):
"""100 Hz: publish current raw ranges as UwbRangeArray."""
now = time.monotonic()
timeout = self._p["range_timeout_s"]
sep = self._p["anchor_separation"]
with self._lock:
# Collect valid (non-stale) ranges
valid = {}
for aid, (r, rssi, t) in self._ranges.items():
if now - t <= timeout:
valid[aid] = (r, rssi, t)
# Build and publish UwbRangeArray regardless (even if partial)
valid = {
aid: entry
for aid, entry in self._ranges.items()
if (now - entry[3]) <= timeout
}
hdr = Header()
hdr.stamp = self.get_clock().now().to_msg()
hdr.frame_id = "base_link"
arr = UwbRangeArray()
arr.header = hdr
for aid, (r, rssi, _) in valid.items():
for aid, (r, rssi, tag_id, _) in valid.items():
entry = UwbRange()
entry.header = hdr
entry.anchor_id = aid
entry.range_m = float(r)
entry.raw_mm = int(round(r * 1000.0))
entry.rssi = float(rssi)
entry.tag_id = tag_id
arr.ranges.append(entry)
self._ranges_pub.publish(arr)
# Need both anchors to triangulate
if 0 not in valid or 1 not in valid:
def _bearing_publish_cb(self):
"""10 Hz: Kalman predict+update, publish fused bearing."""
now = time.monotonic()
timeout = self._p["range_timeout_s"]
sep = self._p["anchor_separation"]
with self._lock:
valid = {
aid: entry
for aid, entry in self._ranges.items()
if (now - entry[3]) <= timeout
}
if not valid:
return
r0 = valid[0][0]
r1 = valid[1][0]
try:
x_t, y_t = triangulate_2anchor(
r0=r0,
r1=r1,
sep=sep,
anchor_z=self._p["anchor_height"],
tag_z=self._p["tag_height"],
)
except (ValueError, ZeroDivisionError) as exc:
self.get_logger().warn(f"Triangulation error: {exc}")
return
# Kalman filter update
dt = 1.0 / self._p["publish_rate"]
both_fresh = 0 in valid and 1 in valid
confidence = 1.0 if both_fresh else 0.5
active_tag = valid[next(iter(valid))][2]
dt = 1.0 / self._p["bearing_rate"]
self._kf.predict(dt=dt)
self._kf.update(x_t, y_t)
if both_fresh:
r0 = valid[0][0]
r1 = valid[1][0]
try:
x_t, y_t = triangulate_2anchor(
r0=r0, r1=r1, sep=sep,
anchor_z=self._p["anchor_height"],
tag_z=self._p["tag_height"],
)
except (ValueError, ZeroDivisionError) as exc:
self.get_logger().warn(f"Triangulation error: {exc}")
return
self._kf.update(x_t, y_t)
kx, ky = self._kf.position()
# Publish PoseStamped in base_link
bearing = bearing_from_pos(kx, ky)
range_m = math.sqrt(kx * kx + ky * ky)
hdr = Header()
hdr.stamp = self.get_clock().now().to_msg()
hdr.frame_id = "base_link"
brg_msg = UwbBearing()
brg_msg.header = hdr
brg_msg.bearing_rad = float(bearing)
brg_msg.range_m = float(range_m)
brg_msg.confidence = float(confidence)
brg_msg.tag_id = active_tag
self._bearing_pub.publish(brg_msg)
pose = PoseStamped()
pose.header = hdr
pose.pose.position.x = kx
pose.pose.position.y = ky
pose.pose.position.z = 0.0
# Orientation: face the person (yaw = atan2(y, x))
yaw = math.atan2(ky, kx)
yaw = bearing
pose.pose.orientation.z = math.sin(yaw / 2.0)
pose.pose.orientation.w = math.cos(yaw / 2.0)
self._target_pub.publish(pose)
# ── Entry point ───────────────────────────────────────────────────────────────
def main(args=None):
rclpy.init(args=args)
node = UwbDriverNode()
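
Review note: both timer callbacks in this hunk repeat the same staleness filter, and the bearing callback derives `confidence` from which anchors survive it. A minimal sketch of that logic as pure functions is shown below. The helper names `fresh_ranges` and `confidence_for` are hypothetical and not part of this PR; the tuple layout `(range_m, rssi, tag_addr, timestamp)` follows the new `_range_cb`.

```python
# Hypothetical helpers, not in the PR: they mirror the filtering done in
# _range_publish_cb / _bearing_publish_cb so it can be unit-tested without ROS2.

def fresh_ranges(ranges, now, timeout_s):
    """Return only entries whose timestamp (index 3) is within timeout_s of now."""
    return {
        aid: entry
        for aid, entry in ranges.items()
        if (now - entry[3]) <= timeout_s
    }


def confidence_for(valid):
    """Map the set of fresh anchors to the confidence values used in UwbBearing."""
    if 0 in valid and 1 in valid:
        return 1.0  # both anchors fresh
    if valid:
        return 0.5  # single anchor only
    return 0.0      # nothing fresh; the node does not publish in this state


# Example: anchor 0 is fresh, anchor 1 is 1.7 s old with a 1.0 s timeout.
ranges = {
    0: (1.00, -60.0, "A1", 10.0),
    1: (1.20, -62.0, "A1", 8.5),
}
valid = fresh_ranges(ranges, now=10.2, timeout_s=1.0)
print(sorted(valid), confidence_for(valid))
```

Factoring the filter out this way would also let the test file cover the single-anchor path, which currently runs only `predict()` without an `update()`.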


@ -7,7 +7,7 @@ No ROS2 / serial / GPU dependencies — runs with plain pytest.
import math
import pytest
from saltybot_uwb.ranging_math import triangulate_2anchor, KalmanFilter2D
from saltybot_uwb.ranging_math import triangulate_2anchor, KalmanFilter2D, bearing_from_pos
# ── triangulate_2anchor ───────────────────────────────────────────────────────
@ -172,3 +172,47 @@ class TestKalmanFilter2D:
x, y = kf.position()
assert not math.isnan(x)
assert not math.isnan(y)
# ── bearing_from_pos ──────────────────────────────────────────────────────────
class TestBearingFromPos:
def test_directly_ahead_zero_bearing(self):
"""Person directly ahead: x=2, y=0 → bearing ≈ 0."""
b = bearing_from_pos(2.0, 0.0)
assert abs(b) < 0.001
def test_left_gives_positive_bearing(self):
"""Person to the left (y>0): bearing should be positive."""
b = bearing_from_pos(1.0, 1.0)
assert b > 0.0
assert abs(b - math.pi / 4.0) < 0.001
def test_right_gives_negative_bearing(self):
"""Person to the right (y<0): bearing should be negative."""
b = bearing_from_pos(1.0, -1.0)
assert b < 0.0
assert abs(b + math.pi / 4.0) < 0.001
def test_directly_left_ninety_degrees(self):
"""Person directly to the left: x=0, y=1 → bearing = π/2."""
b = bearing_from_pos(0.0, 1.0)
assert abs(b - math.pi / 2.0) < 0.001
def test_directly_right_minus_ninety_degrees(self):
"""Person directly to the right: x=0, y=-1 → bearing = -π/2."""
b = bearing_from_pos(0.0, -1.0)
assert abs(b + math.pi / 2.0) < 0.001
def test_range_pi_to_minus_pi(self):
"""Bearing is always in -π..+π."""
for x in [-2.0, -0.1, 0.1, 2.0]:
for y in [-2.0, -0.1, 0.1, 2.0]:
b = bearing_from_pos(x, y)
assert -math.pi <= b <= math.pi
def test_no_nan_for_tiny_distance(self):
"""Very close target should not produce NaN."""
b = bearing_from_pos(0.001, 0.001)
assert not math.isnan(b)
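
Review note: the body of `bearing_from_pos` is not shown in this diff; the commit summary only describes it as atan2-based. Assuming it is a thin wrapper around `math.atan2(y, x)`, a sketch that satisfies the tests above would look like this (the real implementation in `ranging_math.py` may differ):

```python
import math


def bearing_from_pos(x: float, y: float) -> float:
    """Assumed implementation: bearing in radians, positive to the left (CCW).

    x is forward and y is left in base_link, so atan2(y, x) gives 0 straight
    ahead and a result in the closed interval [-pi, +pi].
    """
    return math.atan2(y, x)


print(bearing_from_pos(1.0, 1.0))   # about pi/4, target ahead-left
print(bearing_from_pos(0.0, -1.0))  # about -pi/2, target directly right
```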


@ -8,6 +8,7 @@ find_package(rosidl_default_generators REQUIRED)
rosidl_generate_interfaces(${PROJECT_NAME}
"msg/UwbRange.msg"
"msg/UwbRangeArray.msg"
"msg/UwbBearing.msg"
DEPENDENCIES std_msgs
)


@ -0,0 +1,18 @@
# UwbBearing.msg — 10 Hz Kalman-fused bearing estimate from dual-anchor UWB.
#
# bearing_rad : horizontal bearing to tag in base_link frame (radians)
# positive = tag to the left (CCW), negative = right (CW)
# range: -π .. +π
# range_m : Kalman-filtered horizontal range to tag (metres)
# confidence : quality indicator
# 1.0 = both anchors fresh
# 0.5 = single anchor only (bearing less reliable)
# 0.0 = stale / no data (message not published in this state)
# tag_id : enrolled tag identifier that produced this estimate
std_msgs/Header header
float32 bearing_rad
float32 range_m
float32 confidence
string tag_id
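
Review note: a consumer of `/uwb/bearing` that needs a Cartesian target can recover it from `bearing_rad` and `range_m`, since the node computes both from the same Kalman position. The sketch below assumes the sign convention documented above (x forward, y left in `base_link`); `bearing_to_xy` is a hypothetical helper name, not part of this PR.

```python
import math


def bearing_to_xy(bearing_rad: float, range_m: float):
    """Convert a UwbBearing-style (bearing, range) pair back to (x, y).

    Assumes base_link axes: x forward, y left, bearing positive CCW.
    """
    x = range_m * math.cos(bearing_rad)
    y = range_m * math.sin(bearing_rad)
    return x, y


# Round trip: a target 2 m ahead and 2 m to the left.
kx, ky = 2.0, 2.0
bearing = math.atan2(ky, kx)
rng = math.hypot(kx, ky)
x, y = bearing_to_xy(bearing, rng)
print(round(x, 6), round(y, 6))
```

Because the message fields are `float32`, a subscriber should expect small rounding differences relative to the `float64` values published on `/uwb/target`.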


@ -4,6 +4,7 @@
# range_m : measured horizontal range in metres (after height correction)
# raw_mm : raw TWR range from AT+RANGE? response, millimetres
# rssi : received signal strength (dBm), 0 if not reported by module
# tag_id : enrolled tag identifier; empty string if tag pairing is disabled
std_msgs/Header header
@ -11,3 +12,4 @@ uint8 anchor_id
float32 range_m
uint32 raw_mm
float32 rssi
string tag_id
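
Review note: `raw_mm` and `range_m` are related by the millimetre-to-metre conversion and range gate in `_range_cb`. A minimal sketch of that gate follows; `gate_range` is a hypothetical name, and the default bounds are the `min_range_m` and `max_range_m` values declared in the driver node.

```python
from typing import Optional


def gate_range(range_mm: int,
               min_range_m: float = 0.05,
               max_range_m: float = 8.0) -> Optional[float]:
    """Convert a raw TWR range in mm to metres, or return None if out of bounds."""
    range_m = range_mm / 1000.0
    if range_m < min_range_m or range_m > max_range_m:
        return None  # the driver silently drops such samples
    return range_m


print(gate_range(1500))  # accepted
print(gate_range(49))    # below min_range_m, rejected
print(gate_range(8001))  # above max_range_m, rejected
```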