Compare commits
7 Commits
0612eedbcd
...
f61a03b3c5
| Author | SHA1 | Date | |
|---|---|---|---|
| f61a03b3c5 | |||
| d9c983f666 | |||
| 54e9274405 | |||
| b432492785 | |||
| 9a68dfdb2e | |||
| d872ea5e34 | |||
| 84790412d6 |
@ -0,0 +1,8 @@
|
|||||||
|
# Default parameters for the person_state_tracker node.
person_state_tracker:
  ros__parameters:
    # Distance in metres below which a person is considered engaged.
    engagement_distance: 2.0
    # Seconds without detection before a person is marked ABSENT.
    absent_timeout: 5.0
    # Seconds after which an ABSENT track is removed entirely.
    prune_timeout: 30.0
    # Rate (Hz) of the state-update / publish timer.
    update_rate: 10.0
    # Number of /surround/cam<i>/image_raw topics the node subscribes to.
    n_cameras: 4
    # Whether UWB anchor data is available (enables the /uwb/positions subscription).
    uwb_enabled: false
|
||||||
35
jetson/ros2_ws/src/saltybot_social/launch/social.launch.py
Normal file
35
jetson/ros2_ws/src/saltybot_social/launch/social.launch.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from launch import LaunchDescription
|
||||||
|
from launch.actions import DeclareLaunchArgument
|
||||||
|
from launch.substitutions import LaunchConfiguration
|
||||||
|
from launch_ros.actions import Node
|
||||||
|
|
||||||
|
|
||||||
|
def generate_launch_description():
    """Build the launch description for the person_state_tracker node.

    Declares one launch argument per tunable and forwards each of them to
    the node as a parameter of the same name.
    """
    # (argument name, default value as a launch string, description)
    argument_specs = (
        ('engagement_distance', '2.0',
         'Distance in metres below which a person is considered engaged'),
        ('absent_timeout', '5.0',
         'Seconds without detection before marking person as ABSENT'),
        ('uwb_enabled', 'false',
         'Whether UWB anchor data is available'),
    )

    declared_arguments = [
        DeclareLaunchArgument(arg_name, default_value=default, description=text)
        for arg_name, default, text in argument_specs
    ]

    # Every declared argument is passed straight through as a node parameter.
    forwarded_parameters = {
        arg_name: LaunchConfiguration(arg_name)
        for arg_name, _default, _text in argument_specs
    }

    tracker_node = Node(
        package='saltybot_social',
        executable='person_state_tracker',
        name='person_state_tracker',
        output='screen',
        parameters=[forwarded_parameters],
    )

    return LaunchDescription(declared_arguments + [tracker_node])
|
||||||
25
jetson/ros2_ws/src/saltybot_social/package.xml
Normal file
25
jetson/ros2_ws/src/saltybot_social/package.xml
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
<?xml version="1.0"?>
|
||||||
|
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||||
|
<package format="3">
  <name>saltybot_social</name>
  <version>0.1.0</version>
  <description>Multi-modal person identity fusion and state tracking for saltybot</description>
  <maintainer email="seb@vayrette.com">seb</maintainer>
  <license>MIT</license>

  <depend>rclpy</depend>
  <!-- person_state_tracker_node imports builtin_interfaces.msg.Time directly -->
  <depend>builtin_interfaces</depend>
  <depend>std_msgs</depend>
  <depend>geometry_msgs</depend>
  <depend>sensor_msgs</depend>
  <depend>vision_msgs</depend>
  <depend>saltybot_social_msgs</depend>
  <depend>tf2_ros</depend>
  <depend>tf2_geometry_msgs</depend>
  <depend>cv_bridge</depend>

  <!-- launch/social.launch.py imports launch and launch_ros at run time -->
  <exec_depend>launch</exec_depend>
  <exec_depend>launch_ros</exec_depend>

  <test_depend>ament_copyright</test_depend>
  <test_depend>ament_flake8</test_depend>
  <test_depend>ament_pep257</test_depend>
  <test_depend>python3-pytest</test_depend>

  <export>
    <build_type>ament_python</build_type>
  </export>
</package>
|
||||||
@ -0,0 +1,201 @@
|
|||||||
|
import math
|
||||||
|
import time
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from saltybot_social.person_state import PersonTrack, State
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityFuser:
    """Multi-modal identity fusion: face_id + speaker_id + UWB anchor -> unified person_id.

    Each observation is matched to an existing ``PersonTrack`` (by modality
    identifier first, then by spatial proximity when a position is given).
    When nothing matches, a new track with a fresh ``person_id`` is created.

    Every ``now`` argument is a caller-supplied timestamp in seconds; the
    values only need to come from one consistent clock.  ``prune_absent``
    falls back to ``time.monotonic()`` when no timestamp is passed.
    """

    def __init__(self, position_window: float = 3.0,
                 match_radius: float = 0.5):
        """Create an empty fuser.

        Args:
            position_window: Maximum age (seconds) of a track's ``last_seen``
                for it to be eligible for proximity / speaker association.
            match_radius: Maximum 3D distance between an observation and a
                track position for a proximity match (same units as the
                positions passed in).
        """
        self._tracks: dict[int, PersonTrack] = {}
        self._next_id: int = 1
        self._position_window = position_window
        self._match_radius = match_radius
        # Number of most recent distance samples kept per track.
        self._max_history = 20

    def _distance_3d(self, a: Optional[tuple], b: Optional[tuple]) -> float:
        """Return the Euclidean distance between two points, or inf if either is None."""
        if a is None or b is None:
            return float('inf')
        return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))

    def _compute_distance_and_bearing(self, position: tuple) -> Tuple[float, float]:
        """Return (range from the origin, bearing in degrees via atan2(y, x))."""
        x, y, z = position
        distance = math.sqrt(x * x + y * y + z * z)
        bearing = math.degrees(math.atan2(y, x))
        return distance, bearing

    def _allocate_id(self) -> int:
        """Return the next unused person_id (monotonically increasing from 1)."""
        pid = self._next_id
        self._next_id += 1
        return pid

    def _apply_position(self, track: PersonTrack, position_3d: tuple) -> None:
        """Store a new position on *track* and refresh everything derived from it.

        Updates distance and bearing, appends the distance to the track's
        history and trims that history to the last ``_max_history`` samples.
        """
        track.position = position_3d
        dist, bearing = self._compute_distance_and_bearing(position_3d)
        track.distance = dist
        track.bearing_deg = bearing
        track.history_distances.append(dist)
        if len(track.history_distances) > self._max_history:
            track.history_distances = track.history_distances[-self._max_history:]

    def _nearest_recent_track(self, position_3d: tuple,
                              now: float) -> Optional[PersonTrack]:
        """Return the closest track within the match radius seen recently, else None.

        A track qualifies when its stored position is strictly closer than
        ``match_radius`` and its ``last_seen`` is younger than
        ``position_window``.  Tracks without a position never match because
        their distance is treated as infinite.
        """
        best_track = None
        best_dist = self._match_radius
        for track in self._tracks.values():
            d = self._distance_3d(track.position, position_3d)
            if d < best_dist and (now - track.last_seen) < self._position_window:
                best_dist = d
                best_track = track
        return best_track

    def update_face(self, face_id: int, name: str,
                    position_3d: Optional[tuple], camera_id: int,
                    now: float) -> int:
        """Fold a face / person observation into the tracks and return its person_id.

        Matching order:
          1. ``face_id >= 0``: the track that already carries this face_id.
          2. ``face_id < 0`` with a position: the nearest recently seen track
             within the match radius.
          3. Otherwise a new track is created.

        Args:
            face_id: Recognised face identifier, or a negative value when the
                observation has no face identity.
            name: Display name; ignored for an existing track when empty or
                ``'unknown'``.
            position_3d: Optional (x, y, z) position of the observation.
            camera_id: Identifier of the camera that produced the observation.
            now: Timestamp of the observation in seconds.
        """
        if face_id >= 0:
            for track in self._tracks.values():
                if track.face_id == face_id:
                    # Never overwrite a known name with a placeholder.
                    if name and name != 'unknown':
                        track.name = name
                    if position_3d is not None:
                        self._apply_position(track, position_3d)
                    track.camera_id = camera_id
                    track.last_seen = now
                    track.last_face_seen = now
                    return track.person_id

        # Anonymous observations can only be associated by proximity.  An
        # identified face that is not tracked yet always gets its own track.
        if face_id < 0 and position_3d is not None:
            best_track = self._nearest_recent_track(position_3d, now)
            if best_track is not None:
                self._apply_position(best_track, position_3d)
                best_track.camera_id = camera_id
                best_track.last_seen = now
                return best_track.person_id

        # Create new track
        pid = self._allocate_id()
        track = PersonTrack(person_id=pid, face_id=face_id, name=name,
                            camera_id=camera_id, last_seen=now, last_face_seen=now)
        if position_3d is not None:
            self._apply_position(track, position_3d)
        self._tracks[pid] = track
        return pid

    def update_speaker(self, speaker_id: str, now: float) -> int:
        """Associate a speaker with a track and return that track's person_id.

        A track that already carries ``speaker_id`` is refreshed.  Otherwise
        the speaker is assigned to the recently seen track with the smallest
        positive distance; when no such track exists a new one is created.
        """
        # First check if any track already has this speaker_id
        for track in self._tracks.values():
            if track.speaker_id == speaker_id:
                track.last_seen = now
                return track.person_id

        # Assign to nearest recently-seen track.  distance == 0 means the
        # track never received a position, so it is not a candidate.
        best_track = None
        best_dist = float('inf')
        for track in self._tracks.values():
            if (now - track.last_seen) < self._position_window and track.distance > 0:
                if track.distance < best_dist:
                    best_dist = track.distance
                    best_track = track

        if best_track is not None:
            best_track.speaker_id = speaker_id
            best_track.last_seen = now
            return best_track.person_id

        # No suitable track found; create a new one
        pid = self._allocate_id()
        track = PersonTrack(person_id=pid, speaker_id=speaker_id,
                            last_seen=now)
        self._tracks[pid] = track
        return pid

    def update_uwb(self, uwb_anchor_id: str, position_3d: tuple,
                   now: float) -> int:
        """Fold a UWB position into the tracks and return the matching person_id.

        A track already bound to ``uwb_anchor_id`` is updated in place.
        Otherwise the anchor is bound to the nearest recently seen track
        within the match radius, or to a newly created track.
        """
        # Check existing UWB assignment
        for track in self._tracks.values():
            if track.uwb_anchor_id == uwb_anchor_id:
                self._apply_position(track, position_3d)
                track.last_seen = now
                return track.person_id

        # Proximity match
        best_track = self._nearest_recent_track(position_3d, now)
        if best_track is not None:
            best_track.uwb_anchor_id = uwb_anchor_id
            self._apply_position(best_track, position_3d)
            best_track.last_seen = now
            return best_track.person_id

        # Create new track
        pid = self._allocate_id()
        track = PersonTrack(person_id=pid, uwb_anchor_id=uwb_anchor_id,
                            position=position_3d, last_seen=now)
        self._apply_position(track, position_3d)
        self._tracks[pid] = track
        return pid

    def get_all_tracks(self) -> List[PersonTrack]:
        """Return a new list holding every current track (the tracks themselves are shared)."""
        return list(self._tracks.values())

    @staticmethod
    def compute_attention(tracks: List[PersonTrack]) -> int:
        """Return the person_id to focus on, or -1 if nobody qualifies.

        Only ENGAGED or TALKING tracks with a positive distance are
        considered.  The nearest TALKING track wins; without one, the nearest
        remaining candidate is chosen.
        """
        candidates = [t for t in tracks
                      if t.state in (State.ENGAGED, State.TALKING) and t.distance > 0]
        if not candidates:
            return -1
        # Prefer talking, then nearest engaged
        talking = [t for t in candidates if t.state == State.TALKING]
        return min(talking or candidates, key=lambda t: t.distance).person_id

    def prune_absent(self, absent_timeout: float = 30.0,
                     now: Optional[float] = None):
        """Remove ABSENT tracks whose last_seen is older than *absent_timeout*.

        Args:
            absent_timeout: Age in seconds beyond which an ABSENT track is dropped.
            now: Reference timestamp on the same clock as the ``now`` values
                passed to the update methods.  Defaults to ``time.monotonic()``.
        """
        if now is None:
            now = time.monotonic()
        to_remove = [pid for pid, t in self._tracks.items()
                     if t.state == State.ABSENT and (now - t.last_seen) > absent_timeout]
        for pid in to_remove:
            del self._tracks[pid]

    @staticmethod
    def detect_group(tracks: List[PersonTrack], radius: float = 1.5) -> bool:
        """Return True if at least two non-ABSENT tracks are closer than *radius*.

        Tracks without a position are ignored.
        """
        active = [t for t in tracks
                  if t.position is not None and t.state != State.ABSENT]
        for i, a in enumerate(active):
            for b in active[i + 1:]:
                dx = a.position[0] - b.position[0]
                dy = a.position[1] - b.position[1]
                dz = a.position[2] - b.position[2]
                if math.sqrt(dx * dx + dy * dy + dz * dz) < radius:
                    return True
        return False
|
||||||
@ -0,0 +1,54 @@
|
|||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import IntEnum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class State(IntEnum):
    """Interaction state of a tracked person (integer-valued)."""

    UNKNOWN = 0
    APPROACHING = 1
    ENGAGED = 2
    TALKING = 3
    LEAVING = 4
    ABSENT = 5


@dataclass
class PersonTrack:
    """Fused record for one tracked person plus its interaction state machine."""

    person_id: int
    face_id: int = -1  # -1 until a face identity is associated
    speaker_id: str = ''  # empty until a speaker identity is associated
    uwb_anchor_id: str = ''  # empty until a UWB anchor is associated
    name: str = 'unknown'
    state: State = State.UNKNOWN
    position: Optional[tuple] = None  # (x, y, z) in base_link
    distance: float = 0.0  # 0.0 doubles as "no distance known yet"
    bearing_deg: float = 0.0
    engagement_score: float = 0.0
    camera_id: int = -1
    last_seen: float = field(default_factory=time.monotonic)
    last_face_seen: float = 0.0
    history_distances: list = field(default_factory=list)  # last N distances for trend

    def update_state(self, now: float, engagement_distance: float = 2.0,
                     absent_timeout: float = 5.0):
        """Advance the state machine from elapsed time, speaker and distance trend.

        Rules are evaluated in priority order:
          1. Not seen for more than ``absent_timeout`` seconds -> ABSENT.
          2. A speaker_id is set -> TALKING.
          3. With at least three distance samples: ENGAGED when closer than
             ``engagement_distance``; otherwise APPROACHING / LEAVING when the
             distance changed by less than -0.2 / more than 0.3 over the last
             three samples; otherwise UNKNOWN.
          4. With fewer samples but a positive distance: ENGAGED when closer
             than ``engagement_distance``, else APPROACHING.
          5. Otherwise the state is left unchanged.

        Args:
            now: Current timestamp on the same clock as ``last_seen``.
            engagement_distance: Distance below which the person is ENGAGED.
            absent_timeout: Seconds without an update before ABSENT.
        """
        age = now - self.last_seen
        if age > absent_timeout:
            self.state = State.ABSENT
            return

        # NOTE(review): speaker_id is never cleared in this class, so a track
        # stays TALKING until it goes ABSENT - confirm this is intended.
        if self.speaker_id:
            self.state = State.TALKING
            return

        if len(self.history_distances) >= 3:
            # Change in distance across the last three samples.
            trend = self.history_distances[-1] - self.history_distances[-3]
            if self.distance < engagement_distance:
                self.state = State.ENGAGED
            elif trend < -0.2:  # moving closer
                self.state = State.APPROACHING
            elif trend > 0.3:  # moving away
                self.state = State.LEAVING
            else:
                # Outside engagement range (checked above) and not moving.
                self.state = State.UNKNOWN
        elif self.distance > 0:
            if self.distance < engagement_distance:
                self.state = State.ENGAGED
            else:
                self.state = State.APPROACHING
|
||||||
@ -0,0 +1,273 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||||
|
|
||||||
|
from std_msgs.msg import Int32, String
|
||||||
|
from geometry_msgs.msg import PoseArray, PoseStamped, Point
|
||||||
|
from sensor_msgs.msg import Image
|
||||||
|
from builtin_interfaces.msg import Time as TimeMsg
|
||||||
|
|
||||||
|
from saltybot_social_msgs.msg import (
|
||||||
|
FaceDetection,
|
||||||
|
FaceDetectionArray,
|
||||||
|
PersonState,
|
||||||
|
PersonStateArray,
|
||||||
|
)
|
||||||
|
from saltybot_social.identity_fuser import IdentityFuser
|
||||||
|
from saltybot_social.person_state import State
|
||||||
|
|
||||||
|
|
||||||
|
class PersonStateTrackerNode(Node):
    """Main ROS2 node for multi-modal person tracking.

    Feeds face detections, speaker ids, optional UWB poses and the
    ``/person/target`` pose into an ``IdentityFuser``.  On a timer it updates
    every track's state, then publishes a ``PersonStateArray`` on
    ``/social/persons`` and the attention target id on
    ``/social/attention/target_id``.
    """

    # Rate used when the 'update_rate' parameter is not a positive number.
    _FALLBACK_UPDATE_RATE_HZ = 10.0

    def __init__(self):
        """Declare parameters and create subscriptions, publishers and the timer."""
        super().__init__('person_state_tracker')

        # Parameters
        self.declare_parameter('engagement_distance', 2.0)
        self.declare_parameter('absent_timeout', 5.0)
        self.declare_parameter('prune_timeout', 30.0)
        self.declare_parameter('update_rate', 10.0)
        self.declare_parameter('n_cameras', 4)
        self.declare_parameter('uwb_enabled', False)

        self._engagement_distance = self.get_parameter('engagement_distance').value
        self._absent_timeout = self.get_parameter('absent_timeout').value
        self._prune_timeout = self.get_parameter('prune_timeout').value
        update_rate = self.get_parameter('update_rate').value
        n_cameras = self.get_parameter('n_cameras').value
        uwb_enabled = self.get_parameter('uwb_enabled').value

        # A zero or negative rate cannot be turned into a timer period.
        if update_rate <= 0:
            self.get_logger().warn(
                f'Invalid update_rate={update_rate}; '
                f'falling back to {self._FALLBACK_UPDATE_RATE_HZ}Hz')
            update_rate = self._FALLBACK_UPDATE_RATE_HZ

        self._fuser = IdentityFuser()

        # QoS profiles
        best_effort_qos = QoSProfile(
            reliability=ReliabilityPolicy.BEST_EFFORT,
            history=HistoryPolicy.KEEP_LAST,
            depth=5
        )
        reliable_qos = QoSProfile(
            reliability=ReliabilityPolicy.RELIABLE,
            history=HistoryPolicy.KEEP_LAST,
            depth=1
        )

        # Subscriptions
        self.create_subscription(
            FaceDetectionArray,
            '/social/faces/detections',
            self._on_face_detections,
            best_effort_qos
        )
        self.create_subscription(
            String,
            '/social/speech/speaker_id',
            self._on_speaker_id,
            10
        )
        if uwb_enabled:
            self.create_subscription(
                PoseArray,
                '/uwb/positions',
                self._on_uwb,
                10
            )

        # Camera subscriptions (monitor topic existence)
        self.create_subscription(
            Image, '/camera/color/image_raw',
            self._on_camera_image, best_effort_qos)
        for i in range(n_cameras):
            self.create_subscription(
                Image, f'/surround/cam{i}/image_raw',
                self._on_camera_image, best_effort_qos)

        # Existing person tracker position
        self.create_subscription(
            PoseStamped,
            '/person/target',
            self._on_person_target,
            10
        )

        # Publishers
        self._persons_pub = self.create_publisher(
            PersonStateArray, '/social/persons', best_effort_qos)
        self._attention_pub = self.create_publisher(
            Int32, '/social/attention/target_id', reliable_qos)

        # Timer
        timer_period = 1.0 / update_rate
        self._timer = self.create_timer(timer_period, self._on_timer)

        self.get_logger().info(
            f'PersonStateTracker started: engagement={self._engagement_distance}m, '
            f'absent_timeout={self._absent_timeout}s, rate={update_rate}Hz, '
            f'cameras={n_cameras}, uwb={uwb_enabled}')

    @staticmethod
    def _approximate_position(face):
        """Return a rough (depth, lateral, vertical) tuple for a face, or None.

        Placeholder heuristic: depth is derived from the bbox width and the
        two offsets from the bbox centre.  A real implementation would look
        the position up in the depth image.
        """
        # NOTE(review): subtracting 0.5 suggests the bbox fields are
        # normalised to [0, 1] - confirm against the FaceDetection message.
        if not (face.bbox_x > 0 or face.bbox_y > 0):
            return None
        return (
            max(0.5, 1.0 / max(face.bbox_w, 0.01)),  # rough depth from bbox width
            face.bbox_x + face.bbox_w / 2.0 - 0.5,   # x offset from center
            face.bbox_y + face.bbox_h / 2.0 - 0.5,   # y offset from center
        )

    @staticmethod
    def _camera_id_from_frame(frame_id) -> int:
        """Parse the integer that follows the last 'cam' in *frame_id*; -1 if absent."""
        if not frame_id or 'cam' not in frame_id:
            return -1
        try:
            return int(frame_id.split('cam')[-1].split('_')[0])
        except ValueError:
            # e.g. 'camera_link': 'cam' is present but not followed by a number.
            return -1

    def _on_face_detections(self, msg: FaceDetectionArray):
        """Feed every detected face into the fuser."""
        now = time.monotonic()
        for face in msg.faces:
            self._fuser.update_face(
                face_id=face.face_id,
                name=face.person_name,
                position_3d=self._approximate_position(face),
                camera_id=self._camera_id_from_frame(
                    getattr(face.header, 'frame_id', '')),
                now=now
            )

    def _on_speaker_id(self, msg: String):
        """Feed a speaker id into the fuser, stripping an optional 'speaker_' prefix."""
        now = time.monotonic()
        speaker_id = msg.data
        # Parse "speaker_<id>" format or raw id
        if speaker_id.startswith('speaker_'):
            speaker_id = speaker_id[len('speaker_'):]
        self._fuser.update_speaker(speaker_id, now)

    def _on_uwb(self, msg: PoseArray):
        """Feed each UWB pose into the fuser under an index-based anchor id."""
        now = time.monotonic()
        for i, pose in enumerate(msg.poses):
            position = (pose.position.x, pose.position.y, pose.position.z)
            anchor_id = f'uwb_{i}'
            if hasattr(msg, 'header') and msg.header.frame_id:
                anchor_id = f'{msg.header.frame_id}_{i}'
            self._fuser.update_uwb(anchor_id, position, now)

    def _on_camera_image(self, msg: Image):
        """Ignore camera images; the subscription only monitors the topic."""
        # Monitor only -- no processing here
        pass

    def _on_person_target(self, msg: PoseStamped):
        """Feed the external person-tracker pose into the fuser as an anonymous person."""
        now = time.monotonic()
        position_3d = (
            msg.pose.position.x,
            msg.pose.position.y,
            msg.pose.position.z
        )
        # Update position for unidentified person (from YOLOv8 tracker)
        self._fuser.update_face(
            face_id=-1,
            name='unknown',
            position_3d=position_3d,
            camera_id=-1,
            now=now
        )

    def _engagement_score(self, track) -> float:
        """Map a track's state (and distance, when ENGAGED) to a score in [0, 1]."""
        if track.state == State.TALKING:
            return 1.0
        if track.state == State.ENGAGED:
            # Linearly higher the closer the person is within engagement range.
            return max(0.0, 1.0 - track.distance / self._engagement_distance)
        if track.state == State.APPROACHING:
            return 0.3
        if track.state == State.LEAVING:
            return 0.1
        # ABSENT, UNKNOWN and anything else
        return 0.0

    def _on_timer(self):
        """Update all track states, prune stale tracks and publish the results."""
        now = time.monotonic()

        # Update state machine and engagement score for all tracks
        for track in self._fuser.get_all_tracks():
            track.update_state(
                now,
                engagement_distance=self._engagement_distance,
                absent_timeout=self._absent_timeout
            )
            track.engagement_score = self._engagement_score(track)

        # Prune long-absent tracks, then re-read the list so that tracks
        # removed in this cycle are not published any more.
        self._fuser.prune_absent(self._prune_timeout)
        tracks = self._fuser.get_all_tracks()

        # Compute attention target
        attention_id = IdentityFuser.compute_attention(tracks)

        # Build and publish PersonStateArray
        msg = PersonStateArray()
        msg.header.stamp = self.get_clock().now().to_msg()
        msg.header.frame_id = 'base_link'
        msg.primary_attention_id = attention_id

        for track in tracks:
            ps = PersonState()
            ps.header.stamp = msg.header.stamp
            ps.header.frame_id = 'base_link'
            ps.person_id = track.person_id
            ps.person_name = track.name
            ps.face_id = track.face_id
            ps.speaker_id = track.speaker_id
            ps.uwb_anchor_id = track.uwb_anchor_id
            if track.position is not None:
                ps.position = Point(
                    x=float(track.position[0]),
                    y=float(track.position[1]),
                    z=float(track.position[2]))
            ps.distance = float(track.distance)
            ps.bearing_deg = float(track.bearing_deg)
            ps.state = int(track.state)
            ps.engagement_score = float(track.engagement_score)
            ps.last_seen = self._mono_to_ros_time(track.last_seen)
            ps.camera_id = track.camera_id
            msg.persons.append(ps)

        self._persons_pub.publish(msg)

        # Publish attention target
        att_msg = Int32()
        att_msg.data = attention_id
        self._attention_pub.publish(att_msg)

    def _mono_to_ros_time(self, mono: float) -> TimeMsg:
        """Convert a ``time.monotonic()`` timestamp to an approximate ROS time.

        The age of the timestamp is subtracted from the node clock, so the
        result lives on the same clock as the published ``header.stamp``
        (including simulated time) instead of the wall clock.
        """
        age_ns = int((time.monotonic() - mono) * 1e9)
        # Clamp at zero: a freshly started (simulated) clock may be younger
        # than the track's age.
        stamp_ns = max(0, self.get_clock().now().nanoseconds - age_ns)
        t = TimeMsg()
        t.sec, t.nanosec = divmod(stamp_ns, 1_000_000_000)
        return t
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None):
    """Run the person_state_tracker node until interrupted.

    Args:
        args: Command-line arguments forwarded to ``rclpy.init``.
    """
    rclpy.init(args=args)
    node = PersonStateTrackerNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the node.
        pass
    finally:
        node.destroy_node()
        # The context may already have been shut down (e.g. by the SIGINT
        # handler); shutting it down a second time would raise.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()
|
||||||
4
jetson/ros2_ws/src/saltybot_social/setup.cfg
Normal file
4
jetson/ros2_ws/src/saltybot_social/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
[develop]
|
||||||
|
script_dir=$base/lib/saltybot_social
|
||||||
|
[install]
|
||||||
|
install_scripts=$base/lib/saltybot_social
|
||||||
32
jetson/ros2_ws/src/saltybot_social/setup.py
Normal file
32
jetson/ros2_ws/src/saltybot_social/setup.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from setuptools import find_packages, setup
|
||||||
|
import os
|
||||||
|
from glob import glob
|
||||||
|
|
||||||
|
# Name shared by the ament resource marker, the share/ install paths and setup().
package_name = 'saltybot_social'

# Files installed next to the Python package: the ament index marker, the
# manifest, every launch file and every YAML config.
_data_files = [
    ('share/ament_index/resource_index/packages',
     ['resource/' + package_name]),
    ('share/' + package_name, ['package.xml']),
    # The character classes match launch.py, launch.xml and launch.yaml files.
    (os.path.join('share', package_name, 'launch'),
     glob(os.path.join('launch', '*launch.[pxy][yma]*'))),
    (os.path.join('share', package_name, 'config'),
     glob(os.path.join('config', '*.yaml'))),
]

# Executables exposed to `ros2 run` / launch files.
_console_scripts = [
    'person_state_tracker = saltybot_social.person_state_tracker_node:main',
]

setup(
    name=package_name,
    version='0.1.0',
    packages=find_packages(exclude=['test']),
    data_files=_data_files,
    install_requires=['setuptools'],
    zip_safe=True,
    maintainer='seb',
    maintainer_email='seb@vayrette.com',
    description='Multi-modal person identity fusion and state tracking for saltybot',
    license='MIT',
    tests_require=['pytest'],
    entry_points={'console_scripts': _console_scripts},
)
|
||||||
0
jetson/ros2_ws/src/saltybot_social/test/__init__.py
Normal file
0
jetson/ros2_ws/src/saltybot_social/test/__init__.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
# Default parameters for the face_recognizer node.
face_recognizer:
  ros__parameters:
    # Path to SCRFD TensorRT engine file.
    scrfd_engine_path: '/mnt/nvme/saltybot/models/scrfd_2.5g.engine'
    # Path to SCRFD ONNX model file (fallback).
    scrfd_onnx_path: '/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx'
    # Path to ArcFace TensorRT engine file.
    arcface_engine_path: '/mnt/nvme/saltybot/models/arcface_r50.engine'
    # Path to ArcFace ONNX model file (fallback).
    arcface_onnx_path: '/mnt/nvme/saltybot/models/arcface_r50.onnx'
    # Directory for persistent face gallery storage.
    gallery_dir: '/mnt/nvme/saltybot/gallery'
    # Cosine similarity threshold for face recognition.
    recognition_threshold: 0.35
    # Publish annotated debug image to /social/faces/debug_image.
    publish_debug_image: false
    # Maximum faces to process per frame.
    max_faces: 10
    # SCRFD detection confidence threshold.
    scrfd_conf_thresh: 0.5
|
||||||
@ -0,0 +1,80 @@
|
|||||||
|
"""
|
||||||
|
face_recognition.launch.py -- Launch file for the SCRFD + ArcFace face recognition node.
|
||||||
|
|
||||||
|
Launches the face_recognizer node with configurable model paths and parameters.
|
||||||
|
The RealSense camera must be running separately (e.g., via realsense.launch.py).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from launch import LaunchDescription
|
||||||
|
from launch.actions import DeclareLaunchArgument
|
||||||
|
from launch.substitutions import LaunchConfiguration
|
||||||
|
from launch_ros.actions import Node
|
||||||
|
|
||||||
|
|
||||||
|
def generate_launch_description():
    """Generate launch description for face recognition pipeline.

    Declares one launch argument per node parameter and forwards each of
    them, under the same name, to the face_recognizer node.
    """
    # (argument name, default value as a launch string, description)
    argument_specs = (
        ('scrfd_engine_path',
         '/mnt/nvme/saltybot/models/scrfd_2.5g.engine',
         'Path to SCRFD TensorRT engine file'),
        ('scrfd_onnx_path',
         '/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx',
         'Path to SCRFD ONNX model file (fallback)'),
        ('arcface_engine_path',
         '/mnt/nvme/saltybot/models/arcface_r50.engine',
         'Path to ArcFace TensorRT engine file'),
        ('arcface_onnx_path',
         '/mnt/nvme/saltybot/models/arcface_r50.onnx',
         'Path to ArcFace ONNX model file (fallback)'),
        ('gallery_dir',
         '/mnt/nvme/saltybot/gallery',
         'Directory for persistent face gallery storage'),
        ('recognition_threshold',
         '0.35',
         'Cosine similarity threshold for face recognition'),
        ('publish_debug_image',
         'false',
         'Publish annotated debug image to /social/faces/debug_image'),
        ('max_faces',
         '10',
         'Maximum faces to process per frame'),
        ('scrfd_conf_thresh',
         '0.5',
         'SCRFD detection confidence threshold'),
    )

    launch_entities = [
        DeclareLaunchArgument(arg_name, default_value=default, description=text)
        for arg_name, default, text in argument_specs
    ]

    # Every declared argument is passed straight through as a node parameter.
    node_parameters = {
        arg_name: LaunchConfiguration(arg_name)
        for arg_name, _default, _text in argument_specs
    }

    launch_entities.append(
        Node(
            package='saltybot_social_face',
            executable='face_recognition',
            name='face_recognizer',
            output='screen',
            parameters=[node_parameters],
        )
    )

    return LaunchDescription(launch_entities)
|
||||||
27
jetson/ros2_ws/src/saltybot_social_face/package.xml
Normal file
27
jetson/ros2_ws/src/saltybot_social_face/package.xml
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
<?xml version="1.0"?>
|
||||||
|
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||||
|
<package format="3">
|
||||||
|
<name>saltybot_social_face</name>
|
||||||
|
<version>0.1.0</version>
|
||||||
|
<description>SCRFD face detection and ArcFace recognition for SaltyBot social interactions</description>
|
||||||
|
<maintainer email="seb@vayrette.com">seb</maintainer>
|
||||||
|
<license>MIT</license>
|
||||||
|
|
||||||
|
<depend>rclpy</depend>
|
||||||
|
<depend>sensor_msgs</depend>
|
||||||
|
<depend>cv_bridge</depend>
|
||||||
|
<depend>image_transport</depend>
|
||||||
|
<depend>saltybot_social_msgs</depend>
|
||||||
|
|
||||||
|
<exec_depend>python3-numpy</exec_depend>
|
||||||
|
<exec_depend>python3-opencv</exec_depend>
|
||||||
|
|
||||||
|
<test_depend>ament_copyright</test_depend>
|
||||||
|
<test_depend>ament_flake8</test_depend>
|
||||||
|
<test_depend>ament_pep257</test_depend>
|
||||||
|
<test_depend>python3-pytest</test_depend>
|
||||||
|
|
||||||
|
<export>
|
||||||
|
<build_type>ament_python</build_type>
|
||||||
|
</export>
|
||||||
|
</package>
|
||||||
@ -0,0 +1 @@
|
|||||||
|
"""SaltyBot social face detection and recognition package."""
|
||||||
@ -0,0 +1,316 @@
|
|||||||
|
"""
|
||||||
|
arcface_recognizer.py -- ArcFace face embedding extraction and gallery matching.
|
||||||
|
|
||||||
|
Performs face alignment using 5-point landmarks (insightface standard reference),
|
||||||
|
extracts 512-dimensional embeddings via ArcFace (TRT FP16 or ONNX fallback),
|
||||||
|
and matches against a persistent gallery using cosine similarity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# InsightFace standard reference landmarks for 112x112 alignment
|
||||||
|
ARCFACE_SRC = np.array([
|
||||||
|
[38.2946, 51.6963], # left eye
|
||||||
|
[73.5318, 51.5014], # right eye
|
||||||
|
[56.0252, 71.7366], # nose
|
||||||
|
[41.5493, 92.3655], # left mouth
|
||||||
|
[70.7299, 92.2041], # right mouth
|
||||||
|
], dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
# -- Inference backends --------------------------------------------------------
|
||||||
|
|
||||||
|
class _TRTBackend:
    """TensorRT inference engine for ArcFace.

    Deserializes a TensorRT engine, allocates one page-locked host buffer and
    one device buffer per I/O tensor, and runs inference on its own CUDA
    stream.

    Args:
        engine_path: Path to the serialized TensorRT engine file.

    Raises:
        RuntimeError: If the engine cannot be deserialized.
    """

    def __init__(self, engine_path: str):
        # Imported lazily so this module can be imported on machines that do
        # not have TensorRT / PyCUDA installed.
        import tensorrt as trt
        import pycuda.driver as cuda
        import pycuda.autoinit  # noqa: F401  (imported for its side effect)

        self._cuda = cuda
        trt_logger = trt.Logger(trt.Logger.WARNING)
        with open(engine_path, 'rb') as f, trt.Runtime(trt_logger) as runtime:
            self._engine = runtime.deserialize_cuda_engine(f.read())
        if self._engine is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise RuntimeError(
                f'Failed to deserialize TensorRT engine: {engine_path}')
        self._context = self._engine.create_execution_context()

        self._inputs = []
        self._outputs = []
        self._bindings = []
        for i in range(self._engine.num_io_tensors):
            name = self._engine.get_tensor_name(i)
            dtype = trt.nptype(self._engine.get_tensor_dtype(name))
            shape = tuple(self._engine.get_tensor_shape(name))
            nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
            host_mem = cuda.pagelocked_empty(shape, dtype)
            device_mem = cuda.mem_alloc(nbytes)
            self._bindings.append(int(device_mem))
            buffers = {'host': host_mem, 'device': device_mem}
            if self._engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                self._inputs.append(buffers)
            else:
                # Outputs remember their shape so infer() can restore it.
                buffers['shape'] = shape
                self._outputs.append(buffers)
        self._stream = cuda.Stream()

    def infer(self, input_data: np.ndarray) -> np.ndarray:
        """Run inference and return the embedding vector.

        Args:
            input_data: Input tensor holding exactly as many elements as the
                engine's first input tensor.

        Returns:
            A copy of the first output tensor, in the engine's output shape.

        Raises:
            ValueError: If ``input_data`` has the wrong number of elements.
        """
        host_in = self._inputs[0]['host']
        # The host buffer keeps the engine's N-D tensor shape; a flattened
        # 1-D source cannot be broadcast into it, so reshape to match.
        np.copyto(host_in, np.asarray(input_data).reshape(host_in.shape))
        self._cuda.memcpy_htod_async(
            self._inputs[0]['device'], host_in, self._stream)
        self._context.execute_async_v2(self._bindings, self._stream.handle)
        for out in self._outputs:
            self._cuda.memcpy_dtoh_async(out['host'], out['device'], self._stream)
        # Wait for the async copies before reading the host buffers.
        self._stream.synchronize()
        first_out = self._outputs[0]
        return first_out['host'].reshape(first_out['shape']).copy()
|
||||||
|
|
||||||
|
|
||||||
|
class _ONNXBackend:
    """ONNX Runtime inference (CUDA EP with CPU fallback).

    Args:
        onnx_path: Path to the ONNX model file.
    """

    def __init__(self, onnx_path: str):
        import onnxruntime as ort

        # Providers are listed in order of preference: CUDA first, then CPU.
        self._session = ort.InferenceSession(
            onnx_path,
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
        )
        first_input = self._session.get_inputs()[0]
        self._input_name = first_input.name

    def infer(self, input_data: np.ndarray) -> np.ndarray:
        """Run inference and return the embedding vector.

        Args:
            input_data: Tensor fed to the model's first input.

        Returns:
            The first output produced by the session.
        """
        feed = {self._input_name: input_data}
        outputs = self._session.run(None, feed)
        return outputs[0]
|
||||||
|
|
||||||
|
|
||||||
|
# -- Face alignment ------------------------------------------------------------
|
||||||
|
|
||||||
|
def align_face(bgr: np.ndarray, landmarks_10: list[float]) -> np.ndarray:
    """Align a face to 112x112 using 5-point landmarks.

    A partial affine transform is estimated from the detected landmarks to the
    ``ARCFACE_SRC`` reference points and applied to the image. If no transform
    can be estimated, a square region around the landmarks is cropped and
    resized instead. If even that region is empty (for example when all
    landmarks coincide or lie outside the image), a black 112x112 image is
    returned so the caller never receives an exception from this step.

    Args:
        bgr: Source BGR image.
        landmarks_10: Flat list of 10 floats [x0,y0, x1,y1, ..., x4,y4].

    Returns:
        Aligned BGR face crop of shape (112, 112, 3).
    """
    src_pts = np.array(landmarks_10, dtype=np.float32).reshape(5, 2)
    M, _ = cv2.estimateAffinePartial2D(src_pts, ARCFACE_SRC)
    if M is not None:
        return cv2.warpAffine(bgr, M, (112, 112), borderMode=cv2.BORDER_REPLICATE)

    # Fallback: simple crop and resize from bbox-like region
    blank = np.zeros((112, 112, 3), dtype=bgr.dtype)
    if not np.all(np.isfinite(src_pts)):
        # int() below would raise on NaN/inf coordinates.
        return blank

    cx = np.mean(src_pts[:, 0])
    cy = np.mean(src_pts[:, 1])
    spread = max(np.ptp(src_pts[:, 0]), np.ptp(src_pts[:, 1])) * 1.5
    half = spread / 2
    x1 = max(0, int(cx - half))
    y1 = max(0, int(cy - half))
    x2 = min(bgr.shape[1], int(cx + half))
    y2 = min(bgr.shape[0], int(cy + half))
    crop = bgr[y1:y2, x1:x2]
    if crop.size == 0:
        # Degenerate landmarks give a zero-area crop; cv2.resize would raise.
        return blank
    return cv2.resize(crop, (112, 112), interpolation=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
|
|
||||||
|
# -- Main recognizer class -----------------------------------------------------
|
||||||
|
|
||||||
|
class ArcFaceRecognizer:
    """ArcFace face embedding extractor and gallery matcher.

    The gallery is a dict mapping an integer person ID to a dict with the keys
    ``name``, ``embedding`` (L2-normalized float32 vector), ``samples`` and
    ``enrolled_at``.

    Args:
        engine_path: Path to TensorRT engine file.
        onnx_path: Path to ONNX model file (used if engine not available).
    """

    def __init__(self, engine_path: str = '', onnx_path: str = ''):
        self._backend: Optional[_TRTBackend | _ONNXBackend] = None
        self.gallery: dict[int, dict] = {}

        # Try TRT first, then ONNX
        if engine_path and os.path.isfile(engine_path):
            try:
                self._backend = _TRTBackend(engine_path)
                logger.info('ArcFace TensorRT backend loaded: %s', engine_path)
                return
            except Exception as e:
                logger.warning('ArcFace TRT load failed (%s), trying ONNX', e)

        if onnx_path and os.path.isfile(onnx_path):
            try:
                self._backend = _ONNXBackend(onnx_path)
                logger.info('ArcFace ONNX backend loaded: %s', onnx_path)
                return
            except Exception as e:
                logger.error('ArcFace ONNX load failed: %s', e)

        logger.error('No ArcFace model loaded. Recognition will be unavailable.')

    @staticmethod
    def _meta_path_for(gallery_path: str) -> str:
        """Return the JSON sidecar path that belongs to ``gallery_path``.

        Only a trailing ``.npz`` extension is stripped, so directory names
        containing ``.npz`` are left alone and a path without the extension
        never maps back onto the gallery file itself.
        """
        root, ext = os.path.splitext(gallery_path)
        if ext != '.npz':
            root = gallery_path
        return root + '_meta.json'

    @staticmethod
    def _l2_normalized(vec: np.ndarray) -> np.ndarray:
        """Return ``vec`` scaled to unit L2 norm (unchanged if the norm is 0)."""
        norm = np.linalg.norm(vec)
        return vec / norm if norm > 0 else vec

    @property
    def is_loaded(self) -> bool:
        """Return True if a backend is loaded and ready."""
        return self._backend is not None

    def embed(self, bgr_face_112x112: np.ndarray) -> np.ndarray:
        """Extract 512-dim L2-normalized embedding from a 112x112 aligned face.

        Args:
            bgr_face_112x112: Aligned face crop, BGR, shape (112, 112, 3).

        Returns:
            L2-normalized embedding of shape (512,). An all-zero vector is
            returned when no backend is loaded.
        """
        if self._backend is None:
            return np.zeros(512, dtype=np.float32)

        # Preprocess: BGR->RGB, /255, subtract 0.5, divide 0.5 -> [1,3,112,112]
        rgb = cv2.cvtColor(bgr_face_112x112, cv2.COLOR_BGR2RGB).astype(np.float32)
        rgb = rgb / 255.0
        rgb = (rgb - 0.5) / 0.5
        blob = rgb.transpose(2, 0, 1)[np.newaxis]  # [1, 3, 112, 112]
        blob = np.ascontiguousarray(blob)

        output = self._backend.infer(blob)
        embedding = output.flatten()[:512].astype(np.float32)
        return self._l2_normalized(embedding)

    def align_and_embed(self, bgr_image: np.ndarray, landmarks_10: list[float]) -> np.ndarray:
        """Align face using landmarks and extract embedding.

        Args:
            bgr_image: Full BGR image.
            landmarks_10: Flat list of 10 floats from SCRFD detection.

        Returns:
            L2-normalized embedding of shape (512,).
        """
        aligned = align_face(bgr_image, landmarks_10)
        return self.embed(aligned)

    def load_gallery(self, gallery_path: str) -> None:
        """Load gallery from .npz file with JSON metadata sidecar.

        Archive entries whose key is not an integer are skipped with a
        warning. ``self.gallery`` is replaced only after the whole archive has
        been read.

        Args:
            gallery_path: Path to the .npz gallery file.
        """
        import json

        if not os.path.isfile(gallery_path):
            logger.info('No gallery file at %s, starting empty.', gallery_path)
            self.gallery = {}
            return

        meta = {}
        meta_path = self._meta_path_for(gallery_path)
        if os.path.isfile(meta_path):
            with open(meta_path, 'r') as f:
                meta = json.load(f)

        loaded: dict[int, dict] = {}
        # NpzFile keeps the archive open until closed; release it promptly.
        with np.load(gallery_path, allow_pickle=False) as data:
            for key in data.files:
                try:
                    pid = int(key)
                except ValueError:
                    logger.warning('Skipping non-integer gallery key %r in %s',
                                   key, gallery_path)
                    continue
                info = meta.get(str(pid), {})
                loaded[pid] = {
                    'name': info.get('name', f'person_{pid}'),
                    'embedding': self._l2_normalized(data[key].astype(np.float32)),
                    'samples': info.get('samples', 1),
                    'enrolled_at': info.get('enrolled_at', 0.0),
                }
        self.gallery = loaded
        logger.info('Gallery loaded: %d persons from %s', len(self.gallery), gallery_path)

    def save_gallery(self, gallery_path: str) -> None:
        """Save gallery to .npz file with JSON metadata sidecar.

        Args:
            gallery_path: Path to the .npz gallery file.
        """
        import json

        arrays = {}
        meta = {}
        for pid, info in self.gallery.items():
            arrays[str(pid)] = info['embedding']
            meta[str(pid)] = {
                'name': info['name'],
                'samples': info['samples'],
                'enrolled_at': info['enrolled_at'],
            }

        os.makedirs(os.path.dirname(gallery_path) or '.', exist_ok=True)
        np.savez(gallery_path, **arrays)

        with open(self._meta_path_for(gallery_path), 'w') as f:
            json.dump(meta, f, indent=2)
        logger.info('Gallery saved: %d persons to %s', len(self.gallery), gallery_path)

    def match(self, embedding: np.ndarray, threshold: float = 0.35) -> tuple[int, str, float]:
        """Match an embedding against the gallery.

        Args:
            embedding: L2-normalized query embedding of shape (512,).
            threshold: Minimum cosine similarity for a match.

        Returns:
            (person_id, person_name, score) or (-1, '', 0.0) if no match.
        """
        if not self.gallery:
            return (-1, '', 0.0)

        best_pid = -1
        best_name = ''
        best_score = 0.0

        # Both vectors are unit length, so the dot product is the cosine.
        for pid, info in self.gallery.items():
            score = float(np.dot(embedding, info['embedding']))
            if score > best_score:
                best_score = score
                best_pid = pid
                best_name = info['name']

        if best_score >= threshold:
            return (best_pid, best_name, best_score)
        return (-1, '', 0.0)

    def enroll(self, person_id: int, person_name: str, embeddings_list: list[np.ndarray]) -> None:
        """Enroll a person by averaging multiple embeddings.

        An existing entry with the same ``person_id`` is overwritten. Nothing
        happens when ``embeddings_list`` is empty.

        Args:
            person_id: Unique integer ID for this person.
            person_name: Human-readable name.
            embeddings_list: List of L2-normalized embeddings to average.
        """
        import time as _time

        if not embeddings_list:
            return

        mean_emb = np.mean(embeddings_list, axis=0).astype(np.float32)

        self.gallery[person_id] = {
            'name': person_name,
            'embedding': self._l2_normalized(mean_emb),
            'samples': len(embeddings_list),
            'enrolled_at': _time.time(),
        }
        logger.info('Enrolled person %d (%s) with %d samples.',
                    person_id, person_name, len(embeddings_list))
|
||||||
@ -0,0 +1,78 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
enrollment_cli.py -- CLI tool for enrolling persons via the /social/enroll service.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ros2 run saltybot_social_face enrollment_cli -- --name Alice --mode face --samples 10
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
|
||||||
|
from saltybot_social_msgs.srv import EnrollPerson
|
||||||
|
|
||||||
|
|
||||||
|
class EnrollmentCLI(Node):
    """Simple CLI node that calls the EnrollPerson service.

    All work happens in the constructor: it waits for the service, sends one
    request, spins until the response arrives (or the timeout expires) and
    logs the outcome.

    Args:
        name: Name of the person to enroll.
        mode: Enrollment mode string forwarded to the service.
        n_samples: Number of samples the service should collect.
    """

    def __init__(self, name: str, mode: str, n_samples: int):
        super().__init__('enrollment_cli')
        log = self.get_logger()
        self._client = self.create_client(EnrollPerson, '/social/enroll')

        log.info('Waiting for /social/enroll service...')
        if not self._client.wait_for_service(timeout_sec=10.0):
            log.error('Service /social/enroll not available.')
            return

        request = EnrollPerson.Request()
        request.name = name
        request.mode = mode
        request.n_samples = n_samples

        # rclpy logger methods take a single, already formatted message;
        # printf-style positional arguments raise TypeError.
        log.info(f'Enrolling "{name}" (mode={mode}, samples={n_samples})...')

        future = self._client.call_async(request)
        rclpy.spin_until_future_complete(self, future, timeout_sec=120.0)

        result = future.result()
        if result is None:
            log.error('Enrollment service call timed out or failed.')
            return

        if result.success:
            log.info(
                f'Enrollment successful: person_id={result.person_id}, '
                f'{result.message}')
        else:
            log.error(f'Enrollment failed: {result.message}')
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None):
    """Entry point for enrollment CLI.

    Args:
        args: Optional argument list (without the program name). When None,
            ``sys.argv[1:]`` is used.
    """
    parser = argparse.ArgumentParser(description='Enroll a person for face recognition.')
    parser.add_argument('--name', type=str, required=True,
                        help='Name of the person to enroll.')
    parser.add_argument('--mode', type=str, default='face',
                        choices=['face', 'voice', 'both'],
                        help='Enrollment mode (default: face).')
    parser.add_argument('--samples', type=int, default=10,
                        help='Number of face samples to collect (default: 10).')

    # Parse only known args so ROS2 remapping args pass through
    argv = sys.argv[1:] if args is None else list(args)
    parsed, remaining = parser.parse_known_args(args=argv)
    if parsed.samples < 1:
        parser.error('--samples must be a positive integer.')

    rclpy.init(args=remaining)
    node = None
    try:
        # Node does all work in __init__; building it inside the try block
        # guarantees rclpy is shut down even if construction raises.
        node = EnrollmentCLI(parsed.name, parsed.mode, parsed.samples)
    finally:
        if node is not None:
            node.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()
|
||||||
@ -0,0 +1,206 @@
|
|||||||
|
"""
|
||||||
|
face_gallery.py -- Persistent face embedding gallery backed by numpy .npz + JSON.
|
||||||
|
|
||||||
|
Thread-safe gallery storage for face recognition. Embeddings are stored in a
|
||||||
|
.npz file, with a sidecar metadata.json containing names, sample counts, and
|
||||||
|
enrollment timestamps. Auto-increment IDs start at 1.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FaceGallery:
    """Persistent, thread-safe face embedding gallery.

    Every public method takes an internal lock. Embeddings are stored
    L2-normalized as float32. Person IDs are auto-incremented, starting at 1
    or one above the largest ID found on disk.

    Args:
        gallery_dir: Directory for gallery.npz and metadata.json files.
    """

    def __init__(self, gallery_dir: str):
        self._gallery_dir = gallery_dir
        self._npz_path = os.path.join(gallery_dir, 'gallery.npz')
        self._meta_path = os.path.join(gallery_dir, 'metadata.json')
        self._gallery: dict[int, dict] = {}
        self._next_id = 1
        self._lock = threading.Lock()

    @staticmethod
    def _unit(vec: np.ndarray) -> np.ndarray:
        """Return a float32 copy of ``vec`` scaled to unit L2 norm.

        A zero vector is returned unscaled.
        """
        out = np.array(vec, dtype=np.float32)
        norm = np.linalg.norm(out)
        return out / norm if norm > 0 else out

    def load(self) -> None:
        """Load gallery from disk. Populates internal gallery dict.

        The in-memory gallery is cleared first. Archive entries whose key is
        not an integer are skipped with a warning.
        """
        with self._lock:
            self._gallery = {}
            self._next_id = 1

            if not os.path.isfile(self._npz_path):
                logger.info('No gallery file at %s, starting empty.', self._npz_path)
                return

            meta: dict = {}
            if os.path.isfile(self._meta_path):
                with open(self._meta_path, 'r') as f:
                    meta = json.load(f)

            # NpzFile keeps the archive open until closed; release it promptly.
            with np.load(self._npz_path, allow_pickle=False) as data:
                for key in data.files:
                    try:
                        pid = int(key)
                    except ValueError:
                        logger.warning('Skipping non-integer gallery key %r in %s',
                                       key, self._npz_path)
                        continue
                    info = meta.get(str(pid), {})
                    self._gallery[pid] = {
                        'name': info.get('name', f'person_{pid}'),
                        'embedding': self._unit(data[key]),
                        'samples': info.get('samples', 1),
                        'enrolled_at': info.get('enrolled_at', 0.0),
                    }
                    self._next_id = max(self._next_id, pid + 1)

            logger.info('Gallery loaded: %d persons from %s',
                        len(self._gallery), self._npz_path)

    def save(self) -> None:
        """Save gallery to disk (npz + JSON sidecar).

        Both files are first written to ``*.tmp`` siblings and then moved into
        place, so an interrupted write does not leave a truncated gallery.
        """
        with self._lock:
            os.makedirs(self._gallery_dir, exist_ok=True)

            arrays = {}
            meta = {}
            for pid, info in self._gallery.items():
                arrays[str(pid)] = info['embedding']
                meta[str(pid)] = {
                    'name': info['name'],
                    'samples': info['samples'],
                    'enrolled_at': info['enrolled_at'],
                }

            npz_tmp = self._npz_path + '.tmp'
            meta_tmp = self._meta_path + '.tmp'
            # A file object is passed so numpy does not append ".npz".
            with open(npz_tmp, 'wb') as f:
                np.savez(f, **arrays)
            with open(meta_tmp, 'w') as f:
                json.dump(meta, f, indent=2)
            os.replace(npz_tmp, self._npz_path)
            os.replace(meta_tmp, self._meta_path)

            logger.info('Gallery saved: %d persons to %s',
                        len(self._gallery), self._npz_path)

    def add_person(self, name: str, embedding: np.ndarray, samples: int = 1) -> int:
        """Add a new person to the gallery.

        Args:
            name: Person's name.
            embedding: L2-normalized 512-dim embedding.
            samples: Number of samples used to compute the embedding.

        Returns:
            Assigned person_id (auto-increment integer).
        """
        with self._lock:
            pid = self._next_id
            self._next_id += 1
            self._gallery[pid] = {
                'name': name,
                'embedding': self._unit(embedding),
                'samples': samples,
                'enrolled_at': time.time(),
            }
            logger.info('Added person %d (%s), %d samples.', pid, name, samples)
            return pid

    def update_name(self, person_id: int, new_name: str) -> bool:
        """Update a person's name.

        Args:
            person_id: The ID of the person to update.
            new_name: New name string.

        Returns:
            True if the person was found and updated.
        """
        with self._lock:
            entry = self._gallery.get(person_id)
            if entry is None:
                return False
            entry['name'] = new_name
            return True

    def delete_person(self, person_id: int) -> bool:
        """Remove a person from the gallery.

        Args:
            person_id: The ID of the person to delete.

        Returns:
            True if the person was found and removed.
        """
        with self._lock:
            if person_id not in self._gallery:
                return False
            del self._gallery[person_id]
            logger.info('Deleted person %d.', person_id)
            return True

    def get_all(self) -> list[dict]:
        """Get all gallery entries.

        Returns:
            List of dicts with keys: person_id, name, embedding, samples,
            enrolled_at. Embeddings are copies, so callers may modify them.
        """
        with self._lock:
            return [
                {
                    'person_id': pid,
                    'name': info['name'],
                    'embedding': info['embedding'].copy(),
                    'samples': info['samples'],
                    'enrolled_at': info['enrolled_at'],
                }
                for pid, info in self._gallery.items()
            ]

    def match(self, query_embedding: np.ndarray, threshold: float = 0.35) -> tuple[int, str, float]:
        """Match a query embedding against the gallery using cosine similarity.

        Args:
            query_embedding: L2-normalized 512-dim embedding.
            threshold: Minimum cosine similarity for a match.

        Returns:
            (person_id, name, score) or (-1, '', 0.0) if no match.
        """
        with self._lock:
            if not self._gallery:
                return (-1, '', 0.0)

            best_pid = -1
            best_name = ''
            best_score = 0.0

            # Re-normalize defensively; stored embeddings are unit length.
            query = self._unit(query_embedding)

            for pid, info in self._gallery.items():
                score = float(np.dot(query, info['embedding']))
                if score > best_score:
                    best_score = score
                    best_pid = pid
                    best_name = info['name']

            if best_score >= threshold:
                return (best_pid, best_name, best_score)
            return (-1, '', 0.0)

    def __len__(self) -> int:
        with self._lock:
            return len(self._gallery)
|
||||||
@ -0,0 +1,431 @@
|
|||||||
|
"""
|
||||||
|
face_recognition_node.py -- ROS2 node for SCRFD face detection + ArcFace recognition.
|
||||||
|
|
||||||
|
Pipeline:
|
||||||
|
1. Subscribe to /camera/color/image_raw (RealSense D435i color stream).
|
||||||
|
2. Run SCRFD face detection (TensorRT FP16 or ONNX fallback).
|
||||||
|
3. For each detected face, align and extract ArcFace embedding.
|
||||||
|
4. Match embedding against persistent gallery.
|
||||||
|
5. Publish FaceDetectionArray with identified faces.
|
||||||
|
|
||||||
|
Services:
|
||||||
|
/social/enroll -- Enroll a new person (collects N face samples).
|
||||||
|
/social/persons/list -- List all enrolled persons.
|
||||||
|
/social/persons/delete -- Delete a person from the gallery.
|
||||||
|
/social/persons/update -- Update a person's name.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy, DurabilityPolicy
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
from cv_bridge import CvBridge
|
||||||
|
|
||||||
|
from sensor_msgs.msg import Image
|
||||||
|
from builtin_interfaces.msg import Time
|
||||||
|
|
||||||
|
from saltybot_social_msgs.msg import (
|
||||||
|
FaceDetection,
|
||||||
|
FaceDetectionArray,
|
||||||
|
FaceEmbedding,
|
||||||
|
FaceEmbeddingArray,
|
||||||
|
)
|
||||||
|
from saltybot_social_msgs.srv import (
|
||||||
|
EnrollPerson,
|
||||||
|
ListPersons,
|
||||||
|
DeletePerson,
|
||||||
|
UpdatePerson,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .scrfd_detector import SCRFDDetector
|
||||||
|
from .arcface_recognizer import ArcFaceRecognizer
|
||||||
|
from .face_gallery import FaceGallery
|
||||||
|
|
||||||
|
|
||||||
|
class FaceRecognitionNode(Node):
|
||||||
|
"""ROS2 node: SCRFD face detection + ArcFace gallery matching."""
|
||||||
|
|
||||||
|
def __init__(self):
    """Declare parameters, load models and gallery, and wire up ROS I/O."""
    super().__init__('face_recognizer')
    self._bridge = CvBridge()
    self._frame_count = 0
    self._fps_t0 = time.monotonic()

    # -- Parameters --------------------------------------------------------
    self.declare_parameter('scrfd_engine_path',
                           '/mnt/nvme/saltybot/models/scrfd_2.5g.engine')
    self.declare_parameter('scrfd_onnx_path',
                           '/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx')
    self.declare_parameter('arcface_engine_path',
                           '/mnt/nvme/saltybot/models/arcface_r50.engine')
    self.declare_parameter('arcface_onnx_path',
                           '/mnt/nvme/saltybot/models/arcface_r50.onnx')
    self.declare_parameter('gallery_dir', '/mnt/nvme/saltybot/gallery')
    self.declare_parameter('recognition_threshold', 0.35)
    self.declare_parameter('publish_debug_image', False)
    self.declare_parameter('max_faces', 10)
    self.declare_parameter('scrfd_conf_thresh', 0.5)

    self._recognition_threshold = self.get_parameter('recognition_threshold').value
    self._pub_debug_flag = self.get_parameter('publish_debug_image').value
    self._max_faces = self.get_parameter('max_faces').value

    # -- Models ------------------------------------------------------------
    self._detector = SCRFDDetector(
        engine_path=self.get_parameter('scrfd_engine_path').value,
        onnx_path=self.get_parameter('scrfd_onnx_path').value,
        conf_thresh=self.get_parameter('scrfd_conf_thresh').value,
    )
    self._recognizer = ArcFaceRecognizer(
        engine_path=self.get_parameter('arcface_engine_path').value,
        onnx_path=self.get_parameter('arcface_onnx_path').value,
    )

    # -- Gallery -----------------------------------------------------------
    gallery_dir = self.get_parameter('gallery_dir').value
    self._gallery = FaceGallery(gallery_dir)
    self._gallery.load()
    # rclpy logger methods take a single, already formatted message;
    # printf-style positional arguments raise TypeError.
    self.get_logger().info(f'Gallery loaded: {len(self._gallery)} persons.')

    # -- Enrollment state --------------------------------------------------
    self._enrolling = None  # {name, samples_needed, collected: [embeddings]}
    # Per-frame scratch state for picking the largest face; created here so
    # every method can read it without an AttributeError.
    self._enroll_best_area = 0.0
    self._enroll_best_embedding = None

    # -- QoS profiles ------------------------------------------------------
    best_effort_qos = QoSProfile(
        reliability=ReliabilityPolicy.BEST_EFFORT,
        history=HistoryPolicy.KEEP_LAST,
        depth=1,
    )
    reliable_qos = QoSProfile(
        reliability=ReliabilityPolicy.RELIABLE,
        durability=DurabilityPolicy.TRANSIENT_LOCAL,
        history=HistoryPolicy.KEEP_LAST,
        depth=1,
    )

    # -- Subscribers -------------------------------------------------------
    self.create_subscription(
        Image,
        '/camera/color/image_raw',
        self._on_image,
        best_effort_qos,
    )

    # -- Publishers --------------------------------------------------------
    self._pub_detections = self.create_publisher(
        FaceDetectionArray, '/social/faces/detections', best_effort_qos)
    self._pub_embeddings = self.create_publisher(
        FaceEmbeddingArray, '/social/faces/embeddings', reliable_qos)
    if self._pub_debug_flag:
        # Only created when enabled; _on_image checks for the attribute.
        self._pub_debug_img = self.create_publisher(
            Image, '/social/faces/debug_image', best_effort_qos)

    # -- Services ----------------------------------------------------------
    self.create_service(EnrollPerson, '/social/enroll', self._handle_enroll)
    self.create_service(ListPersons, '/social/persons/list', self._handle_list)
    self.create_service(DeletePerson, '/social/persons/delete', self._handle_delete)
    self.create_service(UpdatePerson, '/social/persons/update', self._handle_update)

    # Publish initial gallery state
    self._publish_gallery_embeddings()

    self.get_logger().info('FaceRecognitionNode ready.')
|
||||||
|
|
||||||
|
# -- Image callback --------------------------------------------------------
|
||||||
|
|
||||||
|
def _on_image(self, msg: Image):
    """Process incoming camera frame: detect, recognize, publish.

    Steps: decode the frame to BGR, run the detector, keep at most
    ``max_faces`` highest-scoring detections, embed and match each face,
    publish the FaceDetectionArray, commit this frame's enrollment sample
    (if an enrollment is active), optionally publish a debug image, and log
    the frame rate every 30 frames.

    Args:
        msg: Incoming color image.
    """
    log = self.get_logger()

    try:
        bgr = self._bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')
    except Exception as e:
        # rclpy logger methods take a single, already formatted message;
        # printf-style positional arguments raise TypeError.
        log.error(f'Image decode error: {e}', throttle_duration_sec=5.0)
        return

    # Detect faces
    detections = self._detector.detect(bgr)

    # Limit face count
    if len(detections) > self._max_faces:
        detections = sorted(
            detections, key=lambda d: d['score'], reverse=True)[:self._max_faces]

    # Build output message
    det_array = FaceDetectionArray()
    det_array.header = msg.header

    for det in detections:
        # Extract embedding and match gallery
        embedding = self._recognizer.align_and_embed(bgr, det['kps'])
        pid, pname, score = self._gallery.match(
            embedding, self._recognition_threshold)

        # Handle enrollment: collect embedding from largest face
        if self._enrolling is not None:
            self._enrollment_collect(det, embedding)

        # Build FaceDetection message
        face_msg = FaceDetection()
        face_msg.header = msg.header
        face_msg.face_id = pid
        face_msg.person_name = pname
        face_msg.confidence = det['score']
        face_msg.recognition_score = score

        # bbox is [x1, y1, x2, y2]; the message carries x, y, width, height.
        bbox = det['bbox']
        face_msg.bbox_x = bbox[0]
        face_msg.bbox_y = bbox[1]
        face_msg.bbox_w = bbox[2] - bbox[0]
        face_msg.bbox_h = bbox[3] - bbox[1]

        kps = det['kps']
        for i in range(10):
            face_msg.landmarks[i] = kps[i]

        det_array.faces.append(face_msg)

    self._pub_detections.publish(det_array)

    # Commit this frame's best enrollment sample. Without this call the
    # samples gathered by _enrollment_collect are never appended and the
    # enrollment cannot complete. getattr() guards against the scratch
    # attribute not having been created yet.
    if (self._enrolling is not None
            and getattr(self, '_enroll_best_embedding', None) is not None):
        self._enrollment_frame_end()

    # Debug image
    if self._pub_debug_flag and hasattr(self, '_pub_debug_img'):
        debug_img = self._draw_debug(bgr, detections, det_array.faces)
        debug_msg = self._bridge.cv2_to_imgmsg(debug_img, encoding='bgr8')
        # Carry over stamp/frame_id so the overlay can be matched to its frame.
        debug_msg.header = msg.header
        self._pub_debug_img.publish(debug_msg)

    # FPS logging
    self._frame_count += 1
    if self._frame_count % 30 == 0:
        elapsed = time.monotonic() - self._fps_t0
        fps = 30.0 / elapsed if elapsed > 0 else 0.0
        self._fps_t0 = time.monotonic()
        log.info(f'FPS: {fps:.1f} | faces: {len(detections)}')
|
||||||
|
|
||||||
|
# -- Enrollment logic ------------------------------------------------------
|
||||||
|
|
||||||
|
def _enrollment_collect(self, det: dict, embedding: np.ndarray):
|
||||||
|
"""Collect an embedding sample during enrollment (largest face only)."""
|
||||||
|
if self._enrolling is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Only collect from the largest face (by bbox area)
|
||||||
|
bbox = det['bbox']
|
||||||
|
area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
||||||
|
|
||||||
|
if not hasattr(self, '_enroll_best_area'):
|
||||||
|
self._enroll_best_area = 0.0
|
||||||
|
self._enroll_best_embedding = None
|
||||||
|
|
||||||
|
if area > self._enroll_best_area:
|
||||||
|
self._enroll_best_area = area
|
||||||
|
self._enroll_best_embedding = embedding
|
||||||
|
|
||||||
|
def _enrollment_frame_end(self):
|
||||||
|
"""Called at end of each frame to finalize enrollment sample collection."""
|
||||||
|
if self._enrolling is None or self._enroll_best_embedding is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._enrolling['collected'].append(self._enroll_best_embedding)
|
||||||
|
self._enroll_best_area = 0.0
|
||||||
|
self._enroll_best_embedding = None
|
||||||
|
|
||||||
|
collected = len(self._enrolling['collected'])
|
||||||
|
needed = self._enrolling['samples_needed']
|
||||||
|
self.get_logger().info('Enrollment: %d/%d samples for "%s".',
|
||||||
|
collected, needed, self._enrolling['name'])
|
||||||
|
|
||||||
|
if collected >= needed:
|
||||||
|
# Finalize enrollment
|
||||||
|
name = self._enrolling['name']
|
||||||
|
embeddings = self._enrolling['collected']
|
||||||
|
mean_emb = np.mean(embeddings, axis=0).astype(np.float32)
|
||||||
|
norm = np.linalg.norm(mean_emb)
|
||||||
|
if norm > 0:
|
||||||
|
mean_emb = mean_emb / norm
|
||||||
|
|
||||||
|
pid = self._gallery.add_person(name, mean_emb, samples=len(embeddings))
|
||||||
|
self._gallery.save()
|
||||||
|
self._publish_gallery_embeddings()
|
||||||
|
|
||||||
|
self.get_logger().info('Enrollment complete: person %d (%s).', pid, name)
|
||||||
|
|
||||||
|
# Store result for the service callback
|
||||||
|
self._enrolling['result_pid'] = pid
|
||||||
|
self._enrolling['done'] = True
|
||||||
|
self._enrolling = None
|
||||||
|
|
||||||
|
# -- Service handlers ------------------------------------------------------
|
||||||
|
|
||||||
|
def _handle_enroll(self, request, response):
|
||||||
|
"""Handle EnrollPerson service: start collecting face samples."""
|
||||||
|
name = request.name.strip()
|
||||||
|
if not name:
|
||||||
|
response.success = False
|
||||||
|
response.message = 'Name cannot be empty.'
|
||||||
|
response.person_id = -1
|
||||||
|
return response
|
||||||
|
|
||||||
|
n_samples = request.n_samples if request.n_samples > 0 else 10
|
||||||
|
|
||||||
|
self.get_logger().info('Starting enrollment for "%s" (%d samples).',
|
||||||
|
name, n_samples)
|
||||||
|
|
||||||
|
# Set enrollment state — frames will collect embeddings
|
||||||
|
self._enrolling = {
|
||||||
|
'name': name,
|
||||||
|
'samples_needed': n_samples,
|
||||||
|
'collected': [],
|
||||||
|
'done': False,
|
||||||
|
'result_pid': -1,
|
||||||
|
}
|
||||||
|
self._enroll_best_area = 0.0
|
||||||
|
self._enroll_best_embedding = None
|
||||||
|
|
||||||
|
# Spin until enrollment is done (blocking service)
|
||||||
|
rate = self.create_rate(10) # 10 Hz check
|
||||||
|
timeout_sec = n_samples * 2.0 + 10.0 # generous timeout
|
||||||
|
t0 = time.monotonic()
|
||||||
|
|
||||||
|
while not self._enrolling.get('done', False):
|
||||||
|
# Finalize any pending frame collection
|
||||||
|
self._enrollment_frame_end()
|
||||||
|
|
||||||
|
if time.monotonic() - t0 > timeout_sec:
|
||||||
|
self._enrolling = None
|
||||||
|
response.success = False
|
||||||
|
response.message = f'Enrollment timed out after {timeout_sec:.0f}s.'
|
||||||
|
response.person_id = -1
|
||||||
|
return response
|
||||||
|
|
||||||
|
rclpy.spin_once(self, timeout_sec=0.1)
|
||||||
|
|
||||||
|
response.success = True
|
||||||
|
response.message = f'Enrolled "{name}" with {n_samples} samples.'
|
||||||
|
response.person_id = self._enrolling.get('result_pid', -1) if self._enrolling else -1
|
||||||
|
|
||||||
|
# Clean up (already set to None in _enrollment_frame_end on success)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _handle_list(self, request, response):
|
||||||
|
"""Handle ListPersons service: return all gallery entries."""
|
||||||
|
entries = self._gallery.get_all()
|
||||||
|
for entry in entries:
|
||||||
|
emb_msg = FaceEmbedding()
|
||||||
|
emb_msg.person_id = entry['person_id']
|
||||||
|
emb_msg.person_name = entry['name']
|
||||||
|
emb_msg.embedding = entry['embedding'].tolist()
|
||||||
|
emb_msg.sample_count = entry['samples']
|
||||||
|
|
||||||
|
secs = int(entry['enrolled_at'])
|
||||||
|
nsecs = int((entry['enrolled_at'] - secs) * 1e9)
|
||||||
|
emb_msg.enrolled_at = Time(sec=secs, nanosec=nsecs)
|
||||||
|
|
||||||
|
response.persons.append(emb_msg)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _handle_delete(self, request, response):
|
||||||
|
"""Handle DeletePerson service: remove a person from the gallery."""
|
||||||
|
if self._gallery.delete_person(request.person_id):
|
||||||
|
self._gallery.save()
|
||||||
|
self._publish_gallery_embeddings()
|
||||||
|
response.success = True
|
||||||
|
response.message = f'Deleted person {request.person_id}.'
|
||||||
|
else:
|
||||||
|
response.success = False
|
||||||
|
response.message = f'Person {request.person_id} not found.'
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _handle_update(self, request, response):
|
||||||
|
"""Handle UpdatePerson service: rename a person."""
|
||||||
|
new_name = request.new_name.strip()
|
||||||
|
if not new_name:
|
||||||
|
response.success = False
|
||||||
|
response.message = 'New name cannot be empty.'
|
||||||
|
return response
|
||||||
|
|
||||||
|
if self._gallery.update_name(request.person_id, new_name):
|
||||||
|
self._gallery.save()
|
||||||
|
self._publish_gallery_embeddings()
|
||||||
|
response.success = True
|
||||||
|
response.message = f'Updated person {request.person_id} to "{new_name}".'
|
||||||
|
else:
|
||||||
|
response.success = False
|
||||||
|
response.message = f'Person {request.person_id} not found.'
|
||||||
|
return response
|
||||||
|
|
||||||
|
# -- Gallery publishing ----------------------------------------------------
|
||||||
|
|
||||||
|
def _publish_gallery_embeddings(self):
    """Publish the whole gallery as one FaceEmbeddingArray message.

    The header is stamped with the node clock's current time and each
    gallery entry is converted to a ``FaceEmbedding``, with ``enrolled_at``
    (float seconds) split into whole seconds and nanoseconds.
    """
    records = self._gallery.get_all()
    array_msg = FaceEmbeddingArray()
    array_msg.header.stamp = self.get_clock().now().to_msg()

    for record in records:
        item = FaceEmbedding()
        item.person_id = record['person_id']
        item.person_name = record['name']
        item.embedding = record['embedding'].tolist()
        item.sample_count = record['samples']

        whole_secs = int(record['enrolled_at'])
        frac_nanos = int((record['enrolled_at'] - whole_secs) * 1e9)
        item.enrolled_at = Time(sec=whole_secs, nanosec=frac_nanos)

        array_msg.embeddings.append(item)

    self._pub_embeddings.publish(array_msg)
|
||||||
|
|
||||||
|
# -- Debug image -----------------------------------------------------------
|
||||||
|
|
||||||
|
def _draw_debug(self, bgr: np.ndarray, detections: list[dict],
|
||||||
|
face_msgs: list) -> np.ndarray:
|
||||||
|
"""Draw bounding boxes, landmarks, and names on the image."""
|
||||||
|
vis = bgr.copy()
|
||||||
|
for det, face_msg in zip(detections, face_msgs):
|
||||||
|
bbox = det['bbox']
|
||||||
|
x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
|
||||||
|
|
||||||
|
# Color: green if recognized, yellow if unknown
|
||||||
|
if face_msg.face_id >= 0:
|
||||||
|
color = (0, 255, 0)
|
||||||
|
label = f'{face_msg.person_name} ({face_msg.recognition_score:.2f})'
|
||||||
|
else:
|
||||||
|
color = (0, 255, 255)
|
||||||
|
label = f'unknown ({face_msg.confidence:.2f})'
|
||||||
|
|
||||||
|
cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2)
|
||||||
|
cv2.putText(vis, label, (x1, y1 - 8),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
|
||||||
|
|
||||||
|
# Draw landmarks
|
||||||
|
kps = det['kps']
|
||||||
|
for k in range(5):
|
||||||
|
px, py = int(kps[k * 2]), int(kps[k * 2 + 1])
|
||||||
|
cv2.circle(vis, (px, py), 2, (0, 0, 255), -1)
|
||||||
|
|
||||||
|
return vis
|
||||||
|
|
||||||
|
|
||||||
|
# -- Entry point ---------------------------------------------------------------
|
||||||
|
|
||||||
|
def main(args=None):
    """ROS2 entry point for face_recognition node.

    Initialises rclpy, spins a ``FaceRecognitionNode`` until interrupted,
    then destroys the node and shuts rclpy down.
    """
    rclpy.init(args=args)
    node = FaceRecognitionNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        # On Ctrl-C rclpy's own signal handling may already have shut the
        # context down; shutting down a second time raises, so only do it
        # while the context is still valid.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()
|
||||||
@ -0,0 +1,350 @@
|
|||||||
|
"""
|
||||||
|
scrfd_detector.py -- SCRFD face detection with TensorRT FP16 + ONNX fallback.
|
||||||
|
|
||||||
|
SCRFD (Sample and Computation Redistribution for Face Detection) produces
|
||||||
|
9 output tensors across 3 strides (8, 16, 32), each with score, bbox, and
|
||||||
|
keypoint branches. This module handles anchor generation, bbox/keypoint
|
||||||
|
decoding, and NMS to produce final face detections.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_STRIDES = [8, 16, 32]
|
||||||
|
_NUM_ANCHORS = 2 # anchors per cell per stride
|
||||||
|
|
||||||
|
|
||||||
|
# -- Inference backends --------------------------------------------------------
|
||||||
|
|
||||||
|
class _TRTBackend:
|
||||||
|
"""TensorRT inference engine for SCRFD."""
|
||||||
|
|
||||||
|
def __init__(self, engine_path: str):
|
||||||
|
import tensorrt as trt
|
||||||
|
import pycuda.driver as cuda
|
||||||
|
import pycuda.autoinit # noqa: F401
|
||||||
|
|
||||||
|
self._cuda = cuda
|
||||||
|
trt_logger = trt.Logger(trt.Logger.WARNING)
|
||||||
|
with open(engine_path, 'rb') as f, trt.Runtime(trt_logger) as runtime:
|
||||||
|
self._engine = runtime.deserialize_cuda_engine(f.read())
|
||||||
|
self._context = self._engine.create_execution_context()
|
||||||
|
|
||||||
|
self._inputs = []
|
||||||
|
self._outputs = []
|
||||||
|
self._output_names = []
|
||||||
|
self._bindings = []
|
||||||
|
for i in range(self._engine.num_io_tensors):
|
||||||
|
name = self._engine.get_tensor_name(i)
|
||||||
|
dtype = trt.nptype(self._engine.get_tensor_dtype(name))
|
||||||
|
shape = tuple(self._engine.get_tensor_shape(name))
|
||||||
|
nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
|
||||||
|
host_mem = cuda.pagelocked_empty(shape, dtype)
|
||||||
|
device_mem = cuda.mem_alloc(nbytes)
|
||||||
|
self._bindings.append(int(device_mem))
|
||||||
|
if self._engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
|
||||||
|
self._inputs.append({'host': host_mem, 'device': device_mem})
|
||||||
|
else:
|
||||||
|
self._outputs.append({'host': host_mem, 'device': device_mem,
|
||||||
|
'shape': shape})
|
||||||
|
self._output_names.append(name)
|
||||||
|
self._stream = cuda.Stream()
|
||||||
|
|
||||||
|
def infer(self, input_data: np.ndarray) -> list[np.ndarray]:
|
||||||
|
"""Run inference and return output tensors."""
|
||||||
|
np.copyto(self._inputs[0]['host'], input_data.ravel())
|
||||||
|
self._cuda.memcpy_htod_async(
|
||||||
|
self._inputs[0]['device'], self._inputs[0]['host'], self._stream)
|
||||||
|
self._context.execute_async_v2(self._bindings, self._stream.handle)
|
||||||
|
results = []
|
||||||
|
for out in self._outputs:
|
||||||
|
self._cuda.memcpy_dtoh_async(out['host'], out['device'], self._stream)
|
||||||
|
self._stream.synchronize()
|
||||||
|
for out in self._outputs:
|
||||||
|
results.append(out['host'].reshape(out['shape']).copy())
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class _ONNXBackend:
|
||||||
|
"""ONNX Runtime inference (CUDA EP with CPU fallback)."""
|
||||||
|
|
||||||
|
def __init__(self, onnx_path: str):
|
||||||
|
import onnxruntime as ort
|
||||||
|
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||||
|
self._session = ort.InferenceSession(onnx_path, providers=providers)
|
||||||
|
self._input_name = self._session.get_inputs()[0].name
|
||||||
|
self._output_names = [o.name for o in self._session.get_outputs()]
|
||||||
|
|
||||||
|
def infer(self, input_data: np.ndarray) -> list[np.ndarray]:
|
||||||
|
"""Run inference and return output tensors."""
|
||||||
|
return self._session.run(None, {self._input_name: input_data})
|
||||||
|
|
||||||
|
|
||||||
|
# -- NMS ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _nms(boxes: np.ndarray, scores: np.ndarray, iou_thresh: float) -> list[int]:
|
||||||
|
"""Non-maximum suppression. boxes: [N, 4] as x1,y1,x2,y2."""
|
||||||
|
if len(boxes) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
|
||||||
|
areas = (x2 - x1) * (y2 - y1)
|
||||||
|
order = scores.argsort()[::-1]
|
||||||
|
keep = []
|
||||||
|
|
||||||
|
while order.size > 0:
|
||||||
|
i = order[0]
|
||||||
|
keep.append(int(i))
|
||||||
|
if order.size == 1:
|
||||||
|
break
|
||||||
|
xx1 = np.maximum(x1[i], x1[order[1:]])
|
||||||
|
yy1 = np.maximum(y1[i], y1[order[1:]])
|
||||||
|
xx2 = np.minimum(x2[i], x2[order[1:]])
|
||||||
|
yy2 = np.minimum(y2[i], y2[order[1:]])
|
||||||
|
inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
|
||||||
|
iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-6)
|
||||||
|
remaining = np.where(iou <= iou_thresh)[0]
|
||||||
|
order = order[remaining + 1]
|
||||||
|
|
||||||
|
return keep
|
||||||
|
|
||||||
|
|
||||||
|
# -- Anchor generation ---------------------------------------------------------
|
||||||
|
|
||||||
|
def _generate_anchors(input_h: int, input_w: int) -> dict[int, np.ndarray]:
|
||||||
|
"""Generate anchor centers for each stride.
|
||||||
|
|
||||||
|
Returns dict mapping stride -> array of shape [H*W*num_anchors, 2],
|
||||||
|
where each row is (cx, cy) in input pixel coordinates.
|
||||||
|
"""
|
||||||
|
anchors = {}
|
||||||
|
for stride in _STRIDES:
|
||||||
|
feat_h = input_h // stride
|
||||||
|
feat_w = input_w // stride
|
||||||
|
grid_y, grid_x = np.mgrid[:feat_h, :feat_w]
|
||||||
|
centers = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1).astype(np.float32)
|
||||||
|
centers = (centers + 0.5) * stride
|
||||||
|
# Repeat for num_anchors per cell
|
||||||
|
centers = np.tile(centers, (_NUM_ANCHORS, 1)) # [H*W*2, 2]
|
||||||
|
# Interleave properly: [anchor0_cell0, anchor1_cell0, anchor0_cell1, ...]
|
||||||
|
centers = np.repeat(
|
||||||
|
np.stack([grid_x.ravel(), grid_y.ravel()], axis=1).astype(np.float32),
|
||||||
|
_NUM_ANCHORS, axis=0
|
||||||
|
)
|
||||||
|
centers = (centers + 0.5) * stride
|
||||||
|
anchors[stride] = centers
|
||||||
|
return anchors
|
||||||
|
|
||||||
|
|
||||||
|
# -- Main detector class -------------------------------------------------------
|
||||||
|
|
||||||
|
class SCRFDDetector:
    """SCRFD face detector with TensorRT FP16 and ONNX fallback.

    Args:
        engine_path: Path to TensorRT engine file.
        onnx_path: Path to ONNX model file (used if engine not available).
        conf_thresh: Minimum confidence for detections.
        nms_iou: IoU threshold for NMS.
        input_size: Model input resolution (square).
    """

    def __init__(
        self,
        engine_path: str = '',
        onnx_path: str = '',
        conf_thresh: float = 0.5,
        nms_iou: float = 0.4,
        input_size: int = 640,
    ):
        self._conf_thresh = conf_thresh
        self._nms_iou = nms_iou
        self._input_size = input_size
        self._backend: Optional[_TRTBackend | _ONNXBackend] = None
        self._anchors = _generate_anchors(input_size, input_size)

        # Try TRT first, then ONNX
        if engine_path and os.path.isfile(engine_path):
            try:
                self._backend = _TRTBackend(engine_path)
                logger.info('SCRFD TensorRT backend loaded: %s', engine_path)
                return
            except Exception as e:
                logger.warning('SCRFD TRT load failed (%s), trying ONNX', e)

        if onnx_path and os.path.isfile(onnx_path):
            try:
                self._backend = _ONNXBackend(onnx_path)
                logger.info('SCRFD ONNX backend loaded: %s', onnx_path)
                return
            except Exception as e:
                logger.error('SCRFD ONNX load failed: %s', e)

        logger.error('No SCRFD model loaded. Detection will be unavailable.')

    @property
    def is_loaded(self) -> bool:
        """Return True if a backend is loaded and ready."""
        return self._backend is not None

    def detect(self, bgr: np.ndarray) -> list[dict]:
        """Detect faces in a BGR image.

        Args:
            bgr: Input image in BGR format, shape (H, W, 3).

        Returns:
            List of dicts with keys:
                bbox: [x1, y1, x2, y2] in original image coordinates
                kps: [x0,y0, x1,y1, ..., x4,y4] — 10 floats, 5 landmarks
                score: detection confidence
            Empty when no backend is loaded or the image has no pixels.
        """
        if self._backend is None:
            return []
        # An empty frame would divide by zero when computing the scale.
        if bgr is None or bgr.size == 0:
            return []

        orig_h, orig_w = bgr.shape[:2]
        tensor, scale, pad_w, pad_h = self._preprocess(bgr)
        outputs = self._backend.infer(tensor)
        detections = self._decode_outputs(outputs)
        detections = self._rescale(detections, scale, pad_w, pad_h, orig_w, orig_h)
        return detections

    def _preprocess(self, bgr: np.ndarray) -> tuple[np.ndarray, float, int, int]:
        """Resize to input_size x input_size with letterbox padding, normalize.

        Returns:
            (blob, scale, pad_w, pad_h): a contiguous float32 NCHW blob, the
            resize factor applied to the image, and the left/top padding in
            pixels of the centered letterbox.
        """
        h, w = bgr.shape[:2]
        size = self._input_size
        scale = min(size / h, size / w)
        # Extreme aspect ratios can truncate the short side to 0 pixels,
        # which the resize below cannot handle; keep at least one pixel.
        new_w = max(1, int(w * scale))
        new_h = max(1, int(h * scale))
        pad_w = (size - new_w) // 2
        pad_h = (size - new_h) // 2

        resized = cv2.resize(bgr, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        canvas = np.zeros((size, size, 3), dtype=np.uint8)
        canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized

        # Normalize: subtract 127.5, divide 128.0
        # NOTE(review): channels are fed in BGR order as received — confirm
        # the exported model does not expect RGB.
        blob = (canvas.astype(np.float32) - 127.5) / 128.0
        # HWC -> NCHW
        blob = blob.transpose(2, 0, 1)[np.newaxis]
        return np.ascontiguousarray(blob), scale, pad_w, pad_h

    def _group_outputs(self, outputs: list[np.ndarray]) -> Optional[list[tuple]]:
        """Identify each raw tensor by shape and group them per stride.

        A tensor's role comes from its trailing dimension (1 = score,
        4 = bbox, 10 = keypoints) and its stride from its row count, matched
        against the anchor count generated for that stride. This makes
        decoding independent of the order in which the backend returns
        tensors.

        Returns:
            One (score, bbox, kps) triple per stride, in the order of
            ``self._anchors``, or None when the tensors cannot be identified
            unambiguously (the caller then uses positional order).
        """
        role_by_width = {1: 'score', 4: 'bbox', 10: 'kps'}
        found = {'score': {}, 'bbox': {}, 'kps': {}}
        for tensor in outputs:
            arr = np.asarray(tensor)
            if arr.ndim < 2:
                return None
            width = arr.shape[-1]
            role = role_by_width.get(width)
            if role is None:
                return None
            rows = arr.size // width
            if rows in found[role]:
                return None  # two tensors claim the same slot
            found[role][rows] = arr

        grouped = []
        for anchors in self._anchors.values():
            rows = anchors.shape[0]
            try:
                grouped.append(
                    (found['score'][rows], found['bbox'][rows], found['kps'][rows]))
            except KeyError:
                return None
        return grouped

    def _decode_outputs(self, outputs: list[np.ndarray]) -> list[dict]:
        """Decode SCRFD 9-output format into face detections.

        SCRFD outputs 9 tensors, 3 per stride (score, bbox, kps). Tensors
        are matched to stride and role by shape when possible; otherwise
        the positional order
            score_8, bbox_8, kps_8, score_16, bbox_16, kps_16, score_32, ...
        is assumed.

        Returns:
            Detections surviving the confidence filter and NMS, as dicts
            with ``bbox``, ``kps`` and ``score`` in model-input coordinates.

        Raises:
            ValueError: If the positional fallback is needed but fewer than
                three tensors per stride were supplied.
        """
        n_levels = len(self._anchors)
        grouped = self._group_outputs(outputs)
        if grouped is None:
            if len(outputs) < 3 * n_levels:
                raise ValueError(
                    f'Expected {3 * n_levels} SCRFD output tensors, '
                    f'got {len(outputs)}.')
            grouped = [tuple(outputs[i * 3:i * 3 + 3]) for i in range(n_levels)]

        all_scores = []
        all_bboxes = []
        all_kps = []

        for (stride, anchors), (raw_score, raw_bbox, raw_kps) in zip(
                self._anchors.items(), grouped):
            score_out = np.asarray(raw_score).reshape(-1)    # [n]
            bbox_out = np.asarray(raw_bbox).reshape(-1, 4)   # [n, 4]
            kps_out = np.asarray(raw_kps).reshape(-1, 10)    # [n, 10]
            n = score_out.shape[0]

            # Filter by confidence
            mask = score_out > self._conf_thresh
            if not mask.any():
                continue

            # Extra anchors are trimmed; a stride with more predictions
            # than anchors cannot be decoded and is skipped.
            if anchors.shape[0] < n:
                continue
            anchors = anchors[:n][mask]

            scores = score_out[mask]
            bboxes = bbox_out[mask]
            kps = kps_out[mask]

            # Decode bboxes: center = anchor + pred[:2]*stride,
            # size = exp(pred[2:])*stride
            # NOTE(review): this treats the regression as (dx, dy, log w,
            # log h). The InsightFace reference SCRFD decoder is believed to
            # use four edge distances from the anchor — verify against the
            # deployed model before relying on these boxes.
            cx = anchors[:, 0] + bboxes[:, 0] * stride
            cy = anchors[:, 1] + bboxes[:, 1] * stride
            half_w = np.exp(bboxes[:, 2]) * stride / 2.0
            half_h = np.exp(bboxes[:, 3]) * stride / 2.0
            decoded_bboxes = np.stack(
                [cx - half_w, cy - half_h, cx + half_w, cy + half_h], axis=1)

            # Decode keypoints: kp = anchor + pred * stride
            # (even columns are x offsets, odd columns are y offsets)
            decoded_kps = np.zeros_like(kps)
            decoded_kps[:, 0::2] = anchors[:, 0:1] + kps[:, 0::2] * stride
            decoded_kps[:, 1::2] = anchors[:, 1:2] + kps[:, 1::2] * stride

            all_scores.append(scores)
            all_bboxes.append(decoded_bboxes)
            all_kps.append(decoded_kps)

        if not all_scores:
            return []

        scores = np.concatenate(all_scores)
        bboxes = np.concatenate(all_bboxes)
        kps = np.concatenate(all_kps)

        # NMS
        keep = _nms(bboxes, scores, self._nms_iou)

        return [
            {
                'bbox': bboxes[i].tolist(),
                'kps': kps[i].tolist(),
                'score': float(scores[i]),
            }
            for i in keep
        ]

    def _rescale(
        self,
        detections: list[dict],
        scale: float,
        pad_w: int,
        pad_h: int,
        orig_w: int,
        orig_h: int,
    ) -> list[dict]:
        """Rescale detections from model input space to original image space.

        Undoes the letterbox (subtract padding, divide by scale). Box
        corners are clamped to the image; keypoints are not clamped. The
        detection dicts are modified in place and the same list is returned.
        """
        for det in detections:
            bbox = det['bbox']
            bbox[0] = max(0.0, (bbox[0] - pad_w) / scale)
            bbox[1] = max(0.0, (bbox[1] - pad_h) / scale)
            bbox[2] = min(float(orig_w), (bbox[2] - pad_w) / scale)
            bbox[3] = min(float(orig_h), (bbox[3] - pad_h) / scale)
            det['bbox'] = bbox

            kps = det['kps']
            for k in range(5):
                kps[k * 2] = (kps[k * 2] - pad_w) / scale
                kps[k * 2 + 1] = (kps[k * 2 + 1] - pad_h) / scale
            det['kps'] = kps
        return detections
|
||||||
@ -0,0 +1,112 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
build_face_trt_engines.py -- Build TensorRT FP16 engines for SCRFD and ArcFace.
|
||||||
|
|
||||||
|
Converts ONNX model files to optimized TensorRT engines with FP16 precision
|
||||||
|
for fast inference on Jetson Orin Nano Super.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 build_face_trt_engines.py \
|
||||||
|
--scrfd-onnx /path/to/scrfd_2.5g_bnkps.onnx \
|
||||||
|
--arcface-onnx /path/to/arcface_r50.onnx \
|
||||||
|
--output-dir /mnt/nvme/saltybot/models \
|
||||||
|
--fp16 --workspace-mb 1024
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
def build_engine(onnx_path: str, engine_path: str, fp16: bool, workspace_mb: int):
    """Convert an ONNX model into a serialized TensorRT engine on disk.

    Args:
        onnx_path: Path to the source ONNX model file.
        engine_path: Output path for the serialized TensorRT engine; its
            parent directory is created when missing.
        fp16: Request FP16 precision (used only if the platform reports
            fast FP16 support).
        workspace_mb: Maximum workspace size in megabytes.

    Raises:
        RuntimeError: If the ONNX model cannot be parsed or the engine
            build returns nothing.
    """
    import tensorrt as trt

    trt_log = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(trt_log)
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(explicit_batch)
    onnx_parser = trt.OnnxParser(network, trt_log)

    print(f'Parsing ONNX model: {onnx_path}')
    parse_started = time.monotonic()

    with open(onnx_path, 'rb') as model_file:
        parsed_ok = onnx_parser.parse(model_file.read())
    if not parsed_ok:
        for err_idx in range(onnx_parser.num_errors):
            print(f'  ONNX parse error: {onnx_parser.get_error(err_idx)}')
        raise RuntimeError(f'Failed to parse {onnx_path}')

    parse_time = time.monotonic() - parse_started
    print(f'  Parsed in {parse_time:.1f}s')

    workspace_bytes = workspace_mb * (1 << 20)
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
    if fp16 and builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)
        print('  FP16 enabled.')
    elif fp16:
        print('  Warning: FP16 not supported on this platform, using FP32.')

    print('Building engine (this may take several minutes)...')
    build_started = time.monotonic()
    serialized = builder.build_serialized_network(network, config)
    build_time = time.monotonic() - build_started

    if serialized is None:
        raise RuntimeError('Engine build failed.')

    # A bare file name has no directory part; fall back to the current dir.
    out_dir = os.path.dirname(engine_path) or '.'
    os.makedirs(out_dir, exist_ok=True)
    with open(engine_path, 'wb') as engine_file:
        engine_file.write(serialized)

    size_mb = os.path.getsize(engine_path) / (1 << 20)
    print(f'  Engine saved: {engine_path} ({size_mb:.1f} MB, built in {build_time:.1f}s)')
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse the command line and build the requested TensorRT engines.

    At least one of ``--scrfd-onnx`` / ``--arcface-onnx`` must be given;
    each supplied model is built into a fixed engine file name inside
    ``--output-dir``.
    """
    cli = argparse.ArgumentParser(
        description='Build TensorRT FP16 engines for SCRFD and ArcFace.')
    cli.add_argument('--scrfd-onnx', type=str, default='',
                     help='Path to SCRFD ONNX model.')
    cli.add_argument('--arcface-onnx', type=str, default='',
                     help='Path to ArcFace ONNX model.')
    cli.add_argument('--output-dir', type=str,
                     default='/mnt/nvme/saltybot/models',
                     help='Output directory for engine files.')
    cli.add_argument('--fp16', action='store_true', default=True,
                     help='Enable FP16 precision (default: True).')
    cli.add_argument('--no-fp16', action='store_false', dest='fp16',
                     help='Disable FP16 (use FP32 only).')
    cli.add_argument('--workspace-mb', type=int, default=1024,
                     help='TRT workspace size in MB (default: 1024).')
    opts = cli.parse_args()

    if not (opts.scrfd_onnx or opts.arcface_onnx):
        cli.error('At least one of --scrfd-onnx or --arcface-onnx is required.')

    # (banner, source ONNX path, engine file name) — built in this order.
    jobs = (
        ('\n=== Building SCRFD engine ===', opts.scrfd_onnx, 'scrfd_2.5g.engine'),
        ('\n=== Building ArcFace engine ===', opts.arcface_onnx, 'arcface_r50.engine'),
    )
    for banner, onnx_file, engine_name in jobs:
        if not onnx_file:
            continue
        engine_path = os.path.join(opts.output_dir, engine_name)
        print(banner)
        build_engine(onnx_file, engine_path, opts.fp16, opts.workspace_mb)

    print('\nDone.')


if __name__ == '__main__':
    main()
|
||||||
4
jetson/ros2_ws/src/saltybot_social_face/setup.cfg
Normal file
4
jetson/ros2_ws/src/saltybot_social_face/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
[develop]
|
||||||
|
script_dir=$base/lib/saltybot_social_face
|
||||||
|
[install]
|
||||||
|
install_scripts=$base/lib/saltybot_social_face
|
||||||
30
jetson/ros2_ws/src/saltybot_social_face/setup.py
Normal file
30
jetson/ros2_ws/src/saltybot_social_face/setup.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
"""Setup for saltybot_social_face package."""
|
||||||
|
|
||||||
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
package_name = 'saltybot_social_face'
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name=package_name,
|
||||||
|
version='0.1.0',
|
||||||
|
packages=find_packages(exclude=['test']),
|
||||||
|
data_files=[
|
||||||
|
('share/ament_index/resource_index/packages', ['resource/' + package_name]),
|
||||||
|
('share/' + package_name, ['package.xml']),
|
||||||
|
('share/' + package_name + '/launch', ['launch/face_recognition.launch.py']),
|
||||||
|
('share/' + package_name + '/config', ['config/face_recognition_params.yaml']),
|
||||||
|
],
|
||||||
|
install_requires=['setuptools'],
|
||||||
|
zip_safe=True,
|
||||||
|
maintainer='seb',
|
||||||
|
maintainer_email='seb@vayrette.com',
|
||||||
|
description='SCRFD face detection and ArcFace recognition for SaltyBot social interactions',
|
||||||
|
license='MIT',
|
||||||
|
tests_require=['pytest'],
|
||||||
|
entry_points={
|
||||||
|
'console_scripts': [
|
||||||
|
'face_recognition = saltybot_social_face.face_recognition_node:main',
|
||||||
|
'enrollment_cli = saltybot_social_face.enrollment_cli:main',
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
24
jetson/ros2_ws/src/saltybot_social_msgs/CMakeLists.txt
Normal file
24
jetson/ros2_ws/src/saltybot_social_msgs/CMakeLists.txt
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.8)
|
||||||
|
project(saltybot_social_msgs)
|
||||||
|
|
||||||
|
find_package(ament_cmake REQUIRED)
|
||||||
|
find_package(rosidl_default_generators REQUIRED)
|
||||||
|
find_package(std_msgs REQUIRED)
|
||||||
|
find_package(geometry_msgs REQUIRED)
|
||||||
|
find_package(builtin_interfaces REQUIRED)
|
||||||
|
|
||||||
|
rosidl_generate_interfaces(${PROJECT_NAME}
|
||||||
|
"msg/FaceDetection.msg"
|
||||||
|
"msg/FaceDetectionArray.msg"
|
||||||
|
"msg/FaceEmbedding.msg"
|
||||||
|
"msg/FaceEmbeddingArray.msg"
|
||||||
|
"msg/PersonState.msg"
|
||||||
|
"msg/PersonStateArray.msg"
|
||||||
|
"srv/EnrollPerson.srv"
|
||||||
|
"srv/ListPersons.srv"
|
||||||
|
"srv/DeletePerson.srv"
|
||||||
|
"srv/UpdatePerson.srv"
|
||||||
|
DEPENDENCIES std_msgs geometry_msgs builtin_interfaces
|
||||||
|
)
|
||||||
|
|
||||||
|
ament_package()
|
||||||
@ -0,0 +1,10 @@
|
|||||||
|
std_msgs/Header header
|
||||||
|
int32 face_id
|
||||||
|
string person_name
|
||||||
|
float32 confidence
|
||||||
|
float32 recognition_score
|
||||||
|
float32 bbox_x
|
||||||
|
float32 bbox_y
|
||||||
|
float32 bbox_w
|
||||||
|
float32 bbox_h
|
||||||
|
float32[10] landmarks
|
||||||
@ -0,0 +1,2 @@
|
|||||||
|
std_msgs/Header header
|
||||||
|
saltybot_social_msgs/FaceDetection[] faces
|
||||||
@ -0,0 +1,5 @@
|
|||||||
|
int32 person_id
|
||||||
|
string person_name
|
||||||
|
float32[] embedding
|
||||||
|
builtin_interfaces/Time enrolled_at
|
||||||
|
int32 sample_count
|
||||||
@ -0,0 +1,2 @@
|
|||||||
|
std_msgs/Header header
|
||||||
|
saltybot_social_msgs/FaceEmbedding[] embeddings
|
||||||
19
jetson/ros2_ws/src/saltybot_social_msgs/msg/PersonState.msg
Normal file
19
jetson/ros2_ws/src/saltybot_social_msgs/msg/PersonState.msg
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
std_msgs/Header header
|
||||||
|
int32 person_id
|
||||||
|
string person_name
|
||||||
|
int32 face_id
|
||||||
|
string speaker_id
|
||||||
|
string uwb_anchor_id
|
||||||
|
geometry_msgs/Point position
|
||||||
|
float32 distance
|
||||||
|
float32 bearing_deg
|
||||||
|
uint8 state
|
||||||
|
uint8 STATE_UNKNOWN=0
|
||||||
|
uint8 STATE_APPROACHING=1
|
||||||
|
uint8 STATE_ENGAGED=2
|
||||||
|
uint8 STATE_TALKING=3
|
||||||
|
uint8 STATE_LEAVING=4
|
||||||
|
uint8 STATE_ABSENT=5
|
||||||
|
float32 engagement_score
|
||||||
|
builtin_interfaces/Time last_seen
|
||||||
|
int32 camera_id
|
||||||
@ -0,0 +1,3 @@
|
|||||||
|
std_msgs/Header header
|
||||||
|
saltybot_social_msgs/PersonState[] persons
|
||||||
|
int32 primary_attention_id
|
||||||
19
jetson/ros2_ws/src/saltybot_social_msgs/package.xml
Normal file
19
jetson/ros2_ws/src/saltybot_social_msgs/package.xml
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
<?xml version="1.0"?>
|
||||||
|
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||||
|
<package format="3">
|
||||||
|
<name>saltybot_social_msgs</name>
|
||||||
|
<version>0.1.0</version>
|
||||||
|
<description>Custom ROS2 messages and services for saltybot social capabilities</description>
|
||||||
|
<maintainer email="seb@vayrette.com">seb</maintainer>
|
||||||
|
<license>MIT</license>
|
||||||
|
<buildtool_depend>ament_cmake</buildtool_depend>
|
||||||
|
<depend>std_msgs</depend>
|
||||||
|
<depend>geometry_msgs</depend>
|
||||||
|
<depend>builtin_interfaces</depend>
|
||||||
|
<build_depend>rosidl_default_generators</build_depend>
|
||||||
|
<exec_depend>rosidl_default_runtime</exec_depend>
|
||||||
|
<member_of_group>rosidl_interface_packages</member_of_group>
|
||||||
|
<export>
|
||||||
|
<build_type>ament_cmake</build_type>
|
||||||
|
</export>
|
||||||
|
</package>
|
||||||
@ -0,0 +1,4 @@
|
|||||||
|
int32 person_id
|
||||||
|
---
|
||||||
|
bool success
|
||||||
|
string message
|
||||||
@ -0,0 +1,7 @@
|
|||||||
|
string name
|
||||||
|
string mode
|
||||||
|
int32 n_samples
|
||||||
|
---
|
||||||
|
bool success
|
||||||
|
string message
|
||||||
|
int32 person_id
|
||||||
@ -0,0 +1,2 @@
|
|||||||
|
---
|
||||||
|
saltybot_social_msgs/FaceEmbedding[] persons
|
||||||
@ -0,0 +1,5 @@
|
|||||||
|
int32 person_id
|
||||||
|
string new_name
|
||||||
|
---
|
||||||
|
bool success
|
||||||
|
string message
|
||||||
@ -0,0 +1,22 @@
|
|||||||
|
social_nav:
|
||||||
|
ros__parameters:
|
||||||
|
follow_mode: 'shadow'
|
||||||
|
follow_distance: 1.2
|
||||||
|
lead_distance: 2.0
|
||||||
|
orbit_radius: 1.5
|
||||||
|
max_linear_speed: 1.0
|
||||||
|
max_linear_speed_fast: 5.5
|
||||||
|
max_angular_speed: 1.0
|
||||||
|
goal_tolerance: 0.3
|
||||||
|
routes_dir: '/mnt/nvme/saltybot/routes'
|
||||||
|
home_x: 0.0
|
||||||
|
home_y: 0.0
|
||||||
|
map_resolution: 0.05
|
||||||
|
obstacle_inflation_cells: 3
|
||||||
|
|
||||||
|
midas_depth:
|
||||||
|
ros__parameters:
|
||||||
|
onnx_path: '/mnt/nvme/saltybot/models/midas_small.onnx'
|
||||||
|
engine_path: '/mnt/nvme/saltybot/models/midas_small.engine'
|
||||||
|
process_rate: 5.0
|
||||||
|
output_scale: 1.0
|
||||||
@ -0,0 +1,57 @@
|
|||||||
|
"""Launch file for saltybot social navigation."""
|
||||||
|
|
||||||
|
from launch import LaunchDescription
|
||||||
|
from launch.actions import DeclareLaunchArgument
|
||||||
|
from launch.substitutions import LaunchConfiguration
|
||||||
|
from launch_ros.actions import Node
|
||||||
|
|
||||||
|
|
||||||
|
def generate_launch_description():
    """Build the launch description for the social navigation stack.

    Declares the ``follow_mode``, ``follow_distance``, ``max_linear_speed``
    and ``routes_dir`` launch arguments and starts three executables from
    the ``saltybot_social_nav`` package: ``social_nav``, ``midas_depth`` and
    ``waypoint_teacher``.
    """
    launch_arguments = [
        DeclareLaunchArgument(
            'follow_mode',
            default_value='shadow',
            description='Follow mode: shadow/lead/side/orbit/loose/tight',
        ),
        DeclareLaunchArgument(
            'follow_distance',
            default_value='1.2',
            description='Follow distance in meters',
        ),
        DeclareLaunchArgument(
            'max_linear_speed',
            default_value='1.0',
            description='Max linear speed (m/s)',
        ),
        DeclareLaunchArgument(
            'routes_dir',
            default_value='/mnt/nvme/saltybot/routes',
            description='Directory for saved routes',
        ),
    ]

    # Follower node: every tunable is forwarded from a launch argument.
    social_nav_params = {
        'follow_mode': LaunchConfiguration('follow_mode'),
        'follow_distance': LaunchConfiguration('follow_distance'),
        'max_linear_speed': LaunchConfiguration('max_linear_speed'),
        'routes_dir': LaunchConfiguration('routes_dir'),
    }
    social_nav_node = Node(
        package='saltybot_social_nav',
        executable='social_nav',
        name='social_nav',
        output='screen',
        parameters=[social_nav_params],
    )

    # Depth node: model paths and rates are fixed here, not launch arguments.
    midas_params = {
        'onnx_path': '/mnt/nvme/saltybot/models/midas_small.onnx',
        'engine_path': '/mnt/nvme/saltybot/models/midas_small.engine',
        'process_rate': 5.0,
        'output_scale': 1.0,
    }
    midas_depth_node = Node(
        package='saltybot_social_nav',
        executable='midas_depth',
        name='midas_depth',
        output='screen',
        parameters=[midas_params],
    )

    # Route recorder: shares the routes_dir launch argument.
    waypoint_teacher_params = {
        'routes_dir': LaunchConfiguration('routes_dir'),
        'recording_interval': 0.5,
    }
    waypoint_teacher_node = Node(
        package='saltybot_social_nav',
        executable='waypoint_teacher',
        name='waypoint_teacher',
        output='screen',
        parameters=[waypoint_teacher_params],
    )

    return LaunchDescription([
        *launch_arguments,
        social_nav_node,
        midas_depth_node,
        waypoint_teacher_node,
    ])
|
||||||
28
jetson/ros2_ws/src/saltybot_social_nav/package.xml
Normal file
28
jetson/ros2_ws/src/saltybot_social_nav/package.xml
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
  <name>saltybot_social_nav</name>
  <version>0.1.0</version>
  <description>Social navigation for saltybot: follow modes, waypoint teaching, A* avoidance, MiDaS depth</description>
  <maintainer email="seb@vayrette.com">seb</maintainer>
  <license>MIT</license>

  <depend>rclpy</depend>
  <depend>std_msgs</depend>
  <depend>geometry_msgs</depend>
  <depend>nav_msgs</depend>
  <depend>sensor_msgs</depend>
  <depend>cv_bridge</depend>
  <depend>tf2_ros</depend>
  <depend>tf2_geometry_msgs</depend>
  <depend>saltybot_social_msgs</depend>

  <!-- Python libraries imported by the nodes: numpy, scipy.ndimage, cv2 -->
  <exec_depend>python3-numpy</exec_depend>
  <exec_depend>python3-scipy</exec_depend>
  <exec_depend>python3-opencv</exec_depend>

  <!-- Required to run the launch file shipped with this package -->
  <exec_depend>launch</exec_depend>
  <exec_depend>launch_ros</exec_depend>

  <test_depend>ament_copyright</test_depend>
  <test_depend>ament_flake8</test_depend>
  <test_depend>ament_pep257</test_depend>
  <test_depend>python3-pytest</test_depend>

  <export>
    <build_type>ament_python</build_type>
  </export>
</package>
|
||||||
@ -0,0 +1,82 @@
|
|||||||
|
"""astar.py -- A* path planner for saltybot social navigation."""
|
||||||
|
|
||||||
|
import heapq
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def astar(grid: np.ndarray, start: tuple, goal: tuple,
|
||||||
|
obstacle_val: int = 100) -> list | None:
|
||||||
|
"""
|
||||||
|
A* on a 2D occupancy grid (row, col indexing).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grid: 2D numpy array, values 0=free, >=obstacle_val=obstacle
|
||||||
|
start: (row, col) start cell
|
||||||
|
goal: (row, col) goal cell
|
||||||
|
obstacle_val: cells with value >= obstacle_val are blocked
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (row, col) tuples from start to goal, or None if no path.
|
||||||
|
"""
|
||||||
|
rows, cols = grid.shape
|
||||||
|
|
||||||
|
def h(a, b):
|
||||||
|
return abs(a[0] - b[0]) + abs(a[1] - b[1]) # Manhattan heuristic
|
||||||
|
|
||||||
|
open_set = []
|
||||||
|
heapq.heappush(open_set, (h(start, goal), 0, start))
|
||||||
|
came_from = {}
|
||||||
|
g_score = {start: 0}
|
||||||
|
|
||||||
|
# 8-directional movement
|
||||||
|
neighbors_delta = [
|
||||||
|
(-1, -1), (-1, 0), (-1, 1),
|
||||||
|
(0, -1), (0, 1),
|
||||||
|
(1, -1), (1, 0), (1, 1),
|
||||||
|
]
|
||||||
|
|
||||||
|
while open_set:
|
||||||
|
_, cost, current = heapq.heappop(open_set)
|
||||||
|
|
||||||
|
if current == goal:
|
||||||
|
path = []
|
||||||
|
while current in came_from:
|
||||||
|
path.append(current)
|
||||||
|
current = came_from[current]
|
||||||
|
path.append(start)
|
||||||
|
return list(reversed(path))
|
||||||
|
|
||||||
|
if cost > g_score.get(current, float('inf')):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for dr, dc in neighbors_delta:
|
||||||
|
nr, nc = current[0] + dr, current[1] + dc
|
||||||
|
if not (0 <= nr < rows and 0 <= nc < cols):
|
||||||
|
continue
|
||||||
|
if grid[nr, nc] >= obstacle_val:
|
||||||
|
continue
|
||||||
|
move_cost = 1.414 if (dr != 0 and dc != 0) else 1.0
|
||||||
|
new_g = g_score[current] + move_cost
|
||||||
|
neighbor = (nr, nc)
|
||||||
|
if new_g < g_score.get(neighbor, float('inf')):
|
||||||
|
g_score[neighbor] = new_g
|
||||||
|
f = new_g + h(neighbor, goal)
|
||||||
|
came_from[neighbor] = current
|
||||||
|
heapq.heappush(open_set, (f, new_g, neighbor))
|
||||||
|
|
||||||
|
return None # No path found
|
||||||
|
|
||||||
|
|
||||||
|
def inflate_obstacles(grid: np.ndarray, inflation_radius_cells: int,
                      obstacle_threshold: int = 50,
                      inflated_val: int = 100) -> np.ndarray:
    """Inflate obstacles for robot footprint safety.

    Every cell whose value is >= ``obstacle_threshold`` is treated as an
    obstacle and grown by a square window of half-width
    ``inflation_radius_cells``. All cells covered by the grown mask,
    including the original obstacle cells, are set to ``inflated_val``.

    Args:
        grid: 2D occupancy grid; it is not modified.
        inflation_radius_cells: half-width of the square inflation window
            in cells. Values <= 0 are treated as 0 (no growth).
        obstacle_threshold: minimum cell value counted as an obstacle.
        inflated_val: value written into every inflated cell.

    Returns:
        A new array with the same shape and dtype as ``grid``.
    """
    from scipy.ndimage import binary_dilation

    result = grid.copy()
    if grid.size == 0:
        # Nothing to dilate in an empty map.
        return result

    # A negative radius would give np.ones a negative dimension and raise.
    radius = max(0, int(inflation_radius_cells))
    window = 2 * radius + 1

    obstacle_mask = grid >= obstacle_threshold
    kernel = np.ones((window, window), dtype=bool)
    inflated = binary_dilation(obstacle_mask, structure=kernel)
    result[inflated] = inflated_val
    return result
|
||||||
@ -0,0 +1,82 @@
|
|||||||
|
"""follow_modes.py -- Follow mode geometry for saltybot social navigation."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from enum import Enum
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class FollowMode(Enum):
    """Follow behaviours selectable by name.

    Each member's value is the lowercase mode name, so ``FollowMode('shadow')``
    yields ``FollowMode.SHADOW``.
    """

    # Stay directly behind at follow_distance.
    SHADOW = 'shadow'
    # Move ahead of the person by lead_distance.
    LEAD = 'lead'
    # Stay to the right (or left) at side_offset.
    SIDE = 'side'
    # Circle around the person at orbit_radius.
    ORBIT = 'orbit'
    # General follow, larger tolerance.
    LOOSE = 'loose'
    # Close follow, small tolerance.
    TIGHT = 'tight'
|
||||||
|
|
||||||
|
|
||||||
|
def compute_shadow_target(person_pos, person_bearing_deg, follow_dist=1.2):
    """Return the point ``follow_dist`` behind the person.

    The offset direction is the person's bearing rotated by 180 degrees.
    The x offset uses the sine and the y offset the cosine of that angle,
    so a bearing of 0 degrees points along +y. The z component is copied
    from ``person_pos`` unchanged.
    """
    reverse_heading = math.radians(person_bearing_deg + 180.0)
    offset_x = follow_dist * math.sin(reverse_heading)
    offset_y = follow_dist * math.cos(reverse_heading)
    return (person_pos[0] + offset_x, person_pos[1] + offset_y, person_pos[2])
|
||||||
|
|
||||||
|
|
||||||
|
def compute_lead_target(person_pos, person_bearing_deg, lead_dist=2.0):
    """Return the point ``lead_dist`` ahead of the person.

    The offset is taken along ``person_bearing_deg`` (sine for x, cosine
    for y); the z component is copied from ``person_pos`` unchanged.
    """
    heading = math.radians(person_bearing_deg)
    offset_x = lead_dist * math.sin(heading)
    offset_y = lead_dist * math.cos(heading)
    return (person_pos[0] + offset_x, person_pos[1] + offset_y, person_pos[2])
|
||||||
|
|
||||||
|
|
||||||
|
def compute_side_target(person_pos, person_bearing_deg, side_dist=1.0, right=True):
    """Return the point ``side_dist`` beside the person.

    The offset direction is the person's bearing plus 90 degrees when
    ``right`` is true and minus 90 degrees otherwise (sine for x, cosine
    for y). The z component is copied from ``person_pos`` unchanged.
    """
    lateral_deg = 90.0 if right else -90.0
    side_heading = math.radians(person_bearing_deg + lateral_deg)
    offset_x = side_dist * math.sin(side_heading)
    offset_y = side_dist * math.cos(side_heading)
    return (person_pos[0] + offset_x, person_pos[1] + offset_y, person_pos[2])
|
||||||
|
|
||||||
|
|
||||||
|
def compute_orbit_target(person_pos, orbit_angle_deg, orbit_radius=1.5):
    """Return the point at ``orbit_angle_deg`` on a circle around the person.

    The circle has radius ``orbit_radius``; the x offset uses the sine and
    the y offset the cosine of the angle. The z component is copied from
    ``person_pos`` unchanged.
    """
    theta = math.radians(orbit_angle_deg)
    offset_x = orbit_radius * math.sin(theta)
    offset_y = orbit_radius * math.cos(theta)
    return (person_pos[0] + offset_x, person_pos[1] + offset_y, person_pos[2])
|
||||||
|
|
||||||
|
|
||||||
|
def compute_loose_target(person_pos, robot_pos, follow_dist=2.0, tolerance=0.8):
    """Relaxed follow: only move once the person is clearly out of range.

    While the planar distance to the person is at most
    ``follow_dist + tolerance`` the robot position is returned as-is (the
    very same object). Otherwise the target lies on the straight line from
    the robot to the person, ``follow_dist`` short of the person, and takes
    its z component from ``person_pos``.
    """
    gap_x = person_pos[0] - robot_pos[0]
    gap_y = person_pos[1] - robot_pos[1]
    separation = math.hypot(gap_x, gap_y)

    # Inside the deadband: stay put.
    if separation <= follow_dist + tolerance:
        return robot_pos

    # Fraction of the robot-to-person segment to travel.
    fraction = (separation - follow_dist) / separation
    target_x = robot_pos[0] + gap_x * fraction
    target_y = robot_pos[1] + gap_y * fraction
    return (target_x, target_y, person_pos[2])
|
||||||
|
|
||||||
|
|
||||||
|
def compute_tight_target(person_pos, follow_dist=0.6):
    """Close follow: the person's position shifted by ``follow_dist`` along -y.

    The x and z components are copied from ``person_pos`` unchanged; the
    person's bearing is not taken into account.
    """
    target_x = person_pos[0]
    target_y = person_pos[1] - follow_dist
    target_z = person_pos[2]
    return (target_x, target_y, target_z)
|
||||||
|
|
||||||
|
|
||||||
|
# Spoken phrase -> follow mode. Phrases are grouped per mode; the resulting
# dict keeps the listed order (all SHADOW phrases first, TIGHT phrases last).
MODE_VOICE_COMMANDS = {
    phrase: mode
    for mode, phrases in (
        (FollowMode.SHADOW, ('shadow', 'follow me', 'behind me')),
        (FollowMode.LEAD, ('lead', 'go ahead', 'lead me')),
        (FollowMode.SIDE, ('side', 'stay beside')),
        (FollowMode.ORBIT, ('orbit', 'circle me')),
        (FollowMode.LOOSE, ('loose', 'give me space')),
        (FollowMode.TIGHT, ('tight', 'stay close')),
    )
    for phrase in phrases
}
|
||||||
@ -0,0 +1,231 @@
|
|||||||
|
"""
|
||||||
|
midas_depth_node.py -- MiDaS monocular depth estimation for saltybot.
|
||||||
|
|
||||||
|
Uses MiDaS_small via ONNX Runtime or TensorRT FP16.
|
||||||
|
Provides relative depth estimates for cameras without active depth (CSI cameras).
|
||||||
|
|
||||||
|
Publishes /social/depth/cam{i}/image (sensor_msgs/Image, float32, relative depth)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||||
|
from sensor_msgs.msg import Image
|
||||||
|
from cv_bridge import CvBridge
|
||||||
|
|
||||||
|
# MiDaS_small input size
|
||||||
|
_MIDAS_H = 256
|
||||||
|
_MIDAS_W = 256
|
||||||
|
# ImageNet normalization
|
||||||
|
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
|
||||||
|
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
class _TRTBackend:
|
||||||
|
"""TensorRT inference backend for MiDaS."""
|
||||||
|
|
||||||
|
def __init__(self, engine_path: str, logger):
|
||||||
|
self._logger = logger
|
||||||
|
try:
|
||||||
|
import tensorrt as trt
|
||||||
|
import pycuda.driver as cuda
|
||||||
|
import pycuda.autoinit # noqa: F401
|
||||||
|
|
||||||
|
self._cuda = cuda
|
||||||
|
rt_logger = trt.Logger(trt.Logger.WARNING)
|
||||||
|
with open(engine_path, 'rb') as f:
|
||||||
|
engine = trt.Runtime(rt_logger).deserialize_cuda_engine(f.read())
|
||||||
|
self._context = engine.create_execution_context()
|
||||||
|
|
||||||
|
# Allocate buffers
|
||||||
|
self._d_input = cuda.mem_alloc(1 * 3 * _MIDAS_H * _MIDAS_W * 4)
|
||||||
|
self._d_output = cuda.mem_alloc(1 * _MIDAS_H * _MIDAS_W * 4)
|
||||||
|
self._h_output = np.empty((_MIDAS_H, _MIDAS_W), dtype=np.float32)
|
||||||
|
self._stream = cuda.Stream()
|
||||||
|
self._logger.info(f'TRT engine loaded: {engine_path}')
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f'TRT init failed: {e}')
|
||||||
|
|
||||||
|
def infer(self, input_tensor: np.ndarray) -> np.ndarray:
|
||||||
|
self._cuda.memcpy_htod_async(
|
||||||
|
self._d_input, input_tensor.ravel(), self._stream)
|
||||||
|
self._context.execute_async_v2(
|
||||||
|
bindings=[int(self._d_input), int(self._d_output)],
|
||||||
|
stream_handle=self._stream.handle)
|
||||||
|
self._cuda.memcpy_dtoh_async(
|
||||||
|
self._h_output, self._d_output, self._stream)
|
||||||
|
self._stream.synchronize()
|
||||||
|
return self._h_output.copy()
|
||||||
|
|
||||||
|
|
||||||
|
class _ONNXBackend:
|
||||||
|
"""ONNX Runtime inference backend for MiDaS."""
|
||||||
|
|
||||||
|
def __init__(self, onnx_path: str, logger):
|
||||||
|
self._logger = logger
|
||||||
|
try:
|
||||||
|
import onnxruntime as ort
|
||||||
|
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||||
|
self._session = ort.InferenceSession(onnx_path, providers=providers)
|
||||||
|
self._input_name = self._session.get_inputs()[0].name
|
||||||
|
self._logger.info(f'ONNX model loaded: {onnx_path}')
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f'ONNX init failed: {e}')
|
||||||
|
|
||||||
|
def infer(self, input_tensor: np.ndarray) -> np.ndarray:
|
||||||
|
result = self._session.run(None, {self._input_name: input_tensor})
|
||||||
|
return result[0].squeeze()
|
||||||
|
|
||||||
|
|
||||||
|
class MiDaSDepthNode(Node):
    """MiDaS monocular depth estimation node.

    Subscribes to every topic in the ``camera_topics`` parameter, caches the
    most recent frame per camera, and on each timer tick runs inference for
    one camera (round-robin), publishing the result as a ``32FC1`` image on
    ``/social/depth/cam{i}/image``.
    """

    # Used when the ``process_rate`` parameter is zero or negative.
    _DEFAULT_PROCESS_RATE = 5.0

    def __init__(self):
        """Declare parameters, load a backend and wire up subs/pubs/timer."""
        super().__init__('midas_depth')

        # Parameters
        self.declare_parameter('onnx_path',
                               '/mnt/nvme/saltybot/models/midas_small.onnx')
        self.declare_parameter('engine_path',
                               '/mnt/nvme/saltybot/models/midas_small.engine')
        self.declare_parameter('camera_topics', [
            '/surround/cam0/image_raw',
            '/surround/cam1/image_raw',
            '/surround/cam2/image_raw',
            '/surround/cam3/image_raw',
        ])
        self.declare_parameter('output_scale', 1.0)
        self.declare_parameter('process_rate', 5.0)

        onnx_path = self.get_parameter('onnx_path').value
        engine_path = self.get_parameter('engine_path').value
        self._camera_topics = self.get_parameter('camera_topics').value
        self._output_scale = self.get_parameter('output_scale').value
        process_rate = self.get_parameter('process_rate').value
        if process_rate <= 0.0:
            # 1.0 / process_rate below would divide by zero or yield a
            # negative timer period.
            self.get_logger().warn(
                f'Invalid process_rate {process_rate}; '
                f'using {self._DEFAULT_PROCESS_RATE} Hz instead')
            process_rate = self._DEFAULT_PROCESS_RATE

        # Initialize inference backend (TRT preferred, ONNX fallback)
        self._backend = self._load_backend(engine_path, onnx_path)

        self._bridge = CvBridge()

        # Latest frames per camera (round-robin processing)
        self._latest_frames = [None] * len(self._camera_topics)
        self._current_cam_idx = 0

        # QoS for camera subscriptions
        cam_qos = QoSProfile(
            reliability=ReliabilityPolicy.BEST_EFFORT,
            history=HistoryPolicy.KEEP_LAST,
            depth=1,
        )

        # Subscribe to each camera topic; idx=i binds the index at lambda
        # creation time so each callback keeps its own camera number.
        self._cam_subs = []
        for i, topic in enumerate(self._camera_topics):
            sub = self.create_subscription(
                Image, topic,
                lambda msg, idx=i: self._on_image(msg, idx),
                cam_qos)
            self._cam_subs.append(sub)

        # Publishers: one per camera
        self._depth_pubs = [
            self.create_publisher(Image, f'/social/depth/cam{i}/image', 10)
            for i in range(len(self._camera_topics))
        ]

        # Timer: round-robin across cameras
        timer_period = 1.0 / process_rate
        self._timer = self.create_timer(timer_period, self._timer_callback)

        self.get_logger().info(
            f'MiDaS depth node started: {len(self._camera_topics)} cameras '
            f'@ {process_rate} Hz')

    def _load_backend(self, engine_path: str, onnx_path: str):
        """Return the first backend that initialises, or None.

        The TensorRT engine is tried first when its file exists, then the
        ONNX model. Each failure is logged together with its cause.
        """
        log = self.get_logger()
        have_engine = os.path.exists(engine_path)
        have_onnx = os.path.exists(onnx_path)

        if have_engine:
            try:
                return _TRTBackend(engine_path, log)
            except RuntimeError as e:
                log.warn(f'TRT failed, trying ONNX fallback: {e}')
        if have_onnx:
            try:
                return _ONNXBackend(onnx_path, log)
            except RuntimeError as e:
                log.error(f'ONNX backend failed: {e}')

        if have_engine or have_onnx:
            log.error(
                'MiDaS backend initialisation failed. '
                'Depth estimation disabled.')
        else:
            log.error('No MiDaS model found. Depth estimation disabled.')
        return None

    def _on_image(self, msg: Image, cam_idx: int):
        """Cache latest frame for each camera."""
        self._latest_frames[cam_idx] = msg

    def _preprocess(self, bgr: np.ndarray) -> np.ndarray:
        """Preprocess BGR image to MiDaS input tensor [1,3,256,256]."""
        import cv2
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        resized = cv2.resize(rgb, (_MIDAS_W, _MIDAS_H),
                             interpolation=cv2.INTER_LINEAR)
        normalized = (resized.astype(np.float32) / 255.0 - _MEAN) / _STD
        # HWC -> CHW -> NCHW. The transpose only swaps strides, so force a
        # C-contiguous float32 buffer whose memory order matches NCHW.
        tensor = normalized.transpose(2, 0, 1)[np.newaxis, ...]
        return np.ascontiguousarray(tensor, dtype=np.float32)

    def _infer(self, tensor: np.ndarray) -> np.ndarray:
        """Run inference, returns [256,256] float32 relative inverse depth.

        Without a backend an all-zero map is returned.
        NOTE(review): zeros are then published like real estimates; confirm
        consumers can tell a disabled node from valid output.
        """
        if self._backend is None:
            return np.zeros((_MIDAS_H, _MIDAS_W), dtype=np.float32)
        return self._backend.infer(tensor)

    def _postprocess(self, raw: np.ndarray, orig_shape: tuple) -> np.ndarray:
        """Resize depth back to original image shape, apply output_scale."""
        import cv2
        h, w = orig_shape[:2]
        depth = cv2.resize(raw, (w, h), interpolation=cv2.INTER_LINEAR)
        # Published as 32FC1, so pin the dtype to float32.
        return np.asarray(depth * self._output_scale, dtype=np.float32)

    def _timer_callback(self):
        """Process one camera per tick (round-robin)."""
        if not self._camera_topics:
            return

        idx = self._current_cam_idx
        self._current_cam_idx = (idx + 1) % len(self._camera_topics)

        msg = self._latest_frames[idx]
        if msg is None:
            return
        # Consume the frame so a stalled camera does not cause the same
        # image to be re-inferred and republished on every round.
        self._latest_frames[idx] = None

        try:
            bgr = self._bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')
        except Exception as e:
            self.get_logger().warn(f'cv_bridge error cam{idx}: {e}')
            return

        tensor = self._preprocess(bgr)
        raw_depth = self._infer(tensor)
        depth_map = self._postprocess(raw_depth, bgr.shape)

        # Publish as float32 Image, reusing the source frame's header so
        # the depth map carries the same stamp and frame_id.
        depth_msg = self._bridge.cv2_to_imgmsg(depth_map, encoding='32FC1')
        depth_msg.header = msg.header
        self._depth_pubs[idx].publish(depth_msg)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None):
    """Entry point: spin the MiDaS depth node until interrupted.

    The node is destroyed on exit, and the rclpy context is shut down only
    if it is still valid. The context may already have been shut down by
    the time Ctrl-C reaches this code, and shutting it down a second time
    raises.
    """
    rclpy.init(args=args)
    node = None
    try:
        # Constructed inside the try so the context is still cleaned up
        # when node construction itself fails.
        node = MiDaSDepthNode()
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        if node is not None:
            node.destroy_node()
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()
|
||||||
@ -0,0 +1,584 @@
|
|||||||
|
"""
|
||||||
|
social_nav_node.py -- Social navigation node for saltybot.
|
||||||
|
|
||||||
|
Orchestrates person following with multiple modes, voice commands,
|
||||||
|
waypoint teaching/replay, and A* obstacle avoidance.
|
||||||
|
|
||||||
|
Follow modes:
|
||||||
|
shadow -- stay directly behind at follow_distance
|
||||||
|
lead -- move ahead of person
|
||||||
|
side -- stay beside (default right)
|
||||||
|
orbit -- circle around person
|
||||||
|
loose -- relaxed follow with deadband
|
||||||
|
tight -- close follow
|
||||||
|
|
||||||
|
Waypoint teaching:
|
||||||
|
Voice command "teach route <name>" -> record mode ON
|
||||||
|
Voice command "stop teaching" -> save route
|
||||||
|
Voice command "replay route <name>" -> playback
|
||||||
|
|
||||||
|
Voice commands:
|
||||||
|
"follow me" / "shadow" -> SHADOW mode
|
||||||
|
"lead me" / "go ahead" -> LEAD mode
|
||||||
|
"stay beside" -> SIDE mode
|
||||||
|
"orbit" -> ORBIT mode
|
||||||
|
"give me space" -> LOOSE mode
|
||||||
|
"stay close" -> TIGHT mode
|
||||||
|
"stop" / "halt" -> STOP
|
||||||
|
"go home" -> navigate to home waypoint
|
||||||
|
"teach route <name>" -> start recording
|
||||||
|
"stop teaching" -> finish recording
|
||||||
|
"replay route <name>" -> playback recorded route
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
import re
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||||
|
from geometry_msgs.msg import Twist, PoseStamped
|
||||||
|
from nav_msgs.msg import OccupancyGrid, Odometry
|
||||||
|
from std_msgs.msg import String, Int32
|
||||||
|
|
||||||
|
from .follow_modes import (
|
||||||
|
FollowMode, MODE_VOICE_COMMANDS,
|
||||||
|
compute_shadow_target, compute_lead_target, compute_side_target,
|
||||||
|
compute_orbit_target, compute_loose_target, compute_tight_target,
|
||||||
|
)
|
||||||
|
from .astar import astar, inflate_obstacles
|
||||||
|
from .waypoint_teacher import WaypointRoute, WaypointReplayer
|
||||||
|
|
||||||
|
# Try importing social msgs; fallback gracefully
|
||||||
|
try:
|
||||||
|
from saltybot_social_msgs.msg import PersonStateArray
|
||||||
|
_HAS_SOCIAL_MSGS = True
|
||||||
|
except ImportError:
|
||||||
|
_HAS_SOCIAL_MSGS = False
|
||||||
|
|
||||||
|
# Proportional controller gains
|
||||||
|
_K_ANG = 2.0 # angular gain
|
||||||
|
_K_LIN = 0.8 # linear gain
|
||||||
|
_HIGH_SPEED_THRESHOLD = 3.0 # m/s person velocity triggers fast mode
|
||||||
|
_PREDICT_AHEAD_S = 0.3 # seconds to extrapolate position
|
||||||
|
_TEACH_MIN_DIST = 0.5 # meters between recorded waypoints
|
||||||
|
|
||||||
|
|
||||||
|
class SocialNavNode(Node):
|
||||||
|
"""Main social navigation orchestrator."""
|
||||||
|
|
||||||
|
def __init__(self):
    """Set up parameters, tracking state, ROS interfaces and the 20 Hz timer."""
    super().__init__('social_nav')

    # -- Parameters: declare every (name, default) pair, then read back --
    parameter_defaults = (
        ('follow_mode', 'shadow'),
        ('follow_distance', 1.2),
        ('lead_distance', 2.0),
        ('orbit_radius', 1.5),
        ('max_linear_speed', 1.0),
        ('max_linear_speed_fast', 5.5),
        ('max_angular_speed', 1.0),
        ('goal_tolerance', 0.3),
        ('routes_dir', '/mnt/nvme/saltybot/routes'),
        ('home_x', 0.0),
        ('home_y', 0.0),
        ('map_resolution', 0.05),
        ('obstacle_inflation_cells', 3),
    )
    for param_name, default in parameter_defaults:
        self.declare_parameter(param_name, default)

    def read(param_name):
        return self.get_parameter(param_name).value

    # FollowMode(...) raises ValueError for an unknown mode string.
    self._follow_mode = FollowMode(read('follow_mode'))
    self._follow_distance = read('follow_distance')
    self._lead_distance = read('lead_distance')
    self._orbit_radius = read('orbit_radius')
    self._max_lin = read('max_linear_speed')
    self._max_lin_fast = read('max_linear_speed_fast')
    self._max_ang = read('max_angular_speed')
    self._goal_tol = read('goal_tolerance')
    self._routes_dir = read('routes_dir')
    self._home_x = read('home_x')
    self._home_y = read('home_y')
    self._map_resolution = read('map_resolution')
    self._inflation_cells = read('obstacle_inflation_cells')

    # -- State: robot pose --
    self._robot_x = 0.0
    self._robot_y = 0.0
    self._robot_yaw = 0.0

    # -- State: followed person --
    self._target_person_pos = None  # (x, y, z)
    self._target_person_bearing = 0.0
    self._target_person_id = -1
    # Recent (time, x, y) samples, for velocity estimation.
    self._person_history = deque(maxlen=5)

    # -- State: behaviour flags --
    self._stopped = False
    self._go_home = False

    # -- State: occupancy grid for A* --
    self._occ_grid = None
    self._occ_origin = (0.0, 0.0)
    self._occ_resolution = 0.05
    self._astar_path = None

    # -- State: orbit --
    self._orbit_angle = 0.0

    # -- State: waypoint teaching / replay --
    self._teaching = False
    self._current_route = None
    self._last_teach_x = None
    self._last_teach_y = None
    self._replayer = None

    # -- QoS profiles (both keep only the newest sample) --
    best_effort_qos = QoSProfile(
        reliability=ReliabilityPolicy.BEST_EFFORT,
        history=HistoryPolicy.KEEP_LAST,
        depth=1,
    )
    reliable_qos = QoSProfile(
        reliability=ReliabilityPolicy.RELIABLE,
        history=HistoryPolicy.KEEP_LAST,
        depth=1,
    )

    # -- Subscriptions --
    if _HAS_SOCIAL_MSGS:
        self.create_subscription(
            PersonStateArray, '/social/persons',
            self._on_persons, best_effort_qos)
    else:
        self.get_logger().warn(
            'saltybot_social_msgs not found; '
            'using /person/target fallback')

    # The single-pose fallback topic is subscribed in both cases.
    self.create_subscription(
        PoseStamped, '/person/target',
        self._on_person_target, best_effort_qos)
    self.create_subscription(
        String, '/social/speech/command', self._on_voice_command, 10)
    self.create_subscription(
        String, '/social/speech/transcript', self._on_transcript, 10)
    self.create_subscription(
        OccupancyGrid, '/map', self._on_map, reliable_qos)
    self.create_subscription(
        Odometry, '/odom', self._on_odom, best_effort_qos)
    self.create_subscription(
        Int32, '/social/attention/target_id', self._on_target_id, 10)

    # -- Publishers --
    # NOTE(review): a BEST_EFFORT publisher is not matched by RELIABLE
    # subscribers; confirm the /cmd_vel consumer uses best-effort QoS.
    self._cmd_vel_pub = self.create_publisher(
        Twist, '/cmd_vel', best_effort_qos)
    self._mode_pub = self.create_publisher(
        String, '/social/nav/mode', reliable_qos)
    self._target_pub = self.create_publisher(
        PoseStamped, '/social/nav/target_pos', 10)
    self._status_pub = self.create_publisher(
        String, '/social/nav/status', best_effort_qos)

    # -- Main loop timer (20 Hz) --
    self._timer = self.create_timer(0.05, self._control_loop)

    self.get_logger().info(
        f'Social nav started: mode={self._follow_mode.value}, '
        f'dist={self._follow_distance}m')
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------
|
||||||
|
# Subscriptions
|
||||||
|
# ----------------------------------------------------------------
|
||||||
|
|
||||||
|
def _on_persons(self, msg):
|
||||||
|
"""Handle PersonStateArray from social perception."""
|
||||||
|
target_id = msg.primary_attention_id
|
||||||
|
if self._target_person_id >= 0:
|
||||||
|
target_id = self._target_person_id
|
||||||
|
|
||||||
|
for p in msg.persons:
|
||||||
|
if p.person_id == target_id or (
|
||||||
|
target_id < 0 and len(msg.persons) > 0):
|
||||||
|
pos = (p.position.x, p.position.y, p.position.z)
|
||||||
|
self._update_person_position(pos, p.bearing_deg)
|
||||||
|
break
|
||||||
|
|
||||||
|
def _on_person_target(self, msg: PoseStamped):
    """Fallback input: track the person from a single PoseStamped.

    The yaw angle extracted from the pose quaternion is converted to
    degrees and passed on as the person's bearing.
    """
    point = msg.pose.position
    quat = msg.pose.orientation

    # Yaw (rotation about z) from the quaternion components.
    sin_term = 2.0 * (quat.w * quat.z + quat.x * quat.y)
    cos_term = 1.0 - 2.0 * (quat.y * quat.y + quat.z * quat.z)
    yaw_deg = math.degrees(math.atan2(sin_term, cos_term))

    self._update_person_position((point.x, point.y, point.z), yaw_deg)
|
||||||
|
|
||||||
|
def _update_person_position(self, pos, bearing_deg):
|
||||||
|
"""Update person tracking state and record history."""
|
||||||
|
now = time.time()
|
||||||
|
self._target_person_pos = pos
|
||||||
|
self._target_person_bearing = bearing_deg
|
||||||
|
self._person_history.append((now, pos[0], pos[1]))
|
||||||
|
|
||||||
|
def _on_odom(self, msg: Odometry):
    """Copy the robot's x, y and yaw out of an Odometry message."""
    pose = msg.pose.pose
    self._robot_x = pose.position.x
    self._robot_y = pose.position.y

    # Yaw (rotation about z) from the orientation quaternion.
    quat = pose.orientation
    sin_term = 2.0 * (quat.w * quat.z + quat.x * quat.y)
    cos_term = 1.0 - 2.0 * (quat.y * quat.y + quat.z * quat.z)
    self._robot_yaw = math.atan2(sin_term, cos_term)
|
||||||
|
|
||||||
|
def _on_map(self, msg: OccupancyGrid):
|
||||||
|
"""Cache occupancy grid for A* planning."""
|
||||||
|
w, h = msg.info.width, msg.info.height
|
||||||
|
data = np.array(msg.data, dtype=np.int8).reshape((h, w))
|
||||||
|
# Convert -1 (unknown) to free (0) for planning
|
||||||
|
data[data < 0] = 0
|
||||||
|
self._occ_grid = data.astype(np.int32)
|
||||||
|
self._occ_origin = (msg.info.origin.position.x,
|
||||||
|
msg.info.origin.position.y)
|
||||||
|
self._occ_resolution = msg.info.resolution
|
||||||
|
|
||||||
|
def _on_target_id(self, msg: Int32):
|
||||||
|
"""Switch target person."""
|
||||||
|
self._target_person_id = msg.data
|
||||||
|
self.get_logger().info(f'Target person ID set to {msg.data}')
|
||||||
|
|
||||||
|
def _on_voice_command(self, msg: String):
|
||||||
|
"""Handle discrete voice commands for mode switching."""
|
||||||
|
cmd = msg.data.strip().lower()
|
||||||
|
|
||||||
|
if cmd in ('stop', 'halt'):
|
||||||
|
self._stopped = True
|
||||||
|
self._replayer = None
|
||||||
|
self._publish_status('STOPPED')
|
||||||
|
return
|
||||||
|
|
||||||
|
if cmd in ('resume', 'go', 'start'):
|
||||||
|
self._stopped = False
|
||||||
|
self._publish_status('RESUMED')
|
||||||
|
return
|
||||||
|
|
||||||
|
matched = MODE_VOICE_COMMANDS.get(cmd)
|
||||||
|
if matched:
|
||||||
|
self._follow_mode = matched
|
||||||
|
self._stopped = False
|
||||||
|
mode_msg = String()
|
||||||
|
mode_msg.data = self._follow_mode.value
|
||||||
|
self._mode_pub.publish(mode_msg)
|
||||||
|
self._publish_status(f'MODE: {self._follow_mode.value}')
|
||||||
|
|
||||||
|
def _on_transcript(self, msg: String):
|
||||||
|
"""Handle free-form voice transcripts for route teaching."""
|
||||||
|
text = msg.data.strip().lower()
|
||||||
|
|
||||||
|
# "teach route <name>"
|
||||||
|
m = re.match(r'teach\s+route\s+(\w+)', text)
|
||||||
|
if m:
|
||||||
|
name = m.group(1)
|
||||||
|
self._teaching = True
|
||||||
|
self._current_route = WaypointRoute(name)
|
||||||
|
self._last_teach_x = self._robot_x
|
||||||
|
self._last_teach_y = self._robot_y
|
||||||
|
self._publish_status(f'TEACHING: {name}')
|
||||||
|
self.get_logger().info(f'Recording route: {name}')
|
||||||
|
return
|
||||||
|
|
||||||
|
# "stop teaching"
|
||||||
|
if 'stop teaching' in text:
|
||||||
|
if self._teaching and self._current_route:
|
||||||
|
self._current_route.save(self._routes_dir)
|
||||||
|
self._publish_status(
|
||||||
|
f'SAVED: {self._current_route.name} '
|
||||||
|
f'({len(self._current_route.waypoints)} pts)')
|
||||||
|
self.get_logger().info(
|
||||||
|
f'Route saved: {self._current_route.name}')
|
||||||
|
self._teaching = False
|
||||||
|
self._current_route = None
|
||||||
|
return
|
||||||
|
|
||||||
|
# "replay route <name>"
|
||||||
|
m = re.match(r'replay\s+route\s+(\w+)', text)
|
||||||
|
if m:
|
||||||
|
name = m.group(1)
|
||||||
|
try:
|
||||||
|
route = WaypointRoute.load(self._routes_dir, name)
|
||||||
|
self._replayer = WaypointReplayer(route)
|
||||||
|
self._stopped = False
|
||||||
|
self._publish_status(f'REPLAY: {name}')
|
||||||
|
self.get_logger().info(f'Replaying route: {name}')
|
||||||
|
except FileNotFoundError:
|
||||||
|
self._publish_status(f'ROUTE NOT FOUND: {name}')
|
||||||
|
return
|
||||||
|
|
||||||
|
# "go home"
|
||||||
|
if 'go home' in text:
|
||||||
|
self._go_home = True
|
||||||
|
self._stopped = False
|
||||||
|
self._publish_status('GO HOME')
|
||||||
|
return
|
||||||
|
|
||||||
|
# Also try mode commands from transcript
|
||||||
|
for phrase, mode in MODE_VOICE_COMMANDS.items():
|
||||||
|
if phrase in text:
|
||||||
|
self._follow_mode = mode
|
||||||
|
self._stopped = False
|
||||||
|
mode_msg = String()
|
||||||
|
mode_msg.data = self._follow_mode.value
|
||||||
|
self._mode_pub.publish(mode_msg)
|
||||||
|
self._publish_status(f'MODE: {self._follow_mode.value}')
|
||||||
|
return
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------
|
||||||
|
# Control loop
|
||||||
|
# ----------------------------------------------------------------
|
||||||
|
|
||||||
|
def _control_loop(self):
|
||||||
|
"""Main 20Hz control loop."""
|
||||||
|
# Record waypoint if teaching
|
||||||
|
if self._teaching and self._current_route:
|
||||||
|
self._maybe_record_waypoint()
|
||||||
|
|
||||||
|
# Publish zero velocity if stopped
|
||||||
|
if self._stopped:
|
||||||
|
self._publish_cmd_vel(0.0, 0.0)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Determine navigation target
|
||||||
|
target = self._get_nav_target()
|
||||||
|
if target is None:
|
||||||
|
self._publish_cmd_vel(0.0, 0.0)
|
||||||
|
return
|
||||||
|
|
||||||
|
tx, ty, tz = target
|
||||||
|
|
||||||
|
# Publish debug target
|
||||||
|
self._publish_target_pose(tx, ty, tz)
|
||||||
|
|
||||||
|
# Check if arrived
|
||||||
|
dist_to_target = math.hypot(tx - self._robot_x, ty - self._robot_y)
|
||||||
|
if dist_to_target < self._goal_tol:
|
||||||
|
self._publish_cmd_vel(0.0, 0.0)
|
||||||
|
if self._go_home:
|
||||||
|
self._go_home = False
|
||||||
|
self._publish_status('HOME REACHED')
|
||||||
|
return
|
||||||
|
|
||||||
|
# Try A* path if map available
|
||||||
|
if self._occ_grid is not None:
|
||||||
|
path_target = self._plan_astar(tx, ty)
|
||||||
|
if path_target:
|
||||||
|
tx, ty = path_target
|
||||||
|
|
||||||
|
# Determine speed limit
|
||||||
|
max_lin = self._max_lin
|
||||||
|
person_vel = self._estimate_person_velocity()
|
||||||
|
if person_vel > _HIGH_SPEED_THRESHOLD:
|
||||||
|
max_lin = self._max_lin_fast
|
||||||
|
|
||||||
|
# Compute and publish cmd_vel
|
||||||
|
lin, ang = self._compute_cmd_vel(
|
||||||
|
self._robot_x, self._robot_y, self._robot_yaw,
|
||||||
|
tx, ty, max_lin)
|
||||||
|
self._publish_cmd_vel(lin, ang)
|
||||||
|
|
||||||
|
def _get_nav_target(self):
|
||||||
|
"""Determine current navigation target based on mode/state."""
|
||||||
|
# Route replay takes priority
|
||||||
|
if self._replayer and not self._replayer.is_done:
|
||||||
|
self._replayer.check_arrived(self._robot_x, self._robot_y)
|
||||||
|
wp = self._replayer.current_waypoint()
|
||||||
|
if wp:
|
||||||
|
return (wp.x, wp.y, wp.z)
|
||||||
|
else:
|
||||||
|
self._replayer = None
|
||||||
|
self._publish_status('REPLAY DONE')
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Go home
|
||||||
|
if self._go_home:
|
||||||
|
return (self._home_x, self._home_y, 0.0)
|
||||||
|
|
||||||
|
# Person following
|
||||||
|
if self._target_person_pos is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Predict person position ahead for high-speed tracking
|
||||||
|
px, py, pz = self._predict_person_position()
|
||||||
|
bearing = self._target_person_bearing
|
||||||
|
robot_pos = (self._robot_x, self._robot_y, 0.0)
|
||||||
|
|
||||||
|
if self._follow_mode == FollowMode.SHADOW:
|
||||||
|
return compute_shadow_target(
|
||||||
|
(px, py, pz), bearing, self._follow_distance)
|
||||||
|
elif self._follow_mode == FollowMode.LEAD:
|
||||||
|
return compute_lead_target(
|
||||||
|
(px, py, pz), bearing, self._lead_distance)
|
||||||
|
elif self._follow_mode == FollowMode.SIDE:
|
||||||
|
return compute_side_target(
|
||||||
|
(px, py, pz), bearing, self._follow_distance)
|
||||||
|
elif self._follow_mode == FollowMode.ORBIT:
|
||||||
|
self._orbit_angle = (self._orbit_angle + 1.0) % 360.0
|
||||||
|
return compute_orbit_target(
|
||||||
|
(px, py, pz), self._orbit_angle, self._orbit_radius)
|
||||||
|
elif self._follow_mode == FollowMode.LOOSE:
|
||||||
|
return compute_loose_target(
|
||||||
|
(px, py, pz), robot_pos, self._follow_distance)
|
||||||
|
elif self._follow_mode == FollowMode.TIGHT:
|
||||||
|
return compute_tight_target(
|
||||||
|
(px, py, pz), self._follow_distance)
|
||||||
|
|
||||||
|
return (px, py, pz)
|
||||||
|
|
||||||
|
def _predict_person_position(self):
|
||||||
|
"""Extrapolate person position using velocity from recent samples."""
|
||||||
|
if self._target_person_pos is None:
|
||||||
|
return (0.0, 0.0, 0.0)
|
||||||
|
|
||||||
|
px, py, pz = self._target_person_pos
|
||||||
|
|
||||||
|
if len(self._person_history) >= 3:
|
||||||
|
# Use last 3 samples for velocity estimation
|
||||||
|
t0, x0, y0 = self._person_history[-3]
|
||||||
|
t1, x1, y1 = self._person_history[-1]
|
||||||
|
dt = t1 - t0
|
||||||
|
if dt > 0.01:
|
||||||
|
vx = (x1 - x0) / dt
|
||||||
|
vy = (y1 - y0) / dt
|
||||||
|
speed = math.hypot(vx, vy)
|
||||||
|
if speed > _HIGH_SPEED_THRESHOLD:
|
||||||
|
px += vx * _PREDICT_AHEAD_S
|
||||||
|
py += vy * _PREDICT_AHEAD_S
|
||||||
|
|
||||||
|
return (px, py, pz)
|
||||||
|
|
||||||
|
def _estimate_person_velocity(self) -> float:
|
||||||
|
"""Estimate person speed from recent position history."""
|
||||||
|
if len(self._person_history) < 2:
|
||||||
|
return 0.0
|
||||||
|
t0, x0, y0 = self._person_history[-2]
|
||||||
|
t1, x1, y1 = self._person_history[-1]
|
||||||
|
dt = t1 - t0
|
||||||
|
if dt < 0.01:
|
||||||
|
return 0.0
|
||||||
|
return math.hypot(x1 - x0, y1 - y0) / dt
|
||||||
|
|
||||||
|
def _plan_astar(self, target_x, target_y):
|
||||||
|
"""Run A* on occupancy grid, return next waypoint in world coords."""
|
||||||
|
grid = self._occ_grid
|
||||||
|
res = self._occ_resolution
|
||||||
|
ox, oy = self._occ_origin
|
||||||
|
|
||||||
|
# World to grid
|
||||||
|
def w2g(wx, wy):
|
||||||
|
return (int((wy - oy) / res), int((wx - ox) / res))
|
||||||
|
|
||||||
|
start = w2g(self._robot_x, self._robot_y)
|
||||||
|
goal = w2g(target_x, target_y)
|
||||||
|
|
||||||
|
rows, cols = grid.shape
|
||||||
|
if not (0 <= start[0] < rows and 0 <= start[1] < cols):
|
||||||
|
return None
|
||||||
|
if not (0 <= goal[0] < rows and 0 <= goal[1] < cols):
|
||||||
|
return None
|
||||||
|
|
||||||
|
inflated = inflate_obstacles(grid, self._inflation_cells)
|
||||||
|
path = astar(inflated, start, goal)
|
||||||
|
|
||||||
|
if path and len(path) > 1:
|
||||||
|
# Follow a lookahead point (3 steps ahead or end)
|
||||||
|
lookahead_idx = min(3, len(path) - 1)
|
||||||
|
r, c = path[lookahead_idx]
|
||||||
|
wx = ox + c * res + res / 2.0
|
||||||
|
wy = oy + r * res + res / 2.0
|
||||||
|
self._astar_path = path
|
||||||
|
return (wx, wy)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _compute_cmd_vel(self, rx, ry, ryaw, tx, ty, max_lin):
    """Proportional controller towards ``(tx, ty)``.

    Args:
        rx, ry, ryaw: robot position and heading (radians).
        tx, ty: target position.
        max_lin: upper bound for the linear velocity.

    Returns:
        ``(linear, angular)``: linear is clamped to ``[0, max_lin]`` and
        scaled down to zero as the heading error approaches 90 degrees;
        angular is clamped to ``[-self._max_ang, self._max_ang]``.
    """
    offset_x = tx - rx
    offset_y = ty - ry
    heading_error = math.atan2(offset_y, offset_x) - ryaw

    # Wrap the heading error into [-pi, pi].
    full_turn = 2.0 * math.pi
    while heading_error > math.pi:
        heading_error -= full_turn
    while heading_error < -math.pi:
        heading_error += full_turn

    turn_limit = self._max_ang
    turn_rate = max(-turn_limit, min(turn_limit, _K_ANG * heading_error))

    # Slow down while turning hard; no forward motion at >= 90 degrees off.
    slowdown = max(0.0, 1.0 - abs(heading_error) / (math.pi / 2.0))
    forward = _K_LIN * math.hypot(offset_x, offset_y) * slowdown
    forward = max(0.0, min(max_lin, forward))

    return (forward, turn_rate)
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------
|
||||||
|
# Waypoint teaching
|
||||||
|
# ----------------------------------------------------------------
|
||||||
|
|
||||||
|
def _maybe_record_waypoint(self):
    """Append a waypoint once the robot has moved ``_TEACH_MIN_DIST``.

    The distance is measured from the last recorded position, which is
    initialised to the current position on first use.
    """
    here_x, here_y = self._robot_x, self._robot_y

    if self._last_teach_x is None:
        self._last_teach_x = here_x
        self._last_teach_y = here_y

    moved = math.hypot(here_x - self._last_teach_x,
                       here_y - self._last_teach_y)

    if moved >= _TEACH_MIN_DIST:
        heading_deg = math.degrees(self._robot_yaw)
        self._current_route.add(here_x, here_y, 0.0, heading_deg)
        self._last_teach_x = here_x
        self._last_teach_y = here_y
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------
|
||||||
|
# Publishers
|
||||||
|
# ----------------------------------------------------------------
|
||||||
|
|
||||||
|
def _publish_cmd_vel(self, linear: float, angular: float):
    """Publish a Twist with the given forward (x) and yaw (z) velocity."""
    command = Twist()
    command.angular.z = angular
    command.linear.x = linear
    self._cmd_vel_pub.publish(command)
|
||||||
|
|
||||||
|
def _publish_target_pose(self, x, y, z):
    """Publish the navigation target as a PoseStamped in the 'map' frame."""
    target = PoseStamped()

    header = target.header
    header.stamp = self.get_clock().now().to_msg()
    header.frame_id = 'map'

    position = target.pose.position
    position.x = x
    position.y = y
    position.z = z

    self._target_pub.publish(target)
|
||||||
|
|
||||||
|
def _publish_status(self, status: str):
    """Publish ``status`` on the status topic and mirror it to the log."""
    report = String()
    report.data = status
    self._status_pub.publish(report)
    self.get_logger().info(f'Nav status: {status}')
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None):
    """Entry point: spin a SocialNavNode until interrupted."""
    rclpy.init(args=args)
    node = SocialNavNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        # The context may already have been shut down (e.g. by the SIGINT
        # handler); shutting it down a second time raises.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()
|
||||||
@ -0,0 +1,91 @@
|
|||||||
|
"""waypoint_teacher.py -- Record and replay waypoint routes."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import math
|
||||||
|
from pathlib import Path
|
||||||
|
from dataclasses import dataclass, asdict
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Waypoint:
    """One recorded pose of a route."""

    # Position of the waypoint.
    x: float
    y: float
    z: float
    # Heading at the waypoint, in degrees.
    yaw_deg: float
    # Time the waypoint was recorded.
    timestamp: float
    # Optional free-text label.
    label: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class WaypointRoute:
    """A named sequence of waypoints that can be stored as a JSON file.

    Routes are stored as ``<routes_dir>/<name>.json``.
    """

    def __init__(self, name: str):
        self.name = name
        self.waypoints: list[Waypoint] = []
        self.created_at = time.time()

    def add(self, x, y, z, yaw_deg, label=''):
        """Append a waypoint stamped with the current time."""
        self.waypoints.append(Waypoint(x, y, z, yaw_deg, time.time(), label))

    def to_dict(self):
        """Return a JSON-serialisable dict describing the route."""
        return {
            'name': self.name,
            'created_at': self.created_at,
            'waypoints': [asdict(w) for w in self.waypoints],
        }

    @classmethod
    def from_dict(cls, d):
        """Rebuild a route from the structure produced by ``to_dict``.

        Raises ``KeyError`` when ``name`` or ``waypoints`` is missing and
        ``TypeError`` when a waypoint entry has unexpected fields.
        """
        route = cls(d['name'])
        route.created_at = d.get('created_at', 0)
        route.waypoints = [Waypoint(**w) for w in d['waypoints']]
        return route

    def save(self, routes_dir: str):
        """Write the route to ``<routes_dir>/<name>.json``.

        The directory is created if needed.  The data is written to a
        temporary sibling file and then renamed over the target, so an
        interrupted write cannot leave a truncated route file behind.
        """
        directory = Path(routes_dir)
        directory.mkdir(parents=True, exist_ok=True)
        path = directory / f'{self.name}.json'
        # The '.tmp' suffix keeps the scratch file out of list_routes().
        scratch = path.with_name(path.name + '.tmp')
        scratch.write_text(
            json.dumps(self.to_dict(), indent=2), encoding='utf-8')
        scratch.replace(path)

    @classmethod
    def load(cls, routes_dir: str, name: str):
        """Load ``<routes_dir>/<name>.json``.

        Raises ``FileNotFoundError`` if the file does not exist and
        ``ValueError`` (``json.JSONDecodeError``) if it is not valid JSON.
        """
        path = Path(routes_dir) / f'{name}.json'
        return cls.from_dict(json.loads(path.read_text(encoding='utf-8')))

    @staticmethod
    def list_routes(routes_dir: str) -> list[str]:
        """Return the stored route names in alphabetical order.

        A missing directory yields an empty list.
        """
        d = Path(routes_dir)
        if not d.exists():
            return []
        # glob() order is filesystem dependent; sort for stable output.
        return sorted(p.stem for p in d.glob('*.json'))
|
||||||
|
|
||||||
|
|
||||||
|
class WaypointReplayer:
    """Steps through a route's waypoints one at a time.

    The current waypoint advances by one each time ``check_arrived`` sees
    the robot within ``arrival_radius`` of it.
    """

    def __init__(self, route: WaypointRoute, arrival_radius: float = 0.3):
        self._arrival_radius = arrival_radius
        self._route = route
        self._idx = 0

    def current_waypoint(self) -> Waypoint | None:
        """Return the waypoint being approached, or ``None`` when finished."""
        points = self._route.waypoints
        return points[self._idx] if self._idx < len(points) else None

    def check_arrived(self, robot_x, robot_y) -> bool:
        """Advance to the next waypoint if the robot reached the current one.

        Returns ``True`` when the index was advanced, ``False`` otherwise
        (including when the route is already finished).
        """
        target = self.current_waypoint()
        if target is None:
            return False
        offset = math.hypot(robot_x - target.x, robot_y - target.y)
        arrived = offset < self._arrival_radius
        if arrived:
            self._idx += 1
        return arrived

    @property
    def is_done(self) -> bool:
        """``True`` once every waypoint has been reached."""
        return self._idx >= len(self._route.waypoints)

    def reset(self):
        """Start again from the first waypoint."""
        self._idx = 0
|
||||||
@ -0,0 +1,135 @@
|
|||||||
|
"""
|
||||||
|
waypoint_teacher_node.py -- Standalone waypoint teacher ROS2 node.
|
||||||
|
|
||||||
|
Listens to /social/speech/transcript for "teach route <name>" and "stop teaching".
|
||||||
|
Records robot pose at configurable intervals. Saves/loads routes via WaypointRoute.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||||
|
from nav_msgs.msg import Odometry
|
||||||
|
from std_msgs.msg import String
|
||||||
|
|
||||||
|
from .waypoint_teacher import WaypointRoute
|
||||||
|
|
||||||
|
|
||||||
|
class WaypointTeacherNode(Node):
    """Standalone node that records waypoint routes on voice command.

    Listens on ``/social/speech/transcript`` for "teach route <name>",
    "stop teaching" and "list routes", tracks the robot pose from ``/odom``
    and reports on ``/social/waypoint/status``.
    """

    def __init__(self):
        super().__init__('waypoint_teacher')

        self.declare_parameter('routes_dir', '/mnt/nvme/saltybot/routes')
        self.declare_parameter('recording_interval', 0.5)  # meters

        self._routes_dir = self.get_parameter('routes_dir').value
        self._interval = self.get_parameter('recording_interval').value

        # Recording state.
        self._teaching = False
        self._route = None
        self._last_x = None
        self._last_y = None

        # Latest robot pose from odometry.
        self._robot_x = 0.0
        self._robot_y = 0.0
        self._robot_yaw = 0.0

        odom_qos = QoSProfile(
            reliability=ReliabilityPolicy.BEST_EFFORT,
            history=HistoryPolicy.KEEP_LAST,
            depth=1)

        self.create_subscription(Odometry, '/odom', self._on_odom, odom_qos)
        self.create_subscription(
            String, '/social/speech/transcript', self._on_transcript, 10)

        self._status_pub = self.create_publisher(
            String, '/social/waypoint/status', 10)

        # 10 Hz tick that checks the distance travelled.
        self._timer = self.create_timer(0.1, self._record_tick)

        self.get_logger().info(
            f'Waypoint teacher ready (interval={self._interval}m, '
            f'dir={self._routes_dir})')

    def _on_odom(self, msg: Odometry):
        """Cache the robot's planar pose (x, y, yaw in radians)."""
        pose = msg.pose.pose
        q = pose.orientation
        self._robot_x = pose.position.x
        self._robot_y = pose.position.y
        self._robot_yaw = math.atan2(
            2.0 * (q.w * q.z + q.x * q.y),
            1.0 - 2.0 * (q.y * q.y + q.z * q.z))

    def _on_transcript(self, msg: String):
        """Dispatch a voice transcript to the matching teaching action."""
        text = msg.data.strip().lower()

        import re
        started = re.match(r'teach\s+route\s+(\w+)', text)
        if started:
            self._start_recording(started.group(1))
        elif 'stop teaching' in text:
            self._finish_recording()
        elif 'list routes' in text:
            self._report_routes()

    def _start_recording(self, name):
        """Begin a fresh route anchored at the current robot position."""
        self._route = WaypointRoute(name)
        self._teaching = True
        self._last_x = self._robot_x
        self._last_y = self._robot_y
        self._pub_status(f'RECORDING: {name}')
        self.get_logger().info(f'Recording route: {name}')

    def _finish_recording(self):
        """Save the route being recorded (if any) and leave teaching mode."""
        if self._teaching and self._route:
            self._route.save(self._routes_dir)
            n = len(self._route.waypoints)
            self._pub_status(f'SAVED: {self._route.name} ({n} waypoints)')
            self.get_logger().info(
                f'Route saved: {self._route.name} ({n} pts)')
        self._teaching = False
        self._route = None

    def _report_routes(self):
        """Publish the names of the stored routes."""
        routes = WaypointRoute.list_routes(self._routes_dir)
        self._pub_status(f'ROUTES: {", ".join(routes) or "(none)"}')

    def _record_tick(self):
        """Record a waypoint once the robot moved the configured interval."""
        if not self._teaching or self._route is None:
            return

        here_x, here_y = self._robot_x, self._robot_y
        if self._last_x is None:
            self._last_x = here_x
            self._last_y = here_y

        moved = math.hypot(here_x - self._last_x, here_y - self._last_y)

        if moved >= self._interval:
            heading_deg = math.degrees(self._robot_yaw)
            self._route.add(here_x, here_y, 0.0, heading_deg)
            self._last_x = here_x
            self._last_y = here_y

    def _pub_status(self, text: str):
        """Publish ``text`` on the waypoint status topic."""
        report = String()
        report.data = text
        self._status_pub.publish(report)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None):
    """Entry point: spin a WaypointTeacherNode until interrupted."""
    rclpy.init(args=args)
    node = WaypointTeacherNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        # The context may already have been shut down (e.g. by the SIGINT
        # handler); shutting it down a second time raises.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()
|
||||||
@ -0,0 +1,80 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
build_midas_trt_engine.py -- Build TensorRT FP16 engine for MiDaS_small from ONNX.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 build_midas_trt_engine.py \
|
||||||
|
--onnx /mnt/nvme/saltybot/models/midas_small.onnx \
|
||||||
|
--engine /mnt/nvme/saltybot/models/midas_small.engine \
|
||||||
|
--fp16
|
||||||
|
|
||||||
|
Requires: tensorrt, pycuda
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def build_engine(onnx_path: str, engine_path: str, fp16: bool = True):
    """Build a serialized TensorRT engine from an ONNX model.

    Args:
        onnx_path: ONNX model to parse.
        engine_path: where the serialized engine is written; the parent
            directory is created if needed.
        fp16: request FP16 precision; ignored with a warning when the
            platform reports no fast FP16 support.

    Exits the process with status 1 when TensorRT is not importable, the
    ONNX model cannot be parsed, or the engine build fails.
    """
    try:
        import tensorrt as trt
    except ImportError:
        print('ERROR: tensorrt not found. Install TensorRT first.')
        sys.exit(1)

    trt_logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(trt_logger)
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(explicit_batch)
    onnx_parser = trt.OnnxParser(network, trt_logger)

    print(f'Parsing ONNX model: {onnx_path}')
    with open(onnx_path, 'rb') as model_file:
        parsed_ok = onnx_parser.parse(model_file.read())
    if not parsed_ok:
        for idx in range(onnx_parser.num_errors):
            print(f'  ONNX parse error: {onnx_parser.get_error(idx)}')
        sys.exit(1)

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1GB

    if fp16:
        if builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
            print('FP16 mode enabled')
        else:
            print('WARNING: FP16 not supported on this platform, using FP32')

    print('Building TensorRT engine (this may take several minutes)...')
    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        print('ERROR: Failed to build engine')
        sys.exit(1)

    # dirname is '' for a bare file name; fall back to the current directory.
    out_dir = os.path.dirname(engine_path) or '.'
    os.makedirs(out_dir, exist_ok=True)
    with open(engine_path, 'wb') as engine_file:
        engine_file.write(serialized)

    size_mb = len(serialized) / (1024 * 1024)
    print(f'Engine saved: {engine_path} ({size_mb:.1f} MB)')
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse the command line and build the engine.

    FP16 is used unless ``--fp32`` is given; ``--fp16`` is accepted but
    its value is not consulted.
    """
    cli = argparse.ArgumentParser(
        description='Build TensorRT FP16 engine for MiDaS_small')
    cli.add_argument(
        '--onnx', required=True, help='Path to MiDaS ONNX model')
    cli.add_argument(
        '--engine', required=True, help='Output TRT engine path')
    cli.add_argument(
        '--fp16', action='store_true', default=True,
        help='Enable FP16 (default: True)')
    cli.add_argument(
        '--fp32', action='store_true', help='Force FP32 (disable FP16)')
    options = cli.parse_args()

    build_engine(options.onnx, options.engine, fp16=not options.fp32)


if __name__ == '__main__':
    main()
|
||||||
4
jetson/ros2_ws/src/saltybot_social_nav/setup.cfg
Normal file
4
jetson/ros2_ws/src/saltybot_social_nav/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
[develop]
|
||||||
|
script_dir=$base/lib/saltybot_social_nav
|
||||||
|
[install]
|
||||||
|
install_scripts=$base/lib/saltybot_social_nav
|
||||||
31
jetson/ros2_ws/src/saltybot_social_nav/setup.py
Normal file
31
jetson/ros2_ws/src/saltybot_social_nav/setup.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
from setuptools import setup
|
||||||
|
import os
|
||||||
|
from glob import glob
|
||||||
|
|
||||||
|
package_name = 'saltybot_social_nav'

# Install location of the package's shared resources.
share_dir = os.path.join('share', package_name)

data_files = [
    ('share/ament_index/resource_index/packages', ['resource/' + package_name]),
    ('share/' + package_name, ['package.xml']),
    (os.path.join(share_dir, 'launch'), glob('launch/*.py')),
    (os.path.join(share_dir, 'config'), glob('config/*.yaml')),
]

# Executable name -> module:function.
console_scripts = [
    'social_nav = saltybot_social_nav.social_nav_node:main',
    'midas_depth = saltybot_social_nav.midas_depth_node:main',
    'waypoint_teacher = saltybot_social_nav.waypoint_teacher_node:main',
]

setup(
    name=package_name,
    version='0.1.0',
    packages=[package_name],
    data_files=data_files,
    install_requires=['setuptools'],
    zip_safe=True,
    maintainer='seb',
    maintainer_email='seb@vayrette.com',
    description='Social navigation for saltybot: follow modes, waypoint teaching, A* avoidance, MiDaS depth',
    license='MIT',
    tests_require=['pytest'],
    entry_points={'console_scripts': console_scripts},
)
|
||||||
@ -0,0 +1,12 @@
|
|||||||
|
# Copyright 2026 SaltyLab
|
||||||
|
# Licensed under MIT
|
||||||
|
|
||||||
|
from ament_copyright.main import main
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.copyright
@pytest.mark.linter
def test_copyright():
    """Fail when the ament copyright check reports a non-zero status."""
    exit_code = main(argv=['.', 'test'])
    assert exit_code == 0, 'Found errors'
|
||||||
14
jetson/ros2_ws/src/saltybot_social_nav/test/test_flake8.py
Normal file
14
jetson/ros2_ws/src/saltybot_social_nav/test/test_flake8.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# Copyright 2026 SaltyLab
|
||||||
|
# Licensed under MIT
|
||||||
|
|
||||||
|
from ament_flake8.main import main_with_errors
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.flake8
@pytest.mark.linter
def test_flake8():
    """Fail, listing every finding, when ament flake8 reports problems."""
    exit_code, findings = main_with_errors(argv=[])
    summary = 'Found %d code style errors / warnings:\n' % len(findings)
    assert exit_code == 0, summary + '\n'.join(findings)
|
||||||
12
jetson/ros2_ws/src/saltybot_social_nav/test/test_pep257.py
Normal file
12
jetson/ros2_ws/src/saltybot_social_nav/test/test_pep257.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
# Copyright 2026 SaltyLab
|
||||||
|
# Licensed under MIT
|
||||||
|
|
||||||
|
from ament_pep257.main import main
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.pep257
@pytest.mark.linter
def test_pep257():
    """Fail when the ament pep257 check reports a non-zero status."""
    exit_code = main(argv=['.', 'test'])
    assert exit_code == 0, 'Found code style errors / warnings'
|
||||||
@ -1,57 +1,24 @@
|
|||||||
# uwb_config.yaml — MaUWB ESP32-S3 DW3000 UWB follow-me system
|
# uwb_config.yaml — MaUWB ESP32-S3 DW3000 UWB integration (Issue #90)
|
||||||
#
|
|
||||||
# Hardware layout:
|
|
||||||
# Anchor-0 (port side) → USB port_a, y = +anchor_separation/2
|
|
||||||
# Anchor-1 (starboard side) → USB port_b, y = -anchor_separation/2
|
|
||||||
# Tag on person → belt clip, ~0.9m above ground
|
|
||||||
#
|
#
|
||||||
# Run with:
|
# Run with:
|
||||||
# ros2 launch saltybot_uwb uwb.launch.py
|
# ros2 launch saltybot_uwb uwb.launch.py
|
||||||
# Override at launch:
|
|
||||||
# ros2 launch saltybot_uwb uwb.launch.py port_a:=/dev/ttyUSB2
|
|
||||||
|
|
||||||
# ── Serial ports ──────────────────────────────────────────────────────────────
|
port_a: /dev/uwb-anchor0
|
||||||
# Set udev rules to get stable symlinks:
|
port_b: /dev/uwb-anchor1
|
||||||
# /dev/uwb-anchor0 → port_a
|
baudrate: 115200
|
||||||
# /dev/uwb-anchor1 → port_b
|
|
||||||
# (See jetson/docs/pinout.md for udev setup)
|
|
||||||
port_a: /dev/uwb-anchor0 # Anchor-0 (port)
|
|
||||||
port_b: /dev/uwb-anchor1 # Anchor-1 (starboard)
|
|
||||||
baudrate: 115200 # MaUWB default — do not change
|
|
||||||
|
|
||||||
# ── Anchor geometry ────────────────────────────────────────────────────────────
|
anchor_separation: 0.25
|
||||||
# anchor_separation: centre-to-centre distance between anchors (metres)
|
anchor_height: 0.80
|
||||||
# Must match physical mounting. Larger = more accurate lateral resolution.
|
tag_height: 0.90
|
||||||
anchor_separation: 0.25 # metres (25cm)
|
|
||||||
|
|
||||||
# anchor_height: height of anchors above ground (metres)
|
range_timeout_s: 1.0
|
||||||
# Orin stem mount ≈ 0.80m on the saltybot platform
|
max_range_m: 8.0
|
||||||
anchor_height: 0.80 # metres
|
min_range_m: 0.05
|
||||||
|
|
||||||
# tag_height: height of person's belt-clip tag above ground (metres)
|
|
||||||
tag_height: 0.90 # metres (adjust per user)
|
|
||||||
|
|
||||||
# ── Range validity ─────────────────────────────────────────────────────────────
|
|
||||||
# range_timeout_s: stale anchor — excluded from triangulation after this gap
|
|
||||||
range_timeout_s: 1.0 # seconds
|
|
||||||
|
|
||||||
# max_range_m: discard ranges beyond this (DW3000 indoor practical limit ≈8m)
|
|
||||||
max_range_m: 8.0 # metres
|
|
||||||
|
|
||||||
# min_range_m: discard ranges below this (likely multipath artefacts)
|
|
||||||
min_range_m: 0.05 # metres
|
|
||||||
|
|
||||||
# ── Kalman filter ──────────────────────────────────────────────────────────────
|
|
||||||
# kf_process_noise: Q scalar — how dynamic the person's motion is
|
|
||||||
# Higher → faster response, more jitter
|
|
||||||
kf_process_noise: 0.1
|
kf_process_noise: 0.1
|
||||||
|
|
||||||
# kf_meas_noise: R scalar — how noisy the UWB ranging is
|
|
||||||
# DW3000 indoor accuracy ≈ 10cm RMS → 0.1m → R ≈ 0.01
|
|
||||||
# Use 0.3 to be conservative on first deployment
|
|
||||||
kf_meas_noise: 0.3
|
kf_meas_noise: 0.3
|
||||||
|
|
||||||
# ── Publish rate ──────────────────────────────────────────────────────────────
|
range_rate: 100.0
|
||||||
# Should match or exceed the AT+RANGE? poll rate from both anchors.
|
bearing_rate: 10.0
|
||||||
# 20Hz means 50ms per cycle; each anchor query takes ~10ms → headroom ok.
|
|
||||||
publish_rate: 20.0 # Hz
|
enrolled_tag_ids: [""]
|
||||||
|
|||||||
@ -1,10 +1,14 @@
|
|||||||
"""
|
"""
|
||||||
uwb.launch.py — Launch UWB driver node for MaUWB ESP32-S3 follow-me.
|
uwb.launch.py — Launch UWB driver node for MaUWB ESP32-S3 DW3000 (Issue #90).
|
||||||
|
|
||||||
|
Topics:
|
||||||
|
/uwb/ranges 100 Hz raw TWR ranges
|
||||||
|
/uwb/bearing 10 Hz Kalman-fused bearing
|
||||||
|
/uwb/target 10 Hz triangulated PoseStamped (backwards compat)
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
ros2 launch saltybot_uwb uwb.launch.py
|
ros2 launch saltybot_uwb uwb.launch.py
|
||||||
ros2 launch saltybot_uwb uwb.launch.py port_a:=/dev/ttyUSB2 port_b:=/dev/ttyUSB3
|
ros2 launch saltybot_uwb uwb.launch.py enrolled_tag_ids:="['0xDEADBEEF']"
|
||||||
ros2 launch saltybot_uwb uwb.launch.py anchor_separation:=0.30 publish_rate:=10.0
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@ -31,7 +35,9 @@ def generate_launch_description():
|
|||||||
DeclareLaunchArgument("min_range_m", default_value="0.05"),
|
DeclareLaunchArgument("min_range_m", default_value="0.05"),
|
||||||
DeclareLaunchArgument("kf_process_noise", default_value="0.1"),
|
DeclareLaunchArgument("kf_process_noise", default_value="0.1"),
|
||||||
DeclareLaunchArgument("kf_meas_noise", default_value="0.3"),
|
DeclareLaunchArgument("kf_meas_noise", default_value="0.3"),
|
||||||
DeclareLaunchArgument("publish_rate", default_value="20.0"),
|
DeclareLaunchArgument("range_rate", default_value="100.0"),
|
||||||
|
DeclareLaunchArgument("bearing_rate", default_value="10.0"),
|
||||||
|
DeclareLaunchArgument("enrolled_tag_ids", default_value="['']"),
|
||||||
|
|
||||||
Node(
|
Node(
|
||||||
package="saltybot_uwb",
|
package="saltybot_uwb",
|
||||||
@ -52,7 +58,9 @@ def generate_launch_description():
|
|||||||
"min_range_m": LaunchConfiguration("min_range_m"),
|
"min_range_m": LaunchConfiguration("min_range_m"),
|
||||||
"kf_process_noise": LaunchConfiguration("kf_process_noise"),
|
"kf_process_noise": LaunchConfiguration("kf_process_noise"),
|
||||||
"kf_meas_noise": LaunchConfiguration("kf_meas_noise"),
|
"kf_meas_noise": LaunchConfiguration("kf_meas_noise"),
|
||||||
"publish_rate": LaunchConfiguration("publish_rate"),
|
"range_rate": LaunchConfiguration("range_rate"),
|
||||||
|
"bearing_rate": LaunchConfiguration("bearing_rate"),
|
||||||
|
"enrolled_tag_ids": LaunchConfiguration("enrolled_tag_ids"),
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
|
|||||||
@ -29,6 +29,26 @@ Returns (x_t, y_t); caller should treat negative x_t as 0.
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
# ── Bearing helper ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def bearing_from_pos(x: float, y: float) -> float:
|
||||||
|
"""
|
||||||
|
Compute horizontal bearing to a point (x, y) in base_link.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : forward distance (metres)
|
||||||
|
y : lateral offset (metres, positive = left of robot / CCW)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bearing_rad : bearing in radians, range -π .. +π
|
||||||
|
positive = target to the left (CCW)
|
||||||
|
0 = directly ahead
|
||||||
|
"""
|
||||||
|
return math.atan2(y, x)
|
||||||
|
|
||||||
|
|
||||||
# ── Triangulation ─────────────────────────────────────────────────────────────
|
# ── Triangulation ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
def triangulate_2anchor(
|
def triangulate_2anchor(
|
||||||
|
|||||||
@ -1,32 +1,52 @@
|
|||||||
"""
|
"""
|
||||||
uwb_driver_node.py — ROS2 node for MaUWB ESP32-S3 DW3000 follow-me system.
|
uwb_driver_node.py — ROS2 node for MaUWB ESP32-S3 DW3000 UWB integration.
|
||||||
|
|
||||||
Hardware
|
Hardware
|
||||||
────────
|
────────
|
||||||
• 2× MaUWB ESP32-S3 DW3000 anchors on robot stem (USB → Orin Nano)
|
• 2× MaUWB ESP32-S3 DW3000 anchors on robot stem (USB → Orin)
|
||||||
- Anchor-0: port side (y = +sep/2)
|
- Anchor-0: port side (y = +sep/2)
|
||||||
- Anchor-1: starboard (y = -sep/2)
|
- Anchor-1: starboard (y = -sep/2)
|
||||||
• 1× MaUWB tag on person (belt clip)
|
• 1× MaUWB tag per enrolled person (belt clip)
|
||||||
|
|
||||||
AT command interface (115200 8N1)
|
AT command interface (115200 8N1)
|
||||||
──────────────────────────────────
|
──────────────────────────────────
|
||||||
Query: AT+RANGE?\r\n
|
Query:
|
||||||
Response (from anchors):
|
AT+RANGE?\r\n
|
||||||
+RANGE:<anchor_id>,<range_mm>[,<rssi>]\r\n
|
|
||||||
|
|
||||||
Config:
|
Response (from anchors, TWR protocol):
|
||||||
AT+anchor_tag=ANCHOR\r\n — set module as anchor
|
+RANGE:<anchor_id>,<range_mm>[,<rssi>[,<tag_addr>]]\r\n
|
||||||
AT+anchor_tag=TAG\r\n — set module as tag
|
|
||||||
|
Tag pairing (optional — targets a specific enrolled tag):
|
||||||
|
AT+RANGE_ADDR=<tag_addr>\r\n → anchor only ranges with that tag
|
||||||
|
|
||||||
Publishes
|
Publishes
|
||||||
─────────
|
─────────
|
||||||
/uwb/target (geometry_msgs/PoseStamped) — triangulated person position in base_link
|
/uwb/ranges (saltybot_uwb_msgs/UwbRangeArray) 100 Hz — raw anchor ranges
|
||||||
/uwb/ranges (saltybot_uwb_msgs/UwbRangeArray) — raw ranges from both anchors
|
/uwb/bearing (saltybot_uwb_msgs/UwbBearing) 10 Hz — Kalman-fused bearing
|
||||||
|
/uwb/target (geometry_msgs/PoseStamped) 10 Hz — triangulated position
|
||||||
|
(kept for backwards compat)
|
||||||
|
|
||||||
Safety
|
Tag pairing
|
||||||
──────
|
───────────
|
||||||
If a range is stale (> range_timeout_s), that anchor is excluded from
|
Set enrolled_tag_ids to a list of tag address strings (e.g. ["0x1234ABCD"]).
|
||||||
triangulation. If both anchors are stale, /uwb/target is not published.
|
When non-empty, ranges from unrecognised tags (including +RANGE responses that carry no tag address) are silently discarded.
|
||||||
|
The matched tag address is stamped in UwbRange.tag_id and UwbBearing.tag_id.
|
||||||
|
When enrolled_tag_ids is empty, all ranges are accepted (tag_id = the tag address reported by the anchor, or "" if none was reported).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
──────────
|
||||||
|
port_a, port_b serial ports for anchor-0 / anchor-1
|
||||||
|
baudrate 115200 (default)
|
||||||
|
anchor_separation centre-to-centre anchor spacing (m)
|
||||||
|
anchor_height anchor mounting height (m)
|
||||||
|
tag_height person tag height (m)
|
||||||
|
range_timeout_s stale-anchor threshold (s)
|
||||||
|
max_range_m / min_range_m validity window (m)
|
||||||
|
kf_process_noise Kalman Q scalar
|
||||||
|
kf_meas_noise Kalman R scalar
|
||||||
|
range_rate Hz — /uwb/ranges publish rate (default 100)
|
||||||
|
bearing_rate Hz — /uwb/bearing publish rate (default 10)
|
||||||
|
enrolled_tag_ids list[str] — accepted tag addresses; [""] (default) or [] = accept all
|
||||||
|
|
||||||
Usage
|
Usage
|
||||||
─────
|
─────
|
||||||
@ -44,8 +64,8 @@ from rclpy.node import Node
|
|||||||
from geometry_msgs.msg import PoseStamped
|
from geometry_msgs.msg import PoseStamped
|
||||||
from std_msgs.msg import Header
|
from std_msgs.msg import Header
|
||||||
|
|
||||||
from saltybot_uwb_msgs.msg import UwbRange, UwbRangeArray
|
from saltybot_uwb_msgs.msg import UwbRange, UwbRangeArray, UwbBearing
|
||||||
from saltybot_uwb.ranging_math import triangulate_2anchor, KalmanFilter2D
|
from saltybot_uwb.ranging_math import triangulate_2anchor, KalmanFilter2D, bearing_from_pos
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import serial
|
import serial
|
||||||
@ -54,26 +74,31 @@ except ImportError:
|
|||||||
_SERIAL_AVAILABLE = False
|
_SERIAL_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
# Regex: +RANGE:<id>,<mm> or +RANGE:<id>,<mm>,<rssi>
|
# +RANGE:<id>,<mm> or +RANGE:<id>,<mm>,<rssi> or +RANGE:<id>,<mm>,<rssi>,<tag>
|
||||||
_RANGE_RE = re.compile(
|
_RANGE_RE = re.compile(
|
||||||
r"\+RANGE\s*:\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(-?[\d.]+))?",
|
r"\+RANGE\s*:\s*(\d+)\s*,\s*(\d+)"
|
||||||
|
r"(?:\s*,\s*(-?[\d.]+)"
|
||||||
|
r"(?:\s*,\s*([\w:x]+))?"
|
||||||
|
r")?",
|
||||||
re.IGNORECASE,
|
re.IGNORECASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SerialReader(threading.Thread):
|
class SerialReader(threading.Thread):
|
||||||
"""
|
"""
|
||||||
Background thread: polls an anchor's UART, fires callback on every
|
Background thread: polls one anchor's UART at maximum TWR rate,
|
||||||
valid +RANGE response.
|
fires callback on every valid +RANGE response.
|
||||||
|
Supports optional tag pairing via AT+RANGE_ADDR command.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, port, baudrate, anchor_id, callback, logger):
|
def __init__(self, port, baudrate, anchor_id, callback, logger, tag_addr=None):
|
||||||
super().__init__(daemon=True)
|
super().__init__(daemon=True)
|
||||||
self._port = port
|
self._port = port
|
||||||
self._baudrate = baudrate
|
self._baudrate = baudrate
|
||||||
self._anchor_id = anchor_id
|
self._anchor_id = anchor_id
|
||||||
self._callback = callback
|
self._callback = callback
|
||||||
self._logger = logger
|
self._logger = logger
|
||||||
|
self._tag_addr = tag_addr
|
||||||
self._running = False
|
self._running = False
|
||||||
self._ser = None
|
self._ser = None
|
||||||
|
|
||||||
@ -86,7 +111,10 @@ class SerialReader(threading.Thread):
|
|||||||
)
|
)
|
||||||
self._logger.info(
|
self._logger.info(
|
||||||
f"Anchor-{self._anchor_id}: opened {self._port}"
|
f"Anchor-{self._anchor_id}: opened {self._port}"
|
||||||
|
+ (f" paired with tag {self._tag_addr}" if self._tag_addr else "")
|
||||||
)
|
)
|
||||||
|
if self._tag_addr:
|
||||||
|
self._send_pairing_cmd()
|
||||||
self._read_loop()
|
self._read_loop()
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
self._logger.warn(
|
self._logger.warn(
|
||||||
@ -96,12 +124,24 @@ class SerialReader(threading.Thread):
|
|||||||
self._ser.close()
|
self._ser.close()
|
||||||
time.sleep(2.0)
|
time.sleep(2.0)
|
||||||
|
|
||||||
|
def _send_pairing_cmd(self):
|
||||||
|
"""Configure the anchor to range only with the paired tag."""
|
||||||
|
try:
|
||||||
|
cmd = f"AT+RANGE_ADDR={self._tag_addr}\r\n".encode("ascii")
|
||||||
|
self._ser.write(cmd)
|
||||||
|
time.sleep(0.1)
|
||||||
|
self._logger.info(
|
||||||
|
f"Anchor-{self._anchor_id}: sent tag pairing {self._tag_addr}"
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
self._logger.warn(
|
||||||
|
f"Anchor-{self._anchor_id}: pairing cmd failed: {exc}"
|
||||||
|
)
|
||||||
|
|
||||||
def _read_loop(self):
|
def _read_loop(self):
|
||||||
while self._running:
|
while self._running:
|
||||||
try:
|
try:
|
||||||
# Query the anchor
|
|
||||||
self._ser.write(b"AT+RANGE?\r\n")
|
self._ser.write(b"AT+RANGE?\r\n")
|
||||||
# Read up to 10 lines waiting for a +RANGE response
|
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
raw = self._ser.readline()
|
raw = self._ser.readline()
|
||||||
if not raw:
|
if not raw:
|
||||||
@ -111,13 +151,14 @@ class SerialReader(threading.Thread):
|
|||||||
if m:
|
if m:
|
||||||
range_mm = int(m.group(2))
|
range_mm = int(m.group(2))
|
||||||
rssi = float(m.group(3)) if m.group(3) else 0.0
|
rssi = float(m.group(3)) if m.group(3) else 0.0
|
||||||
self._callback(self._anchor_id, range_mm, rssi)
|
tag_addr = m.group(4) if m.group(4) else ""
|
||||||
|
self._callback(self._anchor_id, range_mm, rssi, tag_addr)
|
||||||
break
|
break
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
self._logger.warn(
|
self._logger.warn(
|
||||||
f"Anchor-{self._anchor_id} read error: {exc}"
|
f"Anchor-{self._anchor_id} read error: {exc}"
|
||||||
)
|
)
|
||||||
break # trigger reconnect
|
break
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self._running = False
|
self._running = False
|
||||||
@ -130,9 +171,8 @@ class UwbDriverNode(Node):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("uwb_driver")
|
super().__init__("uwb_driver")
|
||||||
|
|
||||||
# ── Parameters ────────────────────────────────────────────────────────
|
self.declare_parameter("port_a", "/dev/uwb-anchor0")
|
||||||
self.declare_parameter("port_a", "/dev/ttyUSB0")
|
self.declare_parameter("port_b", "/dev/uwb-anchor1")
|
||||||
self.declare_parameter("port_b", "/dev/ttyUSB1")
|
|
||||||
self.declare_parameter("baudrate", 115200)
|
self.declare_parameter("baudrate", 115200)
|
||||||
self.declare_parameter("anchor_separation", 0.25)
|
self.declare_parameter("anchor_separation", 0.25)
|
||||||
self.declare_parameter("anchor_height", 0.80)
|
self.declare_parameter("anchor_height", 0.80)
|
||||||
@ -142,36 +182,40 @@ class UwbDriverNode(Node):
|
|||||||
self.declare_parameter("min_range_m", 0.05)
|
self.declare_parameter("min_range_m", 0.05)
|
||||||
self.declare_parameter("kf_process_noise", 0.1)
|
self.declare_parameter("kf_process_noise", 0.1)
|
||||||
self.declare_parameter("kf_meas_noise", 0.3)
|
self.declare_parameter("kf_meas_noise", 0.3)
|
||||||
self.declare_parameter("publish_rate", 20.0)
|
self.declare_parameter("range_rate", 100.0)
|
||||||
|
self.declare_parameter("bearing_rate", 10.0)
|
||||||
|
self.declare_parameter("enrolled_tag_ids", [""])
|
||||||
|
|
||||||
self._p = self._load_params()
|
self._p = self._load_params()
|
||||||
|
|
||||||
# ── State (protected by lock) ──────────────────────────────────────
|
raw_ids = self.get_parameter("enrolled_tag_ids").value
|
||||||
|
self._enrolled_tags = [t.strip() for t in raw_ids if t.strip()]
|
||||||
|
paired_tag = self._enrolled_tags[0] if self._enrolled_tags else None
|
||||||
|
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
self._ranges = {} # anchor_id → (range_m, rssi, timestamp)
|
self._ranges: dict = {}
|
||||||
self._kf = KalmanFilter2D(
|
self._kf = KalmanFilter2D(
|
||||||
process_noise=self._p["kf_process_noise"],
|
process_noise=self._p["kf_process_noise"],
|
||||||
measurement_noise=self._p["kf_meas_noise"],
|
measurement_noise=self._p["kf_meas_noise"],
|
||||||
dt=1.0 / self._p["publish_rate"],
|
dt=1.0 / self._p["bearing_rate"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── Publishers ────────────────────────────────────────────────────
|
self._ranges_pub = self.create_publisher(UwbRangeArray, "/uwb/ranges", 10)
|
||||||
self._target_pub = self.create_publisher(
|
self._bearing_pub = self.create_publisher(UwbBearing, "/uwb/bearing", 10)
|
||||||
PoseStamped, "/uwb/target", 10)
|
self._target_pub = self.create_publisher(PoseStamped, "/uwb/target", 10)
|
||||||
self._ranges_pub = self.create_publisher(
|
|
||||||
UwbRangeArray, "/uwb/ranges", 10)
|
|
||||||
|
|
||||||
# ── Serial readers ────────────────────────────────────────────────
|
|
||||||
if _SERIAL_AVAILABLE:
|
if _SERIAL_AVAILABLE:
|
||||||
self._reader_a = SerialReader(
|
self._reader_a = SerialReader(
|
||||||
self._p["port_a"], self._p["baudrate"],
|
self._p["port_a"], self._p["baudrate"],
|
||||||
anchor_id=0, callback=self._range_cb,
|
anchor_id=0, callback=self._range_cb,
|
||||||
logger=self.get_logger(),
|
logger=self.get_logger(),
|
||||||
|
tag_addr=paired_tag,
|
||||||
)
|
)
|
||||||
self._reader_b = SerialReader(
|
self._reader_b = SerialReader(
|
||||||
self._p["port_b"], self._p["baudrate"],
|
self._p["port_b"], self._p["baudrate"],
|
||||||
anchor_id=1, callback=self._range_cb,
|
anchor_id=1, callback=self._range_cb,
|
||||||
logger=self.get_logger(),
|
logger=self.get_logger(),
|
||||||
|
tag_addr=paired_tag,
|
||||||
)
|
)
|
||||||
self._reader_a.start()
|
self._reader_a.start()
|
||||||
self._reader_b.start()
|
self._reader_b.start()
|
||||||
@ -180,19 +224,21 @@ class UwbDriverNode(Node):
|
|||||||
"pyserial not installed — running in simulation mode (no serial I/O)"
|
"pyserial not installed — running in simulation mode (no serial I/O)"
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── Publish timer ─────────────────────────────────────────────────
|
self._range_timer = self.create_timer(
|
||||||
self._timer = self.create_timer(
|
1.0 / self._p["range_rate"], self._range_publish_cb
|
||||||
1.0 / self._p["publish_rate"], self._publish_cb
|
)
|
||||||
|
self._bearing_timer = self.create_timer(
|
||||||
|
1.0 / self._p["bearing_rate"], self._bearing_publish_cb
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_logger().info(
|
self.get_logger().info(
|
||||||
f"UWB driver ready sep={self._p['anchor_separation']}m "
|
f"UWB driver ready sep={self._p['anchor_separation']}m "
|
||||||
f"ports={self._p['port_a']},{self._p['port_b']} "
|
f"ports={self._p['port_a']},{self._p['port_b']} "
|
||||||
f"rate={self._p['publish_rate']}Hz"
|
f"range={self._p['range_rate']}Hz "
|
||||||
|
f"bearing={self._p['bearing_rate']}Hz "
|
||||||
|
f"enrolled_tags={self._enrolled_tags or ['<any>']}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── Helpers ───────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _load_params(self):
|
def _load_params(self):
|
||||||
return {
|
return {
|
||||||
"port_a": self.get_parameter("port_a").value,
|
"port_a": self.get_parameter("port_a").value,
|
||||||
@ -206,90 +252,106 @@ class UwbDriverNode(Node):
|
|||||||
"min_range_m": self.get_parameter("min_range_m").value,
|
"min_range_m": self.get_parameter("min_range_m").value,
|
||||||
"kf_process_noise": self.get_parameter("kf_process_noise").value,
|
"kf_process_noise": self.get_parameter("kf_process_noise").value,
|
||||||
"kf_meas_noise": self.get_parameter("kf_meas_noise").value,
|
"kf_meas_noise": self.get_parameter("kf_meas_noise").value,
|
||||||
"publish_rate": self.get_parameter("publish_rate").value,
|
"range_rate": self.get_parameter("range_rate").value,
|
||||||
|
"bearing_rate": self.get_parameter("bearing_rate").value,
|
||||||
}
|
}
|
||||||
|
|
||||||
# ── Callbacks ─────────────────────────────────────────────────────────────
|
def _is_enrolled(self, tag_addr: str) -> bool:
|
||||||
|
if not self._enrolled_tags:
|
||||||
|
return True
|
||||||
|
return tag_addr in self._enrolled_tags
|
||||||
|
|
||||||
def _range_cb(self, anchor_id: int, range_mm: int, rssi: float):
|
def _range_cb(self, anchor_id: int, range_mm: int, rssi: float, tag_addr: str):
|
||||||
"""Called from serial reader threads — thread-safe update."""
|
if not self._is_enrolled(tag_addr):
|
||||||
|
return
|
||||||
range_m = range_mm / 1000.0
|
range_m = range_mm / 1000.0
|
||||||
p = self._p
|
p = self._p
|
||||||
if range_m < p["min_range_m"] or range_m > p["max_range_m"]:
|
if range_m < p["min_range_m"] or range_m > p["max_range_m"]:
|
||||||
return
|
return
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._ranges[anchor_id] = (range_m, rssi, time.monotonic())
|
self._ranges[anchor_id] = (range_m, rssi, tag_addr, time.monotonic())
|
||||||
|
|
||||||
def _publish_cb(self):
|
def _range_publish_cb(self):
|
||||||
|
"""100 Hz: publish current raw ranges as UwbRangeArray."""
|
||||||
now = time.monotonic()
|
now = time.monotonic()
|
||||||
timeout = self._p["range_timeout_s"]
|
timeout = self._p["range_timeout_s"]
|
||||||
sep = self._p["anchor_separation"]
|
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
# Collect valid (non-stale) ranges
|
valid = {
|
||||||
valid = {}
|
aid: entry
|
||||||
for aid, (r, rssi, t) in self._ranges.items():
|
for aid, entry in self._ranges.items()
|
||||||
if now - t <= timeout:
|
if (now - entry[3]) <= timeout
|
||||||
valid[aid] = (r, rssi, t)
|
}
|
||||||
|
|
||||||
# Build and publish UwbRangeArray regardless (even if partial)
|
|
||||||
hdr = Header()
|
hdr = Header()
|
||||||
hdr.stamp = self.get_clock().now().to_msg()
|
hdr.stamp = self.get_clock().now().to_msg()
|
||||||
hdr.frame_id = "base_link"
|
hdr.frame_id = "base_link"
|
||||||
|
|
||||||
arr = UwbRangeArray()
|
arr = UwbRangeArray()
|
||||||
arr.header = hdr
|
arr.header = hdr
|
||||||
for aid, (r, rssi, _) in valid.items():
|
for aid, (r, rssi, tag_id, _) in valid.items():
|
||||||
entry = UwbRange()
|
entry = UwbRange()
|
||||||
entry.header = hdr
|
entry.header = hdr
|
||||||
entry.anchor_id = aid
|
entry.anchor_id = aid
|
||||||
entry.range_m = float(r)
|
entry.range_m = float(r)
|
||||||
entry.raw_mm = int(round(r * 1000.0))
|
entry.raw_mm = int(round(r * 1000.0))
|
||||||
entry.rssi = float(rssi)
|
entry.rssi = float(rssi)
|
||||||
|
entry.tag_id = tag_id
|
||||||
arr.ranges.append(entry)
|
arr.ranges.append(entry)
|
||||||
self._ranges_pub.publish(arr)
|
self._ranges_pub.publish(arr)
|
||||||
|
|
||||||
# Need both anchors to triangulate
|
def _bearing_publish_cb(self):
|
||||||
if 0 not in valid or 1 not in valid:
|
"""10 Hz: Kalman predict+update, publish fused bearing."""
|
||||||
|
now = time.monotonic()
|
||||||
|
timeout = self._p["range_timeout_s"]
|
||||||
|
sep = self._p["anchor_separation"]
|
||||||
|
with self._lock:
|
||||||
|
valid = {
|
||||||
|
aid: entry
|
||||||
|
for aid, entry in self._ranges.items()
|
||||||
|
if (now - entry[3]) <= timeout
|
||||||
|
}
|
||||||
|
if not valid:
|
||||||
return
|
return
|
||||||
|
both_fresh = 0 in valid and 1 in valid
|
||||||
|
confidence = 1.0 if both_fresh else 0.5
|
||||||
|
active_tag = valid[next(iter(valid))][2]
|
||||||
|
dt = 1.0 / self._p["bearing_rate"]
|
||||||
|
self._kf.predict(dt=dt)
|
||||||
|
if both_fresh:
|
||||||
r0 = valid[0][0]
|
r0 = valid[0][0]
|
||||||
r1 = valid[1][0]
|
r1 = valid[1][0]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
x_t, y_t = triangulate_2anchor(
|
x_t, y_t = triangulate_2anchor(
|
||||||
r0=r0,
|
r0=r0, r1=r1, sep=sep,
|
||||||
r1=r1,
|
|
||||||
sep=sep,
|
|
||||||
anchor_z=self._p["anchor_height"],
|
anchor_z=self._p["anchor_height"],
|
||||||
tag_z=self._p["tag_height"],
|
tag_z=self._p["tag_height"],
|
||||||
)
|
)
|
||||||
except (ValueError, ZeroDivisionError) as exc:
|
except (ValueError, ZeroDivisionError) as exc:
|
||||||
self.get_logger().warn(f"Triangulation error: {exc}")
|
self.get_logger().warn(f"Triangulation error: {exc}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Kalman filter update
|
|
||||||
dt = 1.0 / self._p["publish_rate"]
|
|
||||||
self._kf.predict(dt=dt)
|
|
||||||
self._kf.update(x_t, y_t)
|
self._kf.update(x_t, y_t)
|
||||||
kx, ky = self._kf.position()
|
kx, ky = self._kf.position()
|
||||||
|
bearing = bearing_from_pos(kx, ky)
|
||||||
# Publish PoseStamped in base_link
|
range_m = math.sqrt(kx * kx + ky * ky)
|
||||||
|
hdr = Header()
|
||||||
|
hdr.stamp = self.get_clock().now().to_msg()
|
||||||
|
hdr.frame_id = "base_link"
|
||||||
|
brg_msg = UwbBearing()
|
||||||
|
brg_msg.header = hdr
|
||||||
|
brg_msg.bearing_rad = float(bearing)
|
||||||
|
brg_msg.range_m = float(range_m)
|
||||||
|
brg_msg.confidence = float(confidence)
|
||||||
|
brg_msg.tag_id = active_tag
|
||||||
|
self._bearing_pub.publish(brg_msg)
|
||||||
pose = PoseStamped()
|
pose = PoseStamped()
|
||||||
pose.header = hdr
|
pose.header = hdr
|
||||||
pose.pose.position.x = kx
|
pose.pose.position.x = kx
|
||||||
pose.pose.position.y = ky
|
pose.pose.position.y = ky
|
||||||
pose.pose.position.z = 0.0
|
pose.pose.position.z = 0.0
|
||||||
# Orientation: face the person (yaw = atan2(y, x))
|
yaw = bearing
|
||||||
yaw = math.atan2(ky, kx)
|
|
||||||
pose.pose.orientation.z = math.sin(yaw / 2.0)
|
pose.pose.orientation.z = math.sin(yaw / 2.0)
|
||||||
pose.pose.orientation.w = math.cos(yaw / 2.0)
|
pose.pose.orientation.w = math.cos(yaw / 2.0)
|
||||||
|
|
||||||
self._target_pub.publish(pose)
|
self._target_pub.publish(pose)
|
||||||
|
|
||||||
|
|
||||||
# ── Entry point ───────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def main(args=None):
|
def main(args=None):
|
||||||
rclpy.init(args=args)
|
rclpy.init(args=args)
|
||||||
node = UwbDriverNode()
|
node = UwbDriverNode()
|
||||||
|
|||||||
@ -7,7 +7,7 @@ No ROS2 / serial / GPU dependencies — runs with plain pytest.
|
|||||||
import math
|
import math
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from saltybot_uwb.ranging_math import triangulate_2anchor, KalmanFilter2D
|
from saltybot_uwb.ranging_math import triangulate_2anchor, KalmanFilter2D, bearing_from_pos
|
||||||
|
|
||||||
|
|
||||||
# ── triangulate_2anchor ───────────────────────────────────────────────────────
|
# ── triangulate_2anchor ───────────────────────────────────────────────────────
|
||||||
@ -172,3 +172,47 @@ class TestKalmanFilter2D:
|
|||||||
x, y = kf.position()
|
x, y = kf.position()
|
||||||
assert not math.isnan(x)
|
assert not math.isnan(x)
|
||||||
assert not math.isnan(y)
|
assert not math.isnan(y)
|
||||||
|
|
||||||
|
|
||||||
|
# ── bearing_from_pos ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestBearingFromPos:
|
||||||
|
|
||||||
|
def test_directly_ahead_zero_bearing(self):
|
||||||
|
"""Person directly ahead: x=2, y=0 → bearing ≈ 0."""
|
||||||
|
b = bearing_from_pos(2.0, 0.0)
|
||||||
|
assert abs(b) < 0.001
|
||||||
|
|
||||||
|
def test_left_gives_positive_bearing(self):
|
||||||
|
"""Person to the left (y>0): bearing should be positive."""
|
||||||
|
b = bearing_from_pos(1.0, 1.0)
|
||||||
|
assert b > 0.0
|
||||||
|
assert abs(b - math.pi / 4.0) < 0.001
|
||||||
|
|
||||||
|
def test_right_gives_negative_bearing(self):
|
||||||
|
"""Person to the right (y<0): bearing should be negative."""
|
||||||
|
b = bearing_from_pos(1.0, -1.0)
|
||||||
|
assert b < 0.0
|
||||||
|
assert abs(b + math.pi / 4.0) < 0.001
|
||||||
|
|
||||||
|
def test_directly_left_ninety_degrees(self):
|
||||||
|
"""Person directly to the left: x=0, y=1 → bearing = π/2."""
|
||||||
|
b = bearing_from_pos(0.0, 1.0)
|
||||||
|
assert abs(b - math.pi / 2.0) < 0.001
|
||||||
|
|
||||||
|
def test_directly_right_minus_ninety_degrees(self):
|
||||||
|
"""Person directly to the right: x=0, y=-1 → bearing = -π/2."""
|
||||||
|
b = bearing_from_pos(0.0, -1.0)
|
||||||
|
assert abs(b + math.pi / 2.0) < 0.001
|
||||||
|
|
||||||
|
def test_range_pi_to_minus_pi(self):
|
||||||
|
"""Bearing is always in -π..+π."""
|
||||||
|
for x in [-2.0, -0.1, 0.1, 2.0]:
|
||||||
|
for y in [-2.0, -0.1, 0.1, 2.0]:
|
||||||
|
b = bearing_from_pos(x, y)
|
||||||
|
assert -math.pi <= b <= math.pi
|
||||||
|
|
||||||
|
def test_no_nan_for_tiny_distance(self):
|
||||||
|
"""Very close target should not produce NaN."""
|
||||||
|
b = bearing_from_pos(0.001, 0.001)
|
||||||
|
assert not math.isnan(b)
|
||||||
|
|||||||
@ -8,6 +8,7 @@ find_package(rosidl_default_generators REQUIRED)
|
|||||||
rosidl_generate_interfaces(${PROJECT_NAME}
|
rosidl_generate_interfaces(${PROJECT_NAME}
|
||||||
"msg/UwbRange.msg"
|
"msg/UwbRange.msg"
|
||||||
"msg/UwbRangeArray.msg"
|
"msg/UwbRangeArray.msg"
|
||||||
|
"msg/UwbBearing.msg"
|
||||||
DEPENDENCIES std_msgs
|
DEPENDENCIES std_msgs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
18
jetson/ros2_ws/src/saltybot_uwb_msgs/msg/UwbBearing.msg
Normal file
18
jetson/ros2_ws/src/saltybot_uwb_msgs/msg/UwbBearing.msg
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# UwbBearing.msg — 10 Hz Kalman-fused bearing estimate from dual-anchor UWB.
|
||||||
|
#
|
||||||
|
# bearing_rad : horizontal bearing to tag in base_link frame (radians)
|
||||||
|
# positive = tag to the left (CCW), negative = right (CW)
|
||||||
|
# range: -π .. +π
|
||||||
|
# range_m : Kalman-filtered horizontal range to tag (metres)
|
||||||
|
# confidence : quality indicator
|
||||||
|
# 1.0 = both anchors fresh
|
||||||
|
# 0.5 = single anchor only (no triangulation; estimate is the Kalman prediction without a measurement update)
|
||||||
|
# 0.0 = stale / no data (message not published in this state)
|
||||||
|
# tag_id : enrolled tag identifier that produced this estimate
|
||||||
|
|
||||||
|
std_msgs/Header header
|
||||||
|
|
||||||
|
float32 bearing_rad
|
||||||
|
float32 range_m
|
||||||
|
float32 confidence
|
||||||
|
string tag_id
|
||||||
@ -4,6 +4,7 @@
|
|||||||
# range_m : measured horizontal range in metres (after height correction)
|
# range_m : measured horizontal range in metres (after height correction)
|
||||||
# raw_mm : raw TWR range from AT+RANGE? response, millimetres
|
# raw_mm : raw TWR range from AT+RANGE? response, millimetres
|
||||||
# rssi : received signal strength (dBm), 0 if not reported by module
|
# rssi : received signal strength (dBm), 0 if not reported by module
|
||||||
|
# tag_id : tag address reported in the +RANGE response (an enrolled tag when enrolled_tag_ids is set); empty string if the anchor reported none
|
||||||
|
|
||||||
std_msgs/Header header
|
std_msgs/Header header
|
||||||
|
|
||||||
@ -11,3 +12,4 @@ uint8 anchor_id
|
|||||||
float32 range_m
|
float32 range_m
|
||||||
uint32 raw_mm
|
uint32 raw_mm
|
||||||
float32 rssi
|
float32 rssi
|
||||||
|
string tag_id
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user