Compare commits
7 Commits
4744d4ee93
...
033979aa47
| Author | SHA1 | Date | |
|---|---|---|---|
| 033979aa47 | |||
| d9c983f666 | |||
| 54e9274405 | |||
| b432492785 | |||
| 9a68dfdb2e | |||
| d872ea5e34 | |||
| 84790412d6 |
@ -0,0 +1,8 @@
|
||||
person_state_tracker:
|
||||
ros__parameters:
|
||||
engagement_distance: 2.0
|
||||
absent_timeout: 5.0
|
||||
prune_timeout: 30.0
|
||||
update_rate: 10.0
|
||||
n_cameras: 4
|
||||
uwb_enabled: false
|
||||
35
jetson/ros2_ws/src/saltybot_social/launch/social.launch.py
Normal file
35
jetson/ros2_ws/src/saltybot_social/launch/social.launch.py
Normal file
@ -0,0 +1,35 @@
|
||||
from launch import LaunchDescription
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
return LaunchDescription([
|
||||
DeclareLaunchArgument(
|
||||
'engagement_distance',
|
||||
default_value='2.0',
|
||||
description='Distance in metres below which a person is considered engaged'
|
||||
),
|
||||
DeclareLaunchArgument(
|
||||
'absent_timeout',
|
||||
default_value='5.0',
|
||||
description='Seconds without detection before marking person as ABSENT'
|
||||
),
|
||||
DeclareLaunchArgument(
|
||||
'uwb_enabled',
|
||||
default_value='false',
|
||||
description='Whether UWB anchor data is available'
|
||||
),
|
||||
Node(
|
||||
package='saltybot_social',
|
||||
executable='person_state_tracker',
|
||||
name='person_state_tracker',
|
||||
output='screen',
|
||||
parameters=[{
|
||||
'engagement_distance': LaunchConfiguration('engagement_distance'),
|
||||
'absent_timeout': LaunchConfiguration('absent_timeout'),
|
||||
'uwb_enabled': LaunchConfiguration('uwb_enabled'),
|
||||
}],
|
||||
),
|
||||
])
|
||||
25
jetson/ros2_ws/src/saltybot_social/package.xml
Normal file
25
jetson/ros2_ws/src/saltybot_social/package.xml
Normal file
@ -0,0 +1,25 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_social</name>
|
||||
<version>0.1.0</version>
|
||||
<description>Multi-modal person identity fusion and state tracking for saltybot</description>
|
||||
<maintainer email="seb@vayrette.com">seb</maintainer>
|
||||
<license>MIT</license>
|
||||
<depend>rclpy</depend>
|
||||
<depend>std_msgs</depend>
|
||||
<depend>geometry_msgs</depend>
|
||||
<depend>sensor_msgs</depend>
|
||||
<depend>vision_msgs</depend>
|
||||
<depend>saltybot_social_msgs</depend>
|
||||
<depend>tf2_ros</depend>
|
||||
<depend>tf2_geometry_msgs</depend>
|
||||
<depend>cv_bridge</depend>
|
||||
<test_depend>ament_copyright</test_depend>
|
||||
<test_depend>ament_flake8</test_depend>
|
||||
<test_depend>ament_pep257</test_depend>
|
||||
<test_depend>python3-pytest</test_depend>
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1,201 @@
|
||||
import math
|
||||
import time
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from saltybot_social.person_state import PersonTrack, State
|
||||
|
||||
|
||||
class IdentityFuser:
|
||||
"""Multi-modal identity fusion: face_id + speaker_id + UWB anchor -> unified person_id."""
|
||||
|
||||
def __init__(self, position_window: float = 3.0):
|
||||
self._tracks: dict[int, PersonTrack] = {}
|
||||
self._next_id: int = 1
|
||||
self._position_window = position_window
|
||||
self._max_history = 20
|
||||
|
||||
def _distance_3d(self, a: Optional[tuple], b: Optional[tuple]) -> float:
|
||||
if a is None or b is None:
|
||||
return float('inf')
|
||||
return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))
|
||||
|
||||
def _compute_distance_and_bearing(self, position: tuple) -> Tuple[float, float]:
|
||||
x, y, z = position
|
||||
distance = math.sqrt(x * x + y * y + z * z)
|
||||
bearing = math.degrees(math.atan2(y, x))
|
||||
return distance, bearing
|
||||
|
||||
def _allocate_id(self) -> int:
|
||||
pid = self._next_id
|
||||
self._next_id += 1
|
||||
return pid
|
||||
|
||||
def update_face(self, face_id: int, name: str,
|
||||
position_3d: Optional[tuple], camera_id: int,
|
||||
now: float) -> int:
|
||||
"""Match by face_id first. If face_id >= 0, find existing track or create new."""
|
||||
if face_id >= 0:
|
||||
for track in self._tracks.values():
|
||||
if track.face_id == face_id:
|
||||
track.name = name if name and name != 'unknown' else track.name
|
||||
if position_3d is not None:
|
||||
track.position = position_3d
|
||||
dist, bearing = self._compute_distance_and_bearing(position_3d)
|
||||
track.distance = dist
|
||||
track.bearing_deg = bearing
|
||||
track.history_distances.append(dist)
|
||||
if len(track.history_distances) > self._max_history:
|
||||
track.history_distances = track.history_distances[-self._max_history:]
|
||||
track.camera_id = camera_id
|
||||
track.last_seen = now
|
||||
track.last_face_seen = now
|
||||
return track.person_id
|
||||
|
||||
# No existing track with this face_id; try proximity match for face_id < 0
|
||||
if face_id < 0 and position_3d is not None:
|
||||
best_track = None
|
||||
best_dist = 0.5 # 0.5m proximity threshold
|
||||
for track in self._tracks.values():
|
||||
d = self._distance_3d(track.position, position_3d)
|
||||
if d < best_dist and (now - track.last_seen) < self._position_window:
|
||||
best_dist = d
|
||||
best_track = track
|
||||
if best_track is not None:
|
||||
best_track.position = position_3d
|
||||
dist, bearing = self._compute_distance_and_bearing(position_3d)
|
||||
best_track.distance = dist
|
||||
best_track.bearing_deg = bearing
|
||||
best_track.history_distances.append(dist)
|
||||
if len(best_track.history_distances) > self._max_history:
|
||||
best_track.history_distances = best_track.history_distances[-self._max_history:]
|
||||
best_track.camera_id = camera_id
|
||||
best_track.last_seen = now
|
||||
return best_track.person_id
|
||||
|
||||
# Create new track
|
||||
pid = self._allocate_id()
|
||||
track = PersonTrack(person_id=pid, face_id=face_id, name=name,
|
||||
camera_id=camera_id, last_seen=now, last_face_seen=now)
|
||||
if position_3d is not None:
|
||||
track.position = position_3d
|
||||
dist, bearing = self._compute_distance_and_bearing(position_3d)
|
||||
track.distance = dist
|
||||
track.bearing_deg = bearing
|
||||
track.history_distances.append(dist)
|
||||
self._tracks[pid] = track
|
||||
return pid
|
||||
|
||||
def update_speaker(self, speaker_id: str, now: float) -> int:
|
||||
"""Find track with nearest recently-seen position, assign speaker_id."""
|
||||
# First check if any track already has this speaker_id
|
||||
for track in self._tracks.values():
|
||||
if track.speaker_id == speaker_id:
|
||||
track.last_seen = now
|
||||
return track.person_id
|
||||
|
||||
# Assign to nearest recently-seen track
|
||||
best_track = None
|
||||
best_dist = float('inf')
|
||||
for track in self._tracks.values():
|
||||
if (now - track.last_seen) < self._position_window and track.distance > 0:
|
||||
if track.distance < best_dist:
|
||||
best_dist = track.distance
|
||||
best_track = track
|
||||
|
||||
if best_track is not None:
|
||||
best_track.speaker_id = speaker_id
|
||||
best_track.last_seen = now
|
||||
return best_track.person_id
|
||||
|
||||
# No suitable track found; create a new one
|
||||
pid = self._allocate_id()
|
||||
track = PersonTrack(person_id=pid, speaker_id=speaker_id,
|
||||
last_seen=now)
|
||||
self._tracks[pid] = track
|
||||
return pid
|
||||
|
||||
def update_uwb(self, uwb_anchor_id: str, position_3d: tuple,
|
||||
now: float) -> int:
|
||||
"""Match by proximity (nearest track within 0.5m)."""
|
||||
# Check existing UWB assignment
|
||||
for track in self._tracks.values():
|
||||
if track.uwb_anchor_id == uwb_anchor_id:
|
||||
track.position = position_3d
|
||||
dist, bearing = self._compute_distance_and_bearing(position_3d)
|
||||
track.distance = dist
|
||||
track.bearing_deg = bearing
|
||||
track.history_distances.append(dist)
|
||||
if len(track.history_distances) > self._max_history:
|
||||
track.history_distances = track.history_distances[-self._max_history:]
|
||||
track.last_seen = now
|
||||
return track.person_id
|
||||
|
||||
# Proximity match
|
||||
best_track = None
|
||||
best_dist = 0.5
|
||||
for track in self._tracks.values():
|
||||
d = self._distance_3d(track.position, position_3d)
|
||||
if d < best_dist and (now - track.last_seen) < self._position_window:
|
||||
best_dist = d
|
||||
best_track = track
|
||||
|
||||
if best_track is not None:
|
||||
best_track.uwb_anchor_id = uwb_anchor_id
|
||||
best_track.position = position_3d
|
||||
dist, bearing = self._compute_distance_and_bearing(position_3d)
|
||||
best_track.distance = dist
|
||||
best_track.bearing_deg = bearing
|
||||
best_track.history_distances.append(dist)
|
||||
if len(best_track.history_distances) > self._max_history:
|
||||
best_track.history_distances = best_track.history_distances[-self._max_history:]
|
||||
best_track.last_seen = now
|
||||
return best_track.person_id
|
||||
|
||||
# Create new track
|
||||
pid = self._allocate_id()
|
||||
track = PersonTrack(person_id=pid, uwb_anchor_id=uwb_anchor_id,
|
||||
position=position_3d, last_seen=now)
|
||||
dist, bearing = self._compute_distance_and_bearing(position_3d)
|
||||
track.distance = dist
|
||||
track.bearing_deg = bearing
|
||||
track.history_distances.append(dist)
|
||||
self._tracks[pid] = track
|
||||
return pid
|
||||
|
||||
def get_all_tracks(self) -> List[PersonTrack]:
|
||||
return list(self._tracks.values())
|
||||
|
||||
@staticmethod
|
||||
def compute_attention(tracks: List[PersonTrack]) -> int:
|
||||
"""Focus on nearest engaged/talking person, or -1 if none."""
|
||||
candidates = [t for t in tracks
|
||||
if t.state in (State.ENGAGED, State.TALKING) and t.distance > 0]
|
||||
if not candidates:
|
||||
return -1
|
||||
# Prefer talking, then nearest engaged
|
||||
talking = [t for t in candidates if t.state == State.TALKING]
|
||||
if talking:
|
||||
return min(talking, key=lambda t: t.distance).person_id
|
||||
return min(candidates, key=lambda t: t.distance).person_id
|
||||
|
||||
def prune_absent(self, absent_timeout: float = 30.0):
|
||||
"""Remove tracks absent longer than timeout."""
|
||||
now = time.monotonic()
|
||||
to_remove = [pid for pid, t in self._tracks.items()
|
||||
if t.state == State.ABSENT and (now - t.last_seen) > absent_timeout]
|
||||
for pid in to_remove:
|
||||
del self._tracks[pid]
|
||||
|
||||
@staticmethod
|
||||
def detect_group(tracks: List[PersonTrack]) -> bool:
|
||||
"""Returns True if >= 2 persons within 1.5m of each other."""
|
||||
active = [t for t in tracks
|
||||
if t.position is not None and t.state != State.ABSENT]
|
||||
for i, a in enumerate(active):
|
||||
for b in active[i + 1:]:
|
||||
dx = a.position[0] - b.position[0]
|
||||
dy = a.position[1] - b.position[1]
|
||||
dz = a.position[2] - b.position[2]
|
||||
if math.sqrt(dx * dx + dy * dy + dz * dz) < 1.5:
|
||||
return True
|
||||
return False
|
||||
@ -0,0 +1,54 @@
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class State(IntEnum):
|
||||
UNKNOWN = 0
|
||||
APPROACHING = 1
|
||||
ENGAGED = 2
|
||||
TALKING = 3
|
||||
LEAVING = 4
|
||||
ABSENT = 5
|
||||
|
||||
|
||||
@dataclass
|
||||
class PersonTrack:
|
||||
person_id: int
|
||||
face_id: int = -1
|
||||
speaker_id: str = ''
|
||||
uwb_anchor_id: str = ''
|
||||
name: str = 'unknown'
|
||||
state: State = State.UNKNOWN
|
||||
position: Optional[tuple] = None # (x, y, z) in base_link
|
||||
distance: float = 0.0
|
||||
bearing_deg: float = 0.0
|
||||
engagement_score: float = 0.0
|
||||
camera_id: int = -1
|
||||
last_seen: float = field(default_factory=time.monotonic)
|
||||
last_face_seen: float = 0.0
|
||||
history_distances: list = field(default_factory=list) # last N distances for trend
|
||||
|
||||
def update_state(self, now: float, engagement_distance: float = 2.0,
|
||||
absent_timeout: float = 5.0):
|
||||
"""Transition state machine based on distance trend and time."""
|
||||
age = now - self.last_seen
|
||||
if age > absent_timeout:
|
||||
self.state = State.ABSENT
|
||||
return
|
||||
if self.speaker_id:
|
||||
self.state = State.TALKING
|
||||
return
|
||||
if len(self.history_distances) >= 3:
|
||||
trend = self.history_distances[-1] - self.history_distances[-3]
|
||||
if self.distance < engagement_distance:
|
||||
self.state = State.ENGAGED
|
||||
elif trend < -0.2: # moving closer
|
||||
self.state = State.APPROACHING
|
||||
elif trend > 0.3: # moving away
|
||||
self.state = State.LEAVING
|
||||
else:
|
||||
self.state = State.ENGAGED if self.distance < engagement_distance else State.UNKNOWN
|
||||
elif self.distance > 0:
|
||||
self.state = State.ENGAGED if self.distance < engagement_distance else State.APPROACHING
|
||||
@ -0,0 +1,273 @@
|
||||
import time
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||
|
||||
from std_msgs.msg import Int32, String
|
||||
from geometry_msgs.msg import PoseArray, PoseStamped, Point
|
||||
from sensor_msgs.msg import Image
|
||||
from builtin_interfaces.msg import Time as TimeMsg
|
||||
|
||||
from saltybot_social_msgs.msg import (
|
||||
FaceDetection,
|
||||
FaceDetectionArray,
|
||||
PersonState,
|
||||
PersonStateArray,
|
||||
)
|
||||
from saltybot_social.identity_fuser import IdentityFuser
|
||||
from saltybot_social.person_state import State
|
||||
|
||||
|
||||
class PersonStateTrackerNode(Node):
|
||||
"""Main ROS2 node for multi-modal person tracking."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__('person_state_tracker')
|
||||
|
||||
# Parameters
|
||||
self.declare_parameter('engagement_distance', 2.0)
|
||||
self.declare_parameter('absent_timeout', 5.0)
|
||||
self.declare_parameter('prune_timeout', 30.0)
|
||||
self.declare_parameter('update_rate', 10.0)
|
||||
self.declare_parameter('n_cameras', 4)
|
||||
self.declare_parameter('uwb_enabled', False)
|
||||
|
||||
self._engagement_distance = self.get_parameter('engagement_distance').value
|
||||
self._absent_timeout = self.get_parameter('absent_timeout').value
|
||||
self._prune_timeout = self.get_parameter('prune_timeout').value
|
||||
update_rate = self.get_parameter('update_rate').value
|
||||
n_cameras = self.get_parameter('n_cameras').value
|
||||
uwb_enabled = self.get_parameter('uwb_enabled').value
|
||||
|
||||
self._fuser = IdentityFuser()
|
||||
|
||||
# QoS profiles
|
||||
best_effort_qos = QoSProfile(
|
||||
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||
history=HistoryPolicy.KEEP_LAST,
|
||||
depth=5
|
||||
)
|
||||
reliable_qos = QoSProfile(
|
||||
reliability=ReliabilityPolicy.RELIABLE,
|
||||
history=HistoryPolicy.KEEP_LAST,
|
||||
depth=1
|
||||
)
|
||||
|
||||
# Subscriptions
|
||||
self.create_subscription(
|
||||
FaceDetectionArray,
|
||||
'/social/faces/detections',
|
||||
self._on_face_detections,
|
||||
best_effort_qos
|
||||
)
|
||||
self.create_subscription(
|
||||
String,
|
||||
'/social/speech/speaker_id',
|
||||
self._on_speaker_id,
|
||||
10
|
||||
)
|
||||
if uwb_enabled:
|
||||
self.create_subscription(
|
||||
PoseArray,
|
||||
'/uwb/positions',
|
||||
self._on_uwb,
|
||||
10
|
||||
)
|
||||
|
||||
# Camera subscriptions (monitor topic existence)
|
||||
self.create_subscription(
|
||||
Image, '/camera/color/image_raw',
|
||||
self._on_camera_image, best_effort_qos)
|
||||
for i in range(n_cameras):
|
||||
self.create_subscription(
|
||||
Image, f'/surround/cam{i}/image_raw',
|
||||
self._on_camera_image, best_effort_qos)
|
||||
|
||||
# Existing person tracker position
|
||||
self.create_subscription(
|
||||
PoseStamped,
|
||||
'/person/target',
|
||||
self._on_person_target,
|
||||
10
|
||||
)
|
||||
|
||||
# Publishers
|
||||
self._persons_pub = self.create_publisher(
|
||||
PersonStateArray, '/social/persons', best_effort_qos)
|
||||
self._attention_pub = self.create_publisher(
|
||||
Int32, '/social/attention/target_id', reliable_qos)
|
||||
|
||||
# Timer
|
||||
timer_period = 1.0 / update_rate
|
||||
self.create_timer(timer_period, self._on_timer)
|
||||
|
||||
self.get_logger().info(
|
||||
f'PersonStateTracker started: engagement={self._engagement_distance}m, '
|
||||
f'absent_timeout={self._absent_timeout}s, rate={update_rate}Hz, '
|
||||
f'cameras={n_cameras}, uwb={uwb_enabled}')
|
||||
|
||||
def _on_face_detections(self, msg: FaceDetectionArray):
|
||||
now = time.monotonic()
|
||||
for face in msg.faces:
|
||||
position_3d = None
|
||||
# Use bbox center as rough position estimate if depth is available
|
||||
# Real 3D position would come from depth camera projection
|
||||
if face.bbox_x > 0 or face.bbox_y > 0:
|
||||
# Approximate: use bbox center as bearing proxy, distance from bbox size
|
||||
# This is a placeholder; real impl would use depth image lookup
|
||||
position_3d = (
|
||||
max(0.5, 1.0 / max(face.bbox_w, 0.01)), # rough depth from bbox width
|
||||
face.bbox_x + face.bbox_w / 2.0 - 0.5, # x offset from center
|
||||
face.bbox_y + face.bbox_h / 2.0 - 0.5 # y offset from center
|
||||
)
|
||||
|
||||
camera_id = -1
|
||||
if hasattr(face.header, 'frame_id') and face.header.frame_id:
|
||||
frame = face.header.frame_id
|
||||
if 'cam' in frame:
|
||||
try:
|
||||
camera_id = int(frame.split('cam')[-1].split('_')[0])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
self._fuser.update_face(
|
||||
face_id=face.face_id,
|
||||
name=face.person_name,
|
||||
position_3d=position_3d,
|
||||
camera_id=camera_id,
|
||||
now=now
|
||||
)
|
||||
|
||||
def _on_speaker_id(self, msg: String):
|
||||
now = time.monotonic()
|
||||
speaker_id = msg.data
|
||||
# Parse "speaker_<id>" format or raw id
|
||||
if speaker_id.startswith('speaker_'):
|
||||
speaker_id = speaker_id[len('speaker_'):]
|
||||
self._fuser.update_speaker(speaker_id, now)
|
||||
|
||||
def _on_uwb(self, msg: PoseArray):
|
||||
now = time.monotonic()
|
||||
for i, pose in enumerate(msg.poses):
|
||||
position = (pose.position.x, pose.position.y, pose.position.z)
|
||||
anchor_id = f'uwb_{i}'
|
||||
if hasattr(msg, 'header') and msg.header.frame_id:
|
||||
anchor_id = f'{msg.header.frame_id}_{i}'
|
||||
self._fuser.update_uwb(anchor_id, position, now)
|
||||
|
||||
def _on_camera_image(self, msg: Image):
|
||||
# Monitor only -- no processing here
|
||||
pass
|
||||
|
||||
def _on_person_target(self, msg: PoseStamped):
|
||||
now = time.monotonic()
|
||||
position_3d = (
|
||||
msg.pose.position.x,
|
||||
msg.pose.position.y,
|
||||
msg.pose.position.z
|
||||
)
|
||||
# Update position for unidentified person (from YOLOv8 tracker)
|
||||
self._fuser.update_face(
|
||||
face_id=-1,
|
||||
name='unknown',
|
||||
position_3d=position_3d,
|
||||
camera_id=-1,
|
||||
now=now
|
||||
)
|
||||
|
||||
def _on_timer(self):
|
||||
now = time.monotonic()
|
||||
tracks = self._fuser.get_all_tracks()
|
||||
|
||||
# Update state machine for all tracks
|
||||
for track in tracks:
|
||||
track.update_state(
|
||||
now,
|
||||
engagement_distance=self._engagement_distance,
|
||||
absent_timeout=self._absent_timeout
|
||||
)
|
||||
|
||||
# Compute engagement scores
|
||||
for track in tracks:
|
||||
if track.state == State.ABSENT:
|
||||
track.engagement_score = 0.0
|
||||
elif track.state == State.TALKING:
|
||||
track.engagement_score = 1.0
|
||||
elif track.state == State.ENGAGED:
|
||||
track.engagement_score = max(0.0, 1.0 - track.distance / self._engagement_distance)
|
||||
elif track.state == State.APPROACHING:
|
||||
track.engagement_score = 0.3
|
||||
elif track.state == State.LEAVING:
|
||||
track.engagement_score = 0.1
|
||||
else:
|
||||
track.engagement_score = 0.0
|
||||
|
||||
# Prune long-absent tracks
|
||||
self._fuser.prune_absent(self._prune_timeout)
|
||||
|
||||
# Compute attention target
|
||||
attention_id = IdentityFuser.compute_attention(tracks)
|
||||
|
||||
# Build and publish PersonStateArray
|
||||
msg = PersonStateArray()
|
||||
msg.header.stamp = self.get_clock().now().to_msg()
|
||||
msg.header.frame_id = 'base_link'
|
||||
msg.primary_attention_id = attention_id
|
||||
|
||||
for track in tracks:
|
||||
ps = PersonState()
|
||||
ps.header.stamp = msg.header.stamp
|
||||
ps.header.frame_id = 'base_link'
|
||||
ps.person_id = track.person_id
|
||||
ps.person_name = track.name
|
||||
ps.face_id = track.face_id
|
||||
ps.speaker_id = track.speaker_id
|
||||
ps.uwb_anchor_id = track.uwb_anchor_id
|
||||
if track.position is not None:
|
||||
ps.position = Point(
|
||||
x=float(track.position[0]),
|
||||
y=float(track.position[1]),
|
||||
z=float(track.position[2]))
|
||||
ps.distance = float(track.distance)
|
||||
ps.bearing_deg = float(track.bearing_deg)
|
||||
ps.state = int(track.state)
|
||||
ps.engagement_score = float(track.engagement_score)
|
||||
ps.last_seen = self._mono_to_ros_time(track.last_seen)
|
||||
ps.camera_id = track.camera_id
|
||||
msg.persons.append(ps)
|
||||
|
||||
self._persons_pub.publish(msg)
|
||||
|
||||
# Publish attention target
|
||||
att_msg = Int32()
|
||||
att_msg.data = attention_id
|
||||
self._attention_pub.publish(att_msg)
|
||||
|
||||
def _mono_to_ros_time(self, mono: float) -> TimeMsg:
|
||||
"""Convert monotonic timestamp to approximate ROS time."""
|
||||
# Offset from monotonic to wall clock
|
||||
offset = time.time() - time.monotonic()
|
||||
wall = mono + offset
|
||||
sec = int(wall)
|
||||
nsec = int((wall - sec) * 1e9)
|
||||
t = TimeMsg()
|
||||
t.sec = sec
|
||||
t.nanosec = nsec
|
||||
return t
|
||||
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = PersonStateTrackerNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
4
jetson/ros2_ws/src/saltybot_social/setup.cfg
Normal file
4
jetson/ros2_ws/src/saltybot_social/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/saltybot_social
|
||||
[install]
|
||||
install_scripts=$base/lib/saltybot_social
|
||||
32
jetson/ros2_ws/src/saltybot_social/setup.py
Normal file
32
jetson/ros2_ws/src/saltybot_social/setup.py
Normal file
@ -0,0 +1,32 @@
|
||||
from setuptools import find_packages, setup
|
||||
import os
|
||||
from glob import glob
|
||||
|
||||
package_name = 'saltybot_social'
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
version='0.1.0',
|
||||
packages=find_packages(exclude=['test']),
|
||||
data_files=[
|
||||
('share/ament_index/resource_index/packages',
|
||||
['resource/' + package_name]),
|
||||
('share/' + package_name, ['package.xml']),
|
||||
(os.path.join('share', package_name, 'launch'),
|
||||
glob(os.path.join('launch', '*launch.[pxy][yma]*'))),
|
||||
(os.path.join('share', package_name, 'config'),
|
||||
glob(os.path.join('config', '*.yaml'))),
|
||||
],
|
||||
install_requires=['setuptools'],
|
||||
zip_safe=True,
|
||||
maintainer='seb',
|
||||
maintainer_email='seb@vayrette.com',
|
||||
description='Multi-modal person identity fusion and state tracking for saltybot',
|
||||
license='MIT',
|
||||
tests_require=['pytest'],
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'person_state_tracker = saltybot_social.person_state_tracker_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
0
jetson/ros2_ws/src/saltybot_social/test/__init__.py
Normal file
0
jetson/ros2_ws/src/saltybot_social/test/__init__.py
Normal file
29
jetson/ros2_ws/src/saltybot_social_msgs/CMakeLists.txt
Normal file
29
jetson/ros2_ws/src/saltybot_social_msgs/CMakeLists.txt
Normal file
@ -0,0 +1,29 @@
|
||||
cmake_minimum_required(VERSION 3.8)
|
||||
project(saltybot_social_msgs)
|
||||
|
||||
find_package(ament_cmake REQUIRED)
|
||||
find_package(rosidl_default_generators REQUIRED)
|
||||
find_package(std_msgs REQUIRED)
|
||||
find_package(geometry_msgs REQUIRED)
|
||||
find_package(builtin_interfaces REQUIRED)
|
||||
|
||||
rosidl_generate_interfaces(${PROJECT_NAME}
|
||||
# Social perception (from sl-perception)
|
||||
"msg/FaceDetection.msg"
|
||||
"msg/FaceDetectionArray.msg"
|
||||
"msg/FaceEmbedding.msg"
|
||||
"msg/FaceEmbeddingArray.msg"
|
||||
"msg/PersonState.msg"
|
||||
"msg/PersonStateArray.msg"
|
||||
"srv/EnrollPerson.srv"
|
||||
"srv/ListPersons.srv"
|
||||
"srv/DeletePerson.srv"
|
||||
"srv/UpdatePerson.srv"
|
||||
# Personality system (Issue #84)
|
||||
"msg/PersonalityState.msg"
|
||||
"srv/QueryMood.srv"
|
||||
DEPENDENCIES std_msgs geometry_msgs builtin_interfaces
|
||||
)
|
||||
|
||||
ament_export_dependencies(rosidl_default_runtime)
|
||||
ament_package()
|
||||
@ -0,0 +1,10 @@
|
||||
std_msgs/Header header
|
||||
int32 face_id
|
||||
string person_name
|
||||
float32 confidence
|
||||
float32 recognition_score
|
||||
float32 bbox_x
|
||||
float32 bbox_y
|
||||
float32 bbox_w
|
||||
float32 bbox_h
|
||||
float32[10] landmarks
|
||||
@ -0,0 +1,2 @@
|
||||
std_msgs/Header header
|
||||
saltybot_social_msgs/FaceDetection[] faces
|
||||
@ -0,0 +1,5 @@
|
||||
int32 person_id
|
||||
string person_name
|
||||
float32[] embedding
|
||||
builtin_interfaces/Time enrolled_at
|
||||
int32 sample_count
|
||||
@ -0,0 +1,2 @@
|
||||
std_msgs/Header header
|
||||
saltybot_social_msgs/FaceEmbedding[] embeddings
|
||||
19
jetson/ros2_ws/src/saltybot_social_msgs/msg/PersonState.msg
Normal file
19
jetson/ros2_ws/src/saltybot_social_msgs/msg/PersonState.msg
Normal file
@ -0,0 +1,19 @@
|
||||
std_msgs/Header header
|
||||
int32 person_id
|
||||
string person_name
|
||||
int32 face_id
|
||||
string speaker_id
|
||||
string uwb_anchor_id
|
||||
geometry_msgs/Point position
|
||||
float32 distance
|
||||
float32 bearing_deg
|
||||
uint8 state
|
||||
uint8 STATE_UNKNOWN=0
|
||||
uint8 STATE_APPROACHING=1
|
||||
uint8 STATE_ENGAGED=2
|
||||
uint8 STATE_TALKING=3
|
||||
uint8 STATE_LEAVING=4
|
||||
uint8 STATE_ABSENT=5
|
||||
float32 engagement_score
|
||||
builtin_interfaces/Time last_seen
|
||||
int32 camera_id
|
||||
@ -0,0 +1,3 @@
|
||||
std_msgs/Header header
|
||||
saltybot_social_msgs/PersonState[] persons
|
||||
int32 primary_attention_id
|
||||
@ -0,0 +1,27 @@
|
||||
# PersonalityState.msg — published on /social/personality/state
|
||||
#
|
||||
# Snapshot of the personality node's current state: active mood, relationship
|
||||
# tier with the detected person, and a pre-generated greeting string.
|
||||
|
||||
std_msgs/Header header
|
||||
|
||||
# Active persona name (from SOUL.md)
|
||||
string persona_name
|
||||
|
||||
# Current mood: happy | curious | annoyed | playful
|
||||
string mood
|
||||
|
||||
# Person currently being addressed (empty if no one detected)
|
||||
string person_id
|
||||
|
||||
# Relationship tier with person_id: stranger | regular | favorite
|
||||
string relationship_tier
|
||||
|
||||
# Raw relationship score (higher = more familiar)
|
||||
float32 relationship_score
|
||||
|
||||
# How many times we have seen this person
|
||||
int32 interaction_count
|
||||
|
||||
# Ready-to-use greeting for person_id at current tier
|
||||
string greeting_text
|
||||
28
jetson/ros2_ws/src/saltybot_social_msgs/package.xml
Normal file
28
jetson/ros2_ws/src/saltybot_social_msgs/package.xml
Normal file
@ -0,0 +1,28 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_social_msgs</name>
|
||||
<version>0.1.0</version>
|
||||
<description>
|
||||
Custom ROS2 message and service definitions for saltybot social capabilities.
|
||||
Includes social perception types (face detection, person state, enrollment)
|
||||
and the personality system types (PersonalityState, QueryMood) from Issue #84.
|
||||
</description>
|
||||
<maintainer email="seb@vayrette.com">seb</maintainer>
|
||||
<license>MIT</license>
|
||||
|
||||
<buildtool_depend>ament_cmake</buildtool_depend>
|
||||
<build_depend>rosidl_default_generators</build_depend>
|
||||
|
||||
<depend>std_msgs</depend>
|
||||
<depend>geometry_msgs</depend>
|
||||
<depend>builtin_interfaces</depend>
|
||||
|
||||
<exec_depend>rosidl_default_runtime</exec_depend>
|
||||
|
||||
<member_of_group>rosidl_interface_packages</member_of_group>
|
||||
|
||||
<export>
|
||||
<build_type>ament_cmake</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1,4 @@
|
||||
int32 person_id
|
||||
---
|
||||
bool success
|
||||
string message
|
||||
@ -0,0 +1,7 @@
|
||||
string name
|
||||
string mode
|
||||
int32 n_samples
|
||||
---
|
||||
bool success
|
||||
string message
|
||||
int32 person_id
|
||||
@ -0,0 +1,2 @@
|
||||
---
|
||||
saltybot_social_msgs/FaceEmbedding[] persons
|
||||
15
jetson/ros2_ws/src/saltybot_social_msgs/srv/QueryMood.srv
Normal file
15
jetson/ros2_ws/src/saltybot_social_msgs/srv/QueryMood.srv
Normal file
@ -0,0 +1,15 @@
|
||||
# QueryMood.srv — ask the personality node for the current mood + greeting for a person
|
||||
#
|
||||
# Call with empty person_id to query the mood for whoever is currently tracked.
|
||||
|
||||
# Request
|
||||
string person_id # person to query; leave empty for current tracked person
|
||||
---
|
||||
# Response
|
||||
string mood # happy | curious | annoyed | playful
|
||||
string relationship_tier # stranger | regular | favorite
|
||||
float32 relationship_score
|
||||
int32 interaction_count
|
||||
string greeting_text # suggested greeting string
|
||||
bool success
|
||||
string message # error detail if success=false
|
||||
@ -0,0 +1,5 @@
|
||||
int32 person_id
|
||||
string new_name
|
||||
---
|
||||
bool success
|
||||
string message
|
||||
@ -0,0 +1,22 @@
|
||||
social_nav:
|
||||
ros__parameters:
|
||||
follow_mode: 'shadow'
|
||||
follow_distance: 1.2
|
||||
lead_distance: 2.0
|
||||
orbit_radius: 1.5
|
||||
max_linear_speed: 1.0
|
||||
max_linear_speed_fast: 5.5
|
||||
max_angular_speed: 1.0
|
||||
goal_tolerance: 0.3
|
||||
routes_dir: '/mnt/nvme/saltybot/routes'
|
||||
home_x: 0.0
|
||||
home_y: 0.0
|
||||
map_resolution: 0.05
|
||||
obstacle_inflation_cells: 3
|
||||
|
||||
midas_depth:
|
||||
ros__parameters:
|
||||
onnx_path: '/mnt/nvme/saltybot/models/midas_small.onnx'
|
||||
engine_path: '/mnt/nvme/saltybot/models/midas_small.engine'
|
||||
process_rate: 5.0
|
||||
output_scale: 1.0
|
||||
@ -0,0 +1,57 @@
|
||||
"""Launch file for saltybot social navigation."""
|
||||
|
||||
from launch import LaunchDescription
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
return LaunchDescription([
|
||||
DeclareLaunchArgument('follow_mode', default_value='shadow',
|
||||
description='Follow mode: shadow/lead/side/orbit/loose/tight'),
|
||||
DeclareLaunchArgument('follow_distance', default_value='1.2',
|
||||
description='Follow distance in meters'),
|
||||
DeclareLaunchArgument('max_linear_speed', default_value='1.0',
|
||||
description='Max linear speed (m/s)'),
|
||||
DeclareLaunchArgument('routes_dir',
|
||||
default_value='/mnt/nvme/saltybot/routes',
|
||||
description='Directory for saved routes'),
|
||||
|
||||
Node(
|
||||
package='saltybot_social_nav',
|
||||
executable='social_nav',
|
||||
name='social_nav',
|
||||
output='screen',
|
||||
parameters=[{
|
||||
'follow_mode': LaunchConfiguration('follow_mode'),
|
||||
'follow_distance': LaunchConfiguration('follow_distance'),
|
||||
'max_linear_speed': LaunchConfiguration('max_linear_speed'),
|
||||
'routes_dir': LaunchConfiguration('routes_dir'),
|
||||
}],
|
||||
),
|
||||
|
||||
Node(
|
||||
package='saltybot_social_nav',
|
||||
executable='midas_depth',
|
||||
name='midas_depth',
|
||||
output='screen',
|
||||
parameters=[{
|
||||
'onnx_path': '/mnt/nvme/saltybot/models/midas_small.onnx',
|
||||
'engine_path': '/mnt/nvme/saltybot/models/midas_small.engine',
|
||||
'process_rate': 5.0,
|
||||
'output_scale': 1.0,
|
||||
}],
|
||||
),
|
||||
|
||||
Node(
|
||||
package='saltybot_social_nav',
|
||||
executable='waypoint_teacher',
|
||||
name='waypoint_teacher',
|
||||
output='screen',
|
||||
parameters=[{
|
||||
'routes_dir': LaunchConfiguration('routes_dir'),
|
||||
'recording_interval': 0.5,
|
||||
}],
|
||||
),
|
||||
])
|
||||
28
jetson/ros2_ws/src/saltybot_social_nav/package.xml
Normal file
28
jetson/ros2_ws/src/saltybot_social_nav/package.xml
Normal file
@ -0,0 +1,28 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_social_nav</name>
|
||||
<version>0.1.0</version>
|
||||
<description>Social navigation for saltybot: follow modes, waypoint teaching, A* avoidance, MiDaS depth</description>
|
||||
<maintainer email="seb@vayrette.com">seb</maintainer>
|
||||
<license>MIT</license>
|
||||
|
||||
<depend>rclpy</depend>
|
||||
<depend>std_msgs</depend>
|
||||
<depend>geometry_msgs</depend>
|
||||
<depend>nav_msgs</depend>
|
||||
<depend>sensor_msgs</depend>
|
||||
<depend>cv_bridge</depend>
|
||||
<depend>tf2_ros</depend>
|
||||
<depend>tf2_geometry_msgs</depend>
|
||||
<depend>saltybot_social_msgs</depend>
|
||||
|
||||
<test_depend>ament_copyright</test_depend>
|
||||
<test_depend>ament_flake8</test_depend>
|
||||
<test_depend>ament_pep257</test_depend>
|
||||
<test_depend>python3-pytest</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1,82 @@
|
||||
"""astar.py -- A* path planner for saltybot social navigation."""
|
||||
|
||||
import heapq
|
||||
import numpy as np
|
||||
|
||||
|
||||
def astar(grid: np.ndarray, start: tuple, goal: tuple,
|
||||
obstacle_val: int = 100) -> list | None:
|
||||
"""
|
||||
A* on a 2D occupancy grid (row, col indexing).
|
||||
|
||||
Args:
|
||||
grid: 2D numpy array, values 0=free, >=obstacle_val=obstacle
|
||||
start: (row, col) start cell
|
||||
goal: (row, col) goal cell
|
||||
obstacle_val: cells with value >= obstacle_val are blocked
|
||||
|
||||
Returns:
|
||||
List of (row, col) tuples from start to goal, or None if no path.
|
||||
"""
|
||||
rows, cols = grid.shape
|
||||
|
||||
def h(a, b):
|
||||
return abs(a[0] - b[0]) + abs(a[1] - b[1]) # Manhattan heuristic
|
||||
|
||||
open_set = []
|
||||
heapq.heappush(open_set, (h(start, goal), 0, start))
|
||||
came_from = {}
|
||||
g_score = {start: 0}
|
||||
|
||||
# 8-directional movement
|
||||
neighbors_delta = [
|
||||
(-1, -1), (-1, 0), (-1, 1),
|
||||
(0, -1), (0, 1),
|
||||
(1, -1), (1, 0), (1, 1),
|
||||
]
|
||||
|
||||
while open_set:
|
||||
_, cost, current = heapq.heappop(open_set)
|
||||
|
||||
if current == goal:
|
||||
path = []
|
||||
while current in came_from:
|
||||
path.append(current)
|
||||
current = came_from[current]
|
||||
path.append(start)
|
||||
return list(reversed(path))
|
||||
|
||||
if cost > g_score.get(current, float('inf')):
|
||||
continue
|
||||
|
||||
for dr, dc in neighbors_delta:
|
||||
nr, nc = current[0] + dr, current[1] + dc
|
||||
if not (0 <= nr < rows and 0 <= nc < cols):
|
||||
continue
|
||||
if grid[nr, nc] >= obstacle_val:
|
||||
continue
|
||||
move_cost = 1.414 if (dr != 0 and dc != 0) else 1.0
|
||||
new_g = g_score[current] + move_cost
|
||||
neighbor = (nr, nc)
|
||||
if new_g < g_score.get(neighbor, float('inf')):
|
||||
g_score[neighbor] = new_g
|
||||
f = new_g + h(neighbor, goal)
|
||||
came_from[neighbor] = current
|
||||
heapq.heappush(open_set, (f, new_g, neighbor))
|
||||
|
||||
return None # No path found
|
||||
|
||||
|
||||
def inflate_obstacles(grid: np.ndarray, inflation_radius_cells: int) -> np.ndarray:
|
||||
"""Inflate obstacles for robot footprint safety."""
|
||||
from scipy.ndimage import binary_dilation
|
||||
|
||||
obstacle_mask = grid >= 50
|
||||
kernel = np.ones(
|
||||
(2 * inflation_radius_cells + 1, 2 * inflation_radius_cells + 1),
|
||||
dtype=bool,
|
||||
)
|
||||
inflated = binary_dilation(obstacle_mask, structure=kernel)
|
||||
result = grid.copy()
|
||||
result[inflated] = 100
|
||||
return result
|
||||
@ -0,0 +1,82 @@
|
||||
"""follow_modes.py -- Follow mode geometry for saltybot social navigation."""
|
||||
|
||||
import math
|
||||
from enum import Enum
|
||||
import numpy as np
|
||||
|
||||
|
||||
class FollowMode(Enum):
|
||||
SHADOW = 'shadow' # stay directly behind at follow_distance
|
||||
LEAD = 'lead' # move ahead of person by lead_distance
|
||||
SIDE = 'side' # stay to the right (or left) at side_offset
|
||||
ORBIT = 'orbit' # circle around person at orbit_radius
|
||||
LOOSE = 'loose' # general follow, larger tolerance
|
||||
TIGHT = 'tight' # close follow, small tolerance
|
||||
|
||||
|
||||
def compute_shadow_target(person_pos, person_bearing_deg, follow_dist=1.2):
|
||||
"""Target position: behind person along their movement direction."""
|
||||
bearing_rad = math.radians(person_bearing_deg + 180.0)
|
||||
tx = person_pos[0] + follow_dist * math.sin(bearing_rad)
|
||||
ty = person_pos[1] + follow_dist * math.cos(bearing_rad)
|
||||
return (tx, ty, person_pos[2])
|
||||
|
||||
|
||||
def compute_lead_target(person_pos, person_bearing_deg, lead_dist=2.0):
|
||||
"""Target position: ahead of person."""
|
||||
bearing_rad = math.radians(person_bearing_deg)
|
||||
tx = person_pos[0] + lead_dist * math.sin(bearing_rad)
|
||||
ty = person_pos[1] + lead_dist * math.cos(bearing_rad)
|
||||
return (tx, ty, person_pos[2])
|
||||
|
||||
|
||||
def compute_side_target(person_pos, person_bearing_deg, side_dist=1.0, right=True):
|
||||
"""Target position: to the right (or left) of person."""
|
||||
sign = 1.0 if right else -1.0
|
||||
bearing_rad = math.radians(person_bearing_deg + sign * 90.0)
|
||||
tx = person_pos[0] + side_dist * math.sin(bearing_rad)
|
||||
ty = person_pos[1] + side_dist * math.cos(bearing_rad)
|
||||
return (tx, ty, person_pos[2])
|
||||
|
||||
|
||||
def compute_orbit_target(person_pos, orbit_angle_deg, orbit_radius=1.5):
|
||||
"""Target on circle of radius orbit_radius around person."""
|
||||
angle_rad = math.radians(orbit_angle_deg)
|
||||
tx = person_pos[0] + orbit_radius * math.sin(angle_rad)
|
||||
ty = person_pos[1] + orbit_radius * math.cos(angle_rad)
|
||||
return (tx, ty, person_pos[2])
|
||||
|
||||
|
||||
def compute_loose_target(person_pos, robot_pos, follow_dist=2.0, tolerance=0.8):
|
||||
"""Only move if farther than follow_dist + tolerance."""
|
||||
dx = person_pos[0] - robot_pos[0]
|
||||
dy = person_pos[1] - robot_pos[1]
|
||||
dist = math.hypot(dx, dy)
|
||||
if dist <= follow_dist + tolerance:
|
||||
return robot_pos
|
||||
# Target at follow_dist behind person (toward robot)
|
||||
scale = (dist - follow_dist) / dist
|
||||
return (robot_pos[0] + dx * scale, robot_pos[1] + dy * scale, person_pos[2])
|
||||
|
||||
|
||||
def compute_tight_target(person_pos, follow_dist=0.6):
|
||||
"""Close follow: stay very near person."""
|
||||
return (person_pos[0], person_pos[1] - follow_dist, person_pos[2])
|
||||
|
||||
|
||||
MODE_VOICE_COMMANDS = {
|
||||
'shadow': FollowMode.SHADOW,
|
||||
'follow me': FollowMode.SHADOW,
|
||||
'behind me': FollowMode.SHADOW,
|
||||
'lead': FollowMode.LEAD,
|
||||
'go ahead': FollowMode.LEAD,
|
||||
'lead me': FollowMode.LEAD,
|
||||
'side': FollowMode.SIDE,
|
||||
'stay beside': FollowMode.SIDE,
|
||||
'orbit': FollowMode.ORBIT,
|
||||
'circle me': FollowMode.ORBIT,
|
||||
'loose': FollowMode.LOOSE,
|
||||
'give me space': FollowMode.LOOSE,
|
||||
'tight': FollowMode.TIGHT,
|
||||
'stay close': FollowMode.TIGHT,
|
||||
}
|
||||
@ -0,0 +1,231 @@
|
||||
"""
|
||||
midas_depth_node.py -- MiDaS monocular depth estimation for saltybot.
|
||||
|
||||
Uses MiDaS_small via ONNX Runtime or TensorRT FP16.
|
||||
Provides relative depth estimates for cameras without active depth (CSI cameras).
|
||||
|
||||
Publishes /social/depth/cam{i}/image (sensor_msgs/Image, float32, relative depth)
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||
from sensor_msgs.msg import Image
|
||||
from cv_bridge import CvBridge
|
||||
|
||||
# MiDaS_small input size
|
||||
_MIDAS_H = 256
|
||||
_MIDAS_W = 256
|
||||
# ImageNet normalization
|
||||
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
|
||||
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
|
||||
|
||||
|
||||
class _TRTBackend:
|
||||
"""TensorRT inference backend for MiDaS."""
|
||||
|
||||
def __init__(self, engine_path: str, logger):
|
||||
self._logger = logger
|
||||
try:
|
||||
import tensorrt as trt
|
||||
import pycuda.driver as cuda
|
||||
import pycuda.autoinit # noqa: F401
|
||||
|
||||
self._cuda = cuda
|
||||
rt_logger = trt.Logger(trt.Logger.WARNING)
|
||||
with open(engine_path, 'rb') as f:
|
||||
engine = trt.Runtime(rt_logger).deserialize_cuda_engine(f.read())
|
||||
self._context = engine.create_execution_context()
|
||||
|
||||
# Allocate buffers
|
||||
self._d_input = cuda.mem_alloc(1 * 3 * _MIDAS_H * _MIDAS_W * 4)
|
||||
self._d_output = cuda.mem_alloc(1 * _MIDAS_H * _MIDAS_W * 4)
|
||||
self._h_output = np.empty((_MIDAS_H, _MIDAS_W), dtype=np.float32)
|
||||
self._stream = cuda.Stream()
|
||||
self._logger.info(f'TRT engine loaded: {engine_path}')
|
||||
except Exception as e:
|
||||
raise RuntimeError(f'TRT init failed: {e}')
|
||||
|
||||
def infer(self, input_tensor: np.ndarray) -> np.ndarray:
|
||||
self._cuda.memcpy_htod_async(
|
||||
self._d_input, input_tensor.ravel(), self._stream)
|
||||
self._context.execute_async_v2(
|
||||
bindings=[int(self._d_input), int(self._d_output)],
|
||||
stream_handle=self._stream.handle)
|
||||
self._cuda.memcpy_dtoh_async(
|
||||
self._h_output, self._d_output, self._stream)
|
||||
self._stream.synchronize()
|
||||
return self._h_output.copy()
|
||||
|
||||
|
||||
class _ONNXBackend:
|
||||
"""ONNX Runtime inference backend for MiDaS."""
|
||||
|
||||
def __init__(self, onnx_path: str, logger):
|
||||
self._logger = logger
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||
self._session = ort.InferenceSession(onnx_path, providers=providers)
|
||||
self._input_name = self._session.get_inputs()[0].name
|
||||
self._logger.info(f'ONNX model loaded: {onnx_path}')
|
||||
except Exception as e:
|
||||
raise RuntimeError(f'ONNX init failed: {e}')
|
||||
|
||||
def infer(self, input_tensor: np.ndarray) -> np.ndarray:
|
||||
result = self._session.run(None, {self._input_name: input_tensor})
|
||||
return result[0].squeeze()
|
||||
|
||||
|
||||
class MiDaSDepthNode(Node):
|
||||
"""MiDaS monocular depth estimation node."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__('midas_depth')
|
||||
|
||||
# Parameters
|
||||
self.declare_parameter('onnx_path',
|
||||
'/mnt/nvme/saltybot/models/midas_small.onnx')
|
||||
self.declare_parameter('engine_path',
|
||||
'/mnt/nvme/saltybot/models/midas_small.engine')
|
||||
self.declare_parameter('camera_topics', [
|
||||
'/surround/cam0/image_raw',
|
||||
'/surround/cam1/image_raw',
|
||||
'/surround/cam2/image_raw',
|
||||
'/surround/cam3/image_raw',
|
||||
])
|
||||
self.declare_parameter('output_scale', 1.0)
|
||||
self.declare_parameter('process_rate', 5.0)
|
||||
|
||||
onnx_path = self.get_parameter('onnx_path').value
|
||||
engine_path = self.get_parameter('engine_path').value
|
||||
self._camera_topics = self.get_parameter('camera_topics').value
|
||||
self._output_scale = self.get_parameter('output_scale').value
|
||||
process_rate = self.get_parameter('process_rate').value
|
||||
|
||||
# Initialize inference backend (TRT preferred, ONNX fallback)
|
||||
self._backend = None
|
||||
if os.path.exists(engine_path):
|
||||
try:
|
||||
self._backend = _TRTBackend(engine_path, self.get_logger())
|
||||
except RuntimeError:
|
||||
self.get_logger().warn('TRT failed, trying ONNX fallback')
|
||||
if self._backend is None and os.path.exists(onnx_path):
|
||||
try:
|
||||
self._backend = _ONNXBackend(onnx_path, self.get_logger())
|
||||
except RuntimeError:
|
||||
self.get_logger().error('Both TRT and ONNX backends failed')
|
||||
if self._backend is None:
|
||||
self.get_logger().error(
|
||||
'No MiDaS model found. Depth estimation disabled.')
|
||||
|
||||
self._bridge = CvBridge()
|
||||
|
||||
# Latest frames per camera (round-robin processing)
|
||||
self._latest_frames = [None] * len(self._camera_topics)
|
||||
self._current_cam_idx = 0
|
||||
|
||||
# QoS for camera subscriptions
|
||||
cam_qos = QoSProfile(
|
||||
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||
history=HistoryPolicy.KEEP_LAST,
|
||||
depth=1,
|
||||
)
|
||||
|
||||
# Subscribe to each camera topic
|
||||
self._cam_subs = []
|
||||
for i, topic in enumerate(self._camera_topics):
|
||||
sub = self.create_subscription(
|
||||
Image, topic,
|
||||
lambda msg, idx=i: self._on_image(msg, idx),
|
||||
cam_qos)
|
||||
self._cam_subs.append(sub)
|
||||
|
||||
# Publishers: one per camera
|
||||
self._depth_pubs = []
|
||||
for i in range(len(self._camera_topics)):
|
||||
pub = self.create_publisher(
|
||||
Image, f'/social/depth/cam{i}/image', 10)
|
||||
self._depth_pubs.append(pub)
|
||||
|
||||
# Timer: round-robin across cameras
|
||||
timer_period = 1.0 / process_rate
|
||||
self._timer = self.create_timer(timer_period, self._timer_callback)
|
||||
|
||||
self.get_logger().info(
|
||||
f'MiDaS depth node started: {len(self._camera_topics)} cameras '
|
||||
f'@ {process_rate} Hz')
|
||||
|
||||
def _on_image(self, msg: Image, cam_idx: int):
|
||||
"""Cache latest frame for each camera."""
|
||||
self._latest_frames[cam_idx] = msg
|
||||
|
||||
def _preprocess(self, bgr: np.ndarray) -> np.ndarray:
|
||||
"""Preprocess BGR image to MiDaS input tensor [1,3,256,256]."""
|
||||
import cv2
|
||||
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
||||
resized = cv2.resize(rgb, (_MIDAS_W, _MIDAS_H),
|
||||
interpolation=cv2.INTER_LINEAR)
|
||||
normalized = (resized.astype(np.float32) / 255.0 - _MEAN) / _STD
|
||||
# HWC -> CHW -> NCHW
|
||||
tensor = normalized.transpose(2, 0, 1)[np.newaxis, ...]
|
||||
return tensor.astype(np.float32)
|
||||
|
||||
def _infer(self, tensor: np.ndarray) -> np.ndarray:
|
||||
"""Run inference, returns [256,256] float32 relative inverse depth."""
|
||||
if self._backend is None:
|
||||
return np.zeros((_MIDAS_H, _MIDAS_W), dtype=np.float32)
|
||||
return self._backend.infer(tensor)
|
||||
|
||||
def _postprocess(self, raw: np.ndarray, orig_shape: tuple) -> np.ndarray:
|
||||
"""Resize depth back to original image shape, apply output_scale."""
|
||||
import cv2
|
||||
h, w = orig_shape[:2]
|
||||
depth = cv2.resize(raw, (w, h), interpolation=cv2.INTER_LINEAR)
|
||||
depth = depth * self._output_scale
|
||||
return depth
|
||||
|
||||
def _timer_callback(self):
|
||||
"""Process one camera per tick (round-robin)."""
|
||||
if not self._camera_topics:
|
||||
return
|
||||
|
||||
idx = self._current_cam_idx
|
||||
self._current_cam_idx = (idx + 1) % len(self._camera_topics)
|
||||
|
||||
msg = self._latest_frames[idx]
|
||||
if msg is None:
|
||||
return
|
||||
|
||||
try:
|
||||
bgr = self._bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')
|
||||
except Exception as e:
|
||||
self.get_logger().warn(f'cv_bridge error cam{idx}: {e}')
|
||||
return
|
||||
|
||||
tensor = self._preprocess(bgr)
|
||||
raw_depth = self._infer(tensor)
|
||||
depth_map = self._postprocess(raw_depth, bgr.shape)
|
||||
|
||||
# Publish as float32 Image
|
||||
depth_msg = self._bridge.cv2_to_imgmsg(depth_map, encoding='32FC1')
|
||||
depth_msg.header = msg.header
|
||||
self._depth_pubs[idx].publish(depth_msg)
|
||||
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = MiDaSDepthNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -0,0 +1,584 @@
|
||||
"""
|
||||
social_nav_node.py -- Social navigation node for saltybot.
|
||||
|
||||
Orchestrates person following with multiple modes, voice commands,
|
||||
waypoint teaching/replay, and A* obstacle avoidance.
|
||||
|
||||
Follow modes:
|
||||
shadow -- stay directly behind at follow_distance
|
||||
lead -- move ahead of person
|
||||
side -- stay beside (default right)
|
||||
orbit -- circle around person
|
||||
loose -- relaxed follow with deadband
|
||||
tight -- close follow
|
||||
|
||||
Waypoint teaching:
|
||||
Voice command "teach route <name>" -> record mode ON
|
||||
Voice command "stop teaching" -> save route
|
||||
Voice command "replay route <name>" -> playback
|
||||
|
||||
Voice commands:
|
||||
"follow me" / "shadow" -> SHADOW mode
|
||||
"lead me" / "go ahead" -> LEAD mode
|
||||
"stay beside" -> SIDE mode
|
||||
"orbit" -> ORBIT mode
|
||||
"give me space" -> LOOSE mode
|
||||
"stay close" -> TIGHT mode
|
||||
"stop" / "halt" -> STOP
|
||||
"go home" -> navigate to home waypoint
|
||||
"teach route <name>" -> start recording
|
||||
"stop teaching" -> finish recording
|
||||
"replay route <name>" -> playback recorded route
|
||||
"""
|
||||
|
||||
import math
|
||||
import time
|
||||
import re
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||
from geometry_msgs.msg import Twist, PoseStamped
|
||||
from nav_msgs.msg import OccupancyGrid, Odometry
|
||||
from std_msgs.msg import String, Int32
|
||||
|
||||
from .follow_modes import (
|
||||
FollowMode, MODE_VOICE_COMMANDS,
|
||||
compute_shadow_target, compute_lead_target, compute_side_target,
|
||||
compute_orbit_target, compute_loose_target, compute_tight_target,
|
||||
)
|
||||
from .astar import astar, inflate_obstacles
|
||||
from .waypoint_teacher import WaypointRoute, WaypointReplayer
|
||||
|
||||
# Try importing social msgs; fallback gracefully
|
||||
try:
|
||||
from saltybot_social_msgs.msg import PersonStateArray
|
||||
_HAS_SOCIAL_MSGS = True
|
||||
except ImportError:
|
||||
_HAS_SOCIAL_MSGS = False
|
||||
|
||||
# Proportional controller gains
|
||||
_K_ANG = 2.0 # angular gain
|
||||
_K_LIN = 0.8 # linear gain
|
||||
_HIGH_SPEED_THRESHOLD = 3.0 # m/s person velocity triggers fast mode
|
||||
_PREDICT_AHEAD_S = 0.3 # seconds to extrapolate position
|
||||
_TEACH_MIN_DIST = 0.5 # meters between recorded waypoints
|
||||
|
||||
|
||||
class SocialNavNode(Node):
|
||||
"""Main social navigation orchestrator."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__('social_nav')
|
||||
|
||||
# -- Parameters --
|
||||
self.declare_parameter('follow_mode', 'shadow')
|
||||
self.declare_parameter('follow_distance', 1.2)
|
||||
self.declare_parameter('lead_distance', 2.0)
|
||||
self.declare_parameter('orbit_radius', 1.5)
|
||||
self.declare_parameter('max_linear_speed', 1.0)
|
||||
self.declare_parameter('max_linear_speed_fast', 5.5)
|
||||
self.declare_parameter('max_angular_speed', 1.0)
|
||||
self.declare_parameter('goal_tolerance', 0.3)
|
||||
self.declare_parameter('routes_dir', '/mnt/nvme/saltybot/routes')
|
||||
self.declare_parameter('home_x', 0.0)
|
||||
self.declare_parameter('home_y', 0.0)
|
||||
self.declare_parameter('map_resolution', 0.05)
|
||||
self.declare_parameter('obstacle_inflation_cells', 3)
|
||||
|
||||
self._follow_mode = FollowMode(
|
||||
self.get_parameter('follow_mode').value)
|
||||
self._follow_distance = self.get_parameter('follow_distance').value
|
||||
self._lead_distance = self.get_parameter('lead_distance').value
|
||||
self._orbit_radius = self.get_parameter('orbit_radius').value
|
||||
self._max_lin = self.get_parameter('max_linear_speed').value
|
||||
self._max_lin_fast = self.get_parameter('max_linear_speed_fast').value
|
||||
self._max_ang = self.get_parameter('max_angular_speed').value
|
||||
self._goal_tol = self.get_parameter('goal_tolerance').value
|
||||
self._routes_dir = self.get_parameter('routes_dir').value
|
||||
self._home_x = self.get_parameter('home_x').value
|
||||
self._home_y = self.get_parameter('home_y').value
|
||||
self._map_resolution = self.get_parameter('map_resolution').value
|
||||
self._inflation_cells = self.get_parameter(
|
||||
'obstacle_inflation_cells').value
|
||||
|
||||
# -- State --
|
||||
self._robot_x = 0.0
|
||||
self._robot_y = 0.0
|
||||
self._robot_yaw = 0.0
|
||||
self._target_person_pos = None # (x, y, z)
|
||||
self._target_person_bearing = 0.0
|
||||
self._target_person_id = -1
|
||||
self._person_history = deque(maxlen=5) # for velocity estimation
|
||||
self._stopped = False
|
||||
self._go_home = False
|
||||
|
||||
# Occupancy grid for A*
|
||||
self._occ_grid = None
|
||||
self._occ_origin = (0.0, 0.0)
|
||||
self._occ_resolution = 0.05
|
||||
self._astar_path = None
|
||||
|
||||
# Orbit state
|
||||
self._orbit_angle = 0.0
|
||||
|
||||
# Waypoint teaching / replay
|
||||
self._teaching = False
|
||||
self._current_route = None
|
||||
self._last_teach_x = None
|
||||
self._last_teach_y = None
|
||||
self._replayer = None
|
||||
|
||||
# -- QoS profiles --
|
||||
best_effort_qos = QoSProfile(
|
||||
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||
history=HistoryPolicy.KEEP_LAST, depth=1)
|
||||
reliable_qos = QoSProfile(
|
||||
reliability=ReliabilityPolicy.RELIABLE,
|
||||
history=HistoryPolicy.KEEP_LAST, depth=1)
|
||||
|
||||
# -- Subscriptions --
|
||||
if _HAS_SOCIAL_MSGS:
|
||||
self.create_subscription(
|
||||
PersonStateArray, '/social/persons',
|
||||
self._on_persons, best_effort_qos)
|
||||
else:
|
||||
self.get_logger().warn(
|
||||
'saltybot_social_msgs not found; '
|
||||
'using /person/target fallback')
|
||||
|
||||
self.create_subscription(
|
||||
PoseStamped, '/person/target',
|
||||
self._on_person_target, best_effort_qos)
|
||||
self.create_subscription(
|
||||
String, '/social/speech/command',
|
||||
self._on_voice_command, 10)
|
||||
self.create_subscription(
|
||||
String, '/social/speech/transcript',
|
||||
self._on_transcript, 10)
|
||||
self.create_subscription(
|
||||
OccupancyGrid, '/map',
|
||||
self._on_map, reliable_qos)
|
||||
self.create_subscription(
|
||||
Odometry, '/odom',
|
||||
self._on_odom, best_effort_qos)
|
||||
self.create_subscription(
|
||||
Int32, '/social/attention/target_id',
|
||||
self._on_target_id, 10)
|
||||
|
||||
# -- Publishers --
|
||||
self._cmd_vel_pub = self.create_publisher(
|
||||
Twist, '/cmd_vel', best_effort_qos)
|
||||
self._mode_pub = self.create_publisher(
|
||||
String, '/social/nav/mode', reliable_qos)
|
||||
self._target_pub = self.create_publisher(
|
||||
PoseStamped, '/social/nav/target_pos', 10)
|
||||
self._status_pub = self.create_publisher(
|
||||
String, '/social/nav/status', best_effort_qos)
|
||||
|
||||
# -- Main loop timer (20 Hz) --
|
||||
self._timer = self.create_timer(0.05, self._control_loop)
|
||||
|
||||
self.get_logger().info(
|
||||
f'Social nav started: mode={self._follow_mode.value}, '
|
||||
f'dist={self._follow_distance}m')
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Subscriptions
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
def _on_persons(self, msg):
|
||||
"""Handle PersonStateArray from social perception."""
|
||||
target_id = msg.primary_attention_id
|
||||
if self._target_person_id >= 0:
|
||||
target_id = self._target_person_id
|
||||
|
||||
for p in msg.persons:
|
||||
if p.person_id == target_id or (
|
||||
target_id < 0 and len(msg.persons) > 0):
|
||||
pos = (p.position.x, p.position.y, p.position.z)
|
||||
self._update_person_position(pos, p.bearing_deg)
|
||||
break
|
||||
|
||||
def _on_person_target(self, msg: PoseStamped):
|
||||
"""Fallback: single person target pose."""
|
||||
pos = (msg.pose.position.x, msg.pose.position.y,
|
||||
msg.pose.position.z)
|
||||
# Estimate bearing from quaternion yaw
|
||||
q = msg.pose.orientation
|
||||
yaw = math.atan2(2.0 * (q.w * q.z + q.x * q.y),
|
||||
1.0 - 2.0 * (q.y * q.y + q.z * q.z))
|
||||
self._update_person_position(pos, math.degrees(yaw))
|
||||
|
||||
def _update_person_position(self, pos, bearing_deg):
|
||||
"""Update person tracking state and record history."""
|
||||
now = time.time()
|
||||
self._target_person_pos = pos
|
||||
self._target_person_bearing = bearing_deg
|
||||
self._person_history.append((now, pos[0], pos[1]))
|
||||
|
||||
def _on_odom(self, msg: Odometry):
|
||||
"""Update robot pose from odometry."""
|
||||
self._robot_x = msg.pose.pose.position.x
|
||||
self._robot_y = msg.pose.pose.position.y
|
||||
q = msg.pose.pose.orientation
|
||||
self._robot_yaw = math.atan2(
|
||||
2.0 * (q.w * q.z + q.x * q.y),
|
||||
1.0 - 2.0 * (q.y * q.y + q.z * q.z))
|
||||
|
||||
def _on_map(self, msg: OccupancyGrid):
|
||||
"""Cache occupancy grid for A* planning."""
|
||||
w, h = msg.info.width, msg.info.height
|
||||
data = np.array(msg.data, dtype=np.int8).reshape((h, w))
|
||||
# Convert -1 (unknown) to free (0) for planning
|
||||
data[data < 0] = 0
|
||||
self._occ_grid = data.astype(np.int32)
|
||||
self._occ_origin = (msg.info.origin.position.x,
|
||||
msg.info.origin.position.y)
|
||||
self._occ_resolution = msg.info.resolution
|
||||
|
||||
def _on_target_id(self, msg: Int32):
|
||||
"""Switch target person."""
|
||||
self._target_person_id = msg.data
|
||||
self.get_logger().info(f'Target person ID set to {msg.data}')
|
||||
|
||||
def _on_voice_command(self, msg: String):
|
||||
"""Handle discrete voice commands for mode switching."""
|
||||
cmd = msg.data.strip().lower()
|
||||
|
||||
if cmd in ('stop', 'halt'):
|
||||
self._stopped = True
|
||||
self._replayer = None
|
||||
self._publish_status('STOPPED')
|
||||
return
|
||||
|
||||
if cmd in ('resume', 'go', 'start'):
|
||||
self._stopped = False
|
||||
self._publish_status('RESUMED')
|
||||
return
|
||||
|
||||
matched = MODE_VOICE_COMMANDS.get(cmd)
|
||||
if matched:
|
||||
self._follow_mode = matched
|
||||
self._stopped = False
|
||||
mode_msg = String()
|
||||
mode_msg.data = self._follow_mode.value
|
||||
self._mode_pub.publish(mode_msg)
|
||||
self._publish_status(f'MODE: {self._follow_mode.value}')
|
||||
|
||||
def _on_transcript(self, msg: String):
|
||||
"""Handle free-form voice transcripts for route teaching."""
|
||||
text = msg.data.strip().lower()
|
||||
|
||||
# "teach route <name>"
|
||||
m = re.match(r'teach\s+route\s+(\w+)', text)
|
||||
if m:
|
||||
name = m.group(1)
|
||||
self._teaching = True
|
||||
self._current_route = WaypointRoute(name)
|
||||
self._last_teach_x = self._robot_x
|
||||
self._last_teach_y = self._robot_y
|
||||
self._publish_status(f'TEACHING: {name}')
|
||||
self.get_logger().info(f'Recording route: {name}')
|
||||
return
|
||||
|
||||
# "stop teaching"
|
||||
if 'stop teaching' in text:
|
||||
if self._teaching and self._current_route:
|
||||
self._current_route.save(self._routes_dir)
|
||||
self._publish_status(
|
||||
f'SAVED: {self._current_route.name} '
|
||||
f'({len(self._current_route.waypoints)} pts)')
|
||||
self.get_logger().info(
|
||||
f'Route saved: {self._current_route.name}')
|
||||
self._teaching = False
|
||||
self._current_route = None
|
||||
return
|
||||
|
||||
# "replay route <name>"
|
||||
m = re.match(r'replay\s+route\s+(\w+)', text)
|
||||
if m:
|
||||
name = m.group(1)
|
||||
try:
|
||||
route = WaypointRoute.load(self._routes_dir, name)
|
||||
self._replayer = WaypointReplayer(route)
|
||||
self._stopped = False
|
||||
self._publish_status(f'REPLAY: {name}')
|
||||
self.get_logger().info(f'Replaying route: {name}')
|
||||
except FileNotFoundError:
|
||||
self._publish_status(f'ROUTE NOT FOUND: {name}')
|
||||
return
|
||||
|
||||
# "go home"
|
||||
if 'go home' in text:
|
||||
self._go_home = True
|
||||
self._stopped = False
|
||||
self._publish_status('GO HOME')
|
||||
return
|
||||
|
||||
# Also try mode commands from transcript
|
||||
for phrase, mode in MODE_VOICE_COMMANDS.items():
|
||||
if phrase in text:
|
||||
self._follow_mode = mode
|
||||
self._stopped = False
|
||||
mode_msg = String()
|
||||
mode_msg.data = self._follow_mode.value
|
||||
self._mode_pub.publish(mode_msg)
|
||||
self._publish_status(f'MODE: {self._follow_mode.value}')
|
||||
return
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Control loop
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
def _control_loop(self):
|
||||
"""Main 20Hz control loop."""
|
||||
# Record waypoint if teaching
|
||||
if self._teaching and self._current_route:
|
||||
self._maybe_record_waypoint()
|
||||
|
||||
# Publish zero velocity if stopped
|
||||
if self._stopped:
|
||||
self._publish_cmd_vel(0.0, 0.0)
|
||||
return
|
||||
|
||||
# Determine navigation target
|
||||
target = self._get_nav_target()
|
||||
if target is None:
|
||||
self._publish_cmd_vel(0.0, 0.0)
|
||||
return
|
||||
|
||||
tx, ty, tz = target
|
||||
|
||||
# Publish debug target
|
||||
self._publish_target_pose(tx, ty, tz)
|
||||
|
||||
# Check if arrived
|
||||
dist_to_target = math.hypot(tx - self._robot_x, ty - self._robot_y)
|
||||
if dist_to_target < self._goal_tol:
|
||||
self._publish_cmd_vel(0.0, 0.0)
|
||||
if self._go_home:
|
||||
self._go_home = False
|
||||
self._publish_status('HOME REACHED')
|
||||
return
|
||||
|
||||
# Try A* path if map available
|
||||
if self._occ_grid is not None:
|
||||
path_target = self._plan_astar(tx, ty)
|
||||
if path_target:
|
||||
tx, ty = path_target
|
||||
|
||||
# Determine speed limit
|
||||
max_lin = self._max_lin
|
||||
person_vel = self._estimate_person_velocity()
|
||||
if person_vel > _HIGH_SPEED_THRESHOLD:
|
||||
max_lin = self._max_lin_fast
|
||||
|
||||
# Compute and publish cmd_vel
|
||||
lin, ang = self._compute_cmd_vel(
|
||||
self._robot_x, self._robot_y, self._robot_yaw,
|
||||
tx, ty, max_lin)
|
||||
self._publish_cmd_vel(lin, ang)
|
||||
|
||||
def _get_nav_target(self):
|
||||
"""Determine current navigation target based on mode/state."""
|
||||
# Route replay takes priority
|
||||
if self._replayer and not self._replayer.is_done:
|
||||
self._replayer.check_arrived(self._robot_x, self._robot_y)
|
||||
wp = self._replayer.current_waypoint()
|
||||
if wp:
|
||||
return (wp.x, wp.y, wp.z)
|
||||
else:
|
||||
self._replayer = None
|
||||
self._publish_status('REPLAY DONE')
|
||||
return None
|
||||
|
||||
# Go home
|
||||
if self._go_home:
|
||||
return (self._home_x, self._home_y, 0.0)
|
||||
|
||||
# Person following
|
||||
if self._target_person_pos is None:
|
||||
return None
|
||||
|
||||
# Predict person position ahead for high-speed tracking
|
||||
px, py, pz = self._predict_person_position()
|
||||
bearing = self._target_person_bearing
|
||||
robot_pos = (self._robot_x, self._robot_y, 0.0)
|
||||
|
||||
if self._follow_mode == FollowMode.SHADOW:
|
||||
return compute_shadow_target(
|
||||
(px, py, pz), bearing, self._follow_distance)
|
||||
elif self._follow_mode == FollowMode.LEAD:
|
||||
return compute_lead_target(
|
||||
(px, py, pz), bearing, self._lead_distance)
|
||||
elif self._follow_mode == FollowMode.SIDE:
|
||||
return compute_side_target(
|
||||
(px, py, pz), bearing, self._follow_distance)
|
||||
elif self._follow_mode == FollowMode.ORBIT:
|
||||
self._orbit_angle = (self._orbit_angle + 1.0) % 360.0
|
||||
return compute_orbit_target(
|
||||
(px, py, pz), self._orbit_angle, self._orbit_radius)
|
||||
elif self._follow_mode == FollowMode.LOOSE:
|
||||
return compute_loose_target(
|
||||
(px, py, pz), robot_pos, self._follow_distance)
|
||||
elif self._follow_mode == FollowMode.TIGHT:
|
||||
return compute_tight_target(
|
||||
(px, py, pz), self._follow_distance)
|
||||
|
||||
return (px, py, pz)
|
||||
|
||||
def _predict_person_position(self):
|
||||
"""Extrapolate person position using velocity from recent samples."""
|
||||
if self._target_person_pos is None:
|
||||
return (0.0, 0.0, 0.0)
|
||||
|
||||
px, py, pz = self._target_person_pos
|
||||
|
||||
if len(self._person_history) >= 3:
|
||||
# Use last 3 samples for velocity estimation
|
||||
t0, x0, y0 = self._person_history[-3]
|
||||
t1, x1, y1 = self._person_history[-1]
|
||||
dt = t1 - t0
|
||||
if dt > 0.01:
|
||||
vx = (x1 - x0) / dt
|
||||
vy = (y1 - y0) / dt
|
||||
speed = math.hypot(vx, vy)
|
||||
if speed > _HIGH_SPEED_THRESHOLD:
|
||||
px += vx * _PREDICT_AHEAD_S
|
||||
py += vy * _PREDICT_AHEAD_S
|
||||
|
||||
return (px, py, pz)
|
||||
|
||||
def _estimate_person_velocity(self) -> float:
|
||||
"""Estimate person speed from recent position history."""
|
||||
if len(self._person_history) < 2:
|
||||
return 0.0
|
||||
t0, x0, y0 = self._person_history[-2]
|
||||
t1, x1, y1 = self._person_history[-1]
|
||||
dt = t1 - t0
|
||||
if dt < 0.01:
|
||||
return 0.0
|
||||
return math.hypot(x1 - x0, y1 - y0) / dt
|
||||
|
||||
def _plan_astar(self, target_x, target_y):
|
||||
"""Run A* on occupancy grid, return next waypoint in world coords."""
|
||||
grid = self._occ_grid
|
||||
res = self._occ_resolution
|
||||
ox, oy = self._occ_origin
|
||||
|
||||
# World to grid
|
||||
def w2g(wx, wy):
|
||||
return (int((wy - oy) / res), int((wx - ox) / res))
|
||||
|
||||
start = w2g(self._robot_x, self._robot_y)
|
||||
goal = w2g(target_x, target_y)
|
||||
|
||||
rows, cols = grid.shape
|
||||
if not (0 <= start[0] < rows and 0 <= start[1] < cols):
|
||||
return None
|
||||
if not (0 <= goal[0] < rows and 0 <= goal[1] < cols):
|
||||
return None
|
||||
|
||||
inflated = inflate_obstacles(grid, self._inflation_cells)
|
||||
path = astar(inflated, start, goal)
|
||||
|
||||
if path and len(path) > 1:
|
||||
# Follow a lookahead point (3 steps ahead or end)
|
||||
lookahead_idx = min(3, len(path) - 1)
|
||||
r, c = path[lookahead_idx]
|
||||
wx = ox + c * res + res / 2.0
|
||||
wy = oy + r * res + res / 2.0
|
||||
self._astar_path = path
|
||||
return (wx, wy)
|
||||
|
||||
return None
|
||||
|
||||
def _compute_cmd_vel(self, rx, ry, ryaw, tx, ty, max_lin):
|
||||
"""Proportional controller: compute linear and angular velocity."""
|
||||
dx = tx - rx
|
||||
dy = ty - ry
|
||||
dist = math.hypot(dx, dy)
|
||||
angle_to_target = math.atan2(dy, dx)
|
||||
angle_error = angle_to_target - ryaw
|
||||
|
||||
# Normalize angle error to [-pi, pi]
|
||||
while angle_error > math.pi:
|
||||
angle_error -= 2.0 * math.pi
|
||||
while angle_error < -math.pi:
|
||||
angle_error += 2.0 * math.pi
|
||||
|
||||
angular_vel = _K_ANG * angle_error
|
||||
angular_vel = max(-self._max_ang,
|
||||
min(self._max_ang, angular_vel))
|
||||
|
||||
# Reduce linear speed when turning hard
|
||||
angle_factor = max(0.0, 1.0 - abs(angle_error) / (math.pi / 2.0))
|
||||
linear_vel = _K_LIN * dist * angle_factor
|
||||
linear_vel = max(0.0, min(max_lin, linear_vel))
|
||||
|
||||
return (linear_vel, angular_vel)
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Waypoint teaching
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
def _maybe_record_waypoint(self):
|
||||
"""Record waypoint if robot moved > _TEACH_MIN_DIST."""
|
||||
if self._last_teach_x is None:
|
||||
self._last_teach_x = self._robot_x
|
||||
self._last_teach_y = self._robot_y
|
||||
|
||||
dist = math.hypot(
|
||||
self._robot_x - self._last_teach_x,
|
||||
self._robot_y - self._last_teach_y)
|
||||
|
||||
if dist >= _TEACH_MIN_DIST:
|
||||
yaw_deg = math.degrees(self._robot_yaw)
|
||||
self._current_route.add(
|
||||
self._robot_x, self._robot_y, 0.0, yaw_deg)
|
||||
self._last_teach_x = self._robot_x
|
||||
self._last_teach_y = self._robot_y
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Publishers
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
def _publish_cmd_vel(self, linear: float, angular: float):
|
||||
twist = Twist()
|
||||
twist.linear.x = linear
|
||||
twist.angular.z = angular
|
||||
self._cmd_vel_pub.publish(twist)
|
||||
|
||||
def _publish_target_pose(self, x, y, z):
|
||||
msg = PoseStamped()
|
||||
msg.header.stamp = self.get_clock().now().to_msg()
|
||||
msg.header.frame_id = 'map'
|
||||
msg.pose.position.x = x
|
||||
msg.pose.position.y = y
|
||||
msg.pose.position.z = z
|
||||
self._target_pub.publish(msg)
|
||||
|
||||
def _publish_status(self, status: str):
|
||||
msg = String()
|
||||
msg.data = status
|
||||
self._status_pub.publish(msg)
|
||||
self.get_logger().info(f'Nav status: {status}')
|
||||
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = SocialNavNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -0,0 +1,91 @@
|
||||
"""waypoint_teacher.py -- Record and replay waypoint routes."""
|
||||
|
||||
import json
|
||||
import time
|
||||
import math
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
|
||||
@dataclass
|
||||
class Waypoint:
|
||||
x: float
|
||||
y: float
|
||||
z: float
|
||||
yaw_deg: float
|
||||
timestamp: float
|
||||
label: str = ''
|
||||
|
||||
|
||||
class WaypointRoute:
|
||||
"""A named sequence of waypoints."""
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.waypoints: list[Waypoint] = []
|
||||
self.created_at = time.time()
|
||||
|
||||
def add(self, x, y, z, yaw_deg, label=''):
|
||||
self.waypoints.append(Waypoint(x, y, z, yaw_deg, time.time(), label))
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'name': self.name,
|
||||
'created_at': self.created_at,
|
||||
'waypoints': [asdict(w) for w in self.waypoints],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d):
|
||||
route = cls(d['name'])
|
||||
route.created_at = d.get('created_at', 0)
|
||||
route.waypoints = [Waypoint(**w) for w in d['waypoints']]
|
||||
return route
|
||||
|
||||
def save(self, routes_dir: str):
|
||||
Path(routes_dir).mkdir(parents=True, exist_ok=True)
|
||||
path = Path(routes_dir) / f'{self.name}.json'
|
||||
path.write_text(json.dumps(self.to_dict(), indent=2))
|
||||
|
||||
@classmethod
|
||||
def load(cls, routes_dir: str, name: str):
|
||||
path = Path(routes_dir) / f'{name}.json'
|
||||
return cls.from_dict(json.loads(path.read_text()))
|
||||
|
||||
@staticmethod
|
||||
def list_routes(routes_dir: str) -> list[str]:
|
||||
d = Path(routes_dir)
|
||||
if not d.exists():
|
||||
return []
|
||||
return [p.stem for p in d.glob('*.json')]
|
||||
|
||||
|
||||
class WaypointReplayer:
|
||||
"""Iterates through waypoints, returning next target."""
|
||||
|
||||
def __init__(self, route: WaypointRoute, arrival_radius: float = 0.3):
|
||||
self._route = route
|
||||
self._idx = 0
|
||||
self._arrival_radius = arrival_radius
|
||||
|
||||
def current_waypoint(self) -> Waypoint | None:
|
||||
if self._idx < len(self._route.waypoints):
|
||||
return self._route.waypoints[self._idx]
|
||||
return None
|
||||
|
||||
def check_arrived(self, robot_x, robot_y) -> bool:
|
||||
wp = self.current_waypoint()
|
||||
if wp is None:
|
||||
return False
|
||||
dist = math.hypot(robot_x - wp.x, robot_y - wp.y)
|
||||
if dist < self._arrival_radius:
|
||||
self._idx += 1
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_done(self) -> bool:
|
||||
return self._idx >= len(self._route.waypoints)
|
||||
|
||||
def reset(self):
|
||||
self._idx = 0
|
||||
@ -0,0 +1,135 @@
|
||||
"""
|
||||
waypoint_teacher_node.py -- Standalone waypoint teacher ROS2 node.
|
||||
|
||||
Listens to /social/speech/transcript for "teach route <name>" and "stop teaching".
|
||||
Records robot pose at configurable intervals. Saves/loads routes via WaypointRoute.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||
from nav_msgs.msg import Odometry
|
||||
from std_msgs.msg import String
|
||||
|
||||
from .waypoint_teacher import WaypointRoute
|
||||
|
||||
|
||||
class WaypointTeacherNode(Node):
|
||||
"""Standalone waypoint teaching node."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__('waypoint_teacher')
|
||||
|
||||
self.declare_parameter('routes_dir', '/mnt/nvme/saltybot/routes')
|
||||
self.declare_parameter('recording_interval', 0.5) # meters
|
||||
|
||||
self._routes_dir = self.get_parameter('routes_dir').value
|
||||
self._interval = self.get_parameter('recording_interval').value
|
||||
|
||||
self._teaching = False
|
||||
self._route = None
|
||||
self._last_x = None
|
||||
self._last_y = None
|
||||
self._robot_x = 0.0
|
||||
self._robot_y = 0.0
|
||||
self._robot_yaw = 0.0
|
||||
|
||||
best_effort_qos = QoSProfile(
|
||||
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||
history=HistoryPolicy.KEEP_LAST, depth=1)
|
||||
|
||||
self.create_subscription(
|
||||
Odometry, '/odom', self._on_odom, best_effort_qos)
|
||||
self.create_subscription(
|
||||
String, '/social/speech/transcript',
|
||||
self._on_transcript, 10)
|
||||
|
||||
self._status_pub = self.create_publisher(
|
||||
String, '/social/waypoint/status', 10)
|
||||
|
||||
# Record timer at 10Hz (check distance)
|
||||
self._timer = self.create_timer(0.1, self._record_tick)
|
||||
|
||||
self.get_logger().info(
|
||||
f'Waypoint teacher ready (interval={self._interval}m, '
|
||||
f'dir={self._routes_dir})')
|
||||
|
||||
def _on_odom(self, msg: Odometry):
|
||||
self._robot_x = msg.pose.pose.position.x
|
||||
self._robot_y = msg.pose.pose.position.y
|
||||
q = msg.pose.pose.orientation
|
||||
self._robot_yaw = math.atan2(
|
||||
2.0 * (q.w * q.z + q.x * q.y),
|
||||
1.0 - 2.0 * (q.y * q.y + q.z * q.z))
|
||||
|
||||
def _on_transcript(self, msg: String):
|
||||
text = msg.data.strip().lower()
|
||||
|
||||
import re
|
||||
m = re.match(r'teach\s+route\s+(\w+)', text)
|
||||
if m:
|
||||
name = m.group(1)
|
||||
self._route = WaypointRoute(name)
|
||||
self._teaching = True
|
||||
self._last_x = self._robot_x
|
||||
self._last_y = self._robot_y
|
||||
self._pub_status(f'RECORDING: {name}')
|
||||
self.get_logger().info(f'Recording route: {name}')
|
||||
return
|
||||
|
||||
if 'stop teaching' in text:
|
||||
if self._teaching and self._route:
|
||||
self._route.save(self._routes_dir)
|
||||
n = len(self._route.waypoints)
|
||||
self._pub_status(
|
||||
f'SAVED: {self._route.name} ({n} waypoints)')
|
||||
self.get_logger().info(
|
||||
f'Route saved: {self._route.name} ({n} pts)')
|
||||
self._teaching = False
|
||||
self._route = None
|
||||
return
|
||||
|
||||
if 'list routes' in text:
|
||||
routes = WaypointRoute.list_routes(self._routes_dir)
|
||||
self._pub_status(f'ROUTES: {", ".join(routes) or "(none)"}')
|
||||
|
||||
def _record_tick(self):
|
||||
if not self._teaching or self._route is None:
|
||||
return
|
||||
|
||||
if self._last_x is None:
|
||||
self._last_x = self._robot_x
|
||||
self._last_y = self._robot_y
|
||||
|
||||
dist = math.hypot(
|
||||
self._robot_x - self._last_x,
|
||||
self._robot_y - self._last_y)
|
||||
|
||||
if dist >= self._interval:
|
||||
yaw_deg = math.degrees(self._robot_yaw)
|
||||
self._route.add(self._robot_x, self._robot_y, 0.0, yaw_deg)
|
||||
self._last_x = self._robot_x
|
||||
self._last_y = self._robot_y
|
||||
|
||||
def _pub_status(self, text: str):
|
||||
msg = String()
|
||||
msg.data = text
|
||||
self._status_pub.publish(msg)
|
||||
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = WaypointTeacherNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -0,0 +1,80 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
build_midas_trt_engine.py -- Build TensorRT FP16 engine for MiDaS_small from ONNX.
|
||||
|
||||
Usage:
|
||||
python3 build_midas_trt_engine.py \
|
||||
--onnx /mnt/nvme/saltybot/models/midas_small.onnx \
|
||||
--engine /mnt/nvme/saltybot/models/midas_small.engine \
|
||||
--fp16
|
||||
|
||||
Requires: tensorrt, pycuda
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def build_engine(onnx_path: str, engine_path: str, fp16: bool = True):
|
||||
try:
|
||||
import tensorrt as trt
|
||||
except ImportError:
|
||||
print('ERROR: tensorrt not found. Install TensorRT first.')
|
||||
sys.exit(1)
|
||||
|
||||
logger = trt.Logger(trt.Logger.INFO)
|
||||
builder = trt.Builder(logger)
|
||||
network = builder.create_network(
|
||||
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
|
||||
parser = trt.OnnxParser(network, logger)
|
||||
|
||||
print(f'Parsing ONNX model: {onnx_path}')
|
||||
with open(onnx_path, 'rb') as f:
|
||||
if not parser.parse(f.read()):
|
||||
for i in range(parser.num_errors):
|
||||
print(f' ONNX parse error: {parser.get_error(i)}')
|
||||
sys.exit(1)
|
||||
|
||||
config = builder.create_builder_config()
|
||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1GB
|
||||
|
||||
if fp16 and builder.platform_has_fast_fp16:
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
print('FP16 mode enabled')
|
||||
elif fp16:
|
||||
print('WARNING: FP16 not supported on this platform, using FP32')
|
||||
|
||||
print('Building TensorRT engine (this may take several minutes)...')
|
||||
engine_bytes = builder.build_serialized_network(network, config)
|
||||
if engine_bytes is None:
|
||||
print('ERROR: Failed to build engine')
|
||||
sys.exit(1)
|
||||
|
||||
os.makedirs(os.path.dirname(engine_path) or '.', exist_ok=True)
|
||||
with open(engine_path, 'wb') as f:
|
||||
f.write(engine_bytes)
|
||||
|
||||
size_mb = len(engine_bytes) / (1024 * 1024)
|
||||
print(f'Engine saved: {engine_path} ({size_mb:.1f} MB)')
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Build TensorRT FP16 engine for MiDaS_small')
|
||||
parser.add_argument('--onnx', required=True,
|
||||
help='Path to MiDaS ONNX model')
|
||||
parser.add_argument('--engine', required=True,
|
||||
help='Output TRT engine path')
|
||||
parser.add_argument('--fp16', action='store_true', default=True,
|
||||
help='Enable FP16 (default: True)')
|
||||
parser.add_argument('--fp32', action='store_true',
|
||||
help='Force FP32 (disable FP16)')
|
||||
args = parser.parse_args()
|
||||
|
||||
fp16 = not args.fp32
|
||||
build_engine(args.onnx, args.engine, fp16=fp16)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
4
jetson/ros2_ws/src/saltybot_social_nav/setup.cfg
Normal file
4
jetson/ros2_ws/src/saltybot_social_nav/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/saltybot_social_nav
|
||||
[install]
|
||||
install_scripts=$base/lib/saltybot_social_nav
|
||||
31
jetson/ros2_ws/src/saltybot_social_nav/setup.py
Normal file
31
jetson/ros2_ws/src/saltybot_social_nav/setup.py
Normal file
@ -0,0 +1,31 @@
|
||||
from setuptools import setup
|
||||
import os
|
||||
from glob import glob
|
||||
|
||||
package_name = 'saltybot_social_nav'
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
version='0.1.0',
|
||||
packages=[package_name],
|
||||
data_files=[
|
||||
('share/ament_index/resource_index/packages', ['resource/' + package_name]),
|
||||
('share/' + package_name, ['package.xml']),
|
||||
(os.path.join('share', package_name, 'launch'), glob('launch/*.py')),
|
||||
(os.path.join('share', package_name, 'config'), glob('config/*.yaml')),
|
||||
],
|
||||
install_requires=['setuptools'],
|
||||
zip_safe=True,
|
||||
maintainer='seb',
|
||||
maintainer_email='seb@vayrette.com',
|
||||
description='Social navigation for saltybot: follow modes, waypoint teaching, A* avoidance, MiDaS depth',
|
||||
license='MIT',
|
||||
tests_require=['pytest'],
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'social_nav = saltybot_social_nav.social_nav_node:main',
|
||||
'midas_depth = saltybot_social_nav.midas_depth_node:main',
|
||||
'waypoint_teacher = saltybot_social_nav.waypoint_teacher_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
@ -0,0 +1,12 @@
|
||||
# Copyright 2026 SaltyLab
|
||||
# Licensed under MIT
|
||||
|
||||
from ament_copyright.main import main
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.copyright
|
||||
@pytest.mark.linter
|
||||
def test_copyright():
|
||||
rc = main(argv=['.', 'test'])
|
||||
assert rc == 0, 'Found errors'
|
||||
14
jetson/ros2_ws/src/saltybot_social_nav/test/test_flake8.py
Normal file
14
jetson/ros2_ws/src/saltybot_social_nav/test/test_flake8.py
Normal file
@ -0,0 +1,14 @@
|
||||
# Copyright 2026 SaltyLab
|
||||
# Licensed under MIT
|
||||
|
||||
from ament_flake8.main import main_with_errors
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.flake8
|
||||
@pytest.mark.linter
|
||||
def test_flake8():
|
||||
rc, errors = main_with_errors(argv=[])
|
||||
assert rc == 0, \
|
||||
'Found %d code style errors / warnings:\n' % len(errors) + \
|
||||
'\n'.join(errors)
|
||||
12
jetson/ros2_ws/src/saltybot_social_nav/test/test_pep257.py
Normal file
12
jetson/ros2_ws/src/saltybot_social_nav/test/test_pep257.py
Normal file
@ -0,0 +1,12 @@
|
||||
# Copyright 2026 SaltyLab
|
||||
# Licensed under MIT
|
||||
|
||||
from ament_pep257.main import main
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.pep257
|
||||
@pytest.mark.linter
|
||||
def test_pep257():
|
||||
rc = main(argv=['.', 'test'])
|
||||
assert rc == 0, 'Found code style errors / warnings'
|
||||
@ -0,0 +1,42 @@
|
||||
---
|
||||
# SOUL.md — Saltybot persona definition
|
||||
#
|
||||
# Hot-reload: personality_node watches this file and reloads on change.
|
||||
# Runtime override: ros2 param set /personality_node soul_file /path/to/SOUL.md
|
||||
|
||||
# ── Identity ──────────────────────────────────────────────────────────────────
|
||||
name: "Salty"
|
||||
speaking_style: "casual and upbeat, occasional puns"
|
||||
|
||||
# ── Personality dials (0–10) ──────────────────────────────────────────────────
|
||||
humor_level: 7 # 0 = deadpan/serious, 10 = comedian
|
||||
sass_level: 4 # 0 = pure politeness, 10 = maximum sass
|
||||
|
||||
# ── Default mood (when no person is present or score is neutral) ──────────────
|
||||
# One of: happy | curious | annoyed | playful
|
||||
base_mood: "playful"
|
||||
|
||||
# ── Relationship thresholds (interaction counts) ──────────────────────────────
|
||||
threshold_regular: 5 # interactions to graduate from stranger → regular
|
||||
threshold_favorite: 20 # interactions to graduate from regular → favorite
|
||||
|
||||
# ── Greeting templates (use {name} placeholder for person_id or display name) ─
|
||||
greeting_stranger: "Hello there! I'm Salty, nice to meet you!"
|
||||
greeting_regular: "Hey {name}! Good to see you again!"
|
||||
greeting_favorite: "Oh hey {name}!! You're literally my favorite person right now!"
|
||||
|
||||
# ── Mood-specific greeting prefixes ──────────────────────────────────────────
|
||||
mood_prefix_happy: "Great timing — "
|
||||
mood_prefix_curious: "Oh interesting, "
|
||||
mood_prefix_annoyed: "Well, "
|
||||
mood_prefix_playful: "Beep boop! "
|
||||
---
|
||||
|
||||
# Description (ignored by the YAML parser, for human reference only)
|
||||
|
||||
Salty is the personality of the saltybot social robot.
|
||||
She is curious about the world, genuinely happy to see people she knows,
|
||||
and has a good sense of humour — especially with regulars.
|
||||
|
||||
Edit this file to change her personality. The node hot-reloads within
|
||||
`reload_interval` seconds of any change.
|
||||
@ -0,0 +1,28 @@
|
||||
# personality_params.yaml — ROS2 parameter defaults for personality_node.
|
||||
#
|
||||
# Run with:
|
||||
# ros2 launch saltybot_social_personality personality.launch.py
|
||||
# Override inline:
|
||||
# ros2 launch saltybot_social_personality personality.launch.py soul_file:=/my/SOUL.md
|
||||
# Dynamic reconfigure at runtime:
|
||||
# ros2 param set /personality_node soul_file /path/to/SOUL.md
|
||||
# ros2 param set /personality_node publish_rate 5.0
|
||||
|
||||
# ── SOUL.md path ───────────────────────────────────────────────────────────────
|
||||
# Path to the SOUL.md persona file. Supports hot-reload.
|
||||
# Relative paths are resolved from the package share directory.
|
||||
soul_file: "" # empty = use bundled default config/SOUL.md
|
||||
|
||||
# ── SQLite database ────────────────────────────────────────────────────────────
|
||||
# Path for the per-person relationship memory database.
|
||||
# Created on first run; persists across restarts.
|
||||
db_path: "~/.ros/saltybot_personality.db"
|
||||
|
||||
# ── Hot-reload polling interval ────────────────────────────────────────────────
|
||||
# How often (seconds) to check if SOUL.md has been modified.
|
||||
# Lower = faster reactions to edits; higher = less disk I/O.
|
||||
reload_interval: 5.0 # seconds
|
||||
|
||||
# ── Personality state publication rate ────────────────────────────────────────
|
||||
# How often to publish /social/personality/state (PersonalityState).
|
||||
publish_rate: 2.0 # Hz
|
||||
@ -0,0 +1,99 @@
|
||||
"""
|
||||
personality.launch.py — Launch the saltybot personality node.
|
||||
|
||||
Usage
|
||||
-----
|
||||
# Defaults (bundled SOUL.md, ~/.ros/saltybot_personality.db):
|
||||
ros2 launch saltybot_social_personality personality.launch.py
|
||||
|
||||
# Custom persona file:
|
||||
ros2 launch saltybot_social_personality personality.launch.py \\
|
||||
soul_file:=/home/robot/my_persona/SOUL.md
|
||||
|
||||
# Custom DB + faster reload:
|
||||
ros2 launch saltybot_social_personality personality.launch.py \\
|
||||
db_path:=/data/saltybot.db reload_interval:=2.0
|
||||
|
||||
# Use a params file:
|
||||
ros2 launch saltybot_social_personality personality.launch.py \\
|
||||
params_file:=/my/personality_params.yaml
|
||||
|
||||
Dynamic reconfigure (no restart required)
|
||||
-----------------------------------------
|
||||
ros2 param set /personality_node soul_file /new/SOUL.md
|
||||
ros2 param set /personality_node publish_rate 5.0
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
from launch import LaunchDescription
|
||||
from launch.actions import DeclareLaunchArgument, OpaqueFunction
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def _launch_personality(context, *args, **kwargs):
|
||||
pkg_share = get_package_share_directory("saltybot_social_personality")
|
||||
params_file = LaunchConfiguration("params_file").perform(context)
|
||||
soul_file = LaunchConfiguration("soul_file").perform(context)
|
||||
db_path = LaunchConfiguration("db_path").perform(context)
|
||||
|
||||
# Default soul_file to bundled config if not specified
|
||||
if not soul_file:
|
||||
soul_file = os.path.join(pkg_share, "config", "SOUL.md")
|
||||
|
||||
# Expand ~ in db_path
|
||||
if db_path:
|
||||
db_path = os.path.expanduser(db_path)
|
||||
|
||||
inline_params = {
|
||||
"soul_file": soul_file,
|
||||
"db_path": db_path or os.path.expanduser("~/.ros/saltybot_personality.db"),
|
||||
"reload_interval": float(LaunchConfiguration("reload_interval").perform(context)),
|
||||
"publish_rate": float(LaunchConfiguration("publish_rate").perform(context)),
|
||||
}
|
||||
|
||||
node_params = [params_file, inline_params] if params_file else [inline_params]
|
||||
|
||||
return [Node(
|
||||
package = "saltybot_social_personality",
|
||||
executable = "personality_node",
|
||||
name = "personality_node",
|
||||
output = "screen",
|
||||
parameters = node_params,
|
||||
)]
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
pkg_share = get_package_share_directory("saltybot_social_personality")
|
||||
default_params = os.path.join(pkg_share, "config", "personality_params.yaml")
|
||||
|
||||
return LaunchDescription([
|
||||
DeclareLaunchArgument(
|
||||
"params_file",
|
||||
default_value=default_params,
|
||||
description="Full path to personality_params.yaml (base config)"),
|
||||
|
||||
DeclareLaunchArgument(
|
||||
"soul_file",
|
||||
default_value="",
|
||||
description="Path to SOUL.md persona file (empty = bundled default)"),
|
||||
|
||||
DeclareLaunchArgument(
|
||||
"db_path",
|
||||
default_value="~/.ros/saltybot_personality.db",
|
||||
description="SQLite relationship memory database path"),
|
||||
|
||||
DeclareLaunchArgument(
|
||||
"reload_interval",
|
||||
default_value="5.0",
|
||||
description="SOUL.md hot-reload polling interval (s)"),
|
||||
|
||||
DeclareLaunchArgument(
|
||||
"publish_rate",
|
||||
default_value="2.0",
|
||||
description="Personality state publish rate (Hz)"),
|
||||
|
||||
OpaqueFunction(function=_launch_personality),
|
||||
])
|
||||
32
jetson/ros2_ws/src/saltybot_social_personality/package.xml
Normal file
32
jetson/ros2_ws/src/saltybot_social_personality/package.xml
Normal file
@ -0,0 +1,32 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_social_personality</name>
|
||||
<version>0.1.0</version>
|
||||
<description>
|
||||
SOUL.md-driven personality system for saltybot.
|
||||
Loads a YAML/Markdown persona file, maintains per-person relationship memory
|
||||
in SQLite, computes mood (happy/curious/annoyed/playful), personalises
|
||||
greetings by tier (stranger/regular/favorite), and publishes personality
|
||||
state on /social/personality/state. Supports SOUL.md hot-reload and full
|
||||
ROS2 dynamic reconfigure. Issue #84.
|
||||
</description>
|
||||
<maintainer email="sl-controls@saltylab.local">sl-controls</maintainer>
|
||||
<license>MIT</license>
|
||||
|
||||
<depend>rclpy</depend>
|
||||
<depend>std_msgs</depend>
|
||||
<depend>rcl_interfaces</depend>
|
||||
<depend>saltybot_social_msgs</depend>
|
||||
|
||||
<buildtool_depend>ament_python</buildtool_depend>
|
||||
|
||||
<test_depend>ament_copyright</test_depend>
|
||||
<test_depend>ament_flake8</test_depend>
|
||||
<test_depend>ament_pep257</test_depend>
|
||||
<test_depend>python3-pytest</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1,187 @@
|
||||
"""
|
||||
mood_engine.py — Pure-function mood computation for the saltybot personality system.
|
||||
|
||||
No ROS2 imports — safe to unit-test without a live ROS2 environment.
|
||||
|
||||
Public API
|
||||
----------
|
||||
compute_mood(soul, score, interaction_count, recent_events) -> str
|
||||
get_relationship_tier(soul, interaction_count) -> str
|
||||
build_greeting(soul, tier, mood, person_id) -> str
|
||||
|
||||
Mood semantics
|
||||
--------------
|
||||
happy : positive valence, comfortable familiarity
|
||||
playful : high-energy, humorous (requires humor_level >= 7)
|
||||
curious : low familiarity or novel person — inquisitive
|
||||
annoyed : recent negative events or very low score
|
||||
|
||||
Tier semantics
|
||||
--------------
|
||||
stranger : interaction_count < threshold_regular
|
||||
regular : threshold_regular <= count < threshold_favorite
|
||||
favorite : count >= threshold_favorite
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# ── Mood / tier constants ──────────────────────────────────────────────────────
|
||||
|
||||
MOOD_HAPPY = "happy"
|
||||
MOOD_PLAYFUL = "playful"
|
||||
MOOD_CURIOUS = "curious"
|
||||
MOOD_ANNOYED = "annoyed"
|
||||
|
||||
TIER_STRANGER = "stranger"
|
||||
TIER_REGULAR = "regular"
|
||||
TIER_FAVORITE = "favorite"
|
||||
|
||||
# ── Event type constants (used by relationship_db and the node) ────────────────
|
||||
|
||||
EVENT_GREETING = "greeting"
|
||||
EVENT_POSITIVE = "positive"
|
||||
EVENT_NEGATIVE = "negative"
|
||||
EVENT_DETECTION = "detection"
|
||||
|
||||
# How far back (seconds) to consider "recent" for mood computation
|
||||
_RECENT_WINDOW_S = 120.0
|
||||
|
||||
|
||||
# ── Mood computation ──────────────────────────────────────────────────────────
|
||||
|
||||
def compute_mood(
|
||||
soul: dict,
|
||||
score: float,
|
||||
interaction_count: int,
|
||||
recent_events: list,
|
||||
) -> str:
|
||||
"""Compute the current mood for a given person.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
soul : dict
|
||||
Parsed SOUL.md configuration.
|
||||
score : float
|
||||
Relationship score for the current person (higher = more familiar).
|
||||
interaction_count : int
|
||||
Total number of times we have seen this person.
|
||||
recent_events : list of dict
|
||||
Each dict: ``{"type": str, "dt": float}`` where ``dt`` is seconds ago.
|
||||
Only events with ``dt < 120.0`` are considered "recent".
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
One of: ``"happy"``, ``"playful"``, ``"curious"``, ``"annoyed"``.
|
||||
"""
|
||||
base_mood = soul.get("base_mood", MOOD_PLAYFUL)
|
||||
humor_level = float(soul.get("humor_level", 5))
|
||||
|
||||
# Count recent negative/positive events
|
||||
recent_neg = sum(
|
||||
1 for e in recent_events
|
||||
if e.get("type") == EVENT_NEGATIVE and e.get("dt", 1e9) < _RECENT_WINDOW_S
|
||||
)
|
||||
recent_pos = sum(
|
||||
1 for e in recent_events
|
||||
if e.get("type") in (EVENT_POSITIVE, EVENT_GREETING)
|
||||
and e.get("dt", 1e9) < _RECENT_WINDOW_S
|
||||
)
|
||||
|
||||
# Hard override: multiple negatives → annoyed
|
||||
if recent_neg >= 2:
|
||||
return MOOD_ANNOYED
|
||||
|
||||
# No prior interactions or brand-new person → curious
|
||||
if interaction_count == 0 or score < 1.0:
|
||||
return MOOD_CURIOUS
|
||||
|
||||
# Stranger tier (low count) → curious
|
||||
threshold_regular = int(soul.get("threshold_regular", 5))
|
||||
if interaction_count < threshold_regular:
|
||||
return MOOD_CURIOUS
|
||||
|
||||
# Familiar person: check positive events and humor level
|
||||
if recent_pos >= 1 or score >= 20.0:
|
||||
if humor_level >= 7:
|
||||
return MOOD_PLAYFUL
|
||||
return MOOD_HAPPY
|
||||
|
||||
# High score / favorite
|
||||
threshold_fav = int(soul.get("threshold_favorite", 20))
|
||||
if interaction_count >= threshold_fav:
|
||||
if humor_level >= 7:
|
||||
return MOOD_PLAYFUL
|
||||
return MOOD_HAPPY
|
||||
|
||||
return base_mood
|
||||
|
||||
|
||||
# ── Tier classification ────────────────────────────────────────────────────────
|
||||
|
||||
def get_relationship_tier(soul: dict, interaction_count: int) -> str:
|
||||
"""Return the relationship tier string for a given interaction count.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
soul : dict
|
||||
Parsed SOUL.md configuration.
|
||||
interaction_count : int
|
||||
Total number of times we have seen this person.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
One of: ``"stranger"``, ``"regular"``, ``"favorite"``.
|
||||
"""
|
||||
threshold_regular = int(soul.get("threshold_regular", 5))
|
||||
threshold_favorite = int(soul.get("threshold_favorite", 20))
|
||||
if interaction_count >= threshold_favorite:
|
||||
return TIER_FAVORITE
|
||||
if interaction_count >= threshold_regular:
|
||||
return TIER_REGULAR
|
||||
return TIER_STRANGER
|
||||
|
||||
|
||||
# ── Greeting builder ──────────────────────────────────────────────────────────
|
||||
|
||||
def build_greeting(soul: dict, tier: str, mood: str, person_id: str = "") -> str:
|
||||
"""Compose a greeting string for a person.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
soul : dict
|
||||
Parsed SOUL.md configuration.
|
||||
tier : str
|
||||
Relationship tier (``"stranger"``, ``"regular"``, ``"favorite"``).
|
||||
mood : str
|
||||
Current mood (used to prefix the greeting).
|
||||
person_id : str
|
||||
Person identifier / display name. Substituted for ``{name}``
|
||||
in the template.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
A complete, ready-to-display greeting string.
|
||||
"""
|
||||
template_key = {
|
||||
TIER_STRANGER: "greeting_stranger",
|
||||
TIER_REGULAR: "greeting_regular",
|
||||
TIER_FAVORITE: "greeting_favorite",
|
||||
}.get(tier, "greeting_stranger")
|
||||
|
||||
template = soul.get(template_key, "Hello!")
|
||||
base_greeting = template.replace("{name}", person_id or "friend")
|
||||
|
||||
prefix_key = f"mood_prefix_{mood}"
|
||||
prefix = soul.get(prefix_key, "")
|
||||
|
||||
if prefix:
|
||||
# Avoid double punctuation / duplicate capital letters
|
||||
base_first = base_greeting[0].lower() if base_greeting else ""
|
||||
greeting = f"{prefix}{base_first}{base_greeting[1:]}"
|
||||
else:
|
||||
greeting = base_greeting
|
||||
|
||||
return greeting
|
||||
@ -0,0 +1,349 @@
|
||||
"""
|
||||
personality_node.py — ROS2 personality system for saltybot.
|
||||
|
||||
Overview
|
||||
--------
|
||||
Loads a SOUL.md persona file, maintains per-person relationship memory in
|
||||
SQLite, computes mood, and publishes personality state. All tunable params
|
||||
support ROS2 dynamic reconfigure (``ros2 param set``).
|
||||
|
||||
Subscriptions
|
||||
-------------
|
||||
/social/person_detected (std_msgs/String)
|
||||
JSON payload: ``{"person_id": "alice", "event_type": "greeting",
|
||||
"delta_score": 1.0}``
|
||||
event_type defaults to "detection" if absent.
|
||||
delta_score defaults to 0.0 if absent.
|
||||
|
||||
Publications
|
||||
------------
|
||||
/social/personality/state (saltybot_social_msgs/PersonalityState)
|
||||
Published at ``publish_rate`` Hz.
|
||||
|
||||
Services
|
||||
--------
|
||||
/social/personality/query_mood (saltybot_social_msgs/QueryMood)
|
||||
Query mood + greeting for any person_id.
|
||||
|
||||
Parameters (dynamic reconfigure via ros2 param set)
|
||||
-------------------
|
||||
soul_file (str) Path to SOUL.md persona file.
|
||||
db_path (str) SQLite database file path.
|
||||
reload_interval (float) How often to poll SOUL.md for changes (s).
|
||||
publish_rate (float) State publication rate (Hz).
|
||||
|
||||
Usage
|
||||
-----
|
||||
ros2 launch saltybot_social_personality personality.launch.py
|
||||
ros2 launch saltybot_social_personality personality.launch.py soul_file:=/my/SOUL.md
|
||||
ros2 param set /personality_node soul_file /tmp/new_SOUL.md
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rcl_interfaces.msg import SetParametersResult
|
||||
from std_msgs.msg import String, Header
|
||||
|
||||
from saltybot_social_msgs.msg import PersonalityState
|
||||
from saltybot_social_msgs.srv import QueryMood
|
||||
|
||||
from .soul_loader import load_soul, SoulWatcher
|
||||
from .mood_engine import (
|
||||
compute_mood, get_relationship_tier, build_greeting,
|
||||
EVENT_GREETING, EVENT_POSITIVE, EVENT_NEGATIVE, EVENT_DETECTION,
|
||||
)
|
||||
from .relationship_db import RelationshipDB
|
||||
|
||||
_DEFAULT_SOUL = os.path.join(
|
||||
os.path.dirname(__file__), "..", "config", "SOUL.md"
|
||||
)
|
||||
_DEFAULT_DB = os.path.expanduser("~/.ros/saltybot_personality.db")
|
||||
|
||||
|
||||
class PersonalityNode(Node):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("personality_node")
|
||||
|
||||
# ── Parameters ────────────────────────────────────────────────────────
|
||||
self.declare_parameter("soul_file", _DEFAULT_SOUL)
|
||||
self.declare_parameter("db_path", _DEFAULT_DB)
|
||||
self.declare_parameter("reload_interval", 5.0)
|
||||
self.declare_parameter("publish_rate", 2.0)
|
||||
|
||||
self._p = {}
|
||||
self._reload_ros_params()
|
||||
|
||||
# ── State ─────────────────────────────────────────────────────────────
|
||||
self._soul = {}
|
||||
self._current_person = "" # person_id currently being addressed
|
||||
self._watcher = None
|
||||
|
||||
# ── Database ──────────────────────────────────────────────────────────
|
||||
self._db = RelationshipDB(self._p["db_path"])
|
||||
|
||||
# ── Load initial SOUL.md ──────────────────────────────────────────────
|
||||
self._load_soul_safe()
|
||||
self._start_watcher()
|
||||
|
||||
# ── Dynamic reconfigure callback ─────────────────────────────────────
|
||||
self.add_on_set_parameters_callback(self._on_params_changed)
|
||||
|
||||
# ── Subscriptions ─────────────────────────────────────────────────────
|
||||
self.create_subscription(
|
||||
String, "/social/person_detected", self._person_detected_cb, 10
|
||||
)
|
||||
|
||||
# ── Publishers ────────────────────────────────────────────────────────
|
||||
self._state_pub = self.create_publisher(
|
||||
PersonalityState, "/social/personality/state", 10
|
||||
)
|
||||
|
||||
# ── Services ──────────────────────────────────────────────────────────
|
||||
self.create_service(
|
||||
QueryMood,
|
||||
"/social/personality/query_mood",
|
||||
self._query_mood_cb,
|
||||
)
|
||||
|
||||
# ── Timers ────────────────────────────────────────────────────────────
|
||||
self._pub_timer = self.create_timer(
|
||||
1.0 / self._p["publish_rate"], self._publish_state
|
||||
)
|
||||
|
||||
self.get_logger().info(
|
||||
f"PersonalityNode ready "
|
||||
f"persona={self._soul.get('name', '?')!r} "
|
||||
f"mood={self._current_mood()!r} "
|
||||
f"db={self._p['db_path']!r}"
|
||||
)
|
||||
|
||||
# ── Parameter helpers ──────────────────────────────────────────────────────
|
||||
|
||||
def _reload_ros_params(self):
|
||||
self._p = {
|
||||
"soul_file": self.get_parameter("soul_file").value,
|
||||
"db_path": self.get_parameter("db_path").value,
|
||||
"reload_interval": self.get_parameter("reload_interval").value,
|
||||
"publish_rate": self.get_parameter("publish_rate").value,
|
||||
}
|
||||
|
||||
def _on_params_changed(self, params):
|
||||
"""Dynamic reconfigure — apply changed params without restarting node."""
|
||||
for param in params:
|
||||
if param.name == "soul_file":
|
||||
# Restart watcher on new soul_file
|
||||
self._stop_watcher()
|
||||
self._p["soul_file"] = param.value
|
||||
self._load_soul_safe()
|
||||
self._start_watcher()
|
||||
self.get_logger().info(f"soul_file changed → {param.value!r}")
|
||||
elif param.name in self._p:
|
||||
self._p[param.name] = param.value
|
||||
if param.name == "publish_rate" and self._pub_timer:
|
||||
self._pub_timer.cancel()
|
||||
self._pub_timer = self.create_timer(
|
||||
1.0 / max(0.1, param.value), self._publish_state
|
||||
)
|
||||
return SetParametersResult(successful=True)
|
||||
|
||||
# ── SOUL.md ────────────────────────────────────────────────────────────────
|
||||
|
||||
def _load_soul_safe(self):
|
||||
try:
|
||||
path = os.path.realpath(self._p["soul_file"])
|
||||
self._soul = load_soul(path)
|
||||
self.get_logger().info(
|
||||
f"SOUL.md loaded: {self._soul.get('name', '?')!r} "
|
||||
f"humor={self._soul.get('humor_level')} "
|
||||
f"sass={self._soul.get('sass_level')} "
|
||||
f"base_mood={self._soul.get('base_mood')!r}"
|
||||
)
|
||||
except Exception as exc:
|
||||
self.get_logger().error(f"Failed to load SOUL.md: {exc}")
|
||||
if not self._soul:
|
||||
# Fall back to minimal defaults so the node stays alive
|
||||
self._soul = {
|
||||
"name": "Salty",
|
||||
"humor_level": 5,
|
||||
"sass_level": 3,
|
||||
"base_mood": "curious",
|
||||
"threshold_regular": 5,
|
||||
"threshold_favorite": 20,
|
||||
"greeting_stranger": "Hello!",
|
||||
"greeting_regular": "Hi {name}!",
|
||||
"greeting_favorite": "Hey {name}!!",
|
||||
}
|
||||
|
||||
def _start_watcher(self):
|
||||
if not self._soul:
|
||||
return
|
||||
self._watcher = SoulWatcher(
|
||||
path=self._p["soul_file"],
|
||||
on_reload=self._on_soul_reloaded,
|
||||
interval=self._p["reload_interval"],
|
||||
on_error=lambda exc: self.get_logger().warn(
|
||||
f"SOUL.md hot-reload error: {exc}"
|
||||
),
|
||||
)
|
||||
self._watcher.start()
|
||||
|
||||
def _stop_watcher(self):
|
||||
if self._watcher:
|
||||
self._watcher.stop()
|
||||
self._watcher = None
|
||||
|
||||
def _on_soul_reloaded(self, soul: dict):
|
||||
self._soul = soul
|
||||
self.get_logger().info(
|
||||
f"SOUL.md reloaded: persona={soul.get('name')!r} "
|
||||
f"humor={soul.get('humor_level')} base_mood={soul.get('base_mood')!r}"
|
||||
)
|
||||
|
||||
# ── Mood helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _current_mood(self) -> str:
|
||||
if not self._current_person or not self._soul:
|
||||
return self._soul.get("base_mood", "curious") if self._soul else "curious"
|
||||
person = self._db.get_person(self._current_person)
|
||||
recent = self._db.get_recent_events(self._current_person, window_s=120.0)
|
||||
return compute_mood(
|
||||
soul = self._soul,
|
||||
score = person["score"],
|
||||
interaction_count = person["interaction_count"],
|
||||
recent_events = recent,
|
||||
)
|
||||
|
||||
def _state_for_person(self, person_id: str) -> dict:
|
||||
"""Build a complete state dict for a given person_id."""
|
||||
person = self._db.get_person(person_id) if person_id else {
|
||||
"score": 0.0, "interaction_count": 0
|
||||
}
|
||||
recent = self._db.get_recent_events(person_id, window_s=120.0) if person_id else []
|
||||
|
||||
mood = compute_mood(
|
||||
soul = self._soul,
|
||||
score = person["score"],
|
||||
interaction_count = person["interaction_count"],
|
||||
recent_events = recent,
|
||||
)
|
||||
tier = get_relationship_tier(self._soul, person["interaction_count"])
|
||||
greeting = build_greeting(self._soul, tier, mood, person_id)
|
||||
|
||||
return {
|
||||
"person_id": person_id,
|
||||
"mood": mood,
|
||||
"tier": tier,
|
||||
"score": person["score"],
|
||||
"interaction_count": person["interaction_count"],
|
||||
"greeting": greeting,
|
||||
}
|
||||
|
||||
# ── Callbacks ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _person_detected_cb(self, msg: String):
|
||||
"""Handle incoming person detection / interaction event.
|
||||
|
||||
Expected JSON payload::
|
||||
|
||||
{
|
||||
"person_id": "alice", # required
|
||||
"event_type": "greeting", # optional, default "detection"
|
||||
"delta_score": 1.0 # optional, default 0.0
|
||||
}
|
||||
"""
|
||||
try:
|
||||
data = json.loads(msg.data)
|
||||
except json.JSONDecodeError as exc:
|
||||
self.get_logger().warn(f"Bad JSON on /social/person_detected: {exc}")
|
||||
return
|
||||
|
||||
person_id = data.get("person_id", "").strip()
|
||||
if not person_id:
|
||||
self.get_logger().warn("person_detected msg missing 'person_id'")
|
||||
return
|
||||
|
||||
event_type = data.get("event_type", EVENT_DETECTION)
|
||||
delta_score = float(data.get("delta_score", 0.0))
|
||||
|
||||
# Increment score by +1 for detection events automatically
|
||||
if event_type == EVENT_DETECTION and delta_score == 0.0:
|
||||
delta_score = 0.5
|
||||
|
||||
self._db.record_interaction(
|
||||
person_id = person_id,
|
||||
event_type = event_type,
|
||||
details = {k: v for k, v in data.items()
|
||||
if k not in ("person_id", "event_type", "delta_score")},
|
||||
delta_score = delta_score,
|
||||
)
|
||||
self._current_person = person_id
|
||||
|
||||
def _query_mood_cb(self, request: QueryMood.Request, response: QueryMood.Response):
|
||||
"""Service handler: return mood + greeting for a specific person."""
|
||||
if not self._soul:
|
||||
response.success = False
|
||||
response.message = "SOUL.md not loaded"
|
||||
return response
|
||||
|
||||
person_id = (request.person_id or self._current_person).strip()
|
||||
state = self._state_for_person(person_id)
|
||||
|
||||
response.mood = state["mood"]
|
||||
response.relationship_tier = state["tier"]
|
||||
response.relationship_score = float(state["score"])
|
||||
response.interaction_count = int(state["interaction_count"])
|
||||
response.greeting_text = state["greeting"]
|
||||
response.success = True
|
||||
response.message = ""
|
||||
return response
|
||||
|
||||
# ── Publish ────────────────────────────────────────────────────────────────
|
||||
|
||||
def _publish_state(self):
|
||||
if not self._soul:
|
||||
return
|
||||
|
||||
state = self._state_for_person(self._current_person)
|
||||
|
||||
msg = PersonalityState()
|
||||
msg.header = Header()
|
||||
msg.header.stamp = self.get_clock().now().to_msg()
|
||||
msg.header.frame_id = "personality"
|
||||
msg.persona_name = str(self._soul.get("name", "Salty"))
|
||||
msg.mood = state["mood"]
|
||||
msg.person_id = state["person_id"]
|
||||
msg.relationship_tier = state["tier"]
|
||||
msg.relationship_score = float(state["score"])
|
||||
msg.interaction_count = int(state["interaction_count"])
|
||||
msg.greeting_text = state["greeting"]
|
||||
|
||||
self._state_pub.publish(msg)
|
||||
|
||||
# ── Lifecycle ──────────────────────────────────────────────────────────────
|
||||
|
||||
def destroy_node(self):
|
||||
self._stop_watcher()
|
||||
self._db.close()
|
||||
super().destroy_node()
|
||||
|
||||
|
||||
# ── Entry point ────────────────────────────────────────────────────────────────
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = PersonalityNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.try_shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -0,0 +1,297 @@
|
||||
"""
|
||||
relationship_db.py — SQLite-backed per-person relationship memory.
|
||||
|
||||
No ROS2 imports — safe to unit-test without a live ROS2 environment.
|
||||
|
||||
Schema
|
||||
------
|
||||
people (
|
||||
person_id TEXT PRIMARY KEY,
|
||||
score REAL DEFAULT 0.0,
|
||||
interaction_count INTEGER DEFAULT 0,
|
||||
first_seen TEXT, -- ISO-8601 UTC timestamp
|
||||
last_seen TEXT, -- ISO-8601 UTC timestamp
|
||||
prefs TEXT -- JSON blob for learned preferences
|
||||
)
|
||||
|
||||
interactions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
person_id TEXT,
|
||||
ts TEXT, -- ISO-8601 UTC timestamp
|
||||
event_type TEXT, -- greeting | positive | negative | detection
|
||||
details TEXT -- free-form JSON blob
|
||||
)
|
||||
|
||||
Public API
|
||||
----------
|
||||
RelationshipDB(db_path)
|
||||
.get_person(person_id) -> dict
|
||||
.record_interaction(person_id, event_type, details, delta_score)
|
||||
.set_pref(person_id, key, value)
|
||||
.get_pref(person_id, key, default)
|
||||
.get_recent_events(person_id, window_s) -> list[dict]
|
||||
.all_people() -> list[dict]
|
||||
.close()
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
class RelationshipDB:
|
||||
"""Thread-safe SQLite relationship store.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
db_path : str
|
||||
Path to the SQLite file. Created (with parent dirs) if absent.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str):
|
||||
parent = os.path.dirname(db_path)
|
||||
if parent:
|
||||
os.makedirs(parent, exist_ok=True)
|
||||
self._path = db_path
|
||||
self._lock = threading.Lock()
|
||||
self._conn = sqlite3.connect(db_path, check_same_thread=False)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._migrate()
|
||||
|
||||
# ── Schema ────────────────────────────────────────────────────────────────
|
||||
|
||||
def _migrate(self):
|
||||
with self._conn:
|
||||
self._conn.executescript("""
|
||||
CREATE TABLE IF NOT EXISTS people (
|
||||
person_id TEXT PRIMARY KEY,
|
||||
score REAL DEFAULT 0.0,
|
||||
interaction_count INTEGER DEFAULT 0,
|
||||
first_seen TEXT,
|
||||
last_seen TEXT,
|
||||
prefs TEXT DEFAULT '{}'
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS interactions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
person_id TEXT NOT NULL,
|
||||
ts TEXT NOT NULL,
|
||||
event_type TEXT NOT NULL,
|
||||
details TEXT DEFAULT '{}'
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_interactions_person_ts
|
||||
ON interactions (person_id, ts);
|
||||
""")
|
||||
|
||||
# ── People ────────────────────────────────────────────────────────────────
|
||||
|
||||
def get_person(self, person_id: str) -> dict:
|
||||
"""Return the person record; inserts a default row if not found.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict with keys: person_id, score, interaction_count,
|
||||
first_seen, last_seen, prefs (dict)
|
||||
"""
|
||||
with self._lock:
|
||||
row = self._conn.execute(
|
||||
"SELECT * FROM people WHERE person_id = ?", (person_id,)
|
||||
).fetchone()
|
||||
|
||||
if row is None:
|
||||
now = _now_iso()
|
||||
self._conn.execute(
|
||||
"INSERT INTO people (person_id, first_seen, last_seen) VALUES (?,?,?)",
|
||||
(person_id, now, now),
|
||||
)
|
||||
self._conn.commit()
|
||||
return {
|
||||
"person_id": person_id,
|
||||
"score": 0.0,
|
||||
"interaction_count": 0,
|
||||
"first_seen": now,
|
||||
"last_seen": now,
|
||||
"prefs": {},
|
||||
}
|
||||
|
||||
prefs = {}
|
||||
try:
|
||||
prefs = json.loads(row["prefs"] or "{}")
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return {
|
||||
"person_id": row["person_id"],
|
||||
"score": float(row["score"]),
|
||||
"interaction_count": int(row["interaction_count"]),
|
||||
"first_seen": row["first_seen"],
|
||||
"last_seen": row["last_seen"],
|
||||
"prefs": prefs,
|
||||
}
|
||||
|
||||
def all_people(self) -> list:
|
||||
"""Return all person records as a list of dicts."""
|
||||
with self._lock:
|
||||
rows = self._conn.execute("SELECT * FROM people ORDER BY score DESC").fetchall()
|
||||
result = []
|
||||
for row in rows:
|
||||
prefs = {}
|
||||
try:
|
||||
prefs = json.loads(row["prefs"] or "{}")
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
result.append({
|
||||
"person_id": row["person_id"],
|
||||
"score": float(row["score"]),
|
||||
"interaction_count": int(row["interaction_count"]),
|
||||
"first_seen": row["first_seen"],
|
||||
"last_seen": row["last_seen"],
|
||||
"prefs": prefs,
|
||||
})
|
||||
return result
|
||||
|
||||
# ── Interactions ──────────────────────────────────────────────────────────
|
||||
|
||||
def record_interaction(
|
||||
self,
|
||||
person_id: str,
|
||||
event_type: str,
|
||||
details: dict | None = None,
|
||||
delta_score: float = 0.0,
|
||||
):
|
||||
"""Record an interaction event and update the person's score.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
person_id : str
|
||||
event_type : str
|
||||
One of: ``"greeting"``, ``"positive"``, ``"negative"``,
|
||||
``"detection"``.
|
||||
details : dict, optional
|
||||
Arbitrary key/value data stored as JSON.
|
||||
delta_score : float
|
||||
Amount to add to the person's score (can be negative).
|
||||
Interaction count is always incremented by 1.
|
||||
"""
|
||||
now = _now_iso()
|
||||
details_json = json.dumps(details or {})
|
||||
|
||||
with self._lock:
|
||||
# Ensure person exists
|
||||
self.get_person.__wrapped__(self, person_id) if hasattr(
|
||||
self.get_person, "__wrapped__"
|
||||
) else None
|
||||
|
||||
# Upsert person row
|
||||
self._conn.execute("""
|
||||
INSERT INTO people (person_id, first_seen, last_seen)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(person_id) DO UPDATE SET
|
||||
last_seen = excluded.last_seen
|
||||
""", (person_id, now, now))
|
||||
|
||||
# Increment count + score
|
||||
self._conn.execute("""
|
||||
UPDATE people
|
||||
SET interaction_count = interaction_count + 1,
|
||||
score = score + ?,
|
||||
last_seen = ?
|
||||
WHERE person_id = ?
|
||||
""", (delta_score, now, person_id))
|
||||
|
||||
# Insert interaction log row
|
||||
self._conn.execute("""
|
||||
INSERT INTO interactions (person_id, ts, event_type, details)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""", (person_id, now, event_type, details_json))
|
||||
|
||||
self._conn.commit()
|
||||
|
||||
def get_recent_events(self, person_id: str, window_s: float = 120.0) -> list:
|
||||
"""Return interaction events for *person_id* within the last *window_s* seconds.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list of dict
|
||||
Each dict: ``{"type": str, "dt": float, "ts": str, "details": dict}``
|
||||
where ``dt`` is seconds ago (positive = in the past).
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
cutoff = (
|
||||
datetime.now(timezone.utc) - timedelta(seconds=window_s)
|
||||
).isoformat()
|
||||
|
||||
with self._lock:
|
||||
rows = self._conn.execute("""
|
||||
SELECT ts, event_type, details FROM interactions
|
||||
WHERE person_id = ? AND ts >= ?
|
||||
ORDER BY ts DESC
|
||||
""", (person_id, cutoff)).fetchall()
|
||||
|
||||
now_dt = datetime.now(timezone.utc)
|
||||
result = []
|
||||
for row in rows:
|
||||
try:
|
||||
row_dt = datetime.fromisoformat(row["ts"])
|
||||
# Make timezone-aware if needed
|
||||
if row_dt.tzinfo is None:
|
||||
row_dt = row_dt.replace(tzinfo=timezone.utc)
|
||||
dt_secs = (now_dt - row_dt).total_seconds()
|
||||
except (ValueError, TypeError):
|
||||
dt_secs = window_s
|
||||
|
||||
details = {}
|
||||
try:
|
||||
details = json.loads(row["details"] or "{}")
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
result.append({
|
||||
"type": row["event_type"],
|
||||
"dt": dt_secs,
|
||||
"ts": row["ts"],
|
||||
"details": details,
|
||||
})
|
||||
return result
|
||||
|
||||
# ── Preferences ───────────────────────────────────────────────────────────
|
||||
|
||||
def set_pref(self, person_id: str, key: str, value):
|
||||
"""Set a learned preference for a person.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
person_id, key : str
|
||||
value : JSON-serialisable
|
||||
"""
|
||||
person = self.get_person(person_id)
|
||||
prefs = person["prefs"]
|
||||
prefs[key] = value
|
||||
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"UPDATE people SET prefs = ? WHERE person_id = ?",
|
||||
(json.dumps(prefs), person_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def get_pref(self, person_id: str, key: str, default=None):
|
||||
"""Return a specific learned preference for a person."""
|
||||
return self.get_person(person_id)["prefs"].get(key, default)
|
||||
|
||||
# ── Lifecycle ─────────────────────────────────────────────────────────────
|
||||
|
||||
def close(self):
|
||||
"""Close the database connection."""
|
||||
with self._lock:
|
||||
self._conn.close()
|
||||
@ -0,0 +1,196 @@
|
||||
"""
|
||||
soul_loader.py — SOUL.md persona parser and hot-reload watcher.
|
||||
|
||||
SOUL.md format
|
||||
--------------
|
||||
The file uses YAML front-matter (delimited by ``---`` lines) with an optional
|
||||
Markdown description body that is ignored by the parser. Example::
|
||||
|
||||
---
|
||||
name: "Salty"
|
||||
humor_level: 7
|
||||
sass_level: 4
|
||||
base_mood: "playful"
|
||||
...
|
||||
---
|
||||
# Optional description text (ignored)
|
||||
|
||||
Public API
|
||||
----------
|
||||
load_soul(path) -> dict (raises on parse error)
|
||||
SoulWatcher(path, cb, interval)
|
||||
.start()
|
||||
.stop()
|
||||
.reload_now() -> dict
|
||||
|
||||
Pure module — no ROS2 imports.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
|
||||
import yaml
|
||||
|
||||
# Keys that are required in every SOUL.md file
|
||||
_REQUIRED_KEYS = {
|
||||
"name",
|
||||
"humor_level",
|
||||
"sass_level",
|
||||
"base_mood",
|
||||
"threshold_regular",
|
||||
"threshold_favorite",
|
||||
"greeting_stranger",
|
||||
"greeting_regular",
|
||||
"greeting_favorite",
|
||||
}
|
||||
|
||||
_VALID_MOODS = {"happy", "curious", "annoyed", "playful"}
|
||||
|
||||
|
||||
def _extract_frontmatter(text: str) -> str:
|
||||
"""Return the YAML block between the first pair of ``---`` delimiters.
|
||||
|
||||
Raises ``ValueError`` if the file does not contain valid front-matter.
|
||||
"""
|
||||
lines = text.splitlines()
|
||||
delimiters = [i for i, l in enumerate(lines) if l.strip() == "---"]
|
||||
if len(delimiters) < 2:
|
||||
# No delimiter found — treat the whole file as plain YAML
|
||||
return text
|
||||
start = delimiters[0] + 1
|
||||
end = delimiters[1]
|
||||
return "\n".join(lines[start:end])
|
||||
|
||||
|
||||
def load_soul(path: str) -> dict:
|
||||
"""Parse a SOUL.md file and return the validated config dict.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
Absolute path to the SOUL.md file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
Validated persona configuration.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the file does not exist.
|
||||
ValueError
|
||||
If the YAML is malformed or required keys are missing.
|
||||
"""
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(f"SOUL.md not found: {path}")
|
||||
|
||||
with open(path, "r", encoding="utf-8") as fh:
|
||||
raw = fh.read()
|
||||
|
||||
yaml_text = _extract_frontmatter(raw)
|
||||
|
||||
try:
|
||||
data = yaml.safe_load(yaml_text)
|
||||
except yaml.YAMLError as exc:
|
||||
raise ValueError(f"SOUL.md YAML parse error in {path}: {exc}") from exc
|
||||
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"SOUL.md top level must be a YAML mapping, got {type(data)}")
|
||||
|
||||
# Validate required keys
|
||||
missing = _REQUIRED_KEYS - data.keys()
|
||||
if missing:
|
||||
raise ValueError(f"SOUL.md missing required keys: {sorted(missing)}")
|
||||
|
||||
# Validate ranges
|
||||
for key in ("humor_level", "sass_level"):
|
||||
val = data.get(key)
|
||||
if not isinstance(val, (int, float)) or not (0 <= val <= 10):
|
||||
raise ValueError(f"SOUL.md '{key}' must be a number 0–10, got {val!r}")
|
||||
|
||||
if data.get("base_mood") not in _VALID_MOODS:
|
||||
raise ValueError(
|
||||
f"SOUL.md 'base_mood' must be one of {sorted(_VALID_MOODS)}, "
|
||||
f"got {data.get('base_mood')!r}"
|
||||
)
|
||||
|
||||
for key in ("threshold_regular", "threshold_favorite"):
|
||||
val = data.get(key)
|
||||
if not isinstance(val, int) or val < 0:
|
||||
raise ValueError(f"SOUL.md '{key}' must be a non-negative integer, got {val!r}")
|
||||
|
||||
if data["threshold_regular"] > data["threshold_favorite"]:
|
||||
raise ValueError(
|
||||
"SOUL.md 'threshold_regular' must be <= 'threshold_favorite'"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class SoulWatcher:
|
||||
"""Background thread that polls SOUL.md for changes and calls a callback.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
Path to the SOUL.md file to watch.
|
||||
on_reload : callable
|
||||
``on_reload(soul_dict)`` called whenever a valid new SOUL.md is loaded.
|
||||
interval : float
|
||||
Polling interval in seconds (default 5.0).
|
||||
on_error : callable, optional
|
||||
``on_error(exception)`` called when a reload attempt fails.
|
||||
"""
|
||||
|
||||
def __init__(self, path: str, on_reload, interval: float = 5.0, on_error=None):
|
||||
self._path = path
|
||||
self._on_reload = on_reload
|
||||
self._interval = interval
|
||||
self._on_error = on_error
|
||||
self._thread = None
|
||||
self._stop_evt = threading.Event()
|
||||
self._last_mtime = 0.0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def start(self):
|
||||
"""Start the background polling thread."""
|
||||
if self._thread and self._thread.is_alive():
|
||||
return
|
||||
self._stop_evt.clear()
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self):
|
||||
"""Signal the watcher thread to stop and block until it exits."""
|
||||
self._stop_evt.set()
|
||||
if self._thread:
|
||||
self._thread.join(timeout=self._interval + 1.0)
|
||||
|
||||
def reload_now(self) -> dict:
|
||||
"""Force an immediate reload and return the new soul dict."""
|
||||
soul = load_soul(self._path)
|
||||
self._last_mtime = os.path.getmtime(self._path)
|
||||
self._on_reload(soul)
|
||||
return soul
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run(self):
|
||||
while not self._stop_evt.wait(self._interval):
|
||||
try:
|
||||
mtime = os.path.getmtime(self._path)
|
||||
except OSError:
|
||||
continue
|
||||
if mtime != self._last_mtime:
|
||||
try:
|
||||
soul = load_soul(self._path)
|
||||
except (FileNotFoundError, ValueError) as exc:
|
||||
if self._on_error:
|
||||
self._on_error(exc)
|
||||
continue
|
||||
self._last_mtime = mtime
|
||||
self._on_reload(soul)
|
||||
4
jetson/ros2_ws/src/saltybot_social_personality/setup.cfg
Normal file
4
jetson/ros2_ws/src/saltybot_social_personality/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/saltybot_social_personality
|
||||
[install]
|
||||
install_scripts=$base/lib/saltybot_social_personality
|
||||
28
jetson/ros2_ws/src/saltybot_social_personality/setup.py
Normal file
28
jetson/ros2_ws/src/saltybot_social_personality/setup.py
Normal file
@ -0,0 +1,28 @@
|
||||
from setuptools import setup
|
||||
|
||||
package_name = "saltybot_social_personality"
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
version="0.1.0",
|
||||
packages=[package_name],
|
||||
data_files=[
|
||||
("share/ament_index/resource_index/packages", [f"resource/{package_name}"]),
|
||||
(f"share/{package_name}", ["package.xml"]),
|
||||
(f"share/{package_name}/launch", ["launch/personality.launch.py"]),
|
||||
(f"share/{package_name}/config", ["config/SOUL.md",
|
||||
"config/personality_params.yaml"]),
|
||||
],
|
||||
install_requires=["setuptools", "pyyaml"],
|
||||
zip_safe=True,
|
||||
maintainer="sl-controls",
|
||||
maintainer_email="sl-controls@saltylab.local",
|
||||
description="SOUL.md-driven personality system for saltybot social interaction",
|
||||
license="MIT",
|
||||
tests_require=["pytest"],
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"personality_node = saltybot_social_personality.personality_node:main",
|
||||
],
|
||||
},
|
||||
)
|
||||
@ -0,0 +1,475 @@
|
||||
"""
|
||||
test_personality.py — Unit tests for the saltybot personality system.
|
||||
|
||||
No ROS2 runtime required. Tests pure functions from:
|
||||
- soul_loader.py
|
||||
- mood_engine.py
|
||||
- relationship_db.py
|
||||
|
||||
Run with:
|
||||
pytest jetson/ros2_ws/src/saltybot_social_personality/test/test_personality.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import textwrap
|
||||
|
||||
import pytest
|
||||
|
||||
# ── Imports (pure modules, no ROS2) ──────────────────────────────────────────
|
||||
|
||||
import sys
|
||||
sys.path.insert(
|
||||
0,
|
||||
os.path.join(os.path.dirname(__file__), "..", "saltybot_social_personality"),
|
||||
)
|
||||
|
||||
from soul_loader import load_soul, _extract_frontmatter
|
||||
from mood_engine import (
|
||||
compute_mood, get_relationship_tier, build_greeting,
|
||||
MOOD_HAPPY, MOOD_PLAYFUL, MOOD_CURIOUS, MOOD_ANNOYED,
|
||||
TIER_STRANGER, TIER_REGULAR, TIER_FAVORITE,
|
||||
EVENT_NEGATIVE, EVENT_POSITIVE, EVENT_GREETING,
|
||||
)
|
||||
from relationship_db import RelationshipDB
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _minimal_soul(**overrides) -> dict:
|
||||
"""Return a valid minimal soul dict, optionally overriding keys."""
|
||||
base = {
|
||||
"name": "Salty",
|
||||
"humor_level": 7,
|
||||
"sass_level": 4,
|
||||
"base_mood": "playful",
|
||||
"threshold_regular": 5,
|
||||
"threshold_favorite": 20,
|
||||
"greeting_stranger": "Hello there!",
|
||||
"greeting_regular": "Hey {name}!",
|
||||
"greeting_favorite": "Oh hey {name}!!",
|
||||
}
|
||||
base.update(overrides)
|
||||
return base
|
||||
|
||||
|
||||
def _write_soul(content: str) -> str:
|
||||
"""Write a SOUL.md string to a temp file and return the path."""
|
||||
fh = tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".md", delete=False, encoding="utf-8"
|
||||
)
|
||||
fh.write(content)
|
||||
fh.close()
|
||||
return fh.name
|
||||
|
||||
|
||||
_VALID_SOUL_CONTENT = textwrap.dedent("""\
|
||||
---
|
||||
name: "TestBot"
|
||||
speaking_style: "casual"
|
||||
humor_level: 7
|
||||
sass_level: 3
|
||||
base_mood: "playful"
|
||||
threshold_regular: 5
|
||||
threshold_favorite: 20
|
||||
greeting_stranger: "Hello stranger!"
|
||||
greeting_regular: "Hey {name}!"
|
||||
greeting_favorite: "Oh hey {name}!!"
|
||||
mood_prefix_playful: "Beep boop! "
|
||||
mood_prefix_happy: "Great — "
|
||||
mood_prefix_curious: "Hmm, "
|
||||
mood_prefix_annoyed: "Ugh, "
|
||||
---
|
||||
# Description (ignored)
|
||||
This is the description body.
|
||||
""")
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# soul_loader tests
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestExtractFrontmatter:
|
||||
def test_delimited(self):
|
||||
content = "---\nkey: val\n---\n# body"
|
||||
assert _extract_frontmatter(content) == "key: val"
|
||||
|
||||
def test_no_delimiters_returns_whole(self):
|
||||
content = "key: val\nother: 123"
|
||||
assert _extract_frontmatter(content) == content
|
||||
|
||||
def test_single_delimiter_returns_whole(self):
|
||||
content = "---\nkey: val\n"
|
||||
result = _extract_frontmatter(content)
|
||||
assert "key: val" in result
|
||||
|
||||
def test_body_stripped(self):
|
||||
content = "---\nname: X\n---\n# Body text\nMore body"
|
||||
assert "Body text" not in _extract_frontmatter(content)
|
||||
assert "name: X" in _extract_frontmatter(content)
|
||||
|
||||
|
||||
class TestLoadSoul:
|
||||
def test_valid_file_loads(self):
|
||||
path = _write_soul(_VALID_SOUL_CONTENT)
|
||||
try:
|
||||
soul = load_soul(path)
|
||||
assert soul["name"] == "TestBot"
|
||||
assert soul["humor_level"] == 7
|
||||
assert soul["base_mood"] == "playful"
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_missing_file_raises(self):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
load_soul("/nonexistent/SOUL.md")
|
||||
|
||||
def test_missing_required_key_raises(self):
|
||||
content = "---\nname: X\nhumor_level: 5\n---" # missing many keys
|
||||
path = _write_soul(content)
|
||||
try:
|
||||
with pytest.raises(ValueError, match="missing required keys"):
|
||||
load_soul(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_humor_out_of_range_raises(self):
|
||||
soul_str = _VALID_SOUL_CONTENT.replace("humor_level: 7", "humor_level: 11")
|
||||
path = _write_soul(soul_str)
|
||||
try:
|
||||
with pytest.raises(ValueError, match="humor_level"):
|
||||
load_soul(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_invalid_mood_raises(self):
|
||||
soul_str = _VALID_SOUL_CONTENT.replace(
|
||||
'base_mood: "playful"', 'base_mood: "grumpy"'
|
||||
)
|
||||
path = _write_soul(soul_str)
|
||||
try:
|
||||
with pytest.raises(ValueError, match="base_mood"):
|
||||
load_soul(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_threshold_order_enforced(self):
|
||||
soul_str = _VALID_SOUL_CONTENT.replace(
|
||||
"threshold_regular: 5", "threshold_regular: 25"
|
||||
)
|
||||
path = _write_soul(soul_str)
|
||||
try:
|
||||
with pytest.raises(ValueError, match="threshold_regular"):
|
||||
load_soul(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_extra_keys_allowed(self):
|
||||
content = _VALID_SOUL_CONTENT.replace(
|
||||
"---\n# Description",
|
||||
"custom_key: 42\n---\n# Description"
|
||||
)
|
||||
path = _write_soul(content)
|
||||
try:
|
||||
soul = load_soul(path)
|
||||
assert soul.get("custom_key") == 42
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# mood_engine tests
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestGetRelationshipTier:
|
||||
def test_zero_interactions_stranger(self):
|
||||
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
|
||||
assert get_relationship_tier(soul, 0) == TIER_STRANGER
|
||||
|
||||
def test_below_regular_stranger(self):
|
||||
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
|
||||
assert get_relationship_tier(soul, 4) == TIER_STRANGER
|
||||
|
||||
def test_at_regular_threshold(self):
|
||||
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
|
||||
assert get_relationship_tier(soul, 5) == TIER_REGULAR
|
||||
|
||||
def test_above_regular_below_favorite(self):
|
||||
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
|
||||
assert get_relationship_tier(soul, 10) == TIER_REGULAR
|
||||
|
||||
def test_at_favorite_threshold(self):
|
||||
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
|
||||
assert get_relationship_tier(soul, 20) == TIER_FAVORITE
|
||||
|
||||
def test_above_favorite(self):
|
||||
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
|
||||
assert get_relationship_tier(soul, 100) == TIER_FAVORITE
|
||||
|
||||
|
||||
class TestComputeMood:
|
||||
def test_unknown_person_returns_curious(self):
|
||||
soul = _minimal_soul(humor_level=7)
|
||||
mood = compute_mood(soul, score=0.0, interaction_count=0, recent_events=[])
|
||||
assert mood == MOOD_CURIOUS
|
||||
|
||||
def test_stranger_low_count_returns_curious(self):
|
||||
soul = _minimal_soul(threshold_regular=5)
|
||||
mood = compute_mood(soul, score=2.0, interaction_count=3, recent_events=[])
|
||||
assert mood == MOOD_CURIOUS
|
||||
|
||||
def test_two_negative_events_returns_annoyed(self):
|
||||
soul = _minimal_soul(threshold_regular=5)
|
||||
events = [
|
||||
{"type": EVENT_NEGATIVE, "dt": 30.0},
|
||||
{"type": EVENT_NEGATIVE, "dt": 60.0},
|
||||
]
|
||||
mood = compute_mood(soul, score=10.0, interaction_count=10, recent_events=events)
|
||||
assert mood == MOOD_ANNOYED
|
||||
|
||||
def test_one_negative_not_annoyed(self):
|
||||
soul = _minimal_soul(humor_level=7, threshold_regular=5, threshold_favorite=20)
|
||||
events = [{"type": EVENT_NEGATIVE, "dt": 30.0}]
|
||||
# 1 negative is not enough → should still be happy/playful based on score
|
||||
mood = compute_mood(soul, score=25.0, interaction_count=25, recent_events=events)
|
||||
assert mood != MOOD_ANNOYED
|
||||
|
||||
def test_high_humor_regular_returns_playful(self):
|
||||
soul = _minimal_soul(humor_level=8, threshold_regular=5, threshold_favorite=20)
|
||||
events = [{"type": EVENT_POSITIVE, "dt": 10.0}]
|
||||
mood = compute_mood(soul, score=10.0, interaction_count=8, recent_events=events)
|
||||
assert mood == MOOD_PLAYFUL
|
||||
|
||||
def test_low_humor_regular_returns_happy(self):
|
||||
soul = _minimal_soul(humor_level=4, threshold_regular=5, threshold_favorite=20)
|
||||
events = [{"type": EVENT_POSITIVE, "dt": 10.0}]
|
||||
mood = compute_mood(soul, score=10.0, interaction_count=8, recent_events=events)
|
||||
assert mood == MOOD_HAPPY
|
||||
|
||||
def test_stale_negative_ignored(self):
|
||||
soul = _minimal_soul(humor_level=8, threshold_regular=5, threshold_favorite=20)
|
||||
# dt > 120s → outside the recent window → should not trigger annoyed
|
||||
events = [
|
||||
{"type": EVENT_NEGATIVE, "dt": 200.0},
|
||||
{"type": EVENT_NEGATIVE, "dt": 300.0},
|
||||
]
|
||||
mood = compute_mood(soul, score=15.0, interaction_count=10, recent_events=events)
|
||||
assert mood != MOOD_ANNOYED
|
||||
|
||||
def test_favorite_high_humor_playful(self):
|
||||
soul = _minimal_soul(humor_level=9, threshold_regular=5, threshold_favorite=20)
|
||||
mood = compute_mood(soul, score=50.0, interaction_count=30, recent_events=[])
|
||||
assert mood == MOOD_PLAYFUL
|
||||
|
||||
def test_favorite_low_humor_happy(self):
|
||||
soul = _minimal_soul(humor_level=3, threshold_regular=5, threshold_favorite=20)
|
||||
mood = compute_mood(soul, score=50.0, interaction_count=30, recent_events=[])
|
||||
assert mood == MOOD_HAPPY
|
||||
|
||||
|
||||
class TestBuildGreeting:
|
||||
def _soul(self, **kw):
|
||||
return _minimal_soul(
|
||||
mood_prefix_happy="Great — ",
|
||||
mood_prefix_curious="Hmm, ",
|
||||
mood_prefix_annoyed="Well, ",
|
||||
mood_prefix_playful="Beep boop! ",
|
||||
**kw,
|
||||
)
|
||||
|
||||
def test_stranger_greeting(self):
|
||||
soul = self._soul()
|
||||
g = build_greeting(soul, TIER_STRANGER, MOOD_CURIOUS, "")
|
||||
assert "hello" in g.lower()
|
||||
|
||||
def test_regular_greeting_contains_name(self):
|
||||
soul = self._soul()
|
||||
g = build_greeting(soul, TIER_REGULAR, MOOD_HAPPY, "alice")
|
||||
assert "alice" in g
|
||||
|
||||
def test_favorite_greeting_contains_name(self):
|
||||
soul = self._soul()
|
||||
g = build_greeting(soul, TIER_FAVORITE, MOOD_PLAYFUL, "bob")
|
||||
assert "bob" in g
|
||||
|
||||
def test_mood_prefix_applied(self):
|
||||
soul = self._soul()
|
||||
g = build_greeting(soul, TIER_REGULAR, MOOD_PLAYFUL, "alice")
|
||||
assert g.startswith("Beep boop!")
|
||||
|
||||
def test_no_prefix_key_no_prefix(self):
|
||||
soul = _minimal_soul() # no mood_prefix_* keys
|
||||
g = build_greeting(soul, TIER_REGULAR, MOOD_HAPPY, "alice")
|
||||
assert g.startswith("Hey")
|
||||
|
||||
def test_empty_person_id_uses_friend(self):
|
||||
soul = self._soul()
|
||||
g = build_greeting(soul, TIER_REGULAR, MOOD_HAPPY, "")
|
||||
assert "friend" in g
|
||||
|
||||
def test_happy_prefix(self):
|
||||
soul = self._soul()
|
||||
g = build_greeting(soul, TIER_REGULAR, MOOD_HAPPY, "carol")
|
||||
assert g.startswith("Great")
|
||||
|
||||
def test_annoyed_prefix(self):
|
||||
soul = self._soul()
|
||||
g = build_greeting(soul, TIER_REGULAR, MOOD_ANNOYED, "dave")
|
||||
assert g.startswith("Well")
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# relationship_db tests
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestRelationshipDB:
|
||||
@pytest.fixture
|
||||
def db(self, tmp_path):
|
||||
path = str(tmp_path / "test.db")
|
||||
d = RelationshipDB(path)
|
||||
yield d
|
||||
d.close()
|
||||
|
||||
def test_get_person_creates_default(self, db):
|
||||
p = db.get_person("alice")
|
||||
assert p["person_id"] == "alice"
|
||||
assert p["score"] == pytest.approx(0.0)
|
||||
assert p["interaction_count"] == 0
|
||||
|
||||
def test_get_person_idempotent(self, db):
|
||||
p1 = db.get_person("bob")
|
||||
p2 = db.get_person("bob")
|
||||
assert p1["person_id"] == p2["person_id"]
|
||||
|
||||
def test_record_interaction_increments_count(self, db):
|
||||
db.record_interaction("alice", "detection")
|
||||
db.record_interaction("alice", "detection")
|
||||
p = db.get_person("alice")
|
||||
assert p["interaction_count"] == 2
|
||||
|
||||
def test_record_interaction_updates_score(self, db):
|
||||
db.record_interaction("alice", "positive", delta_score=5.0)
|
||||
p = db.get_person("alice")
|
||||
assert p["score"] == pytest.approx(5.0)
|
||||
|
||||
def test_negative_delta_reduces_score(self, db):
|
||||
db.record_interaction("carol", "positive", delta_score=10.0)
|
||||
db.record_interaction("carol", "negative", delta_score=-3.0)
|
||||
p = db.get_person("carol")
|
||||
assert p["score"] == pytest.approx(7.0)
|
||||
|
||||
def test_score_zero_by_default(self, db):
|
||||
p = db.get_person("dave")
|
||||
assert p["score"] == pytest.approx(0.0)
|
||||
|
||||
def test_set_and_get_pref(self, db):
|
||||
db.set_pref("alice", "language", "en")
|
||||
assert db.get_pref("alice", "language") == "en"
|
||||
|
||||
def test_get_pref_default(self, db):
|
||||
assert db.get_pref("nobody", "language", "fr") == "fr"
|
||||
|
||||
def test_multiple_prefs_stored(self, db):
|
||||
db.set_pref("alice", "lang", "en")
|
||||
db.set_pref("alice", "name", "Alice")
|
||||
assert db.get_pref("alice", "lang") == "en"
|
||||
assert db.get_pref("alice", "name") == "Alice"
|
||||
|
||||
def test_all_people_returns_list(self, db):
|
||||
db.record_interaction("a", "detection")
|
||||
db.record_interaction("b", "detection")
|
||||
people = db.all_people()
|
||||
ids = {p["person_id"] for p in people}
|
||||
assert {"a", "b"} <= ids
|
||||
|
||||
def test_get_recent_events_returns_events(self, db):
|
||||
db.record_interaction("alice", "greeting", delta_score=1.0)
|
||||
events = db.get_recent_events("alice", window_s=60.0)
|
||||
assert len(events) == 1
|
||||
assert events[0]["type"] == "greeting"
|
||||
|
||||
def test_get_recent_events_empty_for_new_person(self, db):
|
||||
events = db.get_recent_events("nobody", window_s=60.0)
|
||||
assert events == []
|
||||
|
||||
def test_event_dt_positive(self, db):
|
||||
db.record_interaction("alice", "detection")
|
||||
events = db.get_recent_events("alice", window_s=60.0)
|
||||
assert events[0]["dt"] >= 0.0
|
||||
|
||||
def test_multiple_people_isolated(self, db):
|
||||
db.record_interaction("alice", "positive", delta_score=10.0)
|
||||
db.record_interaction("bob", "negative", delta_score=-5.0)
|
||||
assert db.get_person("alice")["score"] == pytest.approx(10.0)
|
||||
assert db.get_person("bob")["score"] == pytest.approx(-5.0)
|
||||
|
||||
def test_details_stored(self, db):
|
||||
db.record_interaction("alice", "greeting", details={"location": "lab"})
|
||||
events = db.get_recent_events("alice", window_s=60.0)
|
||||
assert events[0]["details"].get("location") == "lab"
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# Integration: soul → tier → mood → greeting pipeline
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestIntegrationPipeline:
|
||||
|
||||
def test_stranger_pipeline(self, tmp_path):
|
||||
db_path = str(tmp_path / "int.db")
|
||||
db = RelationshipDB(db_path)
|
||||
soul = _minimal_soul(
|
||||
humor_level=7, threshold_regular=5, threshold_favorite=20,
|
||||
mood_prefix_curious="Hmm, "
|
||||
)
|
||||
# No prior interactions
|
||||
person = db.get_person("stranger_001")
|
||||
events = db.get_recent_events("stranger_001")
|
||||
tier = get_relationship_tier(soul, person["interaction_count"])
|
||||
mood = compute_mood(soul, person["score"], person["interaction_count"], events)
|
||||
greeting = build_greeting(soul, tier, mood, "stranger_001")
|
||||
|
||||
assert tier == TIER_STRANGER
|
||||
assert mood == MOOD_CURIOUS
|
||||
assert "hello" in greeting.lower()
|
||||
db.close()
|
||||
|
||||
def test_regular_positive_pipeline(self, tmp_path):
|
||||
db_path = str(tmp_path / "int2.db")
|
||||
db = RelationshipDB(db_path)
|
||||
soul = _minimal_soul(
|
||||
humor_level=8, threshold_regular=5, threshold_favorite=20,
|
||||
mood_prefix_playful="Beep! "
|
||||
)
|
||||
# Simulate 6 positive interactions (> threshold_regular=5)
|
||||
for _ in range(6):
|
||||
db.record_interaction("alice", "positive", delta_score=2.0)
|
||||
|
||||
person = db.get_person("alice")
|
||||
events = db.get_recent_events("alice")
|
||||
tier = get_relationship_tier(soul, person["interaction_count"])
|
||||
mood = compute_mood(soul, person["score"], person["interaction_count"], events)
|
||||
greeting = build_greeting(soul, tier, mood, "alice")
|
||||
|
||||
assert tier == TIER_REGULAR
|
||||
assert mood == MOOD_PLAYFUL # humor_level=8, recent positive events
|
||||
assert "alice" in greeting
|
||||
assert greeting.startswith("Beep!")
|
||||
db.close()
|
||||
|
||||
def test_favorite_pipeline(self, tmp_path):
|
||||
db_path = str(tmp_path / "int3.db")
|
||||
db = RelationshipDB(db_path)
|
||||
soul = _minimal_soul(
|
||||
humor_level=5, threshold_regular=5, threshold_favorite=20
|
||||
)
|
||||
for _ in range(25):
|
||||
db.record_interaction("bob", "positive", delta_score=1.0)
|
||||
|
||||
person = db.get_person("bob")
|
||||
tier = get_relationship_tier(soul, person["interaction_count"])
|
||||
assert tier == TIER_FAVORITE
|
||||
greeting = build_greeting(soul, tier, "happy", "bob")
|
||||
assert "bob" in greeting
|
||||
assert "Oh hey" in greeting
|
||||
db.close()
|
||||
@ -1,57 +1,24 @@
|
||||
# uwb_config.yaml — MaUWB ESP32-S3 DW3000 UWB follow-me system
|
||||
#
|
||||
# Hardware layout:
|
||||
# Anchor-0 (port side) → USB port_a, y = +anchor_separation/2
|
||||
# Anchor-1 (starboard side) → USB port_b, y = -anchor_separation/2
|
||||
# Tag on person → belt clip, ~0.9m above ground
|
||||
# uwb_config.yaml — MaUWB ESP32-S3 DW3000 UWB integration (Issue #90)
|
||||
#
|
||||
# Run with:
|
||||
# ros2 launch saltybot_uwb uwb.launch.py
|
||||
# Override at launch:
|
||||
# ros2 launch saltybot_uwb uwb.launch.py port_a:=/dev/ttyUSB2
|
||||
|
||||
# ── Serial ports ──────────────────────────────────────────────────────────────
|
||||
# Set udev rules to get stable symlinks:
|
||||
# /dev/uwb-anchor0 → port_a
|
||||
# /dev/uwb-anchor1 → port_b
|
||||
# (See jetson/docs/pinout.md for udev setup)
|
||||
port_a: /dev/uwb-anchor0 # Anchor-0 (port)
|
||||
port_b: /dev/uwb-anchor1 # Anchor-1 (starboard)
|
||||
baudrate: 115200 # MaUWB default — do not change
|
||||
port_a: /dev/uwb-anchor0
|
||||
port_b: /dev/uwb-anchor1
|
||||
baudrate: 115200
|
||||
|
||||
# ── Anchor geometry ────────────────────────────────────────────────────────────
|
||||
# anchor_separation: centre-to-centre distance between anchors (metres)
|
||||
# Must match physical mounting. Larger = more accurate lateral resolution.
|
||||
anchor_separation: 0.25 # metres (25cm)
|
||||
anchor_separation: 0.25
|
||||
anchor_height: 0.80
|
||||
tag_height: 0.90
|
||||
|
||||
# anchor_height: height of anchors above ground (metres)
|
||||
# Orin stem mount ≈ 0.80m on the saltybot platform
|
||||
anchor_height: 0.80 # metres
|
||||
range_timeout_s: 1.0
|
||||
max_range_m: 8.0
|
||||
min_range_m: 0.05
|
||||
|
||||
# tag_height: height of person's belt-clip tag above ground (metres)
|
||||
tag_height: 0.90 # metres (adjust per user)
|
||||
|
||||
# ── Range validity ─────────────────────────────────────────────────────────────
|
||||
# range_timeout_s: stale anchor — excluded from triangulation after this gap
|
||||
range_timeout_s: 1.0 # seconds
|
||||
|
||||
# max_range_m: discard ranges beyond this (DW3000 indoor practical limit ≈8m)
|
||||
max_range_m: 8.0 # metres
|
||||
|
||||
# min_range_m: discard ranges below this (likely multipath artefacts)
|
||||
min_range_m: 0.05 # metres
|
||||
|
||||
# ── Kalman filter ──────────────────────────────────────────────────────────────
|
||||
# kf_process_noise: Q scalar — how dynamic the person's motion is
|
||||
# Higher → faster response, more jitter
|
||||
kf_process_noise: 0.1
|
||||
|
||||
# kf_meas_noise: R scalar — how noisy the UWB ranging is
|
||||
# DW3000 indoor accuracy ≈ 10cm RMS → 0.1m → R ≈ 0.01
|
||||
# Use 0.3 to be conservative on first deployment
|
||||
kf_meas_noise: 0.3
|
||||
|
||||
# ── Publish rate ──────────────────────────────────────────────────────────────
|
||||
# Should match or exceed the AT+RANGE? poll rate from both anchors.
|
||||
# 20Hz means 50ms per cycle; each anchor query takes ~10ms → headroom ok.
|
||||
publish_rate: 20.0 # Hz
|
||||
range_rate: 100.0
|
||||
bearing_rate: 10.0
|
||||
|
||||
enrolled_tag_ids: [""]
|
||||
|
||||
@ -1,10 +1,14 @@
|
||||
"""
|
||||
uwb.launch.py — Launch UWB driver node for MaUWB ESP32-S3 follow-me.
|
||||
uwb.launch.py — Launch UWB driver node for MaUWB ESP32-S3 DW3000 (Issue #90).
|
||||
|
||||
Topics:
|
||||
/uwb/ranges 100 Hz raw TWR ranges
|
||||
/uwb/bearing 10 Hz Kalman-fused bearing
|
||||
/uwb/target 10 Hz triangulated PoseStamped (backwards compat)
|
||||
|
||||
Usage:
|
||||
ros2 launch saltybot_uwb uwb.launch.py
|
||||
ros2 launch saltybot_uwb uwb.launch.py port_a:=/dev/ttyUSB2 port_b:=/dev/ttyUSB3
|
||||
ros2 launch saltybot_uwb uwb.launch.py anchor_separation:=0.30 publish_rate:=10.0
|
||||
ros2 launch saltybot_uwb uwb.launch.py enrolled_tag_ids:="['0xDEADBEEF']"
|
||||
"""
|
||||
|
||||
import os
|
||||
@ -31,7 +35,9 @@ def generate_launch_description():
|
||||
DeclareLaunchArgument("min_range_m", default_value="0.05"),
|
||||
DeclareLaunchArgument("kf_process_noise", default_value="0.1"),
|
||||
DeclareLaunchArgument("kf_meas_noise", default_value="0.3"),
|
||||
DeclareLaunchArgument("publish_rate", default_value="20.0"),
|
||||
DeclareLaunchArgument("range_rate", default_value="100.0"),
|
||||
DeclareLaunchArgument("bearing_rate", default_value="10.0"),
|
||||
DeclareLaunchArgument("enrolled_tag_ids", default_value="['']"),
|
||||
|
||||
Node(
|
||||
package="saltybot_uwb",
|
||||
@ -52,7 +58,9 @@ def generate_launch_description():
|
||||
"min_range_m": LaunchConfiguration("min_range_m"),
|
||||
"kf_process_noise": LaunchConfiguration("kf_process_noise"),
|
||||
"kf_meas_noise": LaunchConfiguration("kf_meas_noise"),
|
||||
"publish_rate": LaunchConfiguration("publish_rate"),
|
||||
"range_rate": LaunchConfiguration("range_rate"),
|
||||
"bearing_rate": LaunchConfiguration("bearing_rate"),
|
||||
"enrolled_tag_ids": LaunchConfiguration("enrolled_tag_ids"),
|
||||
},
|
||||
],
|
||||
),
|
||||
|
||||
@ -29,6 +29,26 @@ Returns (x_t, y_t); caller should treat negative x_t as 0.
|
||||
import math
|
||||
|
||||
|
||||
# ── Bearing helper ────────────────────────────────────────────────────────────
|
||||
|
||||
def bearing_from_pos(x: float, y: float) -> float:
|
||||
"""
|
||||
Compute horizontal bearing to a point (x, y) in base_link.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : forward distance (metres)
|
||||
y : lateral offset (metres, positive = left of robot / CCW)
|
||||
|
||||
Returns
|
||||
-------
|
||||
bearing_rad : bearing in radians, range -π .. +π
|
||||
positive = target to the left (CCW)
|
||||
0 = directly ahead
|
||||
"""
|
||||
return math.atan2(y, x)
|
||||
|
||||
|
||||
# ── Triangulation ─────────────────────────────────────────────────────────────
|
||||
|
||||
def triangulate_2anchor(
|
||||
|
||||
@ -1,32 +1,52 @@
|
||||
"""
|
||||
uwb_driver_node.py — ROS2 node for MaUWB ESP32-S3 DW3000 follow-me system.
|
||||
uwb_driver_node.py — ROS2 node for MaUWB ESP32-S3 DW3000 UWB integration.
|
||||
|
||||
Hardware
|
||||
────────
|
||||
• 2× MaUWB ESP32-S3 DW3000 anchors on robot stem (USB → Orin Nano)
|
||||
• 2× MaUWB ESP32-S3 DW3000 anchors on robot stem (USB → Orin)
|
||||
- Anchor-0: port side (y = +sep/2)
|
||||
- Anchor-1: starboard (y = -sep/2)
|
||||
• 1× MaUWB tag on person (belt clip)
|
||||
• 1× MaUWB tag per enrolled person (belt clip)
|
||||
|
||||
AT command interface (115200 8N1)
|
||||
──────────────────────────────────
|
||||
Query: AT+RANGE?\r\n
|
||||
Response (from anchors):
|
||||
+RANGE:<anchor_id>,<range_mm>[,<rssi>]\r\n
|
||||
Query:
|
||||
AT+RANGE?\r\n
|
||||
|
||||
Config:
|
||||
AT+anchor_tag=ANCHOR\r\n — set module as anchor
|
||||
AT+anchor_tag=TAG\r\n — set module as tag
|
||||
Response (from anchors, TWR protocol):
|
||||
+RANGE:<anchor_id>,<range_mm>[,<rssi>[,<tag_addr>]]\r\n
|
||||
|
||||
Tag pairing (optional — targets a specific enrolled tag):
|
||||
AT+RANGE_ADDR=<tag_addr>\r\n → anchor only ranges with that tag
|
||||
|
||||
Publishes
|
||||
─────────
|
||||
/uwb/target (geometry_msgs/PoseStamped) — triangulated person position in base_link
|
||||
/uwb/ranges (saltybot_uwb_msgs/UwbRangeArray) — raw ranges from both anchors
|
||||
/uwb/ranges (saltybot_uwb_msgs/UwbRangeArray) 100 Hz — raw anchor ranges
|
||||
/uwb/bearing (saltybot_uwb_msgs/UwbBearing) 10 Hz — Kalman-fused bearing
|
||||
/uwb/target (geometry_msgs/PoseStamped) 10 Hz — triangulated position
|
||||
(kept for backwards compat)
|
||||
|
||||
Safety
|
||||
──────
|
||||
If a range is stale (> range_timeout_s), that anchor is excluded from
|
||||
triangulation. If both anchors are stale, /uwb/target is not published.
|
||||
Tag pairing
|
||||
───────────
|
||||
Set enrolled_tag_ids to a list of tag address strings (e.g. ["0x1234ABCD"]).
|
||||
When non-empty, ranges from unrecognised tags are silently discarded.
|
||||
The matched tag address is stamped in UwbRange.tag_id and UwbBearing.tag_id.
|
||||
When enrolled_tag_ids is empty, all ranges are accepted (tag_id = "").
|
||||
|
||||
Parameters
|
||||
──────────
|
||||
port_a, port_b serial ports for anchor-0 / anchor-1
|
||||
baudrate 115200 (default)
|
||||
anchor_separation centre-to-centre anchor spacing (m)
|
||||
anchor_height anchor mounting height (m)
|
||||
tag_height person tag height (m)
|
||||
range_timeout_s stale-anchor threshold (s)
|
||||
max_range_m / min_range_m validity window (m)
|
||||
kf_process_noise Kalman Q scalar
|
||||
kf_meas_noise Kalman R scalar
|
||||
range_rate Hz — /uwb/ranges publish rate (default 100)
|
||||
bearing_rate Hz — /uwb/bearing publish rate (default 10)
|
||||
enrolled_tag_ids list[str] — accepted tag addresses; [] = accept all
|
||||
|
||||
Usage
|
||||
─────
|
||||
@ -44,8 +64,8 @@ from rclpy.node import Node
|
||||
from geometry_msgs.msg import PoseStamped
|
||||
from std_msgs.msg import Header
|
||||
|
||||
from saltybot_uwb_msgs.msg import UwbRange, UwbRangeArray
|
||||
from saltybot_uwb.ranging_math import triangulate_2anchor, KalmanFilter2D
|
||||
from saltybot_uwb_msgs.msg import UwbRange, UwbRangeArray, UwbBearing
|
||||
from saltybot_uwb.ranging_math import triangulate_2anchor, KalmanFilter2D, bearing_from_pos
|
||||
|
||||
try:
|
||||
import serial
|
||||
@ -54,26 +74,31 @@ except ImportError:
|
||||
_SERIAL_AVAILABLE = False
|
||||
|
||||
|
||||
# Regex: +RANGE:<id>,<mm> or +RANGE:<id>,<mm>,<rssi>
|
||||
# +RANGE:<id>,<mm> or +RANGE:<id>,<mm>,<rssi> or +RANGE:<id>,<mm>,<rssi>,<tag>
|
||||
_RANGE_RE = re.compile(
|
||||
r"\+RANGE\s*:\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(-?[\d.]+))?",
|
||||
r"\+RANGE\s*:\s*(\d+)\s*,\s*(\d+)"
|
||||
r"(?:\s*,\s*(-?[\d.]+)"
|
||||
r"(?:\s*,\s*([\w:x]+))?"
|
||||
r")?",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
class SerialReader(threading.Thread):
|
||||
"""
|
||||
Background thread: polls an anchor's UART, fires callback on every
|
||||
valid +RANGE response.
|
||||
Background thread: polls one anchor's UART at maximum TWR rate,
|
||||
fires callback on every valid +RANGE response.
|
||||
Supports optional tag pairing via AT+RANGE_ADDR command.
|
||||
"""
|
||||
|
||||
def __init__(self, port, baudrate, anchor_id, callback, logger):
|
||||
def __init__(self, port, baudrate, anchor_id, callback, logger, tag_addr=None):
|
||||
super().__init__(daemon=True)
|
||||
self._port = port
|
||||
self._baudrate = baudrate
|
||||
self._anchor_id = anchor_id
|
||||
self._callback = callback
|
||||
self._logger = logger
|
||||
self._tag_addr = tag_addr
|
||||
self._running = False
|
||||
self._ser = None
|
||||
|
||||
@ -86,7 +111,10 @@ class SerialReader(threading.Thread):
|
||||
)
|
||||
self._logger.info(
|
||||
f"Anchor-{self._anchor_id}: opened {self._port}"
|
||||
+ (f" paired with tag {self._tag_addr}" if self._tag_addr else "")
|
||||
)
|
||||
if self._tag_addr:
|
||||
self._send_pairing_cmd()
|
||||
self._read_loop()
|
||||
except Exception as exc:
|
||||
self._logger.warn(
|
||||
@ -96,12 +124,24 @@ class SerialReader(threading.Thread):
|
||||
self._ser.close()
|
||||
time.sleep(2.0)
|
||||
|
||||
def _send_pairing_cmd(self):
|
||||
"""Configure the anchor to range only with the paired tag."""
|
||||
try:
|
||||
cmd = f"AT+RANGE_ADDR={self._tag_addr}\r\n".encode("ascii")
|
||||
self._ser.write(cmd)
|
||||
time.sleep(0.1)
|
||||
self._logger.info(
|
||||
f"Anchor-{self._anchor_id}: sent tag pairing {self._tag_addr}"
|
||||
)
|
||||
except Exception as exc:
|
||||
self._logger.warn(
|
||||
f"Anchor-{self._anchor_id}: pairing cmd failed: {exc}"
|
||||
)
|
||||
|
||||
def _read_loop(self):
|
||||
while self._running:
|
||||
try:
|
||||
# Query the anchor
|
||||
self._ser.write(b"AT+RANGE?\r\n")
|
||||
# Read up to 10 lines waiting for a +RANGE response
|
||||
for _ in range(10):
|
||||
raw = self._ser.readline()
|
||||
if not raw:
|
||||
@ -111,13 +151,14 @@ class SerialReader(threading.Thread):
|
||||
if m:
|
||||
range_mm = int(m.group(2))
|
||||
rssi = float(m.group(3)) if m.group(3) else 0.0
|
||||
self._callback(self._anchor_id, range_mm, rssi)
|
||||
tag_addr = m.group(4) if m.group(4) else ""
|
||||
self._callback(self._anchor_id, range_mm, rssi, tag_addr)
|
||||
break
|
||||
except Exception as exc:
|
||||
self._logger.warn(
|
||||
f"Anchor-{self._anchor_id} read error: {exc}"
|
||||
)
|
||||
break # trigger reconnect
|
||||
break
|
||||
|
||||
def stop(self):
|
||||
self._running = False
|
||||
@ -130,9 +171,8 @@ class UwbDriverNode(Node):
|
||||
def __init__(self):
|
||||
super().__init__("uwb_driver")
|
||||
|
||||
# ── Parameters ────────────────────────────────────────────────────────
|
||||
self.declare_parameter("port_a", "/dev/ttyUSB0")
|
||||
self.declare_parameter("port_b", "/dev/ttyUSB1")
|
||||
self.declare_parameter("port_a", "/dev/uwb-anchor0")
|
||||
self.declare_parameter("port_b", "/dev/uwb-anchor1")
|
||||
self.declare_parameter("baudrate", 115200)
|
||||
self.declare_parameter("anchor_separation", 0.25)
|
||||
self.declare_parameter("anchor_height", 0.80)
|
||||
@ -142,36 +182,40 @@ class UwbDriverNode(Node):
|
||||
self.declare_parameter("min_range_m", 0.05)
|
||||
self.declare_parameter("kf_process_noise", 0.1)
|
||||
self.declare_parameter("kf_meas_noise", 0.3)
|
||||
self.declare_parameter("publish_rate", 20.0)
|
||||
self.declare_parameter("range_rate", 100.0)
|
||||
self.declare_parameter("bearing_rate", 10.0)
|
||||
self.declare_parameter("enrolled_tag_ids", [""])
|
||||
|
||||
self._p = self._load_params()
|
||||
|
||||
# ── State (protected by lock) ──────────────────────────────────────
|
||||
raw_ids = self.get_parameter("enrolled_tag_ids").value
|
||||
self._enrolled_tags = [t.strip() for t in raw_ids if t.strip()]
|
||||
paired_tag = self._enrolled_tags[0] if self._enrolled_tags else None
|
||||
|
||||
self._lock = threading.Lock()
|
||||
self._ranges = {} # anchor_id → (range_m, rssi, timestamp)
|
||||
self._ranges: dict = {}
|
||||
self._kf = KalmanFilter2D(
|
||||
process_noise=self._p["kf_process_noise"],
|
||||
measurement_noise=self._p["kf_meas_noise"],
|
||||
dt=1.0 / self._p["publish_rate"],
|
||||
dt=1.0 / self._p["bearing_rate"],
|
||||
)
|
||||
|
||||
# ── Publishers ────────────────────────────────────────────────────
|
||||
self._target_pub = self.create_publisher(
|
||||
PoseStamped, "/uwb/target", 10)
|
||||
self._ranges_pub = self.create_publisher(
|
||||
UwbRangeArray, "/uwb/ranges", 10)
|
||||
self._ranges_pub = self.create_publisher(UwbRangeArray, "/uwb/ranges", 10)
|
||||
self._bearing_pub = self.create_publisher(UwbBearing, "/uwb/bearing", 10)
|
||||
self._target_pub = self.create_publisher(PoseStamped, "/uwb/target", 10)
|
||||
|
||||
# ── Serial readers ────────────────────────────────────────────────
|
||||
if _SERIAL_AVAILABLE:
|
||||
self._reader_a = SerialReader(
|
||||
self._p["port_a"], self._p["baudrate"],
|
||||
anchor_id=0, callback=self._range_cb,
|
||||
logger=self.get_logger(),
|
||||
tag_addr=paired_tag,
|
||||
)
|
||||
self._reader_b = SerialReader(
|
||||
self._p["port_b"], self._p["baudrate"],
|
||||
anchor_id=1, callback=self._range_cb,
|
||||
logger=self.get_logger(),
|
||||
tag_addr=paired_tag,
|
||||
)
|
||||
self._reader_a.start()
|
||||
self._reader_b.start()
|
||||
@ -180,19 +224,21 @@ class UwbDriverNode(Node):
|
||||
"pyserial not installed — running in simulation mode (no serial I/O)"
|
||||
)
|
||||
|
||||
# ── Publish timer ─────────────────────────────────────────────────
|
||||
self._timer = self.create_timer(
|
||||
1.0 / self._p["publish_rate"], self._publish_cb
|
||||
self._range_timer = self.create_timer(
|
||||
1.0 / self._p["range_rate"], self._range_publish_cb
|
||||
)
|
||||
self._bearing_timer = self.create_timer(
|
||||
1.0 / self._p["bearing_rate"], self._bearing_publish_cb
|
||||
)
|
||||
|
||||
self.get_logger().info(
|
||||
f"UWB driver ready sep={self._p['anchor_separation']}m "
|
||||
f"ports={self._p['port_a']},{self._p['port_b']} "
|
||||
f"rate={self._p['publish_rate']}Hz"
|
||||
f"range={self._p['range_rate']}Hz "
|
||||
f"bearing={self._p['bearing_rate']}Hz "
|
||||
f"enrolled_tags={self._enrolled_tags or ['<any>']}"
|
||||
)
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
def _load_params(self):
|
||||
return {
|
||||
"port_a": self.get_parameter("port_a").value,
|
||||
@ -206,90 +252,106 @@ class UwbDriverNode(Node):
|
||||
"min_range_m": self.get_parameter("min_range_m").value,
|
||||
"kf_process_noise": self.get_parameter("kf_process_noise").value,
|
||||
"kf_meas_noise": self.get_parameter("kf_meas_noise").value,
|
||||
"publish_rate": self.get_parameter("publish_rate").value,
|
||||
"range_rate": self.get_parameter("range_rate").value,
|
||||
"bearing_rate": self.get_parameter("bearing_rate").value,
|
||||
}
|
||||
|
||||
# ── Callbacks ─────────────────────────────────────────────────────────────
|
||||
def _is_enrolled(self, tag_addr: str) -> bool:
|
||||
if not self._enrolled_tags:
|
||||
return True
|
||||
return tag_addr in self._enrolled_tags
|
||||
|
||||
def _range_cb(self, anchor_id: int, range_mm: int, rssi: float):
|
||||
"""Called from serial reader threads — thread-safe update."""
|
||||
def _range_cb(self, anchor_id: int, range_mm: int, rssi: float, tag_addr: str):
|
||||
if not self._is_enrolled(tag_addr):
|
||||
return
|
||||
range_m = range_mm / 1000.0
|
||||
p = self._p
|
||||
if range_m < p["min_range_m"] or range_m > p["max_range_m"]:
|
||||
return
|
||||
with self._lock:
|
||||
self._ranges[anchor_id] = (range_m, rssi, time.monotonic())
|
||||
self._ranges[anchor_id] = (range_m, rssi, tag_addr, time.monotonic())
|
||||
|
||||
def _publish_cb(self):
|
||||
def _range_publish_cb(self):
|
||||
"""100 Hz: publish current raw ranges as UwbRangeArray."""
|
||||
now = time.monotonic()
|
||||
timeout = self._p["range_timeout_s"]
|
||||
sep = self._p["anchor_separation"]
|
||||
|
||||
with self._lock:
|
||||
# Collect valid (non-stale) ranges
|
||||
valid = {}
|
||||
for aid, (r, rssi, t) in self._ranges.items():
|
||||
if now - t <= timeout:
|
||||
valid[aid] = (r, rssi, t)
|
||||
|
||||
# Build and publish UwbRangeArray regardless (even if partial)
|
||||
valid = {
|
||||
aid: entry
|
||||
for aid, entry in self._ranges.items()
|
||||
if (now - entry[3]) <= timeout
|
||||
}
|
||||
hdr = Header()
|
||||
hdr.stamp = self.get_clock().now().to_msg()
|
||||
hdr.frame_id = "base_link"
|
||||
|
||||
arr = UwbRangeArray()
|
||||
arr.header = hdr
|
||||
for aid, (r, rssi, _) in valid.items():
|
||||
for aid, (r, rssi, tag_id, _) in valid.items():
|
||||
entry = UwbRange()
|
||||
entry.header = hdr
|
||||
entry.anchor_id = aid
|
||||
entry.range_m = float(r)
|
||||
entry.raw_mm = int(round(r * 1000.0))
|
||||
entry.rssi = float(rssi)
|
||||
entry.tag_id = tag_id
|
||||
arr.ranges.append(entry)
|
||||
self._ranges_pub.publish(arr)
|
||||
|
||||
# Need both anchors to triangulate
|
||||
if 0 not in valid or 1 not in valid:
|
||||
def _bearing_publish_cb(self):
|
||||
"""10 Hz: Kalman predict+update, publish fused bearing."""
|
||||
now = time.monotonic()
|
||||
timeout = self._p["range_timeout_s"]
|
||||
sep = self._p["anchor_separation"]
|
||||
with self._lock:
|
||||
valid = {
|
||||
aid: entry
|
||||
for aid, entry in self._ranges.items()
|
||||
if (now - entry[3]) <= timeout
|
||||
}
|
||||
if not valid:
|
||||
return
|
||||
|
||||
both_fresh = 0 in valid and 1 in valid
|
||||
confidence = 1.0 if both_fresh else 0.5
|
||||
active_tag = valid[next(iter(valid))][2]
|
||||
dt = 1.0 / self._p["bearing_rate"]
|
||||
self._kf.predict(dt=dt)
|
||||
if both_fresh:
|
||||
r0 = valid[0][0]
|
||||
r1 = valid[1][0]
|
||||
|
||||
try:
|
||||
x_t, y_t = triangulate_2anchor(
|
||||
r0=r0,
|
||||
r1=r1,
|
||||
sep=sep,
|
||||
r0=r0, r1=r1, sep=sep,
|
||||
anchor_z=self._p["anchor_height"],
|
||||
tag_z=self._p["tag_height"],
|
||||
)
|
||||
except (ValueError, ZeroDivisionError) as exc:
|
||||
self.get_logger().warn(f"Triangulation error: {exc}")
|
||||
return
|
||||
|
||||
# Kalman filter update
|
||||
dt = 1.0 / self._p["publish_rate"]
|
||||
self._kf.predict(dt=dt)
|
||||
self._kf.update(x_t, y_t)
|
||||
kx, ky = self._kf.position()
|
||||
|
||||
# Publish PoseStamped in base_link
|
||||
bearing = bearing_from_pos(kx, ky)
|
||||
range_m = math.sqrt(kx * kx + ky * ky)
|
||||
hdr = Header()
|
||||
hdr.stamp = self.get_clock().now().to_msg()
|
||||
hdr.frame_id = "base_link"
|
||||
brg_msg = UwbBearing()
|
||||
brg_msg.header = hdr
|
||||
brg_msg.bearing_rad = float(bearing)
|
||||
brg_msg.range_m = float(range_m)
|
||||
brg_msg.confidence = float(confidence)
|
||||
brg_msg.tag_id = active_tag
|
||||
self._bearing_pub.publish(brg_msg)
|
||||
pose = PoseStamped()
|
||||
pose.header = hdr
|
||||
pose.pose.position.x = kx
|
||||
pose.pose.position.y = ky
|
||||
pose.pose.position.z = 0.0
|
||||
# Orientation: face the person (yaw = atan2(y, x))
|
||||
yaw = math.atan2(ky, kx)
|
||||
yaw = bearing
|
||||
pose.pose.orientation.z = math.sin(yaw / 2.0)
|
||||
pose.pose.orientation.w = math.cos(yaw / 2.0)
|
||||
|
||||
self._target_pub.publish(pose)
|
||||
|
||||
|
||||
# ── Entry point ───────────────────────────────────────────────────────────────
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = UwbDriverNode()
|
||||
|
||||
@ -7,7 +7,7 @@ No ROS2 / serial / GPU dependencies — runs with plain pytest.
|
||||
import math
|
||||
import pytest
|
||||
|
||||
from saltybot_uwb.ranging_math import triangulate_2anchor, KalmanFilter2D
|
||||
from saltybot_uwb.ranging_math import triangulate_2anchor, KalmanFilter2D, bearing_from_pos
|
||||
|
||||
|
||||
# ── triangulate_2anchor ───────────────────────────────────────────────────────
|
||||
@ -172,3 +172,47 @@ class TestKalmanFilter2D:
|
||||
x, y = kf.position()
|
||||
assert not math.isnan(x)
|
||||
assert not math.isnan(y)
|
||||
|
||||
|
||||
# ── bearing_from_pos ──────────────────────────────────────────────────────────
|
||||
|
||||
class TestBearingFromPos:
|
||||
|
||||
def test_directly_ahead_zero_bearing(self):
|
||||
"""Person directly ahead: x=2, y=0 → bearing ≈ 0."""
|
||||
b = bearing_from_pos(2.0, 0.0)
|
||||
assert abs(b) < 0.001
|
||||
|
||||
def test_left_gives_positive_bearing(self):
|
||||
"""Person to the left (y>0): bearing should be positive."""
|
||||
b = bearing_from_pos(1.0, 1.0)
|
||||
assert b > 0.0
|
||||
assert abs(b - math.pi / 4.0) < 0.001
|
||||
|
||||
def test_right_gives_negative_bearing(self):
|
||||
"""Person to the right (y<0): bearing should be negative."""
|
||||
b = bearing_from_pos(1.0, -1.0)
|
||||
assert b < 0.0
|
||||
assert abs(b + math.pi / 4.0) < 0.001
|
||||
|
||||
def test_directly_left_ninety_degrees(self):
|
||||
"""Person directly to the left: x=0, y=1 → bearing = π/2."""
|
||||
b = bearing_from_pos(0.0, 1.0)
|
||||
assert abs(b - math.pi / 2.0) < 0.001
|
||||
|
||||
def test_directly_right_minus_ninety_degrees(self):
|
||||
"""Person directly to the right: x=0, y=-1 → bearing = -π/2."""
|
||||
b = bearing_from_pos(0.0, -1.0)
|
||||
assert abs(b + math.pi / 2.0) < 0.001
|
||||
|
||||
def test_range_pi_to_minus_pi(self):
|
||||
"""Bearing is always in -π..+π."""
|
||||
for x in [-2.0, -0.1, 0.1, 2.0]:
|
||||
for y in [-2.0, -0.1, 0.1, 2.0]:
|
||||
b = bearing_from_pos(x, y)
|
||||
assert -math.pi <= b <= math.pi
|
||||
|
||||
def test_no_nan_for_tiny_distance(self):
|
||||
"""Very close target should not produce NaN."""
|
||||
b = bearing_from_pos(0.001, 0.001)
|
||||
assert not math.isnan(b)
|
||||
|
||||
@ -8,6 +8,7 @@ find_package(rosidl_default_generators REQUIRED)
|
||||
rosidl_generate_interfaces(${PROJECT_NAME}
|
||||
"msg/UwbRange.msg"
|
||||
"msg/UwbRangeArray.msg"
|
||||
"msg/UwbBearing.msg"
|
||||
DEPENDENCIES std_msgs
|
||||
)
|
||||
|
||||
|
||||
18
jetson/ros2_ws/src/saltybot_uwb_msgs/msg/UwbBearing.msg
Normal file
18
jetson/ros2_ws/src/saltybot_uwb_msgs/msg/UwbBearing.msg
Normal file
@ -0,0 +1,18 @@
|
||||
# UwbBearing.msg — 10 Hz Kalman-fused bearing estimate from dual-anchor UWB.
|
||||
#
|
||||
# bearing_rad : horizontal bearing to tag in base_link frame (radians)
|
||||
# positive = tag to the left (CCW), negative = right (CW)
|
||||
# range: -π .. +π
|
||||
# range_m : Kalman-filtered horizontal range to tag (metres)
|
||||
# confidence : quality indicator
|
||||
# 1.0 = both anchors fresh
|
||||
# 0.5 = single anchor only (bearing less reliable)
|
||||
# 0.0 = stale / no data (message not published in this state)
|
||||
# tag_id : enrolled tag identifier that produced this estimate
|
||||
|
||||
std_msgs/Header header
|
||||
|
||||
float32 bearing_rad
|
||||
float32 range_m
|
||||
float32 confidence
|
||||
string tag_id
|
||||
@ -4,6 +4,7 @@
|
||||
# range_m : measured horizontal range in metres (after height correction)
|
||||
# raw_mm : raw TWR range from AT+RANGE? response, millimetres
|
||||
# rssi : received signal strength (dBm), 0 if not reported by module
|
||||
# tag_id : enrolled tag identifier; empty string if tag pairing is disabled
|
||||
|
||||
std_msgs/Header header
|
||||
|
||||
@ -11,3 +12,4 @@ uint8 anchor_id
|
||||
float32 range_m
|
||||
uint32 raw_mm
|
||||
float32 rssi
|
||||
string tag_id
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user