Compare commits

...

7 Commits

Author SHA1 Message Date
0821845210 feat(social): multi-modal tracking fusion — UWB+camera Kalman filter (Issue #92)
New packages:
  saltybot_social_msgs   — FusedTarget.msg custom message
  saltybot_social_tracking — 4-state Kalman fusion node

saltybot_social_tracking/tracking_fusion_node.py
  Subscribes to /uwb/target (PoseStamped, ~10 Hz) and /person/target
  (PoseStamped, ~30 Hz) and publishes /social/tracking/fused_target
  (FusedTarget) at 20 Hz.

  Source arbitration:
    • "fused"     — both UWB and camera are fresh; confidence-weighted blend
    • "uwb"       — UWB fresh, camera stale
    • "camera"    — camera fresh, UWB stale
    • "predicted" — all sources stale; KF coasts for up to predict_timeout (3 s)

  Kalman filter (kalman_tracker.py):
    State [x, y, vx, vy] with discrete Wiener acceleration noise model
    (process_noise=3.0 m/s²) sized for EUC speeds (20-30 km/h, ≈5.5-8.3 m/s).
    Separate UWB (0.20 m) and camera (0.12 m) measurement noise.
    Velocity estimate converges after ~3 s of 10 Hz UWB measurements.

  Confidence model (source_arbiter.py):
    Per-source confidence = quality × max(0, 1 - age/timeout).
    Composite confidence accounts for KF positional uncertainty and
    is capped at 0.4 during dead-reckoning ("predicted") mode.

Tests: 58/58 pass (no ROS2 runtime required).

Note: saltybot_social_msgs here adds FusedTarget.msg; PR #98
(Issue #84) adds PersonalityState.msg + QueryMood.srv to the same
package. The maintainer should squash-merge #98 first and rebase
this branch on top of it before merging to avoid the package.xml
conflict.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-01 23:33:35 -05:00
d9c983f666 Merge pull request 'feat(social): navigation & path planning #91' (#97) from sl-perception/social-nav into main 2026-03-01 23:30:40 -05:00
54e9274405 Merge pull request 'feat(uwb): MaUWB ESP32-S3 DW3000 dual-anchor bearing driver (Issue #90)' (#99) from sl-firmware/uwb-integration into main 2026-03-01 23:30:12 -05:00
b432492785 Merge pull request 'feat(social): multi-modal person state tracker #82' (#93) from sl-perception/social-person-state into main 2026-03-01 23:30:04 -05:00
9a68dfdb2e feat(uwb): MaUWB ESP32-S3 DW3000 dual-anchor bearing driver (Issue #90)
## Summary
- saltybot_uwb_msgs: add UwbBearing.msg, add tag_id to UwbRange.msg,
  register UwbBearing in CMakeLists.txt
- ranging_math.py: add bearing_from_pos(x, y) helper (atan2-based)
- uwb_driver_node.py: dual-rate architecture
    • 100 Hz /uwb/ranges  — raw TWR ranges with tag_id attribution
    • 10 Hz  /uwb/bearing — Kalman-fused bearing + range estimate
    • enrolled_tag_ids parameter for tag pairing filter
    • AT+RANGE_ADDR=<tag> pairing command on connect
- uwb_config.yaml: range_rate / bearing_rate / enrolled_tag_ids params
- uwb.launch.py: expose new params as launch arguments
- test_ranging_math.py: 7 new bearing_from_pos unit tests

Closes #90

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-01 23:25:08 -05:00
d872ea5e34 feat(social): navigation + follow modes + MiDaS depth + waypoints (Issue #91)
- saltybot_social_msgs: full message/service definitions (standalone compilation)
- saltybot_social_nav: social navigation orchestrator
  - Follow modes: shadow/lead/side/orbit/loose/tight
  - Voice steering: mode switching + route commands via /social/speech/*
  - A* obstacle avoidance on Nav2/SLAM occupancy grid (8-directional, inflation)
  - MiDaS monocular depth for CSI cameras (TRT FP16 + ONNX fallback)
  - Waypoint teaching + replay with WaypointRoute persistence
  - High-speed EUC tracking (5.5 m/s = ~20 km/h)
  - Predictive position extrapolation (0.3s ahead at high speed)
- Launch: social_nav.launch.py (social_nav + midas_depth + waypoint_teacher)
- Config: social_nav_params.yaml
- Script: build_midas_trt_engine.py (ONNX -> TRT FP16)
2026-03-01 23:15:00 -05:00
84790412d6 feat(social): multi-modal person state tracker (Issue #82) 2026-03-01 23:08:22 -05:00
61 changed files with 3657 additions and 155 deletions

View File

@ -0,0 +1,8 @@
person_state_tracker:
ros__parameters:
engagement_distance: 2.0
absent_timeout: 5.0
prune_timeout: 30.0
update_rate: 10.0
n_cameras: 4
uwb_enabled: false

View File

@ -0,0 +1,35 @@
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
return LaunchDescription([
DeclareLaunchArgument(
'engagement_distance',
default_value='2.0',
description='Distance in metres below which a person is considered engaged'
),
DeclareLaunchArgument(
'absent_timeout',
default_value='5.0',
description='Seconds without detection before marking person as ABSENT'
),
DeclareLaunchArgument(
'uwb_enabled',
default_value='false',
description='Whether UWB anchor data is available'
),
Node(
package='saltybot_social',
executable='person_state_tracker',
name='person_state_tracker',
output='screen',
parameters=[{
'engagement_distance': LaunchConfiguration('engagement_distance'),
'absent_timeout': LaunchConfiguration('absent_timeout'),
'uwb_enabled': LaunchConfiguration('uwb_enabled'),
}],
),
])

View File

@ -0,0 +1,25 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_social</name>
<version>0.1.0</version>
<description>Multi-modal person identity fusion and state tracking for saltybot</description>
<maintainer email="seb@vayrette.com">seb</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>std_msgs</depend>
<depend>geometry_msgs</depend>
<depend>sensor_msgs</depend>
<depend>vision_msgs</depend>
<depend>saltybot_social_msgs</depend>
<depend>tf2_ros</depend>
<depend>tf2_geometry_msgs</depend>
<depend>cv_bridge</depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>python3-pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -0,0 +1,201 @@
import math
import time
from typing import List, Optional, Tuple
from saltybot_social.person_state import PersonTrack, State
class IdentityFuser:
"""Multi-modal identity fusion: face_id + speaker_id + UWB anchor -> unified person_id."""
def __init__(self, position_window: float = 3.0):
self._tracks: dict[int, PersonTrack] = {}
self._next_id: int = 1
self._position_window = position_window
self._max_history = 20
def _distance_3d(self, a: Optional[tuple], b: Optional[tuple]) -> float:
if a is None or b is None:
return float('inf')
return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))
def _compute_distance_and_bearing(self, position: tuple) -> Tuple[float, float]:
x, y, z = position
distance = math.sqrt(x * x + y * y + z * z)
bearing = math.degrees(math.atan2(y, x))
return distance, bearing
def _allocate_id(self) -> int:
pid = self._next_id
self._next_id += 1
return pid
def update_face(self, face_id: int, name: str,
position_3d: Optional[tuple], camera_id: int,
now: float) -> int:
"""Match by face_id first. If face_id >= 0, find existing track or create new."""
if face_id >= 0:
for track in self._tracks.values():
if track.face_id == face_id:
track.name = name if name and name != 'unknown' else track.name
if position_3d is not None:
track.position = position_3d
dist, bearing = self._compute_distance_and_bearing(position_3d)
track.distance = dist
track.bearing_deg = bearing
track.history_distances.append(dist)
if len(track.history_distances) > self._max_history:
track.history_distances = track.history_distances[-self._max_history:]
track.camera_id = camera_id
track.last_seen = now
track.last_face_seen = now
return track.person_id
# No existing track with this face_id; try proximity match for face_id < 0
if face_id < 0 and position_3d is not None:
best_track = None
best_dist = 0.5 # 0.5m proximity threshold
for track in self._tracks.values():
d = self._distance_3d(track.position, position_3d)
if d < best_dist and (now - track.last_seen) < self._position_window:
best_dist = d
best_track = track
if best_track is not None:
best_track.position = position_3d
dist, bearing = self._compute_distance_and_bearing(position_3d)
best_track.distance = dist
best_track.bearing_deg = bearing
best_track.history_distances.append(dist)
if len(best_track.history_distances) > self._max_history:
best_track.history_distances = best_track.history_distances[-self._max_history:]
best_track.camera_id = camera_id
best_track.last_seen = now
return best_track.person_id
# Create new track
pid = self._allocate_id()
track = PersonTrack(person_id=pid, face_id=face_id, name=name,
camera_id=camera_id, last_seen=now, last_face_seen=now)
if position_3d is not None:
track.position = position_3d
dist, bearing = self._compute_distance_and_bearing(position_3d)
track.distance = dist
track.bearing_deg = bearing
track.history_distances.append(dist)
self._tracks[pid] = track
return pid
def update_speaker(self, speaker_id: str, now: float) -> int:
"""Find track with nearest recently-seen position, assign speaker_id."""
# First check if any track already has this speaker_id
for track in self._tracks.values():
if track.speaker_id == speaker_id:
track.last_seen = now
return track.person_id
# Assign to nearest recently-seen track
best_track = None
best_dist = float('inf')
for track in self._tracks.values():
if (now - track.last_seen) < self._position_window and track.distance > 0:
if track.distance < best_dist:
best_dist = track.distance
best_track = track
if best_track is not None:
best_track.speaker_id = speaker_id
best_track.last_seen = now
return best_track.person_id
# No suitable track found; create a new one
pid = self._allocate_id()
track = PersonTrack(person_id=pid, speaker_id=speaker_id,
last_seen=now)
self._tracks[pid] = track
return pid
def update_uwb(self, uwb_anchor_id: str, position_3d: tuple,
now: float) -> int:
"""Match by proximity (nearest track within 0.5m)."""
# Check existing UWB assignment
for track in self._tracks.values():
if track.uwb_anchor_id == uwb_anchor_id:
track.position = position_3d
dist, bearing = self._compute_distance_and_bearing(position_3d)
track.distance = dist
track.bearing_deg = bearing
track.history_distances.append(dist)
if len(track.history_distances) > self._max_history:
track.history_distances = track.history_distances[-self._max_history:]
track.last_seen = now
return track.person_id
# Proximity match
best_track = None
best_dist = 0.5
for track in self._tracks.values():
d = self._distance_3d(track.position, position_3d)
if d < best_dist and (now - track.last_seen) < self._position_window:
best_dist = d
best_track = track
if best_track is not None:
best_track.uwb_anchor_id = uwb_anchor_id
best_track.position = position_3d
dist, bearing = self._compute_distance_and_bearing(position_3d)
best_track.distance = dist
best_track.bearing_deg = bearing
best_track.history_distances.append(dist)
if len(best_track.history_distances) > self._max_history:
best_track.history_distances = best_track.history_distances[-self._max_history:]
best_track.last_seen = now
return best_track.person_id
# Create new track
pid = self._allocate_id()
track = PersonTrack(person_id=pid, uwb_anchor_id=uwb_anchor_id,
position=position_3d, last_seen=now)
dist, bearing = self._compute_distance_and_bearing(position_3d)
track.distance = dist
track.bearing_deg = bearing
track.history_distances.append(dist)
self._tracks[pid] = track
return pid
def get_all_tracks(self) -> List[PersonTrack]:
return list(self._tracks.values())
@staticmethod
def compute_attention(tracks: List[PersonTrack]) -> int:
"""Focus on nearest engaged/talking person, or -1 if none."""
candidates = [t for t in tracks
if t.state in (State.ENGAGED, State.TALKING) and t.distance > 0]
if not candidates:
return -1
# Prefer talking, then nearest engaged
talking = [t for t in candidates if t.state == State.TALKING]
if talking:
return min(talking, key=lambda t: t.distance).person_id
return min(candidates, key=lambda t: t.distance).person_id
def prune_absent(self, absent_timeout: float = 30.0):
"""Remove tracks absent longer than timeout."""
now = time.monotonic()
to_remove = [pid for pid, t in self._tracks.items()
if t.state == State.ABSENT and (now - t.last_seen) > absent_timeout]
for pid in to_remove:
del self._tracks[pid]
@staticmethod
def detect_group(tracks: List[PersonTrack]) -> bool:
"""Returns True if >= 2 persons within 1.5m of each other."""
active = [t for t in tracks
if t.position is not None and t.state != State.ABSENT]
for i, a in enumerate(active):
for b in active[i + 1:]:
dx = a.position[0] - b.position[0]
dy = a.position[1] - b.position[1]
dz = a.position[2] - b.position[2]
if math.sqrt(dx * dx + dy * dy + dz * dz) < 1.5:
return True
return False

View File

@ -0,0 +1,54 @@
import time
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Optional
class State(IntEnum):
UNKNOWN = 0
APPROACHING = 1
ENGAGED = 2
TALKING = 3
LEAVING = 4
ABSENT = 5
@dataclass
class PersonTrack:
person_id: int
face_id: int = -1
speaker_id: str = ''
uwb_anchor_id: str = ''
name: str = 'unknown'
state: State = State.UNKNOWN
position: Optional[tuple] = None # (x, y, z) in base_link
distance: float = 0.0
bearing_deg: float = 0.0
engagement_score: float = 0.0
camera_id: int = -1
last_seen: float = field(default_factory=time.monotonic)
last_face_seen: float = 0.0
history_distances: list = field(default_factory=list) # last N distances for trend
def update_state(self, now: float, engagement_distance: float = 2.0,
absent_timeout: float = 5.0):
"""Transition state machine based on distance trend and time."""
age = now - self.last_seen
if age > absent_timeout:
self.state = State.ABSENT
return
if self.speaker_id:
self.state = State.TALKING
return
if len(self.history_distances) >= 3:
trend = self.history_distances[-1] - self.history_distances[-3]
if self.distance < engagement_distance:
self.state = State.ENGAGED
elif trend < -0.2: # moving closer
self.state = State.APPROACHING
elif trend > 0.3: # moving away
self.state = State.LEAVING
else:
self.state = State.ENGAGED if self.distance < engagement_distance else State.UNKNOWN
elif self.distance > 0:
self.state = State.ENGAGED if self.distance < engagement_distance else State.APPROACHING

View File

@ -0,0 +1,273 @@
import time
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from std_msgs.msg import Int32, String
from geometry_msgs.msg import PoseArray, PoseStamped, Point
from sensor_msgs.msg import Image
from builtin_interfaces.msg import Time as TimeMsg
from saltybot_social_msgs.msg import (
FaceDetection,
FaceDetectionArray,
PersonState,
PersonStateArray,
)
from saltybot_social.identity_fuser import IdentityFuser
from saltybot_social.person_state import State
class PersonStateTrackerNode(Node):
"""Main ROS2 node for multi-modal person tracking."""
def __init__(self):
super().__init__('person_state_tracker')
# Parameters
self.declare_parameter('engagement_distance', 2.0)
self.declare_parameter('absent_timeout', 5.0)
self.declare_parameter('prune_timeout', 30.0)
self.declare_parameter('update_rate', 10.0)
self.declare_parameter('n_cameras', 4)
self.declare_parameter('uwb_enabled', False)
self._engagement_distance = self.get_parameter('engagement_distance').value
self._absent_timeout = self.get_parameter('absent_timeout').value
self._prune_timeout = self.get_parameter('prune_timeout').value
update_rate = self.get_parameter('update_rate').value
n_cameras = self.get_parameter('n_cameras').value
uwb_enabled = self.get_parameter('uwb_enabled').value
self._fuser = IdentityFuser()
# QoS profiles
best_effort_qos = QoSProfile(
reliability=ReliabilityPolicy.BEST_EFFORT,
history=HistoryPolicy.KEEP_LAST,
depth=5
)
reliable_qos = QoSProfile(
reliability=ReliabilityPolicy.RELIABLE,
history=HistoryPolicy.KEEP_LAST,
depth=1
)
# Subscriptions
self.create_subscription(
FaceDetectionArray,
'/social/faces/detections',
self._on_face_detections,
best_effort_qos
)
self.create_subscription(
String,
'/social/speech/speaker_id',
self._on_speaker_id,
10
)
if uwb_enabled:
self.create_subscription(
PoseArray,
'/uwb/positions',
self._on_uwb,
10
)
# Camera subscriptions (monitor topic existence)
self.create_subscription(
Image, '/camera/color/image_raw',
self._on_camera_image, best_effort_qos)
for i in range(n_cameras):
self.create_subscription(
Image, f'/surround/cam{i}/image_raw',
self._on_camera_image, best_effort_qos)
# Existing person tracker position
self.create_subscription(
PoseStamped,
'/person/target',
self._on_person_target,
10
)
# Publishers
self._persons_pub = self.create_publisher(
PersonStateArray, '/social/persons', best_effort_qos)
self._attention_pub = self.create_publisher(
Int32, '/social/attention/target_id', reliable_qos)
# Timer
timer_period = 1.0 / update_rate
self.create_timer(timer_period, self._on_timer)
self.get_logger().info(
f'PersonStateTracker started: engagement={self._engagement_distance}m, '
f'absent_timeout={self._absent_timeout}s, rate={update_rate}Hz, '
f'cameras={n_cameras}, uwb={uwb_enabled}')
def _on_face_detections(self, msg: FaceDetectionArray):
now = time.monotonic()
for face in msg.faces:
position_3d = None
# Use bbox center as rough position estimate if depth is available
# Real 3D position would come from depth camera projection
if face.bbox_x > 0 or face.bbox_y > 0:
# Approximate: use bbox center as bearing proxy, distance from bbox size
# This is a placeholder; real impl would use depth image lookup
position_3d = (
max(0.5, 1.0 / max(face.bbox_w, 0.01)), # rough depth from bbox width
face.bbox_x + face.bbox_w / 2.0 - 0.5, # x offset from center
face.bbox_y + face.bbox_h / 2.0 - 0.5 # y offset from center
)
camera_id = -1
if hasattr(face.header, 'frame_id') and face.header.frame_id:
frame = face.header.frame_id
if 'cam' in frame:
try:
camera_id = int(frame.split('cam')[-1].split('_')[0])
except ValueError:
pass
self._fuser.update_face(
face_id=face.face_id,
name=face.person_name,
position_3d=position_3d,
camera_id=camera_id,
now=now
)
def _on_speaker_id(self, msg: String):
now = time.monotonic()
speaker_id = msg.data
# Parse "speaker_<id>" format or raw id
if speaker_id.startswith('speaker_'):
speaker_id = speaker_id[len('speaker_'):]
self._fuser.update_speaker(speaker_id, now)
def _on_uwb(self, msg: PoseArray):
now = time.monotonic()
for i, pose in enumerate(msg.poses):
position = (pose.position.x, pose.position.y, pose.position.z)
anchor_id = f'uwb_{i}'
if hasattr(msg, 'header') and msg.header.frame_id:
anchor_id = f'{msg.header.frame_id}_{i}'
self._fuser.update_uwb(anchor_id, position, now)
def _on_camera_image(self, msg: Image):
# Monitor only -- no processing here
pass
def _on_person_target(self, msg: PoseStamped):
now = time.monotonic()
position_3d = (
msg.pose.position.x,
msg.pose.position.y,
msg.pose.position.z
)
# Update position for unidentified person (from YOLOv8 tracker)
self._fuser.update_face(
face_id=-1,
name='unknown',
position_3d=position_3d,
camera_id=-1,
now=now
)
def _on_timer(self):
now = time.monotonic()
tracks = self._fuser.get_all_tracks()
# Update state machine for all tracks
for track in tracks:
track.update_state(
now,
engagement_distance=self._engagement_distance,
absent_timeout=self._absent_timeout
)
# Compute engagement scores
for track in tracks:
if track.state == State.ABSENT:
track.engagement_score = 0.0
elif track.state == State.TALKING:
track.engagement_score = 1.0
elif track.state == State.ENGAGED:
track.engagement_score = max(0.0, 1.0 - track.distance / self._engagement_distance)
elif track.state == State.APPROACHING:
track.engagement_score = 0.3
elif track.state == State.LEAVING:
track.engagement_score = 0.1
else:
track.engagement_score = 0.0
# Prune long-absent tracks
self._fuser.prune_absent(self._prune_timeout)
# Compute attention target
attention_id = IdentityFuser.compute_attention(tracks)
# Build and publish PersonStateArray
msg = PersonStateArray()
msg.header.stamp = self.get_clock().now().to_msg()
msg.header.frame_id = 'base_link'
msg.primary_attention_id = attention_id
for track in tracks:
ps = PersonState()
ps.header.stamp = msg.header.stamp
ps.header.frame_id = 'base_link'
ps.person_id = track.person_id
ps.person_name = track.name
ps.face_id = track.face_id
ps.speaker_id = track.speaker_id
ps.uwb_anchor_id = track.uwb_anchor_id
if track.position is not None:
ps.position = Point(
x=float(track.position[0]),
y=float(track.position[1]),
z=float(track.position[2]))
ps.distance = float(track.distance)
ps.bearing_deg = float(track.bearing_deg)
ps.state = int(track.state)
ps.engagement_score = float(track.engagement_score)
ps.last_seen = self._mono_to_ros_time(track.last_seen)
ps.camera_id = track.camera_id
msg.persons.append(ps)
self._persons_pub.publish(msg)
# Publish attention target
att_msg = Int32()
att_msg.data = attention_id
self._attention_pub.publish(att_msg)
def _mono_to_ros_time(self, mono: float) -> TimeMsg:
"""Convert monotonic timestamp to approximate ROS time."""
# Offset from monotonic to wall clock
offset = time.time() - time.monotonic()
wall = mono + offset
sec = int(wall)
nsec = int((wall - sec) * 1e9)
t = TimeMsg()
t.sec = sec
t.nanosec = nsec
return t
def main(args=None):
rclpy.init(args=args)
node = PersonStateTrackerNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,4 @@
[develop]
script_dir=$base/lib/saltybot_social
[install]
install_scripts=$base/lib/saltybot_social

View File

@ -0,0 +1,32 @@
from setuptools import find_packages, setup
import os
from glob import glob
package_name = 'saltybot_social'
setup(
name=package_name,
version='0.1.0',
packages=find_packages(exclude=['test']),
data_files=[
('share/ament_index/resource_index/packages',
['resource/' + package_name]),
('share/' + package_name, ['package.xml']),
(os.path.join('share', package_name, 'launch'),
glob(os.path.join('launch', '*launch.[pxy][yma]*'))),
(os.path.join('share', package_name, 'config'),
glob(os.path.join('config', '*.yaml'))),
],
install_requires=['setuptools'],
zip_safe=True,
maintainer='seb',
maintainer_email='seb@vayrette.com',
description='Multi-modal person identity fusion and state tracking for saltybot',
license='MIT',
tests_require=['pytest'],
entry_points={
'console_scripts': [
'person_state_tracker = saltybot_social.person_state_tracker_node:main',
],
},
)

View File

@ -0,0 +1,28 @@
cmake_minimum_required(VERSION 3.8)
project(saltybot_social_msgs)
find_package(ament_cmake REQUIRED)
find_package(rosidl_default_generators REQUIRED)
find_package(std_msgs REQUIRED)
find_package(geometry_msgs REQUIRED)
find_package(builtin_interfaces REQUIRED)
rosidl_generate_interfaces(${PROJECT_NAME}
# Social perception (from sl-perception)
"msg/FaceDetection.msg"
"msg/FaceDetectionArray.msg"
"msg/FaceEmbedding.msg"
"msg/FaceEmbeddingArray.msg"
"msg/PersonState.msg"
"msg/PersonStateArray.msg"
"srv/EnrollPerson.srv"
"srv/ListPersons.srv"
"srv/DeletePerson.srv"
"srv/UpdatePerson.srv"
# Multi-modal tracking fusion (Issue #92)
"msg/FusedTarget.msg"
DEPENDENCIES std_msgs geometry_msgs builtin_interfaces
)
ament_export_dependencies(rosidl_default_runtime)
ament_package()

View File

@ -0,0 +1,10 @@
std_msgs/Header header
int32 face_id
string person_name
float32 confidence
float32 recognition_score
float32 bbox_x
float32 bbox_y
float32 bbox_w
float32 bbox_h
float32[10] landmarks

View File

@ -0,0 +1,2 @@
std_msgs/Header header
saltybot_social_msgs/FaceDetection[] faces

View File

@ -0,0 +1,5 @@
int32 person_id
string person_name
float32[] embedding
builtin_interfaces/Time enrolled_at
int32 sample_count

View File

@ -0,0 +1,2 @@
std_msgs/Header header
saltybot_social_msgs/FaceEmbedding[] embeddings

View File

@ -0,0 +1,19 @@
# FusedTarget.msg — output of the multi-modal tracking fusion node.
#
# Position and velocity are in the base_link frame (robot-centred,
# +X forward, +Y left). z components are always 0.0 for ground-plane tracking.
#
# Confidence: 0.0 = no data / fully predicted; 1.0 = strong fused measurement.
# active_source: "fused" | "uwb" | "camera" | "predicted"
std_msgs/Header header
geometry_msgs/Point position # filtered 2-D position (m), z=0
geometry_msgs/Vector3 velocity # filtered 2-D velocity (m/s), z=0
float32 range_m # Euclidean distance from robot to fused position
float32 bearing_rad # bearing in base_link (+ve = person to the left)
float32 confidence # composite confidence [0.0, 1.0]
string active_source # "fused" | "uwb" | "camera" | "predicted"
string tag_id # UWB tag address (empty when UWB not contributing)

View File

@ -0,0 +1,19 @@
std_msgs/Header header
int32 person_id
string person_name
int32 face_id
string speaker_id
string uwb_anchor_id
geometry_msgs/Point position
float32 distance
float32 bearing_deg
uint8 state
uint8 STATE_UNKNOWN=0
uint8 STATE_APPROACHING=1
uint8 STATE_ENGAGED=2
uint8 STATE_TALKING=3
uint8 STATE_LEAVING=4
uint8 STATE_ABSENT=5
float32 engagement_score
builtin_interfaces/Time last_seen
int32 camera_id

View File

@ -0,0 +1,3 @@
std_msgs/Header header
saltybot_social_msgs/PersonState[] persons
int32 primary_attention_id

View File

@ -0,0 +1,28 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_social_msgs</name>
<version>0.1.0</version>
<description>
Custom ROS2 message and service definitions for saltybot social capabilities.
Includes social perception types (face detection, person state, enrollment)
and multi-modal tracking fusion types (FusedTarget) from Issue #92.
</description>
<maintainer email="seb@vayrette.com">seb</maintainer>
<license>MIT</license>
<buildtool_depend>ament_cmake</buildtool_depend>
<build_depend>rosidl_default_generators</build_depend>
<depend>std_msgs</depend>
<depend>geometry_msgs</depend>
<depend>builtin_interfaces</depend>
<exec_depend>rosidl_default_runtime</exec_depend>
<member_of_group>rosidl_interface_packages</member_of_group>
<export>
<build_type>ament_cmake</build_type>
</export>
</package>

View File

@ -0,0 +1,4 @@
int32 person_id
---
bool success
string message

View File

@ -0,0 +1,7 @@
string name
string mode
int32 n_samples
---
bool success
string message
int32 person_id

View File

@ -0,0 +1,2 @@
---
saltybot_social_msgs/FaceEmbedding[] persons

View File

@ -0,0 +1,5 @@
int32 person_id
string new_name
---
bool success
string message

View File

@ -0,0 +1,22 @@
social_nav:
ros__parameters:
follow_mode: 'shadow'
follow_distance: 1.2
lead_distance: 2.0
orbit_radius: 1.5
max_linear_speed: 1.0
max_linear_speed_fast: 5.5
max_angular_speed: 1.0
goal_tolerance: 0.3
routes_dir: '/mnt/nvme/saltybot/routes'
home_x: 0.0
home_y: 0.0
map_resolution: 0.05
obstacle_inflation_cells: 3
midas_depth:
ros__parameters:
onnx_path: '/mnt/nvme/saltybot/models/midas_small.onnx'
engine_path: '/mnt/nvme/saltybot/models/midas_small.engine'
process_rate: 5.0
output_scale: 1.0

View File

@ -0,0 +1,57 @@
"""Launch file for saltybot social navigation."""
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
return LaunchDescription([
DeclareLaunchArgument('follow_mode', default_value='shadow',
description='Follow mode: shadow/lead/side/orbit/loose/tight'),
DeclareLaunchArgument('follow_distance', default_value='1.2',
description='Follow distance in meters'),
DeclareLaunchArgument('max_linear_speed', default_value='1.0',
description='Max linear speed (m/s)'),
DeclareLaunchArgument('routes_dir',
default_value='/mnt/nvme/saltybot/routes',
description='Directory for saved routes'),
Node(
package='saltybot_social_nav',
executable='social_nav',
name='social_nav',
output='screen',
parameters=[{
'follow_mode': LaunchConfiguration('follow_mode'),
'follow_distance': LaunchConfiguration('follow_distance'),
'max_linear_speed': LaunchConfiguration('max_linear_speed'),
'routes_dir': LaunchConfiguration('routes_dir'),
}],
),
Node(
package='saltybot_social_nav',
executable='midas_depth',
name='midas_depth',
output='screen',
parameters=[{
'onnx_path': '/mnt/nvme/saltybot/models/midas_small.onnx',
'engine_path': '/mnt/nvme/saltybot/models/midas_small.engine',
'process_rate': 5.0,
'output_scale': 1.0,
}],
),
Node(
package='saltybot_social_nav',
executable='waypoint_teacher',
name='waypoint_teacher',
output='screen',
parameters=[{
'routes_dir': LaunchConfiguration('routes_dir'),
'recording_interval': 0.5,
}],
),
])

View File

@ -0,0 +1,28 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_social_nav</name>
<version>0.1.0</version>
<description>Social navigation for saltybot: follow modes, waypoint teaching, A* avoidance, MiDaS depth</description>
<maintainer email="seb@vayrette.com">seb</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>std_msgs</depend>
<depend>geometry_msgs</depend>
<depend>nav_msgs</depend>
<depend>sensor_msgs</depend>
<depend>cv_bridge</depend>
<depend>tf2_ros</depend>
<depend>tf2_geometry_msgs</depend>
<depend>saltybot_social_msgs</depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>python3-pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -0,0 +1,82 @@
"""astar.py -- A* path planner for saltybot social navigation."""
import heapq
import numpy as np
def astar(grid: np.ndarray, start: tuple, goal: tuple,
obstacle_val: int = 100) -> list | None:
"""
A* on a 2D occupancy grid (row, col indexing).
Args:
grid: 2D numpy array, values 0=free, >=obstacle_val=obstacle
start: (row, col) start cell
goal: (row, col) goal cell
obstacle_val: cells with value >= obstacle_val are blocked
Returns:
List of (row, col) tuples from start to goal, or None if no path.
"""
rows, cols = grid.shape
def h(a, b):
return abs(a[0] - b[0]) + abs(a[1] - b[1]) # Manhattan heuristic
open_set = []
heapq.heappush(open_set, (h(start, goal), 0, start))
came_from = {}
g_score = {start: 0}
# 8-directional movement
neighbors_delta = [
(-1, -1), (-1, 0), (-1, 1),
(0, -1), (0, 1),
(1, -1), (1, 0), (1, 1),
]
while open_set:
_, cost, current = heapq.heappop(open_set)
if current == goal:
path = []
while current in came_from:
path.append(current)
current = came_from[current]
path.append(start)
return list(reversed(path))
if cost > g_score.get(current, float('inf')):
continue
for dr, dc in neighbors_delta:
nr, nc = current[0] + dr, current[1] + dc
if not (0 <= nr < rows and 0 <= nc < cols):
continue
if grid[nr, nc] >= obstacle_val:
continue
move_cost = 1.414 if (dr != 0 and dc != 0) else 1.0
new_g = g_score[current] + move_cost
neighbor = (nr, nc)
if new_g < g_score.get(neighbor, float('inf')):
g_score[neighbor] = new_g
f = new_g + h(neighbor, goal)
came_from[neighbor] = current
heapq.heappush(open_set, (f, new_g, neighbor))
return None # No path found
def inflate_obstacles(grid: np.ndarray, inflation_radius_cells: int) -> np.ndarray:
"""Inflate obstacles for robot footprint safety."""
from scipy.ndimage import binary_dilation
obstacle_mask = grid >= 50
kernel = np.ones(
(2 * inflation_radius_cells + 1, 2 * inflation_radius_cells + 1),
dtype=bool,
)
inflated = binary_dilation(obstacle_mask, structure=kernel)
result = grid.copy()
result[inflated] = 100
return result

View File

@ -0,0 +1,82 @@
"""follow_modes.py -- Follow mode geometry for saltybot social navigation."""
import math
from enum import Enum
import numpy as np
class FollowMode(Enum):
SHADOW = 'shadow' # stay directly behind at follow_distance
LEAD = 'lead' # move ahead of person by lead_distance
SIDE = 'side' # stay to the right (or left) at side_offset
ORBIT = 'orbit' # circle around person at orbit_radius
LOOSE = 'loose' # general follow, larger tolerance
TIGHT = 'tight' # close follow, small tolerance
def compute_shadow_target(person_pos, person_bearing_deg, follow_dist=1.2):
"""Target position: behind person along their movement direction."""
bearing_rad = math.radians(person_bearing_deg + 180.0)
tx = person_pos[0] + follow_dist * math.sin(bearing_rad)
ty = person_pos[1] + follow_dist * math.cos(bearing_rad)
return (tx, ty, person_pos[2])
def compute_lead_target(person_pos, person_bearing_deg, lead_dist=2.0):
"""Target position: ahead of person."""
bearing_rad = math.radians(person_bearing_deg)
tx = person_pos[0] + lead_dist * math.sin(bearing_rad)
ty = person_pos[1] + lead_dist * math.cos(bearing_rad)
return (tx, ty, person_pos[2])
def compute_side_target(person_pos, person_bearing_deg, side_dist=1.0, right=True):
"""Target position: to the right (or left) of person."""
sign = 1.0 if right else -1.0
bearing_rad = math.radians(person_bearing_deg + sign * 90.0)
tx = person_pos[0] + side_dist * math.sin(bearing_rad)
ty = person_pos[1] + side_dist * math.cos(bearing_rad)
return (tx, ty, person_pos[2])
def compute_orbit_target(person_pos, orbit_angle_deg, orbit_radius=1.5):
"""Target on circle of radius orbit_radius around person."""
angle_rad = math.radians(orbit_angle_deg)
tx = person_pos[0] + orbit_radius * math.sin(angle_rad)
ty = person_pos[1] + orbit_radius * math.cos(angle_rad)
return (tx, ty, person_pos[2])
def compute_loose_target(person_pos, robot_pos, follow_dist=2.0, tolerance=0.8):
"""Only move if farther than follow_dist + tolerance."""
dx = person_pos[0] - robot_pos[0]
dy = person_pos[1] - robot_pos[1]
dist = math.hypot(dx, dy)
if dist <= follow_dist + tolerance:
return robot_pos
# Target at follow_dist behind person (toward robot)
scale = (dist - follow_dist) / dist
return (robot_pos[0] + dx * scale, robot_pos[1] + dy * scale, person_pos[2])
def compute_tight_target(person_pos, follow_dist=0.6):
"""Close follow: stay very near person."""
return (person_pos[0], person_pos[1] - follow_dist, person_pos[2])
MODE_VOICE_COMMANDS = {
'shadow': FollowMode.SHADOW,
'follow me': FollowMode.SHADOW,
'behind me': FollowMode.SHADOW,
'lead': FollowMode.LEAD,
'go ahead': FollowMode.LEAD,
'lead me': FollowMode.LEAD,
'side': FollowMode.SIDE,
'stay beside': FollowMode.SIDE,
'orbit': FollowMode.ORBIT,
'circle me': FollowMode.ORBIT,
'loose': FollowMode.LOOSE,
'give me space': FollowMode.LOOSE,
'tight': FollowMode.TIGHT,
'stay close': FollowMode.TIGHT,
}

View File

@ -0,0 +1,231 @@
"""
midas_depth_node.py -- MiDaS monocular depth estimation for saltybot.
Uses MiDaS_small via ONNX Runtime or TensorRT FP16.
Provides relative depth estimates for cameras without active depth (CSI cameras).
Publishes /social/depth/cam{i}/image (sensor_msgs/Image, float32, relative depth)
"""
import os
import numpy as np
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from sensor_msgs.msg import Image
from cv_bridge import CvBridge
# MiDaS_small input size
_MIDAS_H = 256
_MIDAS_W = 256
# ImageNet normalization
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
class _TRTBackend:
"""TensorRT inference backend for MiDaS."""
def __init__(self, engine_path: str, logger):
self._logger = logger
try:
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit # noqa: F401
self._cuda = cuda
rt_logger = trt.Logger(trt.Logger.WARNING)
with open(engine_path, 'rb') as f:
engine = trt.Runtime(rt_logger).deserialize_cuda_engine(f.read())
self._context = engine.create_execution_context()
# Allocate buffers
self._d_input = cuda.mem_alloc(1 * 3 * _MIDAS_H * _MIDAS_W * 4)
self._d_output = cuda.mem_alloc(1 * _MIDAS_H * _MIDAS_W * 4)
self._h_output = np.empty((_MIDAS_H, _MIDAS_W), dtype=np.float32)
self._stream = cuda.Stream()
self._logger.info(f'TRT engine loaded: {engine_path}')
except Exception as e:
raise RuntimeError(f'TRT init failed: {e}')
def infer(self, input_tensor: np.ndarray) -> np.ndarray:
self._cuda.memcpy_htod_async(
self._d_input, input_tensor.ravel(), self._stream)
self._context.execute_async_v2(
bindings=[int(self._d_input), int(self._d_output)],
stream_handle=self._stream.handle)
self._cuda.memcpy_dtoh_async(
self._h_output, self._d_output, self._stream)
self._stream.synchronize()
return self._h_output.copy()
class _ONNXBackend:
"""ONNX Runtime inference backend for MiDaS."""
def __init__(self, onnx_path: str, logger):
self._logger = logger
try:
import onnxruntime as ort
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
self._session = ort.InferenceSession(onnx_path, providers=providers)
self._input_name = self._session.get_inputs()[0].name
self._logger.info(f'ONNX model loaded: {onnx_path}')
except Exception as e:
raise RuntimeError(f'ONNX init failed: {e}')
def infer(self, input_tensor: np.ndarray) -> np.ndarray:
result = self._session.run(None, {self._input_name: input_tensor})
return result[0].squeeze()
class MiDaSDepthNode(Node):
"""MiDaS monocular depth estimation node."""
def __init__(self):
super().__init__('midas_depth')
# Parameters
self.declare_parameter('onnx_path',
'/mnt/nvme/saltybot/models/midas_small.onnx')
self.declare_parameter('engine_path',
'/mnt/nvme/saltybot/models/midas_small.engine')
self.declare_parameter('camera_topics', [
'/surround/cam0/image_raw',
'/surround/cam1/image_raw',
'/surround/cam2/image_raw',
'/surround/cam3/image_raw',
])
self.declare_parameter('output_scale', 1.0)
self.declare_parameter('process_rate', 5.0)
onnx_path = self.get_parameter('onnx_path').value
engine_path = self.get_parameter('engine_path').value
self._camera_topics = self.get_parameter('camera_topics').value
self._output_scale = self.get_parameter('output_scale').value
process_rate = self.get_parameter('process_rate').value
# Initialize inference backend (TRT preferred, ONNX fallback)
self._backend = None
if os.path.exists(engine_path):
try:
self._backend = _TRTBackend(engine_path, self.get_logger())
except RuntimeError:
self.get_logger().warn('TRT failed, trying ONNX fallback')
if self._backend is None and os.path.exists(onnx_path):
try:
self._backend = _ONNXBackend(onnx_path, self.get_logger())
except RuntimeError:
self.get_logger().error('Both TRT and ONNX backends failed')
if self._backend is None:
self.get_logger().error(
'No MiDaS model found. Depth estimation disabled.')
self._bridge = CvBridge()
# Latest frames per camera (round-robin processing)
self._latest_frames = [None] * len(self._camera_topics)
self._current_cam_idx = 0
# QoS for camera subscriptions
cam_qos = QoSProfile(
reliability=ReliabilityPolicy.BEST_EFFORT,
history=HistoryPolicy.KEEP_LAST,
depth=1,
)
# Subscribe to each camera topic
self._cam_subs = []
for i, topic in enumerate(self._camera_topics):
sub = self.create_subscription(
Image, topic,
lambda msg, idx=i: self._on_image(msg, idx),
cam_qos)
self._cam_subs.append(sub)
# Publishers: one per camera
self._depth_pubs = []
for i in range(len(self._camera_topics)):
pub = self.create_publisher(
Image, f'/social/depth/cam{i}/image', 10)
self._depth_pubs.append(pub)
# Timer: round-robin across cameras
timer_period = 1.0 / process_rate
self._timer = self.create_timer(timer_period, self._timer_callback)
self.get_logger().info(
f'MiDaS depth node started: {len(self._camera_topics)} cameras '
f'@ {process_rate} Hz')
def _on_image(self, msg: Image, cam_idx: int):
"""Cache latest frame for each camera."""
self._latest_frames[cam_idx] = msg
def _preprocess(self, bgr: np.ndarray) -> np.ndarray:
"""Preprocess BGR image to MiDaS input tensor [1,3,256,256]."""
import cv2
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
resized = cv2.resize(rgb, (_MIDAS_W, _MIDAS_H),
interpolation=cv2.INTER_LINEAR)
normalized = (resized.astype(np.float32) / 255.0 - _MEAN) / _STD
# HWC -> CHW -> NCHW
tensor = normalized.transpose(2, 0, 1)[np.newaxis, ...]
return tensor.astype(np.float32)
def _infer(self, tensor: np.ndarray) -> np.ndarray:
"""Run inference, returns [256,256] float32 relative inverse depth."""
if self._backend is None:
return np.zeros((_MIDAS_H, _MIDAS_W), dtype=np.float32)
return self._backend.infer(tensor)
def _postprocess(self, raw: np.ndarray, orig_shape: tuple) -> np.ndarray:
"""Resize depth back to original image shape, apply output_scale."""
import cv2
h, w = orig_shape[:2]
depth = cv2.resize(raw, (w, h), interpolation=cv2.INTER_LINEAR)
depth = depth * self._output_scale
return depth
def _timer_callback(self):
"""Process one camera per tick (round-robin)."""
if not self._camera_topics:
return
idx = self._current_cam_idx
self._current_cam_idx = (idx + 1) % len(self._camera_topics)
msg = self._latest_frames[idx]
if msg is None:
return
try:
bgr = self._bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')
except Exception as e:
self.get_logger().warn(f'cv_bridge error cam{idx}: {e}')
return
tensor = self._preprocess(bgr)
raw_depth = self._infer(tensor)
depth_map = self._postprocess(raw_depth, bgr.shape)
# Publish as float32 Image
depth_msg = self._bridge.cv2_to_imgmsg(depth_map, encoding='32FC1')
depth_msg.header = msg.header
self._depth_pubs[idx].publish(depth_msg)
def main(args=None):
rclpy.init(args=args)
node = MiDaSDepthNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,584 @@
"""
social_nav_node.py -- Social navigation node for saltybot.
Orchestrates person following with multiple modes, voice commands,
waypoint teaching/replay, and A* obstacle avoidance.
Follow modes:
shadow -- stay directly behind at follow_distance
lead -- move ahead of person
side -- stay beside (default right)
orbit -- circle around person
loose -- relaxed follow with deadband
tight -- close follow
Waypoint teaching:
Voice command "teach route <name>" -> record mode ON
Voice command "stop teaching" -> save route
Voice command "replay route <name>" -> playback
Voice commands:
"follow me" / "shadow" -> SHADOW mode
"lead me" / "go ahead" -> LEAD mode
"stay beside" -> SIDE mode
"orbit" -> ORBIT mode
"give me space" -> LOOSE mode
"stay close" -> TIGHT mode
"stop" / "halt" -> STOP
"go home" -> navigate to home waypoint
"teach route <name>" -> start recording
"stop teaching" -> finish recording
"replay route <name>" -> playback recorded route
"""
import math
import time
import re
from collections import deque
import numpy as np
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from geometry_msgs.msg import Twist, PoseStamped
from nav_msgs.msg import OccupancyGrid, Odometry
from std_msgs.msg import String, Int32
from .follow_modes import (
FollowMode, MODE_VOICE_COMMANDS,
compute_shadow_target, compute_lead_target, compute_side_target,
compute_orbit_target, compute_loose_target, compute_tight_target,
)
from .astar import astar, inflate_obstacles
from .waypoint_teacher import WaypointRoute, WaypointReplayer
# Try importing social msgs; fallback gracefully
try:
from saltybot_social_msgs.msg import PersonStateArray
_HAS_SOCIAL_MSGS = True
except ImportError:
_HAS_SOCIAL_MSGS = False
# Proportional controller gains
_K_ANG = 2.0 # angular gain
_K_LIN = 0.8 # linear gain
_HIGH_SPEED_THRESHOLD = 3.0 # m/s person velocity triggers fast mode
_PREDICT_AHEAD_S = 0.3 # seconds to extrapolate position
_TEACH_MIN_DIST = 0.5 # meters between recorded waypoints
class SocialNavNode(Node):
"""Main social navigation orchestrator."""
def __init__(self):
super().__init__('social_nav')
# -- Parameters --
self.declare_parameter('follow_mode', 'shadow')
self.declare_parameter('follow_distance', 1.2)
self.declare_parameter('lead_distance', 2.0)
self.declare_parameter('orbit_radius', 1.5)
self.declare_parameter('max_linear_speed', 1.0)
self.declare_parameter('max_linear_speed_fast', 5.5)
self.declare_parameter('max_angular_speed', 1.0)
self.declare_parameter('goal_tolerance', 0.3)
self.declare_parameter('routes_dir', '/mnt/nvme/saltybot/routes')
self.declare_parameter('home_x', 0.0)
self.declare_parameter('home_y', 0.0)
self.declare_parameter('map_resolution', 0.05)
self.declare_parameter('obstacle_inflation_cells', 3)
self._follow_mode = FollowMode(
self.get_parameter('follow_mode').value)
self._follow_distance = self.get_parameter('follow_distance').value
self._lead_distance = self.get_parameter('lead_distance').value
self._orbit_radius = self.get_parameter('orbit_radius').value
self._max_lin = self.get_parameter('max_linear_speed').value
self._max_lin_fast = self.get_parameter('max_linear_speed_fast').value
self._max_ang = self.get_parameter('max_angular_speed').value
self._goal_tol = self.get_parameter('goal_tolerance').value
self._routes_dir = self.get_parameter('routes_dir').value
self._home_x = self.get_parameter('home_x').value
self._home_y = self.get_parameter('home_y').value
self._map_resolution = self.get_parameter('map_resolution').value
self._inflation_cells = self.get_parameter(
'obstacle_inflation_cells').value
# -- State --
self._robot_x = 0.0
self._robot_y = 0.0
self._robot_yaw = 0.0
self._target_person_pos = None # (x, y, z)
self._target_person_bearing = 0.0
self._target_person_id = -1
self._person_history = deque(maxlen=5) # for velocity estimation
self._stopped = False
self._go_home = False
# Occupancy grid for A*
self._occ_grid = None
self._occ_origin = (0.0, 0.0)
self._occ_resolution = 0.05
self._astar_path = None
# Orbit state
self._orbit_angle = 0.0
# Waypoint teaching / replay
self._teaching = False
self._current_route = None
self._last_teach_x = None
self._last_teach_y = None
self._replayer = None
# -- QoS profiles --
best_effort_qos = QoSProfile(
reliability=ReliabilityPolicy.BEST_EFFORT,
history=HistoryPolicy.KEEP_LAST, depth=1)
reliable_qos = QoSProfile(
reliability=ReliabilityPolicy.RELIABLE,
history=HistoryPolicy.KEEP_LAST, depth=1)
# -- Subscriptions --
if _HAS_SOCIAL_MSGS:
self.create_subscription(
PersonStateArray, '/social/persons',
self._on_persons, best_effort_qos)
else:
self.get_logger().warn(
'saltybot_social_msgs not found; '
'using /person/target fallback')
self.create_subscription(
PoseStamped, '/person/target',
self._on_person_target, best_effort_qos)
self.create_subscription(
String, '/social/speech/command',
self._on_voice_command, 10)
self.create_subscription(
String, '/social/speech/transcript',
self._on_transcript, 10)
self.create_subscription(
OccupancyGrid, '/map',
self._on_map, reliable_qos)
self.create_subscription(
Odometry, '/odom',
self._on_odom, best_effort_qos)
self.create_subscription(
Int32, '/social/attention/target_id',
self._on_target_id, 10)
# -- Publishers --
self._cmd_vel_pub = self.create_publisher(
Twist, '/cmd_vel', best_effort_qos)
self._mode_pub = self.create_publisher(
String, '/social/nav/mode', reliable_qos)
self._target_pub = self.create_publisher(
PoseStamped, '/social/nav/target_pos', 10)
self._status_pub = self.create_publisher(
String, '/social/nav/status', best_effort_qos)
# -- Main loop timer (20 Hz) --
self._timer = self.create_timer(0.05, self._control_loop)
self.get_logger().info(
f'Social nav started: mode={self._follow_mode.value}, '
f'dist={self._follow_distance}m')
# ----------------------------------------------------------------
# Subscriptions
# ----------------------------------------------------------------
def _on_persons(self, msg):
"""Handle PersonStateArray from social perception."""
target_id = msg.primary_attention_id
if self._target_person_id >= 0:
target_id = self._target_person_id
for p in msg.persons:
if p.person_id == target_id or (
target_id < 0 and len(msg.persons) > 0):
pos = (p.position.x, p.position.y, p.position.z)
self._update_person_position(pos, p.bearing_deg)
break
def _on_person_target(self, msg: PoseStamped):
"""Fallback: single person target pose."""
pos = (msg.pose.position.x, msg.pose.position.y,
msg.pose.position.z)
# Estimate bearing from quaternion yaw
q = msg.pose.orientation
yaw = math.atan2(2.0 * (q.w * q.z + q.x * q.y),
1.0 - 2.0 * (q.y * q.y + q.z * q.z))
self._update_person_position(pos, math.degrees(yaw))
def _update_person_position(self, pos, bearing_deg):
"""Update person tracking state and record history."""
now = time.time()
self._target_person_pos = pos
self._target_person_bearing = bearing_deg
self._person_history.append((now, pos[0], pos[1]))
def _on_odom(self, msg: Odometry):
"""Update robot pose from odometry."""
self._robot_x = msg.pose.pose.position.x
self._robot_y = msg.pose.pose.position.y
q = msg.pose.pose.orientation
self._robot_yaw = math.atan2(
2.0 * (q.w * q.z + q.x * q.y),
1.0 - 2.0 * (q.y * q.y + q.z * q.z))
def _on_map(self, msg: OccupancyGrid):
"""Cache occupancy grid for A* planning."""
w, h = msg.info.width, msg.info.height
data = np.array(msg.data, dtype=np.int8).reshape((h, w))
# Convert -1 (unknown) to free (0) for planning
data[data < 0] = 0
self._occ_grid = data.astype(np.int32)
self._occ_origin = (msg.info.origin.position.x,
msg.info.origin.position.y)
self._occ_resolution = msg.info.resolution
def _on_target_id(self, msg: Int32):
"""Switch target person."""
self._target_person_id = msg.data
self.get_logger().info(f'Target person ID set to {msg.data}')
def _on_voice_command(self, msg: String):
"""Handle discrete voice commands for mode switching."""
cmd = msg.data.strip().lower()
if cmd in ('stop', 'halt'):
self._stopped = True
self._replayer = None
self._publish_status('STOPPED')
return
if cmd in ('resume', 'go', 'start'):
self._stopped = False
self._publish_status('RESUMED')
return
matched = MODE_VOICE_COMMANDS.get(cmd)
if matched:
self._follow_mode = matched
self._stopped = False
mode_msg = String()
mode_msg.data = self._follow_mode.value
self._mode_pub.publish(mode_msg)
self._publish_status(f'MODE: {self._follow_mode.value}')
def _on_transcript(self, msg: String):
"""Handle free-form voice transcripts for route teaching."""
text = msg.data.strip().lower()
# "teach route <name>"
m = re.match(r'teach\s+route\s+(\w+)', text)
if m:
name = m.group(1)
self._teaching = True
self._current_route = WaypointRoute(name)
self._last_teach_x = self._robot_x
self._last_teach_y = self._robot_y
self._publish_status(f'TEACHING: {name}')
self.get_logger().info(f'Recording route: {name}')
return
# "stop teaching"
if 'stop teaching' in text:
if self._teaching and self._current_route:
self._current_route.save(self._routes_dir)
self._publish_status(
f'SAVED: {self._current_route.name} '
f'({len(self._current_route.waypoints)} pts)')
self.get_logger().info(
f'Route saved: {self._current_route.name}')
self._teaching = False
self._current_route = None
return
# "replay route <name>"
m = re.match(r'replay\s+route\s+(\w+)', text)
if m:
name = m.group(1)
try:
route = WaypointRoute.load(self._routes_dir, name)
self._replayer = WaypointReplayer(route)
self._stopped = False
self._publish_status(f'REPLAY: {name}')
self.get_logger().info(f'Replaying route: {name}')
except FileNotFoundError:
self._publish_status(f'ROUTE NOT FOUND: {name}')
return
# "go home"
if 'go home' in text:
self._go_home = True
self._stopped = False
self._publish_status('GO HOME')
return
# Also try mode commands from transcript
for phrase, mode in MODE_VOICE_COMMANDS.items():
if phrase in text:
self._follow_mode = mode
self._stopped = False
mode_msg = String()
mode_msg.data = self._follow_mode.value
self._mode_pub.publish(mode_msg)
self._publish_status(f'MODE: {self._follow_mode.value}')
return
# ----------------------------------------------------------------
# Control loop
# ----------------------------------------------------------------
def _control_loop(self):
"""Main 20Hz control loop."""
# Record waypoint if teaching
if self._teaching and self._current_route:
self._maybe_record_waypoint()
# Publish zero velocity if stopped
if self._stopped:
self._publish_cmd_vel(0.0, 0.0)
return
# Determine navigation target
target = self._get_nav_target()
if target is None:
self._publish_cmd_vel(0.0, 0.0)
return
tx, ty, tz = target
# Publish debug target
self._publish_target_pose(tx, ty, tz)
# Check if arrived
dist_to_target = math.hypot(tx - self._robot_x, ty - self._robot_y)
if dist_to_target < self._goal_tol:
self._publish_cmd_vel(0.0, 0.0)
if self._go_home:
self._go_home = False
self._publish_status('HOME REACHED')
return
# Try A* path if map available
if self._occ_grid is not None:
path_target = self._plan_astar(tx, ty)
if path_target:
tx, ty = path_target
# Determine speed limit
max_lin = self._max_lin
person_vel = self._estimate_person_velocity()
if person_vel > _HIGH_SPEED_THRESHOLD:
max_lin = self._max_lin_fast
# Compute and publish cmd_vel
lin, ang = self._compute_cmd_vel(
self._robot_x, self._robot_y, self._robot_yaw,
tx, ty, max_lin)
self._publish_cmd_vel(lin, ang)
def _get_nav_target(self):
"""Determine current navigation target based on mode/state."""
# Route replay takes priority
if self._replayer and not self._replayer.is_done:
self._replayer.check_arrived(self._robot_x, self._robot_y)
wp = self._replayer.current_waypoint()
if wp:
return (wp.x, wp.y, wp.z)
else:
self._replayer = None
self._publish_status('REPLAY DONE')
return None
# Go home
if self._go_home:
return (self._home_x, self._home_y, 0.0)
# Person following
if self._target_person_pos is None:
return None
# Predict person position ahead for high-speed tracking
px, py, pz = self._predict_person_position()
bearing = self._target_person_bearing
robot_pos = (self._robot_x, self._robot_y, 0.0)
if self._follow_mode == FollowMode.SHADOW:
return compute_shadow_target(
(px, py, pz), bearing, self._follow_distance)
elif self._follow_mode == FollowMode.LEAD:
return compute_lead_target(
(px, py, pz), bearing, self._lead_distance)
elif self._follow_mode == FollowMode.SIDE:
return compute_side_target(
(px, py, pz), bearing, self._follow_distance)
elif self._follow_mode == FollowMode.ORBIT:
self._orbit_angle = (self._orbit_angle + 1.0) % 360.0
return compute_orbit_target(
(px, py, pz), self._orbit_angle, self._orbit_radius)
elif self._follow_mode == FollowMode.LOOSE:
return compute_loose_target(
(px, py, pz), robot_pos, self._follow_distance)
elif self._follow_mode == FollowMode.TIGHT:
return compute_tight_target(
(px, py, pz), self._follow_distance)
return (px, py, pz)
def _predict_person_position(self):
"""Extrapolate person position using velocity from recent samples."""
if self._target_person_pos is None:
return (0.0, 0.0, 0.0)
px, py, pz = self._target_person_pos
if len(self._person_history) >= 3:
# Use last 3 samples for velocity estimation
t0, x0, y0 = self._person_history[-3]
t1, x1, y1 = self._person_history[-1]
dt = t1 - t0
if dt > 0.01:
vx = (x1 - x0) / dt
vy = (y1 - y0) / dt
speed = math.hypot(vx, vy)
if speed > _HIGH_SPEED_THRESHOLD:
px += vx * _PREDICT_AHEAD_S
py += vy * _PREDICT_AHEAD_S
return (px, py, pz)
def _estimate_person_velocity(self) -> float:
"""Estimate person speed from recent position history."""
if len(self._person_history) < 2:
return 0.0
t0, x0, y0 = self._person_history[-2]
t1, x1, y1 = self._person_history[-1]
dt = t1 - t0
if dt < 0.01:
return 0.0
return math.hypot(x1 - x0, y1 - y0) / dt
def _plan_astar(self, target_x, target_y):
"""Run A* on occupancy grid, return next waypoint in world coords."""
grid = self._occ_grid
res = self._occ_resolution
ox, oy = self._occ_origin
# World to grid
def w2g(wx, wy):
return (int((wy - oy) / res), int((wx - ox) / res))
start = w2g(self._robot_x, self._robot_y)
goal = w2g(target_x, target_y)
rows, cols = grid.shape
if not (0 <= start[0] < rows and 0 <= start[1] < cols):
return None
if not (0 <= goal[0] < rows and 0 <= goal[1] < cols):
return None
inflated = inflate_obstacles(grid, self._inflation_cells)
path = astar(inflated, start, goal)
if path and len(path) > 1:
# Follow a lookahead point (3 steps ahead or end)
lookahead_idx = min(3, len(path) - 1)
r, c = path[lookahead_idx]
wx = ox + c * res + res / 2.0
wy = oy + r * res + res / 2.0
self._astar_path = path
return (wx, wy)
return None
def _compute_cmd_vel(self, rx, ry, ryaw, tx, ty, max_lin):
"""Proportional controller: compute linear and angular velocity."""
dx = tx - rx
dy = ty - ry
dist = math.hypot(dx, dy)
angle_to_target = math.atan2(dy, dx)
angle_error = angle_to_target - ryaw
# Normalize angle error to [-pi, pi]
while angle_error > math.pi:
angle_error -= 2.0 * math.pi
while angle_error < -math.pi:
angle_error += 2.0 * math.pi
angular_vel = _K_ANG * angle_error
angular_vel = max(-self._max_ang,
min(self._max_ang, angular_vel))
# Reduce linear speed when turning hard
angle_factor = max(0.0, 1.0 - abs(angle_error) / (math.pi / 2.0))
linear_vel = _K_LIN * dist * angle_factor
linear_vel = max(0.0, min(max_lin, linear_vel))
return (linear_vel, angular_vel)
# ----------------------------------------------------------------
# Waypoint teaching
# ----------------------------------------------------------------
def _maybe_record_waypoint(self):
"""Record waypoint if robot moved > _TEACH_MIN_DIST."""
if self._last_teach_x is None:
self._last_teach_x = self._robot_x
self._last_teach_y = self._robot_y
dist = math.hypot(
self._robot_x - self._last_teach_x,
self._robot_y - self._last_teach_y)
if dist >= _TEACH_MIN_DIST:
yaw_deg = math.degrees(self._robot_yaw)
self._current_route.add(
self._robot_x, self._robot_y, 0.0, yaw_deg)
self._last_teach_x = self._robot_x
self._last_teach_y = self._robot_y
# ----------------------------------------------------------------
# Publishers
# ----------------------------------------------------------------
def _publish_cmd_vel(self, linear: float, angular: float):
twist = Twist()
twist.linear.x = linear
twist.angular.z = angular
self._cmd_vel_pub.publish(twist)
def _publish_target_pose(self, x, y, z):
msg = PoseStamped()
msg.header.stamp = self.get_clock().now().to_msg()
msg.header.frame_id = 'map'
msg.pose.position.x = x
msg.pose.position.y = y
msg.pose.position.z = z
self._target_pub.publish(msg)
def _publish_status(self, status: str):
msg = String()
msg.data = status
self._status_pub.publish(msg)
self.get_logger().info(f'Nav status: {status}')
def main(args=None):
rclpy.init(args=args)
node = SocialNavNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,91 @@
"""waypoint_teacher.py -- Record and replay waypoint routes."""
import json
import time
import math
from pathlib import Path
from dataclasses import dataclass, asdict
@dataclass
class Waypoint:
x: float
y: float
z: float
yaw_deg: float
timestamp: float
label: str = ''
class WaypointRoute:
"""A named sequence of waypoints."""
def __init__(self, name: str):
self.name = name
self.waypoints: list[Waypoint] = []
self.created_at = time.time()
def add(self, x, y, z, yaw_deg, label=''):
self.waypoints.append(Waypoint(x, y, z, yaw_deg, time.time(), label))
def to_dict(self):
return {
'name': self.name,
'created_at': self.created_at,
'waypoints': [asdict(w) for w in self.waypoints],
}
@classmethod
def from_dict(cls, d):
route = cls(d['name'])
route.created_at = d.get('created_at', 0)
route.waypoints = [Waypoint(**w) for w in d['waypoints']]
return route
def save(self, routes_dir: str):
Path(routes_dir).mkdir(parents=True, exist_ok=True)
path = Path(routes_dir) / f'{self.name}.json'
path.write_text(json.dumps(self.to_dict(), indent=2))
@classmethod
def load(cls, routes_dir: str, name: str):
path = Path(routes_dir) / f'{name}.json'
return cls.from_dict(json.loads(path.read_text()))
@staticmethod
def list_routes(routes_dir: str) -> list[str]:
d = Path(routes_dir)
if not d.exists():
return []
return [p.stem for p in d.glob('*.json')]
class WaypointReplayer:
"""Iterates through waypoints, returning next target."""
def __init__(self, route: WaypointRoute, arrival_radius: float = 0.3):
self._route = route
self._idx = 0
self._arrival_radius = arrival_radius
def current_waypoint(self) -> Waypoint | None:
if self._idx < len(self._route.waypoints):
return self._route.waypoints[self._idx]
return None
def check_arrived(self, robot_x, robot_y) -> bool:
wp = self.current_waypoint()
if wp is None:
return False
dist = math.hypot(robot_x - wp.x, robot_y - wp.y)
if dist < self._arrival_radius:
self._idx += 1
return True
return False
@property
def is_done(self) -> bool:
return self._idx >= len(self._route.waypoints)
def reset(self):
self._idx = 0

View File

@ -0,0 +1,135 @@
"""
waypoint_teacher_node.py -- Standalone waypoint teacher ROS2 node.
Listens to /social/speech/transcript for "teach route <name>" and "stop teaching".
Records robot pose at configurable intervals. Saves/loads routes via WaypointRoute.
"""
import math
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from nav_msgs.msg import Odometry
from std_msgs.msg import String
from .waypoint_teacher import WaypointRoute
class WaypointTeacherNode(Node):
"""Standalone waypoint teaching node."""
def __init__(self):
super().__init__('waypoint_teacher')
self.declare_parameter('routes_dir', '/mnt/nvme/saltybot/routes')
self.declare_parameter('recording_interval', 0.5) # meters
self._routes_dir = self.get_parameter('routes_dir').value
self._interval = self.get_parameter('recording_interval').value
self._teaching = False
self._route = None
self._last_x = None
self._last_y = None
self._robot_x = 0.0
self._robot_y = 0.0
self._robot_yaw = 0.0
best_effort_qos = QoSProfile(
reliability=ReliabilityPolicy.BEST_EFFORT,
history=HistoryPolicy.KEEP_LAST, depth=1)
self.create_subscription(
Odometry, '/odom', self._on_odom, best_effort_qos)
self.create_subscription(
String, '/social/speech/transcript',
self._on_transcript, 10)
self._status_pub = self.create_publisher(
String, '/social/waypoint/status', 10)
# Record timer at 10Hz (check distance)
self._timer = self.create_timer(0.1, self._record_tick)
self.get_logger().info(
f'Waypoint teacher ready (interval={self._interval}m, '
f'dir={self._routes_dir})')
def _on_odom(self, msg: Odometry):
self._robot_x = msg.pose.pose.position.x
self._robot_y = msg.pose.pose.position.y
q = msg.pose.pose.orientation
self._robot_yaw = math.atan2(
2.0 * (q.w * q.z + q.x * q.y),
1.0 - 2.0 * (q.y * q.y + q.z * q.z))
def _on_transcript(self, msg: String):
text = msg.data.strip().lower()
import re
m = re.match(r'teach\s+route\s+(\w+)', text)
if m:
name = m.group(1)
self._route = WaypointRoute(name)
self._teaching = True
self._last_x = self._robot_x
self._last_y = self._robot_y
self._pub_status(f'RECORDING: {name}')
self.get_logger().info(f'Recording route: {name}')
return
if 'stop teaching' in text:
if self._teaching and self._route:
self._route.save(self._routes_dir)
n = len(self._route.waypoints)
self._pub_status(
f'SAVED: {self._route.name} ({n} waypoints)')
self.get_logger().info(
f'Route saved: {self._route.name} ({n} pts)')
self._teaching = False
self._route = None
return
if 'list routes' in text:
routes = WaypointRoute.list_routes(self._routes_dir)
self._pub_status(f'ROUTES: {", ".join(routes) or "(none)"}')
def _record_tick(self):
if not self._teaching or self._route is None:
return
if self._last_x is None:
self._last_x = self._robot_x
self._last_y = self._robot_y
dist = math.hypot(
self._robot_x - self._last_x,
self._robot_y - self._last_y)
if dist >= self._interval:
yaw_deg = math.degrees(self._robot_yaw)
self._route.add(self._robot_x, self._robot_y, 0.0, yaw_deg)
self._last_x = self._robot_x
self._last_y = self._robot_y
def _pub_status(self, text: str):
msg = String()
msg.data = text
self._status_pub.publish(msg)
def main(args=None):
rclpy.init(args=args)
node = WaypointTeacherNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,80 @@
#!/usr/bin/env python3
"""
build_midas_trt_engine.py -- Build TensorRT FP16 engine for MiDaS_small from ONNX.
Usage:
python3 build_midas_trt_engine.py \
--onnx /mnt/nvme/saltybot/models/midas_small.onnx \
--engine /mnt/nvme/saltybot/models/midas_small.engine \
--fp16
Requires: tensorrt, pycuda
"""
import argparse
import os
import sys
def build_engine(onnx_path: str, engine_path: str, fp16: bool = True):
try:
import tensorrt as trt
except ImportError:
print('ERROR: tensorrt not found. Install TensorRT first.')
sys.exit(1)
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
print(f'Parsing ONNX model: {onnx_path}')
with open(onnx_path, 'rb') as f:
if not parser.parse(f.read()):
for i in range(parser.num_errors):
print(f' ONNX parse error: {parser.get_error(i)}')
sys.exit(1)
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1GB
if fp16 and builder.platform_has_fast_fp16:
config.set_flag(trt.BuilderFlag.FP16)
print('FP16 mode enabled')
elif fp16:
print('WARNING: FP16 not supported on this platform, using FP32')
print('Building TensorRT engine (this may take several minutes)...')
engine_bytes = builder.build_serialized_network(network, config)
if engine_bytes is None:
print('ERROR: Failed to build engine')
sys.exit(1)
os.makedirs(os.path.dirname(engine_path) or '.', exist_ok=True)
with open(engine_path, 'wb') as f:
f.write(engine_bytes)
size_mb = len(engine_bytes) / (1024 * 1024)
print(f'Engine saved: {engine_path} ({size_mb:.1f} MB)')
def main():
parser = argparse.ArgumentParser(
description='Build TensorRT FP16 engine for MiDaS_small')
parser.add_argument('--onnx', required=True,
help='Path to MiDaS ONNX model')
parser.add_argument('--engine', required=True,
help='Output TRT engine path')
parser.add_argument('--fp16', action='store_true', default=True,
help='Enable FP16 (default: True)')
parser.add_argument('--fp32', action='store_true',
help='Force FP32 (disable FP16)')
args = parser.parse_args()
fp16 = not args.fp32
build_engine(args.onnx, args.engine, fp16=fp16)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,4 @@
[develop]
script_dir=$base/lib/saltybot_social_nav
[install]
install_scripts=$base/lib/saltybot_social_nav

View File

@ -0,0 +1,31 @@
from setuptools import setup
import os
from glob import glob
package_name = 'saltybot_social_nav'
setup(
name=package_name,
version='0.1.0',
packages=[package_name],
data_files=[
('share/ament_index/resource_index/packages', ['resource/' + package_name]),
('share/' + package_name, ['package.xml']),
(os.path.join('share', package_name, 'launch'), glob('launch/*.py')),
(os.path.join('share', package_name, 'config'), glob('config/*.yaml')),
],
install_requires=['setuptools'],
zip_safe=True,
maintainer='seb',
maintainer_email='seb@vayrette.com',
description='Social navigation for saltybot: follow modes, waypoint teaching, A* avoidance, MiDaS depth',
license='MIT',
tests_require=['pytest'],
entry_points={
'console_scripts': [
'social_nav = saltybot_social_nav.social_nav_node:main',
'midas_depth = saltybot_social_nav.midas_depth_node:main',
'waypoint_teacher = saltybot_social_nav.waypoint_teacher_node:main',
],
},
)

View File

@ -0,0 +1,12 @@
# Copyright 2026 SaltyLab
# Licensed under MIT
from ament_copyright.main import main
import pytest
@pytest.mark.copyright
@pytest.mark.linter
def test_copyright():
rc = main(argv=['.', 'test'])
assert rc == 0, 'Found errors'

View File

@ -0,0 +1,14 @@
# Copyright 2026 SaltyLab
# Licensed under MIT
from ament_flake8.main import main_with_errors
import pytest
@pytest.mark.flake8
@pytest.mark.linter
def test_flake8():
rc, errors = main_with_errors(argv=[])
assert rc == 0, \
'Found %d code style errors / warnings:\n' % len(errors) + \
'\n'.join(errors)

View File

@ -0,0 +1,12 @@
# Copyright 2026 SaltyLab
# Licensed under MIT
from ament_pep257.main import main
import pytest
@pytest.mark.pep257
@pytest.mark.linter
def test_pep257():
rc = main(argv=['.', 'test'])
assert rc == 0, 'Found code style errors / warnings'

View File

@ -0,0 +1,5 @@
__pycache__/
*.pyc
*.pyo
*.egg-info/
.pytest_cache/

View File

@ -0,0 +1,48 @@
# tracking_params.yaml — saltybot_social_tracking / TrackingFusionNode
#
# Run with:
# ros2 launch saltybot_social_tracking tracking.launch.py
#
# Topics consumed:
# /uwb/target (geometry_msgs/PoseStamped) — UWB triangulated position
# /person/target (geometry_msgs/PoseStamped) — camera-detected position
#
# Topic produced:
# /social/tracking/fused_target (saltybot_social_msgs/FusedTarget)
# ── Source staleness timeouts ──────────────────────────────────────────────────
# UWB driver publishes at ~10 Hz; 1.5 s = 15 missed cycles before declared stale.
uwb_timeout: 1.5 # seconds
# Camera detector publishes at ~30 Hz; 1.0 s = 30 missed frames before stale.
cam_timeout: 1.0 # seconds
# How long the Kalman filter may coast (dead-reckoning) with no live source
# before the node stops publishing.
# At 10 m/s (EUC top-speed) the robot drifts ≈30 m over 3 s — beyond the UWB
# follow-range, so 3 s is a reasonable hard stop.
predict_timeout: 3.0 # seconds
# ── Kalman filter tuning ───────────────────────────────────────────────────────
# process_noise: acceleration noise std-dev (m/s²).
# EUC riders can brake or accelerate at ~35 m/s²; 3.0 is a good starting point.
# Increase if the filtered track lags behind fast direction changes.
# Decrease if the track is noisy.
process_noise: 3.0 # m/s²
# UWB position measurement noise (std-dev, metres).
# DW3000 TWR accuracy ≈ ±1020 cm; 0.20 accounts for system-level error.
meas_noise_uwb: 0.20 # m
# Camera position noise (std-dev, metres).
# Depth reprojection error with RealSense D435i at 13 m ≈ ±515 cm.
meas_noise_cam: 0.12 # m
# ── Control loop ──────────────────────────────────────────────────────────────
control_rate: 20.0 # Hz — KF predict + publish rate
# ── Source arbiter ────────────────────────────────────────────────────────────
# Minimum normalised confidence for a source to be considered live.
# Range [0, 1]; lower = more permissive; default 0.15 keeps slightly stale
# sources active rather than dropping to "predicted" prematurely.
confidence_threshold: 0.15

View File

@ -0,0 +1,44 @@
"""tracking.launch.py — launch the TrackingFusionNode with default params."""
import os
from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
pkg_share = get_package_share_directory("saltybot_social_tracking")
default_params = os.path.join(pkg_share, "config", "tracking_params.yaml")
return LaunchDescription([
DeclareLaunchArgument(
"params_file",
default_value=default_params,
description="Path to tracking fusion parameter YAML file",
),
DeclareLaunchArgument(
"control_rate",
default_value="20.0",
description="KF predict + publish rate (Hz)",
),
DeclareLaunchArgument(
"predict_timeout",
default_value="3.0",
description="Max KF coast time before stopping publish (s)",
),
Node(
package="saltybot_social_tracking",
executable="tracking_fusion_node",
name="tracking_fusion",
output="screen",
parameters=[
LaunchConfiguration("params_file"),
{
"control_rate": LaunchConfiguration("control_rate"),
"predict_timeout": LaunchConfiguration("predict_timeout"),
},
],
),
])

View File

@ -0,0 +1,31 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_social_tracking</name>
<version>0.1.0</version>
<description>
Multi-modal tracking fusion for saltybot.
Fuses UWB triangulated position (/uwb/target) and camera-detected position
(/person/target) using a 4-state Kalman filter to produce a smooth, low-latency
fused estimate at /social/tracking/fused_target.
Handles EUC rider speeds (20-30 km/h), signal handoff, and predictive coasting.
</description>
<maintainer email="sl-controls@saltylab.local">sl-controls</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>geometry_msgs</depend>
<depend>std_msgs</depend>
<depend>saltybot_social_msgs</depend>
<buildtool_depend>ament_python</buildtool_depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>python3-pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -0,0 +1,134 @@
"""
kalman_tracker.py 4-state linear Kalman filter for 2-D position+velocity tracking.
State vector: [x, y, vx, vy]
Observation: [x_meas, y_meas]
Process model: constant velocity with Wiener process acceleration noise.
Tuned to handle EUC rider speeds (2030 km/h 5.58.3 m/s) with fast
acceleration transients.
Pure module no ROS2 dependency; fully unit-testable.
"""
import math
import numpy as np
class KalmanTracker:
"""
4-state Kalman filter: state = [x, y, vx, vy].
Parameters
----------
process_noise : acceleration noise standard deviation (m/).
Higher values allow the filter to track rapid velocity
changes (EUC acceleration events). Default 3.0 m/.
meas_noise_uwb : UWB position measurement noise std-dev (m). Default 0.20 m.
meas_noise_cam : Camera position measurement noise std-dev (m). Default 0.12 m.
"""
def __init__(
self,
process_noise: float = 3.0,
meas_noise_uwb: float = 0.20,
meas_noise_cam: float = 0.12,
):
self._q = float(process_noise)
self._r_uwb = float(meas_noise_uwb)
self._r_cam = float(meas_noise_cam)
# State [x, y, vx, vy]
self._x = np.zeros(4)
# Covariance — large initial uncertainty (10 m², 10 (m/s)²)
self._P = np.diag([10.0, 10.0, 10.0, 10.0])
# Observation matrix: H * x = [x, y]
self._H = np.array([[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0]])
self._initialized = False
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
@property
def initialized(self) -> bool:
return self._initialized
def initialize(self, x: float, y: float) -> None:
"""Seed the filter at position (x, y) with zero velocity."""
self._x = np.array([x, y, 0.0, 0.0])
self._P = np.diag([1.0, 1.0, 5.0, 5.0])
self._initialized = True
def predict(self, dt: float) -> None:
"""
Advance the filter state by dt seconds.
Uses a discrete Wiener process acceleration model so that positional
uncertainty grows as O(dt^4/4) and velocity uncertainty as O(dt^2).
This lets the filter coast accurately through short signal outages
while still being responsive to EUC velocity changes.
"""
if dt <= 0.0:
return
F = np.array([[1.0, 0.0, dt, 0.0],
[0.0, 1.0, 0.0, dt],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0]])
q = self._q
dt2 = dt * dt
dt3 = dt2 * dt
dt4 = dt3 * dt
Q = (q * q) * np.array([
[dt4 / 4.0, 0.0, dt3 / 2.0, 0.0 ],
[0.0, dt4 / 4.0, 0.0, dt3 / 2.0],
[dt3 / 2.0, 0.0, dt2, 0.0 ],
[0.0, dt3 / 2.0, 0.0, dt2 ],
])
self._x = F @ self._x
self._P = F @ self._P @ F.T + Q
def update(self, x_meas: float, y_meas: float, source: str = "uwb") -> None:
"""
Apply a position measurement (x_meas, y_meas).
source : "uwb" or "camera" selects the appropriate noise covariance.
"""
r = self._r_uwb if source == "uwb" else self._r_cam
R = np.diag([r * r, r * r])
z = np.array([x_meas, y_meas])
innov = z - self._H @ self._x # innovation
S = self._H @ self._P @ self._H.T + R # innovation covariance
K = self._P @ self._H.T @ np.linalg.inv(S) # Kalman gain
self._x = self._x + K @ innov
self._P = (np.eye(4) - K @ self._H) @ self._P
# ------------------------------------------------------------------
# State accessors
# ------------------------------------------------------------------
@property
def position(self) -> tuple:
"""Current filtered position (x, y) in metres."""
return float(self._x[0]), float(self._x[1])
@property
def velocity(self) -> tuple:
"""Current filtered velocity (vx, vy) in m/s."""
return float(self._x[2]), float(self._x[3])
def position_uncertainty_m(self) -> float:
"""RMS positional uncertainty (m) from diagonal of covariance."""
return float(math.sqrt((self._P[0, 0] + self._P[1, 1]) / 2.0))
def covariance_copy(self) -> np.ndarray:
"""Return a copy of the current 4×4 covariance matrix."""
return self._P.copy()

View File

@ -0,0 +1,155 @@
"""
source_arbiter.py Source confidence scoring and selection for tracking fusion.
Two sensor sources are supported:
UWB geometry_msgs/PoseStamped from /uwb/target (triangulated, ~10 Hz)
Camera geometry_msgs/PoseStamped from /person/target (depth+YOLO, ~30 Hz)
Confidence model
----------------
Each source's confidence is its raw measurement quality multiplied by a
linear staleness factor that drops to zero at its respective timeout:
conf = quality * max(0, 1 - age / timeout)
UWB quality is always 1.0 (the ranging hardware confidence is not exposed
by the driver in origin/main; the UWB node already applies Kalman filtering).
Camera quality defaults to 1.0; callers may pass a lower value when the
detection confidence is available.
Source selection
----------------
Both fresh "fused" (confidence-weighted position blend)
UWB only "uwb"
Camera only "camera"
Neither fresh "predicted" (Kalman coasts)
Pure module no ROS2 dependency; fully unit-testable.
"""
import math
def _staleness_factor(age_s: float, timeout_s: float) -> float:
"""Linear decay: 1.0 at age=0, 0.0 at age=timeout, clamped."""
if timeout_s <= 0.0:
return 0.0
return max(0.0, 1.0 - age_s / timeout_s)
def uwb_confidence(age_s: float, timeout_s: float, quality: float = 1.0) -> float:
"""
UWB source confidence.
age_s : seconds since last UWB measurement (0; use large value if never)
timeout_s: staleness threshold (s); confidence reaches 0 at this age
quality : inherent measurement quality [0, 1] (default 1.0)
"""
return quality * _staleness_factor(age_s, timeout_s)
def camera_confidence(
age_s: float, timeout_s: float, quality: float = 1.0
) -> float:
"""
Camera source confidence.
age_s : seconds since last camera detection (0; use large value if never)
timeout_s: staleness threshold (s)
quality : YOLO detection confidence or other quality score [0, 1]
"""
return quality * _staleness_factor(age_s, timeout_s)
def select_source(
uwb_conf: float,
cam_conf: float,
threshold: float = 0.15,
) -> str:
"""
Choose the active tracking source label.
Returns one of: "fused", "uwb", "camera", "predicted".
threshold: minimum confidence for a source to be considered live.
Sources below threshold are ignored.
"""
uwb_ok = uwb_conf >= threshold
cam_ok = cam_conf >= threshold
if uwb_ok and cam_ok:
return "fused"
if uwb_ok:
return "uwb"
if cam_ok:
return "camera"
return "predicted"
def fuse_positions(
uwb_x: float,
uwb_y: float,
uwb_conf: float,
cam_x: float,
cam_y: float,
cam_conf: float,
) -> tuple:
"""
Confidence-weighted position fusion.
Returns (fused_x, fused_y).
When total confidence is zero (shouldn't happen in "fused" state, but
guarded), returns the UWB position as fallback.
"""
total = uwb_conf + cam_conf
if total <= 0.0:
return uwb_x, uwb_y
w = uwb_conf / total
return (
w * uwb_x + (1.0 - w) * cam_x,
w * uwb_y + (1.0 - w) * cam_y,
)
def composite_confidence(
uwb_conf: float,
cam_conf: float,
source: str,
kf_uncertainty_m: float,
max_kf_uncertainty_m: float = 3.0,
) -> float:
"""
Compute a single composite confidence value [0, 1] for the fused output.
source : current source label (from select_source)
kf_uncertainty_m : current KF positional RMS uncertainty
max_kf_uncertainty_m: uncertainty at which confidence collapses to 0
"""
if source == "predicted":
# Decay with growing KF uncertainty; no sensor feeds are live
raw = max(0.0, 1.0 - kf_uncertainty_m / max_kf_uncertainty_m)
return min(0.4, raw) # cap at 0.4 — caller should know this is dead-reckoning
if source == "fused":
raw = max(uwb_conf, cam_conf)
elif source == "uwb":
raw = uwb_conf
else: # "camera"
raw = cam_conf
# Scale by KF health (full confidence only if KF is tight)
kf_health = max(0.0, 1.0 - kf_uncertainty_m / max_kf_uncertainty_m)
return raw * (0.5 + 0.5 * kf_health)
def bearing_and_range(x: float, y: float) -> tuple:
"""
Compute bearing (rad, +ve = left) and range (m) to position (x, y).
Consistent with person_follower_node conventions:
bearing = atan2(y, x) (base_link frame: +X forward, +Y left)
range = sqrt( + )
"""
return math.atan2(y, x), math.sqrt(x * x + y * y)

View File

@ -0,0 +1,257 @@
"""
tracking_fusion_node.py Multi-modal tracking fusion for saltybot.
Subscribes
----------
/uwb/target (geometry_msgs/PoseStamped) UWB-triangulated position (~10 Hz)
/person/target (geometry_msgs/PoseStamped) camera-detected position (~30 Hz)
Publishes
---------
/social/tracking/fused_target (saltybot_social_msgs/FusedTarget) at control_rate Hz
Algorithm
---------
1. Each incoming measurement updates a 4-state Kalman filter [x, y, vx, vy].
2. A 20 Hz timer runs predict+select+publish:
a. KF predict(dt)
b. Compute per-source confidence from measurement age + staleness model
c. If either source is live:
- "fused" confidence-weighted position blend KF update
- "uwb" UWB position KF update
- "camera" camera position KF update
d. Build FusedTarget from KF state + composite confidence
3. If all sources are lost but within predict_timeout, keep publishing with
active_source="predicted" and degrading confidence.
4. Beyond predict_timeout, no message is published (node stays alive).
Kalman tuning for EUC speeds (2030 km/h 5.58.3 m/s)
---------------------------------------------------------
process_noise=3.0 m/ allows rapid acceleration events
predict_timeout=3.0 s coasts 30 m at 10 m/s; acceptable dead-reckoning
Parameters
----------
uwb_timeout : UWB staleness threshold (s) default 1.5
cam_timeout : Camera staleness threshold (s) default 1.0
predict_timeout : Max KF coast before no publish (s) default 3.0
process_noise : KF acceleration noise std-dev (m/) default 3.0
meas_noise_uwb : UWB position noise std-dev (m) default 0.20
meas_noise_cam : Camera position noise std-dev (m) default 0.12
control_rate : Publish / KF predict rate (Hz) default 20.0
confidence_threshold: Min source confidence to use (01) default 0.15
Usage
-----
ros2 launch saltybot_social_tracking tracking.launch.py
"""
import math
import time
import rclpy
from rclpy.node import Node
from geometry_msgs.msg import PoseStamped
from std_msgs.msg import Header
from saltybot_social_msgs.msg import FusedTarget
from saltybot_social_tracking.kalman_tracker import KalmanTracker
from saltybot_social_tracking.source_arbiter import (
uwb_confidence,
camera_confidence,
select_source,
fuse_positions,
composite_confidence,
bearing_and_range,
)
_BIG_AGE = 1e9 # sentinel: source never received
class TrackingFusionNode(Node):
def __init__(self):
super().__init__("tracking_fusion")
# ── Parameters ────────────────────────────────────────────────────────
self.declare_parameter("uwb_timeout", 1.5)
self.declare_parameter("cam_timeout", 1.0)
self.declare_parameter("predict_timeout", 3.0)
self.declare_parameter("process_noise", 3.0)
self.declare_parameter("meas_noise_uwb", 0.20)
self.declare_parameter("meas_noise_cam", 0.12)
self.declare_parameter("control_rate", 20.0)
self.declare_parameter("confidence_threshold", 0.15)
self._p = self._load_params()
# ── State ─────────────────────────────────────────────────────────────
self._kf = KalmanTracker(
process_noise=self._p["process_noise"],
meas_noise_uwb=self._p["meas_noise_uwb"],
meas_noise_cam=self._p["meas_noise_cam"],
)
self._last_uwb_t: float = 0.0 # monotonic; 0 = never received
self._last_cam_t: float = 0.0
self._uwb_x: float = 0.0
self._uwb_y: float = 0.0
self._cam_x: float = 0.0
self._cam_y: float = 0.0
self._uwb_tag_id: str = ""
self._last_predict_t: float = 0.0 # monotonic time of last predict call
self._last_any_t: float = 0.0 # monotonic time of last live measurement
# ── Subscriptions ─────────────────────────────────────────────────────
self.create_subscription(
PoseStamped, "/uwb/target", self._uwb_cb, 10)
self.create_subscription(
PoseStamped, "/person/target", self._cam_cb, 10)
# ── Publisher ─────────────────────────────────────────────────────────
self._pub = self.create_publisher(FusedTarget, "/social/tracking/fused_target", 10)
# ── Timer ─────────────────────────────────────────────────────────────
self._timer = self.create_timer(
1.0 / self._p["control_rate"], self._control_cb)
self.get_logger().info(
f"TrackingFusion ready "
f"rate={self._p['control_rate']}Hz "
f"uwb_timeout={self._p['uwb_timeout']}s "
f"cam_timeout={self._p['cam_timeout']}s "
f"predict_timeout={self._p['predict_timeout']}s "
f"process_noise={self._p['process_noise']}m/s²"
)
# ── Parameter helpers ──────────────────────────────────────────────────────
def _load_params(self) -> dict:
return {
"uwb_timeout": self.get_parameter("uwb_timeout").value,
"cam_timeout": self.get_parameter("cam_timeout").value,
"predict_timeout": self.get_parameter("predict_timeout").value,
"process_noise": self.get_parameter("process_noise").value,
"meas_noise_uwb": self.get_parameter("meas_noise_uwb").value,
"meas_noise_cam": self.get_parameter("meas_noise_cam").value,
"control_rate": self.get_parameter("control_rate").value,
"confidence_threshold": self.get_parameter("confidence_threshold").value,
}
# ── Measurement callbacks ──────────────────────────────────────────────────
def _uwb_cb(self, msg: PoseStamped) -> None:
self._uwb_x = msg.pose.position.x
self._uwb_y = msg.pose.position.y
self._uwb_tag_id = "" # PoseStamped has no tag field; tag reported via /uwb/bearing
t = time.monotonic()
self._last_uwb_t = t
self._last_any_t = t
# Seed KF on first measurement
if not self._kf.initialized:
self._kf.initialize(self._uwb_x, self._uwb_y)
self._last_predict_t = t
def _cam_cb(self, msg: PoseStamped) -> None:
self._cam_x = msg.pose.position.x
self._cam_y = msg.pose.position.y
t = time.monotonic()
self._last_cam_t = t
self._last_any_t = t
if not self._kf.initialized:
self._kf.initialize(self._cam_x, self._cam_y)
self._last_predict_t = t
# ── Control loop ───────────────────────────────────────────────────────────
def _control_cb(self) -> None:
self._p = self._load_params()
if not self._kf.initialized:
return # no data yet — nothing to publish
now = time.monotonic()
dt = now - self._last_predict_t if self._last_predict_t > 0.0 else (
1.0 / self._p["control_rate"]
)
self._last_predict_t = now
# KF predict
self._kf.predict(dt)
# Source confidence
uwb_age = (now - self._last_uwb_t) if self._last_uwb_t > 0.0 else _BIG_AGE
cam_age = (now - self._last_cam_t) if self._last_cam_t > 0.0 else _BIG_AGE
u_conf = uwb_confidence(uwb_age, self._p["uwb_timeout"])
c_conf = camera_confidence(cam_age, self._p["cam_timeout"])
threshold = self._p["confidence_threshold"]
source = select_source(u_conf, c_conf, threshold)
if source == "predicted":
# Check predict_timeout — stop publishing if too stale
last_live_age = (
(now - self._last_any_t) if self._last_any_t > 0.0 else _BIG_AGE
)
if last_live_age > self._p["predict_timeout"]:
return # silently stop publishing
# Apply measurement update if a live source exists
if source == "fused":
fx, fy = fuse_positions(
self._uwb_x, self._uwb_y, u_conf,
self._cam_x, self._cam_y, c_conf,
)
self._kf.update(fx, fy, source="uwb") # use tighter noise for blended
elif source == "uwb":
self._kf.update(self._uwb_x, self._uwb_y, source="uwb")
elif source == "camera":
self._kf.update(self._cam_x, self._cam_y, source="camera")
# "predicted" → no update; KF coasts
# Build and publish message
kx, ky = self._kf.position
vx, vy = self._kf.velocity
kf_unc = self._kf.position_uncertainty_m()
conf = composite_confidence(u_conf, c_conf, source, kf_unc)
bearing, range_m = bearing_and_range(kx, ky)
hdr = Header()
hdr.stamp = self.get_clock().now().to_msg()
hdr.frame_id = "base_link"
msg = FusedTarget()
msg.header = hdr
msg.position.x = kx
msg.position.y = ky
msg.position.z = 0.0
msg.velocity.x = vx
msg.velocity.y = vy
msg.velocity.z = 0.0
msg.range_m = float(range_m)
msg.bearing_rad = float(bearing)
msg.confidence = float(conf)
msg.active_source = source
msg.tag_id = self._uwb_tag_id if "uwb" in source else ""
self._pub.publish(msg)
# ── Entry point ────────────────────────────────────────────────────────────────
def main(args=None):
rclpy.init(args=args)
node = TrackingFusionNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.try_shutdown()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,5 @@
[develop]
script_dir=$base/lib/saltybot_social_tracking
[install]
install_scripts=$base/lib/saltybot_social_tracking

View File

@ -0,0 +1,32 @@
from setuptools import setup, find_packages
import os
from glob import glob
package_name = "saltybot_social_tracking"
setup(
name=package_name,
version="0.1.0",
packages=find_packages(exclude=["test"]),
data_files=[
("share/ament_index/resource_index/packages",
[f"resource/{package_name}"]),
(f"share/{package_name}", ["package.xml"]),
(os.path.join("share", package_name, "config"),
glob("config/*.yaml")),
(os.path.join("share", package_name, "launch"),
glob("launch/*.py")),
],
install_requires=["setuptools"],
zip_safe=True,
maintainer="sl-controls",
maintainer_email="sl-controls@saltylab.local",
description="Multi-modal tracking fusion (UWB + camera Kalman filter)",
license="MIT",
tests_require=["pytest"],
entry_points={
"console_scripts": [
f"tracking_fusion_node = {package_name}.tracking_fusion_node:main",
],
},
)

View File

@ -0,0 +1,438 @@
"""
test_tracking_fusion.py Unit tests for saltybot_social_tracking pure modules.
Tests cover:
- KalmanTracker: initialization, predict, update, state accessors
- source_arbiter: confidence functions, source selection, fusion, bearing
No ROS2 runtime required.
"""
import math
import sys
import os
# Allow running: python -m pytest test/test_tracking_fusion.py
# from the package root without installing the package.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import pytest
import numpy as np
from saltybot_social_tracking.kalman_tracker import KalmanTracker
from saltybot_social_tracking.source_arbiter import (
uwb_confidence,
camera_confidence,
select_source,
fuse_positions,
composite_confidence,
bearing_and_range,
)
# ─────────────────────────────────────────────────────────────────────────────
# KalmanTracker tests
# ─────────────────────────────────────────────────────────────────────────────
class TestKalmanTrackerInit:
def test_not_initialized_by_default(self):
kf = KalmanTracker()
assert not kf.initialized
def test_initialize_sets_position(self):
kf = KalmanTracker()
kf.initialize(3.0, 1.5)
assert kf.initialized
x, y = kf.position
assert abs(x - 3.0) < 1e-9
assert abs(y - 1.5) < 1e-9
def test_initialize_sets_zero_velocity(self):
kf = KalmanTracker()
kf.initialize(1.0, -2.0)
vx, vy = kf.velocity
assert abs(vx) < 1e-9
assert abs(vy) < 1e-9
def test_initialize_origin(self):
kf = KalmanTracker()
kf.initialize(0.0, 0.0)
assert kf.initialized
x, y = kf.position
assert x == 0.0 and y == 0.0
class TestKalmanTrackerPredict:
def test_predict_zero_dt_no_change(self):
kf = KalmanTracker()
kf.initialize(2.0, 1.0)
kf.predict(0.0)
x, y = kf.position
assert abs(x - 2.0) < 1e-9
assert abs(y - 1.0) < 1e-9
def test_predict_negative_dt_no_change(self):
kf = KalmanTracker()
kf.initialize(2.0, 1.0)
kf.predict(-0.1)
x, y = kf.position
assert abs(x - 2.0) < 1e-9
def test_predict_constant_velocity(self):
"""After a position update gives the filter a velocity, predict should extrapolate."""
kf = KalmanTracker(process_noise=0.001, meas_noise_uwb=0.001)
kf.initialize(0.0, 0.0)
# Force filter to track a moving target to build up velocity estimate
dt = 0.05
for i in range(40):
t = i * dt
kf.predict(dt)
kf.update(2.0 * t, 0.0, "uwb") # 2 m/s in x
# After many updates the velocity estimate should be close to 2 m/s
vx, vy = kf.velocity
assert abs(vx - 2.0) < 0.3, f"vx={vx:.3f}"
assert abs(vy) < 0.2
def test_predict_grows_uncertainty(self):
kf = KalmanTracker()
kf.initialize(0.0, 0.0)
unc_before = kf.position_uncertainty_m()
kf.predict(1.0)
unc_after = kf.position_uncertainty_m()
assert unc_after > unc_before
def test_predict_multiple_steps(self):
kf = KalmanTracker()
kf.initialize(0.0, 0.0)
kf.predict(0.1)
kf.predict(0.1)
kf.predict(0.1)
# No assertion on exact value; just verify no exception and state is finite
x, y = kf.position
assert math.isfinite(x) and math.isfinite(y)
class TestKalmanTrackerUpdate:
def test_update_pulls_position_toward_measurement(self):
kf = KalmanTracker()
kf.initialize(0.0, 0.0)
kf.update(5.0, 5.0, "uwb")
x, y = kf.position
assert x > 0.0 and y > 0.0
assert x < 5.0 and y < 5.0 # blended, not jumped
def test_update_reduces_uncertainty(self):
kf = KalmanTracker()
kf.initialize(0.0, 0.0)
kf.predict(1.0) # uncertainty grows
unc_mid = kf.position_uncertainty_m()
kf.update(0.1, 0.1, "uwb") # measurement corrects
unc_after = kf.position_uncertainty_m()
assert unc_after < unc_mid
def test_update_converges_to_true_position(self):
"""Many updates from same point should converge to that point."""
kf = KalmanTracker(meas_noise_uwb=0.01)
kf.initialize(0.0, 0.0)
for _ in range(50):
kf.update(3.0, -1.0, "uwb")
x, y = kf.position
assert abs(x - 3.0) < 0.05, f"x={x:.4f}"
assert abs(y - (-1.0)) < 0.05, f"y={y:.4f}"
def test_update_camera_source_different_noise(self):
"""Camera and UWB updates should both move state (noise model differs)."""
kf1 = KalmanTracker(meas_noise_uwb=0.20, meas_noise_cam=0.10)
kf1.initialize(0.0, 0.0)
kf1.update(5.0, 0.0, "uwb")
x_uwb, _ = kf1.position
kf2 = KalmanTracker(meas_noise_uwb=0.20, meas_noise_cam=0.10)
kf2.initialize(0.0, 0.0)
kf2.update(5.0, 0.0, "camera")
x_cam, _ = kf2.position
# Camera has lower noise → stronger pull toward measurement
assert x_cam > x_uwb
def test_update_unknown_source_defaults_to_camera_noise(self):
kf = KalmanTracker()
kf.initialize(0.0, 0.0)
kf.update(2.0, 0.0, "other") # unknown source — should not raise
x, _ = kf.position
assert x > 0.0
def test_position_uncertainty_finite(self):
kf = KalmanTracker()
kf.initialize(1.0, 1.0)
kf.predict(0.05)
kf.update(1.1, 0.9, "uwb")
assert math.isfinite(kf.position_uncertainty_m())
assert kf.position_uncertainty_m() >= 0.0
def test_covariance_copy_is_independent(self):
kf = KalmanTracker()
kf.initialize(0.0, 0.0)
cov = kf.covariance_copy()
cov[0, 0] = 9999.0 # mutate copy
assert kf.covariance_copy()[0, 0] != 9999.0
# ─────────────────────────────────────────────────────────────────────────────
# source_arbiter tests
# ─────────────────────────────────────────────────────────────────────────────
class TestUwbConfidence:
def test_zero_age_returns_quality(self):
assert abs(uwb_confidence(0.0, 1.5) - 1.0) < 1e-9
def test_at_timeout_returns_zero(self):
assert uwb_confidence(1.5, 1.5) == pytest.approx(0.0)
def test_beyond_timeout_returns_zero(self):
assert uwb_confidence(2.0, 1.5) == 0.0
def test_half_timeout_returns_half(self):
assert uwb_confidence(0.75, 1.5) == pytest.approx(0.5)
def test_quality_scales_result(self):
assert uwb_confidence(0.0, 1.5, quality=0.7) == pytest.approx(0.7)
def test_large_age_returns_zero(self):
assert uwb_confidence(1e9, 1.5) == 0.0
def test_zero_timeout_returns_zero(self):
assert uwb_confidence(0.0, 0.0) == 0.0
class TestCameraConfidence:
def test_zero_age_full_quality(self):
assert camera_confidence(0.0, 1.0, quality=1.0) == pytest.approx(1.0)
def test_at_timeout_zero(self):
assert camera_confidence(1.0, 1.0) == pytest.approx(0.0)
def test_beyond_timeout_zero(self):
assert camera_confidence(2.0, 1.0) == 0.0
def test_quality_scales(self):
# age=0, quality=0.8
assert camera_confidence(0.0, 1.0, quality=0.8) == pytest.approx(0.8)
def test_halfway(self):
assert camera_confidence(0.5, 1.0) == pytest.approx(0.5)
class TestSelectSource:
def test_both_above_threshold_fused(self):
assert select_source(0.8, 0.6) == "fused"
def test_only_uwb_above_threshold(self):
assert select_source(0.8, 0.0) == "uwb"
def test_only_cam_above_threshold(self):
assert select_source(0.0, 0.7) == "camera"
def test_both_below_threshold(self):
assert select_source(0.0, 0.0) == "predicted"
def test_threshold_boundary_uwb(self):
# Exactly at threshold — should be treated as live
assert select_source(0.15, 0.0, threshold=0.15) == "uwb"
def test_threshold_boundary_below(self):
assert select_source(0.14, 0.0, threshold=0.15) == "predicted"
def test_custom_threshold(self):
assert select_source(0.5, 0.0, threshold=0.6) == "predicted"
assert select_source(0.5, 0.0, threshold=0.4) == "uwb"
class TestFusePositions:
def test_equal_confidence_returns_midpoint(self):
x, y = fuse_positions(0.0, 0.0, 1.0, 4.0, 4.0, 1.0)
assert abs(x - 2.0) < 1e-9
assert abs(y - 2.0) < 1e-9
def test_full_uwb_weight_returns_uwb(self):
x, y = fuse_positions(3.0, 1.0, 1.0, 0.0, 0.0, 0.0)
assert abs(x - 3.0) < 1e-9
def test_full_cam_weight_returns_cam(self):
x, y = fuse_positions(0.0, 0.0, 0.0, -2.0, 5.0, 1.0)
assert abs(x - (-2.0)) < 1e-9
assert abs(y - 5.0) < 1e-9
def test_weighted_blend(self):
# UWB at (0,0) conf=3, camera at (4,0) conf=1 → x = 3/4*0 + 1/4*4 = 1
x, y = fuse_positions(0.0, 0.0, 3.0, 4.0, 0.0, 1.0)
assert abs(x - 1.0) < 1e-9
def test_zero_total_returns_uwb_fallback(self):
x, y = fuse_positions(7.0, 2.0, 0.0, 3.0, 1.0, 0.0)
assert abs(x - 7.0) < 1e-9
class TestCompositeConfidence:
def test_fused_source_high_confidence(self):
conf = composite_confidence(0.9, 0.8, "fused", 0.05)
assert conf > 0.7
def test_predicted_source_capped(self):
conf = composite_confidence(0.0, 0.0, "predicted", 0.1)
assert conf <= 0.4
def test_predicted_high_uncertainty_low_confidence(self):
conf = composite_confidence(0.0, 0.0, "predicted", 3.0, max_kf_uncertainty_m=3.0)
assert conf == pytest.approx(0.0)
def test_uwb_only(self):
conf = composite_confidence(0.8, 0.0, "uwb", 0.05)
assert conf > 0.3
def test_camera_only(self):
conf = composite_confidence(0.0, 0.7, "camera", 0.05)
assert conf > 0.2
def test_high_kf_uncertainty_reduces_confidence(self):
low_unc = composite_confidence(0.9, 0.0, "uwb", 0.1)
high_unc = composite_confidence(0.9, 0.0, "uwb", 2.9)
assert low_unc > high_unc
class TestBearingAndRange:
def test_straight_ahead(self):
bearing, rng = bearing_and_range(2.0, 0.0)
assert abs(bearing) < 1e-9
assert abs(rng - 2.0) < 1e-9
def test_left_of_robot(self):
# +Y = left in base_link frame; bearing should be positive
bearing, rng = bearing_and_range(0.0, 1.0)
assert abs(bearing - math.pi / 2.0) < 1e-9
assert abs(rng - 1.0) < 1e-9
def test_right_of_robot(self):
bearing, rng = bearing_and_range(0.0, -1.0)
assert abs(bearing - (-math.pi / 2.0)) < 1e-9
def test_diagonal(self):
bearing, rng = bearing_and_range(1.0, 1.0)
assert abs(bearing - math.pi / 4.0) < 1e-9
assert abs(rng - math.sqrt(2.0)) < 1e-9
def test_at_origin(self):
bearing, rng = bearing_and_range(0.0, 0.0)
assert rng == pytest.approx(0.0)
assert math.isfinite(bearing) # atan2(0,0) = 0 in most implementations
def test_range_always_non_negative(self):
for x, y in [(-1, 0), (0, -1), (-2, -3), (5, -5)]:
_, rng = bearing_and_range(x, y)
assert rng >= 0.0
# ─────────────────────────────────────────────────────────────────────────────
# Integration scenario tests
# ─────────────────────────────────────────────────────────────────────────────
class TestIntegrationScenarios:
def test_euc_speed_velocity_tracking(self):
"""Verify KF can track EUC speed (8 m/s) within 0.5 m/s after warm-up."""
kf = KalmanTracker(process_noise=3.0, meas_noise_uwb=0.20)
kf.initialize(0.0, 0.0)
dt = 1.0 / 10.0 # 10 Hz UWB rate
speed = 8.0 # m/s (≈29 km/h)
for i in range(60):
t = i * dt
kf.predict(dt)
kf.update(speed * t, 0.0, "uwb")
vx, vy = kf.velocity
assert abs(vx - speed) < 0.6, f"vx={vx:.2f} expected≈{speed}"
assert abs(vy) < 0.3
def test_signal_loss_recovery(self):
"""
After 1 s of signal loss the filter should still have a reasonable
position estimate (not diverged to infinity).
"""
kf = KalmanTracker(process_noise=3.0)
kf.initialize(2.0, 0.5)
# Warm up with 2 m/s x motion
dt = 0.05
for i in range(20):
kf.predict(dt)
kf.update(2.0 * (i + 1) * dt, 0.0, "uwb")
# Coast for 1 second (20 × 50 ms) without measurements
for _ in range(20):
kf.predict(dt)
x, y = kf.position
assert math.isfinite(x) and math.isfinite(y)
assert abs(x) < 20.0 # shouldn't have drifted more than 20 m
def test_uwb_to_camera_handoff(self):
"""
Simulate UWB going stale and camera taking over Kalman should
smoothly continue tracking without a jump.
"""
kf = KalmanTracker(meas_noise_uwb=0.20, meas_noise_cam=0.12)
kf.initialize(0.0, 0.0)
dt = 0.05
# Phase 1: UWB active
for i in range(20):
kf.predict(dt)
kf.update(float(i) * 0.1, 0.0, "uwb")
x_at_handoff, _ = kf.position
# Phase 2: Camera takes over from same trajectory
for i in range(20, 40):
kf.predict(dt)
kf.update(float(i) * 0.1, 0.0, "camera")
x_after, _ = kf.position
# Position should have continued progressing (not stuck or reset)
assert x_after > x_at_handoff
def test_confidence_degradation_during_coast(self):
"""Composite confidence should drop as KF uncertainty grows during coast."""
kf = KalmanTracker(process_noise=3.0)
kf.initialize(2.0, 0.0)
# Fresh: tight uncertainty → high confidence
unc_fresh = kf.position_uncertainty_m()
conf_fresh = composite_confidence(0.0, 0.0, "predicted", unc_fresh)
# After 2 s coast
for _ in range(40):
kf.predict(0.05)
unc_stale = kf.position_uncertainty_m()
conf_stale = composite_confidence(0.0, 0.0, "predicted", unc_stale)
assert conf_fresh >= conf_stale
def test_fused_source_confidence_weighted_position(self):
"""Fused position should sit between UWB and camera, closer to higher-conf source."""
# UWB at x=0 with high conf, camera at x=10 with low conf
uwb_c = 0.9
cam_c = 0.1
fx, fy = fuse_positions(0.0, 0.0, uwb_c, 10.0, 0.0, cam_c)
# Should be much closer to UWB (0) than camera (10)
assert fx < 3.0, f"fused_x={fx:.2f}"
def test_select_source_transitions(self):
"""Verify correct source transitions as confidences change."""
assert select_source(0.9, 0.8) == "fused"
assert select_source(0.9, 0.0) == "uwb"
assert select_source(0.0, 0.8) == "camera"
assert select_source(0.0, 0.0) == "predicted"

View File

@ -1,57 +1,24 @@
# uwb_config.yaml — MaUWB ESP32-S3 DW3000 UWB follow-me system
#
# Hardware layout:
# Anchor-0 (port side) → USB port_a, y = +anchor_separation/2
# Anchor-1 (starboard side) → USB port_b, y = -anchor_separation/2
# Tag on person → belt clip, ~0.9m above ground
# uwb_config.yaml — MaUWB ESP32-S3 DW3000 UWB integration (Issue #90)
#
# Run with:
# ros2 launch saltybot_uwb uwb.launch.py
# Override at launch:
# ros2 launch saltybot_uwb uwb.launch.py port_a:=/dev/ttyUSB2
# ── Serial ports ──────────────────────────────────────────────────────────────
# Set udev rules to get stable symlinks:
# /dev/uwb-anchor0 → port_a
# /dev/uwb-anchor1 → port_b
# (See jetson/docs/pinout.md for udev setup)
port_a: /dev/uwb-anchor0 # Anchor-0 (port)
port_b: /dev/uwb-anchor1 # Anchor-1 (starboard)
baudrate: 115200 # MaUWB default — do not change
port_a: /dev/uwb-anchor0
port_b: /dev/uwb-anchor1
baudrate: 115200
# ── Anchor geometry ────────────────────────────────────────────────────────────
# anchor_separation: centre-to-centre distance between anchors (metres)
# Must match physical mounting. Larger = more accurate lateral resolution.
anchor_separation: 0.25 # metres (25cm)
anchor_separation: 0.25
anchor_height: 0.80
tag_height: 0.90
# anchor_height: height of anchors above ground (metres)
# Orin stem mount ≈ 0.80m on the saltybot platform
anchor_height: 0.80 # metres
range_timeout_s: 1.0
max_range_m: 8.0
min_range_m: 0.05
# tag_height: height of person's belt-clip tag above ground (metres)
tag_height: 0.90 # metres (adjust per user)
# ── Range validity ─────────────────────────────────────────────────────────────
# range_timeout_s: stale anchor — excluded from triangulation after this gap
range_timeout_s: 1.0 # seconds
# max_range_m: discard ranges beyond this (DW3000 indoor practical limit ≈8m)
max_range_m: 8.0 # metres
# min_range_m: discard ranges below this (likely multipath artefacts)
min_range_m: 0.05 # metres
# ── Kalman filter ──────────────────────────────────────────────────────────────
# kf_process_noise: Q scalar — how dynamic the person's motion is
# Higher → faster response, more jitter
kf_process_noise: 0.1
kf_meas_noise: 0.3
# kf_meas_noise: R scalar — how noisy the UWB ranging is
# DW3000 indoor accuracy ≈ 10cm RMS → 0.1m → R ≈ 0.01
# Use 0.3 to be conservative on first deployment
kf_meas_noise: 0.3
range_rate: 100.0
bearing_rate: 10.0
# ── Publish rate ──────────────────────────────────────────────────────────────
# Should match or exceed the AT+RANGE? poll rate from both anchors.
# 20Hz means 50ms per cycle; each anchor query takes ~10ms → headroom ok.
publish_rate: 20.0 # Hz
enrolled_tag_ids: [""]

View File

@ -1,10 +1,14 @@
"""
uwb.launch.py Launch UWB driver node for MaUWB ESP32-S3 follow-me.
uwb.launch.py Launch UWB driver node for MaUWB ESP32-S3 DW3000 (Issue #90).
Topics:
/uwb/ranges 100 Hz raw TWR ranges
/uwb/bearing 10 Hz Kalman-fused bearing
/uwb/target 10 Hz triangulated PoseStamped (backwards compat)
Usage:
ros2 launch saltybot_uwb uwb.launch.py
ros2 launch saltybot_uwb uwb.launch.py port_a:=/dev/ttyUSB2 port_b:=/dev/ttyUSB3
ros2 launch saltybot_uwb uwb.launch.py anchor_separation:=0.30 publish_rate:=10.0
ros2 launch saltybot_uwb uwb.launch.py enrolled_tag_ids:="['0xDEADBEEF']"
"""
import os
@ -31,7 +35,9 @@ def generate_launch_description():
DeclareLaunchArgument("min_range_m", default_value="0.05"),
DeclareLaunchArgument("kf_process_noise", default_value="0.1"),
DeclareLaunchArgument("kf_meas_noise", default_value="0.3"),
DeclareLaunchArgument("publish_rate", default_value="20.0"),
DeclareLaunchArgument("range_rate", default_value="100.0"),
DeclareLaunchArgument("bearing_rate", default_value="10.0"),
DeclareLaunchArgument("enrolled_tag_ids", default_value="['']"),
Node(
package="saltybot_uwb",
@ -52,7 +58,9 @@ def generate_launch_description():
"min_range_m": LaunchConfiguration("min_range_m"),
"kf_process_noise": LaunchConfiguration("kf_process_noise"),
"kf_meas_noise": LaunchConfiguration("kf_meas_noise"),
"publish_rate": LaunchConfiguration("publish_rate"),
"range_rate": LaunchConfiguration("range_rate"),
"bearing_rate": LaunchConfiguration("bearing_rate"),
"enrolled_tag_ids": LaunchConfiguration("enrolled_tag_ids"),
},
],
),

View File

@ -29,6 +29,26 @@ Returns (x_t, y_t); caller should treat negative x_t as 0.
import math
# ── Bearing helper ────────────────────────────────────────────────────────────
def bearing_from_pos(x: float, y: float) -> float:
"""
Compute horizontal bearing to a point (x, y) in base_link.
Parameters
----------
x : forward distance (metres)
y : lateral offset (metres, positive = left of robot / CCW)
Returns
-------
bearing_rad : bearing in radians, range -π .. +π
positive = target to the left (CCW)
0 = directly ahead
"""
return math.atan2(y, x)
# ── Triangulation ─────────────────────────────────────────────────────────────
def triangulate_2anchor(

View File

@ -1,32 +1,52 @@
"""
uwb_driver_node.py ROS2 node for MaUWB ESP32-S3 DW3000 follow-me system.
uwb_driver_node.py ROS2 node for MaUWB ESP32-S3 DW3000 UWB integration.
Hardware
2× MaUWB ESP32-S3 DW3000 anchors on robot stem (USB Orin Nano)
2× MaUWB ESP32-S3 DW3000 anchors on robot stem (USB Orin)
- Anchor-0: port side (y = +sep/2)
- Anchor-1: starboard (y = -sep/2)
1× MaUWB tag on person (belt clip)
1× MaUWB tag per enrolled person (belt clip)
AT command interface (115200 8N1)
Query: AT+RANGE?\r\n
Response (from anchors):
+RANGE:<anchor_id>,<range_mm>[,<rssi>]\r\n
Query:
AT+RANGE?\r\n
Config:
AT+anchor_tag=ANCHOR\r\n set module as anchor
AT+anchor_tag=TAG\r\n set module as tag
Response (from anchors, TWR protocol):
+RANGE:<anchor_id>,<range_mm>[,<rssi>[,<tag_addr>]]\r\n
Tag pairing (optional targets a specific enrolled tag):
AT+RANGE_ADDR=<tag_addr>\r\n anchor only ranges with that tag
Publishes
/uwb/target (geometry_msgs/PoseStamped) triangulated person position in base_link
/uwb/ranges (saltybot_uwb_msgs/UwbRangeArray) raw ranges from both anchors
/uwb/ranges (saltybot_uwb_msgs/UwbRangeArray) 100 Hz raw anchor ranges
/uwb/bearing (saltybot_uwb_msgs/UwbBearing) 10 Hz Kalman-fused bearing
/uwb/target (geometry_msgs/PoseStamped) 10 Hz triangulated position
(kept for backwards compat)
Safety
If a range is stale (> range_timeout_s), that anchor is excluded from
triangulation. If both anchors are stale, /uwb/target is not published.
Tag pairing
Set enrolled_tag_ids to a list of tag address strings (e.g. ["0x1234ABCD"]).
When non-empty, ranges from unrecognised tags are silently discarded.
The matched tag address is stamped in UwbRange.tag_id and UwbBearing.tag_id.
When enrolled_tag_ids is empty, all ranges are accepted (tag_id = "").
Parameters
port_a, port_b serial ports for anchor-0 / anchor-1
baudrate 115200 (default)
anchor_separation centre-to-centre anchor spacing (m)
anchor_height anchor mounting height (m)
tag_height person tag height (m)
range_timeout_s stale-anchor threshold (s)
max_range_m / min_range_m validity window (m)
kf_process_noise Kalman Q scalar
kf_meas_noise Kalman R scalar
range_rate Hz /uwb/ranges publish rate (default 100)
bearing_rate Hz /uwb/bearing publish rate (default 10)
enrolled_tag_ids list[str] accepted tag addresses; [] = accept all
Usage
@ -44,8 +64,8 @@ from rclpy.node import Node
from geometry_msgs.msg import PoseStamped
from std_msgs.msg import Header
from saltybot_uwb_msgs.msg import UwbRange, UwbRangeArray
from saltybot_uwb.ranging_math import triangulate_2anchor, KalmanFilter2D
from saltybot_uwb_msgs.msg import UwbRange, UwbRangeArray, UwbBearing
from saltybot_uwb.ranging_math import triangulate_2anchor, KalmanFilter2D, bearing_from_pos
try:
import serial
@ -54,26 +74,31 @@ except ImportError:
_SERIAL_AVAILABLE = False
# Regex: +RANGE:<id>,<mm> or +RANGE:<id>,<mm>,<rssi>
# +RANGE:<id>,<mm> or +RANGE:<id>,<mm>,<rssi> or +RANGE:<id>,<mm>,<rssi>,<tag>
_RANGE_RE = re.compile(
r"\+RANGE\s*:\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(-?[\d.]+))?",
r"\+RANGE\s*:\s*(\d+)\s*,\s*(\d+)"
r"(?:\s*,\s*(-?[\d.]+)"
r"(?:\s*,\s*([\w:x]+))?"
r")?",
re.IGNORECASE,
)
class SerialReader(threading.Thread):
"""
Background thread: polls an anchor's UART, fires callback on every
valid +RANGE response.
Background thread: polls one anchor's UART at maximum TWR rate,
fires callback on every valid +RANGE response.
Supports optional tag pairing via AT+RANGE_ADDR command.
"""
def __init__(self, port, baudrate, anchor_id, callback, logger):
def __init__(self, port, baudrate, anchor_id, callback, logger, tag_addr=None):
super().__init__(daemon=True)
self._port = port
self._baudrate = baudrate
self._anchor_id = anchor_id
self._callback = callback
self._logger = logger
self._tag_addr = tag_addr
self._running = False
self._ser = None
@ -86,7 +111,10 @@ class SerialReader(threading.Thread):
)
self._logger.info(
f"Anchor-{self._anchor_id}: opened {self._port}"
+ (f" paired with tag {self._tag_addr}" if self._tag_addr else "")
)
if self._tag_addr:
self._send_pairing_cmd()
self._read_loop()
except Exception as exc:
self._logger.warn(
@ -96,12 +124,24 @@ class SerialReader(threading.Thread):
self._ser.close()
time.sleep(2.0)
def _send_pairing_cmd(self):
"""Configure the anchor to range only with the paired tag."""
try:
cmd = f"AT+RANGE_ADDR={self._tag_addr}\r\n".encode("ascii")
self._ser.write(cmd)
time.sleep(0.1)
self._logger.info(
f"Anchor-{self._anchor_id}: sent tag pairing {self._tag_addr}"
)
except Exception as exc:
self._logger.warn(
f"Anchor-{self._anchor_id}: pairing cmd failed: {exc}"
)
def _read_loop(self):
while self._running:
try:
# Query the anchor
self._ser.write(b"AT+RANGE?\r\n")
# Read up to 10 lines waiting for a +RANGE response
for _ in range(10):
raw = self._ser.readline()
if not raw:
@ -111,13 +151,14 @@ class SerialReader(threading.Thread):
if m:
range_mm = int(m.group(2))
rssi = float(m.group(3)) if m.group(3) else 0.0
self._callback(self._anchor_id, range_mm, rssi)
tag_addr = m.group(4) if m.group(4) else ""
self._callback(self._anchor_id, range_mm, rssi, tag_addr)
break
except Exception as exc:
self._logger.warn(
f"Anchor-{self._anchor_id} read error: {exc}"
)
break # trigger reconnect
break
def stop(self):
self._running = False
@ -130,48 +171,51 @@ class UwbDriverNode(Node):
def __init__(self):
super().__init__("uwb_driver")
# ── Parameters ────────────────────────────────────────────────────────
self.declare_parameter("port_a", "/dev/ttyUSB0")
self.declare_parameter("port_b", "/dev/ttyUSB1")
self.declare_parameter("baudrate", 115200)
self.declare_parameter("anchor_separation", 0.25)
self.declare_parameter("anchor_height", 0.80)
self.declare_parameter("tag_height", 0.90)
self.declare_parameter("range_timeout_s", 1.0)
self.declare_parameter("max_range_m", 8.0)
self.declare_parameter("min_range_m", 0.05)
self.declare_parameter("kf_process_noise", 0.1)
self.declare_parameter("kf_meas_noise", 0.3)
self.declare_parameter("publish_rate", 20.0)
self.declare_parameter("port_a", "/dev/uwb-anchor0")
self.declare_parameter("port_b", "/dev/uwb-anchor1")
self.declare_parameter("baudrate", 115200)
self.declare_parameter("anchor_separation", 0.25)
self.declare_parameter("anchor_height", 0.80)
self.declare_parameter("tag_height", 0.90)
self.declare_parameter("range_timeout_s", 1.0)
self.declare_parameter("max_range_m", 8.0)
self.declare_parameter("min_range_m", 0.05)
self.declare_parameter("kf_process_noise", 0.1)
self.declare_parameter("kf_meas_noise", 0.3)
self.declare_parameter("range_rate", 100.0)
self.declare_parameter("bearing_rate", 10.0)
self.declare_parameter("enrolled_tag_ids", [""])
self._p = self._load_params()
# ── State (protected by lock) ──────────────────────────────────────
self._lock = threading.Lock()
self._ranges = {} # anchor_id → (range_m, rssi, timestamp)
self._kf = KalmanFilter2D(
raw_ids = self.get_parameter("enrolled_tag_ids").value
self._enrolled_tags = [t.strip() for t in raw_ids if t.strip()]
paired_tag = self._enrolled_tags[0] if self._enrolled_tags else None
self._lock = threading.Lock()
self._ranges: dict = {}
self._kf = KalmanFilter2D(
process_noise=self._p["kf_process_noise"],
measurement_noise=self._p["kf_meas_noise"],
dt=1.0 / self._p["publish_rate"],
dt=1.0 / self._p["bearing_rate"],
)
# ── Publishers ────────────────────────────────────────────────────
self._target_pub = self.create_publisher(
PoseStamped, "/uwb/target", 10)
self._ranges_pub = self.create_publisher(
UwbRangeArray, "/uwb/ranges", 10)
self._ranges_pub = self.create_publisher(UwbRangeArray, "/uwb/ranges", 10)
self._bearing_pub = self.create_publisher(UwbBearing, "/uwb/bearing", 10)
self._target_pub = self.create_publisher(PoseStamped, "/uwb/target", 10)
# ── Serial readers ────────────────────────────────────────────────
if _SERIAL_AVAILABLE:
self._reader_a = SerialReader(
self._p["port_a"], self._p["baudrate"],
anchor_id=0, callback=self._range_cb,
logger=self.get_logger(),
tag_addr=paired_tag,
)
self._reader_b = SerialReader(
self._p["port_b"], self._p["baudrate"],
anchor_id=1, callback=self._range_cb,
logger=self.get_logger(),
tag_addr=paired_tag,
)
self._reader_a.start()
self._reader_b.start()
@ -180,19 +224,21 @@ class UwbDriverNode(Node):
"pyserial not installed — running in simulation mode (no serial I/O)"
)
# ── Publish timer ─────────────────────────────────────────────────
self._timer = self.create_timer(
1.0 / self._p["publish_rate"], self._publish_cb
self._range_timer = self.create_timer(
1.0 / self._p["range_rate"], self._range_publish_cb
)
self._bearing_timer = self.create_timer(
1.0 / self._p["bearing_rate"], self._bearing_publish_cb
)
self.get_logger().info(
f"UWB driver ready sep={self._p['anchor_separation']}m "
f"ports={self._p['port_a']},{self._p['port_b']} "
f"rate={self._p['publish_rate']}Hz"
f"range={self._p['range_rate']}Hz "
f"bearing={self._p['bearing_rate']}Hz "
f"enrolled_tags={self._enrolled_tags or ['<any>']}"
)
# ── Helpers ───────────────────────────────────────────────────────────────
def _load_params(self):
return {
"port_a": self.get_parameter("port_a").value,
@ -206,90 +252,106 @@ class UwbDriverNode(Node):
"min_range_m": self.get_parameter("min_range_m").value,
"kf_process_noise": self.get_parameter("kf_process_noise").value,
"kf_meas_noise": self.get_parameter("kf_meas_noise").value,
"publish_rate": self.get_parameter("publish_rate").value,
"range_rate": self.get_parameter("range_rate").value,
"bearing_rate": self.get_parameter("bearing_rate").value,
}
# ── Callbacks ─────────────────────────────────────────────────────────────
def _is_enrolled(self, tag_addr: str) -> bool:
if not self._enrolled_tags:
return True
return tag_addr in self._enrolled_tags
def _range_cb(self, anchor_id: int, range_mm: int, rssi: float):
"""Called from serial reader threads — thread-safe update."""
def _range_cb(self, anchor_id: int, range_mm: int, rssi: float, tag_addr: str):
if not self._is_enrolled(tag_addr):
return
range_m = range_mm / 1000.0
p = self._p
if range_m < p["min_range_m"] or range_m > p["max_range_m"]:
return
with self._lock:
self._ranges[anchor_id] = (range_m, rssi, time.monotonic())
self._ranges[anchor_id] = (range_m, rssi, tag_addr, time.monotonic())
def _publish_cb(self):
now = time.monotonic()
def _range_publish_cb(self):
"""100 Hz: publish current raw ranges as UwbRangeArray."""
now = time.monotonic()
timeout = self._p["range_timeout_s"]
sep = self._p["anchor_separation"]
with self._lock:
# Collect valid (non-stale) ranges
valid = {}
for aid, (r, rssi, t) in self._ranges.items():
if now - t <= timeout:
valid[aid] = (r, rssi, t)
# Build and publish UwbRangeArray regardless (even if partial)
valid = {
aid: entry
for aid, entry in self._ranges.items()
if (now - entry[3]) <= timeout
}
hdr = Header()
hdr.stamp = self.get_clock().now().to_msg()
hdr.frame_id = "base_link"
arr = UwbRangeArray()
arr.header = hdr
for aid, (r, rssi, _) in valid.items():
for aid, (r, rssi, tag_id, _) in valid.items():
entry = UwbRange()
entry.header = hdr
entry.anchor_id = aid
entry.range_m = float(r)
entry.raw_mm = int(round(r * 1000.0))
entry.rssi = float(rssi)
entry.tag_id = tag_id
arr.ranges.append(entry)
self._ranges_pub.publish(arr)
# Need both anchors to triangulate
if 0 not in valid or 1 not in valid:
def _bearing_publish_cb(self):
"""10 Hz: Kalman predict+update, publish fused bearing."""
now = time.monotonic()
timeout = self._p["range_timeout_s"]
sep = self._p["anchor_separation"]
with self._lock:
valid = {
aid: entry
for aid, entry in self._ranges.items()
if (now - entry[3]) <= timeout
}
if not valid:
return
r0 = valid[0][0]
r1 = valid[1][0]
try:
x_t, y_t = triangulate_2anchor(
r0=r0,
r1=r1,
sep=sep,
anchor_z=self._p["anchor_height"],
tag_z=self._p["tag_height"],
)
except (ValueError, ZeroDivisionError) as exc:
self.get_logger().warn(f"Triangulation error: {exc}")
return
# Kalman filter update
dt = 1.0 / self._p["publish_rate"]
both_fresh = 0 in valid and 1 in valid
confidence = 1.0 if both_fresh else 0.5
active_tag = valid[next(iter(valid))][2]
dt = 1.0 / self._p["bearing_rate"]
self._kf.predict(dt=dt)
self._kf.update(x_t, y_t)
if both_fresh:
r0 = valid[0][0]
r1 = valid[1][0]
try:
x_t, y_t = triangulate_2anchor(
r0=r0, r1=r1, sep=sep,
anchor_z=self._p["anchor_height"],
tag_z=self._p["tag_height"],
)
except (ValueError, ZeroDivisionError) as exc:
self.get_logger().warn(f"Triangulation error: {exc}")
return
self._kf.update(x_t, y_t)
kx, ky = self._kf.position()
# Publish PoseStamped in base_link
bearing = bearing_from_pos(kx, ky)
range_m = math.sqrt(kx * kx + ky * ky)
hdr = Header()
hdr.stamp = self.get_clock().now().to_msg()
hdr.frame_id = "base_link"
brg_msg = UwbBearing()
brg_msg.header = hdr
brg_msg.bearing_rad = float(bearing)
brg_msg.range_m = float(range_m)
brg_msg.confidence = float(confidence)
brg_msg.tag_id = active_tag
self._bearing_pub.publish(brg_msg)
pose = PoseStamped()
pose.header = hdr
pose.pose.position.x = kx
pose.pose.position.y = ky
pose.pose.position.z = 0.0
# Orientation: face the person (yaw = atan2(y, x))
yaw = math.atan2(ky, kx)
yaw = bearing
pose.pose.orientation.z = math.sin(yaw / 2.0)
pose.pose.orientation.w = math.cos(yaw / 2.0)
self._target_pub.publish(pose)
# ── Entry point ───────────────────────────────────────────────────────────────
def main(args=None):
rclpy.init(args=args)
node = UwbDriverNode()

View File

@ -7,7 +7,7 @@ No ROS2 / serial / GPU dependencies — runs with plain pytest.
import math
import pytest
from saltybot_uwb.ranging_math import triangulate_2anchor, KalmanFilter2D
from saltybot_uwb.ranging_math import triangulate_2anchor, KalmanFilter2D, bearing_from_pos
# ── triangulate_2anchor ───────────────────────────────────────────────────────
@ -172,3 +172,47 @@ class TestKalmanFilter2D:
x, y = kf.position()
assert not math.isnan(x)
assert not math.isnan(y)
# ── bearing_from_pos ──────────────────────────────────────────────────────────
class TestBearingFromPos:
def test_directly_ahead_zero_bearing(self):
"""Person directly ahead: x=2, y=0 → bearing ≈ 0."""
b = bearing_from_pos(2.0, 0.0)
assert abs(b) < 0.001
def test_left_gives_positive_bearing(self):
"""Person to the left (y>0): bearing should be positive."""
b = bearing_from_pos(1.0, 1.0)
assert b > 0.0
assert abs(b - math.pi / 4.0) < 0.001
def test_right_gives_negative_bearing(self):
"""Person to the right (y<0): bearing should be negative."""
b = bearing_from_pos(1.0, -1.0)
assert b < 0.0
assert abs(b + math.pi / 4.0) < 0.001
def test_directly_left_ninety_degrees(self):
"""Person directly to the left: x=0, y=1 → bearing = π/2."""
b = bearing_from_pos(0.0, 1.0)
assert abs(b - math.pi / 2.0) < 0.001
def test_directly_right_minus_ninety_degrees(self):
"""Person directly to the right: x=0, y=-1 → bearing = -π/2."""
b = bearing_from_pos(0.0, -1.0)
assert abs(b + math.pi / 2.0) < 0.001
def test_range_pi_to_minus_pi(self):
"""Bearing is always in -π..+π."""
for x in [-2.0, -0.1, 0.1, 2.0]:
for y in [-2.0, -0.1, 0.1, 2.0]:
b = bearing_from_pos(x, y)
assert -math.pi <= b <= math.pi
def test_no_nan_for_tiny_distance(self):
"""Very close target should not produce NaN."""
b = bearing_from_pos(0.001, 0.001)
assert not math.isnan(b)

View File

@ -8,6 +8,7 @@ find_package(rosidl_default_generators REQUIRED)
rosidl_generate_interfaces(${PROJECT_NAME}
"msg/UwbRange.msg"
"msg/UwbRangeArray.msg"
"msg/UwbBearing.msg"
DEPENDENCIES std_msgs
)

View File

@ -0,0 +1,18 @@
# UwbBearing.msg — 10 Hz Kalman-fused bearing estimate from dual-anchor UWB.
#
# bearing_rad : horizontal bearing to tag in base_link frame (radians)
# positive = tag to the left (CCW), negative = right (CW)
# range: -π .. +π
# range_m : Kalman-filtered horizontal range to tag (metres)
# confidence : quality indicator
# 1.0 = both anchors fresh
# 0.5 = single anchor only (bearing less reliable)
# 0.0 = stale / no data (message not published in this state)
# tag_id : enrolled tag identifier that produced this estimate
std_msgs/Header header
float32 bearing_rad
float32 range_m
float32 confidence
string tag_id

View File

@ -4,6 +4,7 @@
# range_m : measured horizontal range in metres (after height correction)
# raw_mm : raw TWR range from AT+RANGE? response, millimetres
# rssi : received signal strength (dBm), 0 if not reported by module
# tag_id : enrolled tag identifier; empty string if tag pairing is disabled
std_msgs/Header header
@ -11,3 +12,4 @@ uint8 anchor_id
float32 range_m
uint32 raw_mm
float32 rssi
string tag_id