feat(social): multi-modal person state tracker (Issue #82)

This commit is contained in:
sl-perception 2026-03-01 23:08:22 -05:00
parent ac6fcb9a42
commit 84790412d6
23 changed files with 734 additions and 0 deletions

View File

@ -0,0 +1,8 @@
# Default parameters for the person_state_tracker node.
# Keep these in sync with the DeclareLaunchArgument defaults in the launch file.
person_state_tracker:
  ros__parameters:
    engagement_distance: 2.0  # metres; below this range a person counts as ENGAGED
    absent_timeout: 5.0       # seconds without any observation before a track goes ABSENT
    prune_timeout: 30.0       # seconds spent ABSENT before the track is deleted
    update_rate: 10.0         # Hz; rate of the state-machine/publish timer
    n_cameras: 4              # number of /surround/camN/image_raw topics to subscribe to
    uwb_enabled: false        # subscribe to /uwb/positions when true

View File

@ -0,0 +1,35 @@
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
    """Launch the person_state_tracker node.

    Exposes every ros__parameter of the node as a launch argument so the
    defaults here stay consistent with config/person_state_tracker.yaml.
    (Previously prune_timeout, update_rate and n_cameras could not be
    overridden from the command line even though the node declares them.)
    """
    return LaunchDescription([
        DeclareLaunchArgument(
            'engagement_distance',
            default_value='2.0',
            description='Distance in metres below which a person is considered engaged'
        ),
        DeclareLaunchArgument(
            'absent_timeout',
            default_value='5.0',
            description='Seconds without detection before marking person as ABSENT'
        ),
        DeclareLaunchArgument(
            'prune_timeout',
            default_value='30.0',
            description='Seconds spent in ABSENT state before a track is deleted'
        ),
        DeclareLaunchArgument(
            'update_rate',
            default_value='10.0',
            description='State update and publish rate in Hz'
        ),
        DeclareLaunchArgument(
            'n_cameras',
            default_value='4',
            description='Number of surround cameras to subscribe to'
        ),
        DeclareLaunchArgument(
            'uwb_enabled',
            default_value='false',
            description='Whether UWB anchor data is available'
        ),
        Node(
            package='saltybot_social',
            executable='person_state_tracker',
            name='person_state_tracker',
            output='screen',
            parameters=[{
                # launch_ros yaml-parses substitution results, so these string
                # arguments arrive at the node as float/int/bool values.
                'engagement_distance': LaunchConfiguration('engagement_distance'),
                'absent_timeout': LaunchConfiguration('absent_timeout'),
                'prune_timeout': LaunchConfiguration('prune_timeout'),
                'update_rate': LaunchConfiguration('update_rate'),
                'n_cameras': LaunchConfiguration('n_cameras'),
                'uwb_enabled': LaunchConfiguration('uwb_enabled'),
            }],
        ),
    ])

View File

@ -0,0 +1,25 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<!-- Manifest for the saltybot_social Python package (person state tracker node). -->
<package format="3">
  <name>saltybot_social</name>
  <version>0.1.0</version>
  <description>Multi-modal person identity fusion and state tracking for saltybot</description>
  <maintainer email="seb@vayrette.com">seb</maintainer>
  <license>MIT</license>

  <depend>rclpy</depend>
  <depend>std_msgs</depend>
  <depend>geometry_msgs</depend>
  <depend>sensor_msgs</depend>
  <depend>vision_msgs</depend>
  <depend>saltybot_social_msgs</depend>
  <depend>tf2_ros</depend>
  <depend>tf2_geometry_msgs</depend>
  <depend>cv_bridge</depend>

  <test_depend>ament_copyright</test_depend>
  <test_depend>ament_flake8</test_depend>
  <test_depend>ament_pep257</test_depend>
  <test_depend>python3-pytest</test_depend>

  <export>
    <build_type>ament_python</build_type>
  </export>
</package>

View File

@ -0,0 +1,201 @@
import math
import time
from typing import List, Optional, Tuple
from saltybot_social.person_state import PersonTrack, State
class IdentityFuser:
    """Multi-modal identity fusion: face_id + speaker_id + UWB anchor -> unified person_id.

    Maintains a dict of PersonTrack keyed by an internally allocated person_id.
    Each modality update either matches an existing track -- by the modality's
    own id first, then by 3D proximity -- or allocates a new track.
    """

    # Proximity threshold (metres) for associating an anonymous observation
    # with an existing, recently-seen track.
    _PROXIMITY_THRESHOLD = 0.5

    def __init__(self, position_window: float = 3.0):
        """
        Args:
            position_window: maximum age (seconds) of a track's last
                observation for it to be eligible for proximity matching.
        """
        self._tracks: dict[int, PersonTrack] = {}
        self._next_id: int = 1
        self._position_window = position_window
        self._max_history = 20  # bound on each track's distance history

    def _distance_3d(self, a: Optional[tuple], b: Optional[tuple]) -> float:
        """Euclidean distance between two 3D points; inf when either is missing."""
        if a is None or b is None:
            return float('inf')
        return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))

    def _compute_distance_and_bearing(self, position: tuple) -> Tuple[float, float]:
        """Return (range_m, bearing_deg) of a base_link-frame (x, y, z) point."""
        x, y, z = position
        distance = math.sqrt(x * x + y * y + z * z)
        bearing = math.degrees(math.atan2(y, x))
        return distance, bearing

    def _allocate_id(self) -> int:
        """Allocate the next monotonically increasing person id."""
        pid = self._next_id
        self._next_id += 1
        return pid

    def _apply_position(self, track: "PersonTrack", position_3d: tuple) -> None:
        """Update a track's position, range, bearing, and bounded distance history.

        Factored out of update_face/update_uwb, which previously repeated this
        sequence five times.
        """
        track.position = position_3d
        dist, bearing = self._compute_distance_and_bearing(position_3d)
        track.distance = dist
        track.bearing_deg = bearing
        track.history_distances.append(dist)
        if len(track.history_distances) > self._max_history:
            track.history_distances = track.history_distances[-self._max_history:]

    def _find_nearby_track(self, position_3d: tuple, now: float) -> Optional["PersonTrack"]:
        """Return the recently-seen track nearest to position_3d within threshold, or None."""
        best_track = None
        best_dist = self._PROXIMITY_THRESHOLD
        for track in self._tracks.values():
            d = self._distance_3d(track.position, position_3d)
            if d < best_dist and (now - track.last_seen) < self._position_window:
                best_dist = d
                best_track = track
        return best_track

    def update_face(self, face_id: int, name: str,
                    position_3d: Optional[tuple], camera_id: int,
                    now: float) -> int:
        """Ingest a face observation and return the fused person_id.

        Matching order: exact face_id (when face_id >= 0), then 3D proximity
        (only for anonymous detections with face_id < 0), else a new track.
        """
        if face_id >= 0:
            for track in self._tracks.values():
                if track.face_id == face_id:
                    # Keep a previously learned name instead of 'unknown'.
                    if name and name != 'unknown':
                        track.name = name
                    if position_3d is not None:
                        self._apply_position(track, position_3d)
                    track.camera_id = camera_id
                    track.last_seen = now
                    track.last_face_seen = now
                    return track.person_id
        # No existing track with this face_id; try proximity match for face_id < 0
        if face_id < 0 and position_3d is not None:
            match = self._find_nearby_track(position_3d, now)
            if match is not None:
                self._apply_position(match, position_3d)
                match.camera_id = camera_id
                match.last_seen = now
                return match.person_id
        # Create new track
        pid = self._allocate_id()
        track = PersonTrack(person_id=pid, face_id=face_id, name=name,
                            camera_id=camera_id, last_seen=now, last_face_seen=now)
        if position_3d is not None:
            self._apply_position(track, position_3d)
        self._tracks[pid] = track
        return pid

    def update_speaker(self, speaker_id: str, now: float) -> int:
        """Ingest a speaker-diarization id and return the fused person_id.

        Matching order: existing speaker_id, then the nearest recently-seen
        track with a known distance, else a new (position-less) track.
        """
        # First check if any track already has this speaker_id
        for track in self._tracks.values():
            if track.speaker_id == speaker_id:
                track.last_seen = now
                return track.person_id
        # Assign to nearest recently-seen track
        best_track = None
        best_dist = float('inf')
        for track in self._tracks.values():
            if (now - track.last_seen) < self._position_window and track.distance > 0:
                if track.distance < best_dist:
                    best_dist = track.distance
                    best_track = track
        if best_track is not None:
            best_track.speaker_id = speaker_id
            best_track.last_seen = now
            return best_track.person_id
        # No suitable track found; create a new one
        pid = self._allocate_id()
        track = PersonTrack(person_id=pid, speaker_id=speaker_id,
                            last_seen=now)
        self._tracks[pid] = track
        return pid

    def update_uwb(self, uwb_anchor_id: str, position_3d: tuple,
                   now: float) -> int:
        """Ingest a UWB anchor fix and return the fused person_id.

        Matching order: existing uwb_anchor_id, then 3D proximity
        (nearest track within the proximity threshold), else a new track.
        """
        # Check existing UWB assignment
        for track in self._tracks.values():
            if track.uwb_anchor_id == uwb_anchor_id:
                self._apply_position(track, position_3d)
                track.last_seen = now
                return track.person_id
        # Proximity match
        match = self._find_nearby_track(position_3d, now)
        if match is not None:
            match.uwb_anchor_id = uwb_anchor_id
            self._apply_position(match, position_3d)
            match.last_seen = now
            return match.person_id
        # Create new track
        pid = self._allocate_id()
        track = PersonTrack(person_id=pid, uwb_anchor_id=uwb_anchor_id,
                            last_seen=now)
        self._apply_position(track, position_3d)
        self._tracks[pid] = track
        return pid

    def get_all_tracks(self) -> List[PersonTrack]:
        """Return a snapshot list of all current tracks."""
        return list(self._tracks.values())

    @staticmethod
    def compute_attention(tracks: List[PersonTrack]) -> int:
        """Return the person_id to attend to, or -1 if nobody qualifies.

        Prefers the nearest TALKING person, then the nearest ENGAGED one;
        tracks with unknown distance (0.0) are never selected.
        """
        candidates = [t for t in tracks
                      if t.state in (State.ENGAGED, State.TALKING) and t.distance > 0]
        if not candidates:
            return -1
        # Prefer talking, then nearest engaged
        talking = [t for t in candidates if t.state == State.TALKING]
        if talking:
            return min(talking, key=lambda t: t.distance).person_id
        return min(candidates, key=lambda t: t.distance).person_id

    def prune_absent(self, absent_timeout: float = 30.0):
        """Delete tracks that have been ABSENT for longer than absent_timeout seconds."""
        now = time.monotonic()
        to_remove = [pid for pid, t in self._tracks.items()
                     if t.state == State.ABSENT and (now - t.last_seen) > absent_timeout]
        for pid in to_remove:
            del self._tracks[pid]

    @staticmethod
    def detect_group(tracks: List[PersonTrack]) -> bool:
        """Return True if at least two non-ABSENT persons are within 1.5 m of each other."""
        active = [t for t in tracks
                  if t.position is not None and t.state != State.ABSENT]
        for i, a in enumerate(active):
            for b in active[i + 1:]:
                dx = a.position[0] - b.position[0]
                dy = a.position[1] - b.position[1]
                dz = a.position[2] - b.position[2]
                if math.sqrt(dx * dx + dy * dy + dz * dz) < 1.5:
                    return True
        return False

View File

@ -0,0 +1,54 @@
import time
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Optional
class State(IntEnum):
    """Lifecycle of a tracked person; values mirror the STATE_* constants in PersonState.msg."""
    UNKNOWN = 0
    APPROACHING = 1
    ENGAGED = 2
    TALKING = 3
    LEAVING = 4
    ABSENT = 5


@dataclass
class PersonTrack:
    """Fused per-person record built from face, speaker, and UWB observations."""
    person_id: int
    face_id: int = -1             # -1 until a recognized face is associated
    speaker_id: str = ''          # diarization id; '' when never heard speaking
    uwb_anchor_id: str = ''       # UWB tag id; '' when unused
    name: str = 'unknown'
    state: State = State.UNKNOWN
    position: Optional[tuple] = None  # (x, y, z) in base_link
    distance: float = 0.0         # range in metres; 0.0 means unknown
    bearing_deg: float = 0.0
    engagement_score: float = 0.0
    camera_id: int = -1
    last_seen: float = field(default_factory=time.monotonic)
    last_face_seen: float = 0.0
    history_distances: list = field(default_factory=list)  # last N ranges for trend estimation

    def update_state(self, now: float, engagement_distance: float = 2.0,
                     absent_timeout: float = 5.0):
        """Advance the state machine from recency, speech activity, and range trend."""
        # Staleness dominates everything else.
        if now - self.last_seen > absent_timeout:
            self.state = State.ABSENT
            return
        # An associated speaker id means the person is actively talking.
        if self.speaker_id:
            self.state = State.TALKING
            return
        history = self.history_distances
        if len(history) < 3:
            # Too few samples for a trend; classify on range alone.
            if self.distance > 0:
                near = self.distance < engagement_distance
                self.state = State.ENGAGED if near else State.APPROACHING
            return
        delta = history[-1] - history[-3]
        if self.distance < engagement_distance:
            self.state = State.ENGAGED
        elif delta < -0.2:
            # Range shrinking: walking toward the robot.
            self.state = State.APPROACHING
        elif delta > 0.3:
            # Range growing: walking away.
            self.state = State.LEAVING
        else:
            # Far away and roughly stationary.
            self.state = State.UNKNOWN

View File

@ -0,0 +1,273 @@
import time
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from std_msgs.msg import Int32, String
from geometry_msgs.msg import PoseArray, PoseStamped, Point
from sensor_msgs.msg import Image
from builtin_interfaces.msg import Time as TimeMsg
from saltybot_social_msgs.msg import (
FaceDetection,
FaceDetectionArray,
PersonState,
PersonStateArray,
)
from saltybot_social.identity_fuser import IdentityFuser
from saltybot_social.person_state import State
class PersonStateTrackerNode(Node):
    """Main ROS2 node for multi-modal person tracking.

    Feeds face detections, speaker ids, UWB fixes, and an external person
    target into an IdentityFuser, runs each track's state machine on a
    fixed-rate timer, and publishes the fused PersonStateArray on
    /social/persons plus the current attention target on
    /social/attention/target_id.
    """

    def __init__(self):
        super().__init__('person_state_tracker')
        # Parameters
        self.declare_parameter('engagement_distance', 2.0)
        self.declare_parameter('absent_timeout', 5.0)
        self.declare_parameter('prune_timeout', 30.0)
        self.declare_parameter('update_rate', 10.0)
        self.declare_parameter('n_cameras', 4)
        self.declare_parameter('uwb_enabled', False)
        self._engagement_distance = self.get_parameter('engagement_distance').value
        self._absent_timeout = self.get_parameter('absent_timeout').value
        self._prune_timeout = self.get_parameter('prune_timeout').value
        update_rate = self.get_parameter('update_rate').value
        n_cameras = self.get_parameter('n_cameras').value
        uwb_enabled = self.get_parameter('uwb_enabled').value
        self._fuser = IdentityFuser()
        # QoS profiles: best-effort for high-rate sensor streams, reliable
        # depth-1 (latest value) for the attention target.
        best_effort_qos = QoSProfile(
            reliability=ReliabilityPolicy.BEST_EFFORT,
            history=HistoryPolicy.KEEP_LAST,
            depth=5
        )
        reliable_qos = QoSProfile(
            reliability=ReliabilityPolicy.RELIABLE,
            history=HistoryPolicy.KEEP_LAST,
            depth=1
        )
        # Subscriptions
        self.create_subscription(
            FaceDetectionArray,
            '/social/faces/detections',
            self._on_face_detections,
            best_effort_qos
        )
        self.create_subscription(
            String,
            '/social/speech/speaker_id',
            self._on_speaker_id,
            10
        )
        if uwb_enabled:
            self.create_subscription(
                PoseArray,
                '/uwb/positions',
                self._on_uwb,
                10
            )
        # Camera subscriptions (monitor topic existence)
        self.create_subscription(
            Image, '/camera/color/image_raw',
            self._on_camera_image, best_effort_qos)
        for i in range(n_cameras):
            self.create_subscription(
                Image, f'/surround/cam{i}/image_raw',
                self._on_camera_image, best_effort_qos)
        # Existing person tracker position
        self.create_subscription(
            PoseStamped,
            '/person/target',
            self._on_person_target,
            10
        )
        # Publishers
        self._persons_pub = self.create_publisher(
            PersonStateArray, '/social/persons', best_effort_qos)
        self._attention_pub = self.create_publisher(
            Int32, '/social/attention/target_id', reliable_qos)
        # Timer
        timer_period = 1.0 / update_rate
        self.create_timer(timer_period, self._on_timer)
        self.get_logger().info(
            f'PersonStateTracker started: engagement={self._engagement_distance}m, '
            f'absent_timeout={self._absent_timeout}s, rate={update_rate}Hz, '
            f'cameras={n_cameras}, uwb={uwb_enabled}')

    def _on_face_detections(self, msg: FaceDetectionArray):
        """Feed each face detection into the fuser with a crude 3D estimate."""
        now = time.monotonic()
        for face in msg.faces:
            position_3d = None
            # Use bbox center as rough position estimate if depth is available
            # Real 3D position would come from depth camera projection
            if face.bbox_x > 0 or face.bbox_y > 0:
                # Approximate: use bbox center as bearing proxy, distance from bbox size
                # This is a placeholder; real impl would use depth image lookup
                # NOTE(review): this assumes bbox fields are normalized [0, 1]
                # image coordinates -- confirm against the detector.
                position_3d = (
                    max(0.5, 1.0 / max(face.bbox_w, 0.01)),  # rough depth from bbox width
                    face.bbox_x + face.bbox_w / 2.0 - 0.5,  # x offset from center
                    face.bbox_y + face.bbox_h / 2.0 - 0.5  # y offset from center
                )
            camera_id = -1
            # Derive the camera index from a frame_id like '...camN...'.
            if hasattr(face.header, 'frame_id') and face.header.frame_id:
                frame = face.header.frame_id
                if 'cam' in frame:
                    try:
                        camera_id = int(frame.split('cam')[-1].split('_')[0])
                    except ValueError:
                        pass
            self._fuser.update_face(
                face_id=face.face_id,
                name=face.person_name,
                position_3d=position_3d,
                camera_id=camera_id,
                now=now
            )

    def _on_speaker_id(self, msg: String):
        """Associate the active speaker id with the nearest recent track."""
        now = time.monotonic()
        speaker_id = msg.data
        # Parse "speaker_<id>" format or raw id
        if speaker_id.startswith('speaker_'):
            speaker_id = speaker_id[len('speaker_'):]
        self._fuser.update_speaker(speaker_id, now)

    def _on_uwb(self, msg: PoseArray):
        """Feed each UWB pose into the fuser, keyed by frame_id + array index.

        NOTE(review): the anchor id depends on the pose's index within the
        array; if the upstream publisher does not keep a stable ordering,
        anchor identities may swap -- verify against the UWB driver.
        """
        now = time.monotonic()
        for i, pose in enumerate(msg.poses):
            position = (pose.position.x, pose.position.y, pose.position.z)
            anchor_id = f'uwb_{i}'
            if hasattr(msg, 'header') and msg.header.frame_id:
                anchor_id = f'{msg.header.frame_id}_{i}'
            self._fuser.update_uwb(anchor_id, position, now)

    def _on_camera_image(self, msg: Image):
        # Monitor only -- no processing here
        pass

    def _on_person_target(self, msg: PoseStamped):
        """Feed the external (anonymous) person-tracker position into the fuser."""
        now = time.monotonic()
        position_3d = (
            msg.pose.position.x,
            msg.pose.position.y,
            msg.pose.position.z
        )
        # Update position for unidentified person (from YOLOv8 tracker)
        self._fuser.update_face(
            face_id=-1,
            name='unknown',
            position_3d=position_3d,
            camera_id=-1,
            now=now
        )

    def _on_timer(self):
        """Periodic tick: run state machines, score engagement, prune, publish."""
        now = time.monotonic()
        tracks = self._fuser.get_all_tracks()
        # Update state machine for all tracks
        for track in tracks:
            track.update_state(
                now,
                engagement_distance=self._engagement_distance,
                absent_timeout=self._absent_timeout
            )
        # Compute engagement scores
        for track in tracks:
            if track.state == State.ABSENT:
                track.engagement_score = 0.0
            elif track.state == State.TALKING:
                track.engagement_score = 1.0
            elif track.state == State.ENGAGED:
                track.engagement_score = max(0.0, 1.0 - track.distance / self._engagement_distance)
            elif track.state == State.APPROACHING:
                track.engagement_score = 0.3
            elif track.state == State.LEAVING:
                track.engagement_score = 0.1
            else:
                track.engagement_score = 0.0
        # Prune long-absent tracks
        # NOTE(review): `tracks` was snapshotted before pruning, so a track
        # removed here is still published once more below (as ABSENT) --
        # confirm downstream consumers tolerate that final message.
        self._fuser.prune_absent(self._prune_timeout)
        # Compute attention target
        attention_id = IdentityFuser.compute_attention(tracks)
        # Build and publish PersonStateArray
        msg = PersonStateArray()
        msg.header.stamp = self.get_clock().now().to_msg()
        msg.header.frame_id = 'base_link'
        msg.primary_attention_id = attention_id
        for track in tracks:
            ps = PersonState()
            ps.header.stamp = msg.header.stamp
            ps.header.frame_id = 'base_link'
            ps.person_id = track.person_id
            ps.person_name = track.name
            ps.face_id = track.face_id
            ps.speaker_id = track.speaker_id
            ps.uwb_anchor_id = track.uwb_anchor_id
            if track.position is not None:
                ps.position = Point(
                    x=float(track.position[0]),
                    y=float(track.position[1]),
                    z=float(track.position[2]))
            ps.distance = float(track.distance)
            ps.bearing_deg = float(track.bearing_deg)
            ps.state = int(track.state)
            ps.engagement_score = float(track.engagement_score)
            ps.last_seen = self._mono_to_ros_time(track.last_seen)
            ps.camera_id = track.camera_id
            msg.persons.append(ps)
        self._persons_pub.publish(msg)
        # Publish attention target
        att_msg = Int32()
        att_msg.data = attention_id
        self._attention_pub.publish(att_msg)

    def _mono_to_ros_time(self, mono: float) -> TimeMsg:
        """Convert monotonic timestamp to approximate ROS time."""
        # Offset from monotonic to wall clock; recomputed each call, so
        # successive conversions of the same mono value may differ slightly.
        offset = time.time() - time.monotonic()
        wall = mono + offset
        sec = int(wall)
        nsec = int((wall - sec) * 1e9)
        t = TimeMsg()
        t.sec = sec
        t.nanosec = nsec
        return t
def main(args=None):
    """CLI entry point: initialize rclpy, spin the tracker, and clean up on exit."""
    rclpy.init(args=args)
    tracker = PersonStateTrackerNode()
    try:
        rclpy.spin(tracker)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the node; exit without a traceback.
        pass
    finally:
        tracker.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()

View File

@ -0,0 +1,4 @@
# Install executables under lib/saltybot_social so `ros2 run` can locate them.
[develop]
script_dir=$base/lib/saltybot_social
[install]
install_scripts=$base/lib/saltybot_social

View File

@ -0,0 +1,32 @@
from setuptools import find_packages, setup
import os
from glob import glob

package_name = 'saltybot_social'

# Standard ament_python setup: installs the package marker, package.xml,
# launch files, and YAML config into the share directory, and exposes the
# tracker node as a console script for `ros2 run`.
setup(
    name=package_name,
    version='0.1.0',
    packages=find_packages(exclude=['test']),
    data_files=[
        ('share/ament_index/resource_index/packages',
            ['resource/' + package_name]),
        ('share/' + package_name, ['package.xml']),
        # Matches *.launch.py / *.launch.xml / *.launch.yaml style names.
        (os.path.join('share', package_name, 'launch'),
            glob(os.path.join('launch', '*launch.[pxy][yma]*'))),
        (os.path.join('share', package_name, 'config'),
            glob(os.path.join('config', '*.yaml'))),
    ],
    install_requires=['setuptools'],
    zip_safe=True,
    maintainer='seb',
    maintainer_email='seb@vayrette.com',
    description='Multi-modal person identity fusion and state tracking for saltybot',
    license='MIT',
    tests_require=['pytest'],
    entry_points={
        'console_scripts': [
            'person_state_tracker = saltybot_social.person_state_tracker_node:main',
        ],
    },
)

View File

@ -0,0 +1,24 @@
# Interface-only package: generates bindings for the saltybot social
# messages and services via rosidl.
cmake_minimum_required(VERSION 3.8)
project(saltybot_social_msgs)

find_package(ament_cmake REQUIRED)
find_package(rosidl_default_generators REQUIRED)
find_package(std_msgs REQUIRED)
find_package(geometry_msgs REQUIRED)
find_package(builtin_interfaces REQUIRED)

rosidl_generate_interfaces(${PROJECT_NAME}
  "msg/FaceDetection.msg"
  "msg/FaceDetectionArray.msg"
  "msg/FaceEmbedding.msg"
  "msg/FaceEmbeddingArray.msg"
  "msg/PersonState.msg"
  "msg/PersonStateArray.msg"
  "srv/EnrollPerson.srv"
  "srv/ListPersons.srv"
  "srv/DeletePerson.srv"
  "srv/UpdatePerson.srv"
  DEPENDENCIES std_msgs geometry_msgs builtin_interfaces
)

ament_package()

View File

@ -0,0 +1,10 @@
# One detected (and optionally recognized) face in a camera frame.
std_msgs/Header header        # frame_id identifies the source camera
int32 face_id                 # recognition id; negative when the face is unrecognized
string person_name            # recognized name, or 'unknown'
float32 confidence            # detection confidence
float32 recognition_score     # face-recognition match score
# Bounding box of the face. Presumably normalized [0, 1] image coordinates
# (the tracker derives depth from bbox_w and offsets from center) -- TODO confirm.
float32 bbox_x
float32 bbox_y
float32 bbox_w
float32 bbox_h
float32[10] landmarks         # presumably 5 (x, y) facial landmark pairs -- TODO confirm layout

View File

@ -0,0 +1,2 @@
# Batch of face detections from one camera frame.
std_msgs/Header header
saltybot_social_msgs/FaceDetection[] faces

View File

@ -0,0 +1,5 @@
# Stored face-recognition record for one enrolled person.
int32 person_id
string person_name
float32[] embedding                 # feature vector; dimensionality is model-dependent -- TODO confirm
builtin_interfaces/Time enrolled_at # when the person was enrolled
int32 sample_count                  # number of face samples behind this record

View File

@ -0,0 +1,2 @@
# Collection of enrolled face embeddings.
std_msgs/Header header
saltybot_social_msgs/FaceEmbedding[] embeddings

View File

@ -0,0 +1,19 @@
# Fused multi-modal state of one tracked person, published by person_state_tracker.
std_msgs/Header header
int32 person_id                   # stable fused id allocated by the identity fuser
string person_name                # 'unknown' until a face is recognized
int32 face_id                     # -1 when no recognized face is associated
string speaker_id                 # '' when the person has never been heard speaking
string uwb_anchor_id              # '' when no UWB tag is associated
geometry_msgs/Point position      # (x, y, z) in base_link
float32 distance                  # range in metres; 0.0 means unknown
float32 bearing_deg               # bearing from the +x axis, degrees
uint8 state                       # one of the STATE_* constants below
uint8 STATE_UNKNOWN=0
uint8 STATE_APPROACHING=1
uint8 STATE_ENGAGED=2
uint8 STATE_TALKING=3
uint8 STATE_LEAVING=4
uint8 STATE_ABSENT=5
float32 engagement_score          # 0.0 (absent) .. 1.0 (talking)
builtin_interfaces/Time last_seen # wall-clock approximation of the last observation
int32 camera_id                   # source camera index; -1 when unknown

View File

@ -0,0 +1,3 @@
# Snapshot of all tracked persons plus the current attention target.
std_msgs/Header header
saltybot_social_msgs/PersonState[] persons
int32 primary_attention_id   # person_id the robot should attend to; -1 when nobody qualifies

View File

@ -0,0 +1,19 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<!-- Manifest for the saltybot_social_msgs interface package (rosidl). -->
<package format="3">
  <name>saltybot_social_msgs</name>
  <version>0.1.0</version>
  <description>Custom ROS2 messages and services for saltybot social capabilities</description>
  <maintainer email="seb@vayrette.com">seb</maintainer>
  <license>MIT</license>

  <buildtool_depend>ament_cmake</buildtool_depend>

  <depend>std_msgs</depend>
  <depend>geometry_msgs</depend>
  <depend>builtin_interfaces</depend>

  <build_depend>rosidl_default_generators</build_depend>
  <exec_depend>rosidl_default_runtime</exec_depend>
  <member_of_group>rosidl_interface_packages</member_of_group>

  <export>
    <build_type>ament_cmake</build_type>
  </export>
</package>

View File

@ -0,0 +1,4 @@
# Remove an enrolled person from the face database.
int32 person_id
---
bool success
string message     # human-readable status or error detail

View File

@ -0,0 +1,7 @@
# Enroll a new person for face recognition.
string name        # display name to store
string mode        # enrollment mode; accepted values defined by the server -- TODO confirm
int32 n_samples    # number of face samples to capture
---
bool success
string message     # human-readable status or error detail
int32 person_id    # id assigned to the newly enrolled person

View File

@ -0,0 +1,2 @@
# List all enrolled persons with their stored face records. No request fields.
---
saltybot_social_msgs/FaceEmbedding[] persons

View File

@ -0,0 +1,5 @@
# Rename an enrolled person.
int32 person_id
string new_name
---
bool success
string message     # human-readable status or error detail