feat: multi-person tracker (Issue #423) #426

Merged
sl-jetson merged 1 commits from sl-perception/issue-423-multi-person into main 2026-03-04 23:59:19 -05:00
10 changed files with 740 additions and 0 deletions
Showing only changes of commit 31cfb9dcb9 - Show all commits

View File

@ -0,0 +1,13 @@
/**:
  ros__parameters:
    # The node declares these parameters at its top level (e.g. "max_people"),
    # so they must sit directly under ros__parameters. Nesting them under an
    # extra "multi_person_tracker_node:" key would rename them to
    # "multi_person_tracker_node.max_people" etc. and they would never be applied.

    # Tracking parameters
    max_people: 10              # Maximum tracked people
    occlusion_grace_s: 3.0      # Seconds to maintain track during occlusion
    embedding_threshold: 0.75   # Cosine similarity threshold for face re-ID
    hsv_threshold: 0.60         # Color histogram similarity threshold

    # Publishing
    publish_hz: 15.0            # Update rate (15+ fps on Orin)
    enable_group_follow: true   # Follow group centroid if no single target
    announce_state: true        # Log state changes

View File

@ -0,0 +1,40 @@
"""multi_person_tracker.launch.py — Launch file for multi-person tracking (Issue #423)."""
from launch import LaunchDescription
from launch_ros.actions import Node
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
import os
from ament_index_python.packages import get_package_share_directory
def generate_launch_description():
    """Build the launch description that starts the multi-person tracker.

    Declares a ``config_file`` launch argument whose default is the parameter
    YAML under this package's share directory, and hands that argument to the
    tracker node as its parameter file.
    """
    package = "saltybot_multi_person_tracker"
    node_name = "multi_person_tracker_node"

    # Default parameter file: <share>/config/multi_person_tracker_params.yaml
    default_params_path = os.path.join(
        get_package_share_directory(package),
        "config",
        "multi_person_tracker_params.yaml",
    )
    declare_config_file = DeclareLaunchArgument(
        "config_file",
        default_value=default_params_path,
        description="Path to multi-person tracker parameters YAML file",
    )

    # The node reads its parameters from whatever "config_file" resolves to.
    params_file = LaunchConfiguration("config_file")
    tracker = Node(
        package=package,
        executable=node_name,
        name=node_name,
        parameters=[params_file],
        remappings=[],
        output="screen",
    )

    return LaunchDescription([declare_config_file, tracker])

View File

@ -0,0 +1,27 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
  <name>saltybot_multi_person_tracker</name>
  <version>0.1.0</version>
  <description>Multi-person tracker with group handling and target priority (Issue #423)</description>
  <maintainer email="sl-perception@saltylab.local">sl-perception</maintainer>
  <license>MIT</license>

  <buildtool_depend>ament_python</buildtool_depend>

  <depend>rclpy</depend>
  <depend>std_msgs</depend>
  <depend>sensor_msgs</depend>
  <depend>geometry_msgs</depend>
  <depend>tf2_ros</depend>
  <depend>saltybot_social_msgs</depend>
  <depend>saltybot_person_reid_msgs</depend>
  <depend>vision_msgs</depend>
  <depend>cv_bridge</depend>
  <!-- System Python libraries must be named by their rosdep keys, not their pip names. -->
  <depend>python3-opencv</depend>
  <depend>python3-numpy</depend>

  <!-- Imported by the package's launch file. -->
  <exec_depend>launch</exec_depend>
  <exec_depend>launch_ros</exec_depend>
  <exec_depend>ament_index_python</exec_depend>

  <!-- Matches tests_require in setup.py. -->
  <test_depend>python3-pytest</test_depend>

  <export>
    <build_type>ament_python</build_type>
  </export>
</package>

View File

@ -0,0 +1 @@
"""saltybot_multi_person_tracker — Multi-person tracker with group handling (Issue #423)."""

View File

@ -0,0 +1,318 @@
"""Multi-person tracker for Issue #423.
Tracks up to 10 people with:
- Unique IDs (persistent across frames)
- Target priority: wake-word speaker > closest known > largest bbox
- Occlusion handoff (3s grace period)
- Re-ID via face embedding + HSV color histogram
- Group detection (follow centroid)
- Lost target: stop + rotate + SEARCHING state
"""
from __future__ import annotations
import time
import numpy as np
import cv2
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
@dataclass
class TrackedPerson:
    """State record for a single person known to the tracker.

    Holds the latest measurement (bearing, distance, confidence), the
    appearance features used for re-identification, and occlusion bookkeeping.
    """

    track_id: int
    bearing_rad: float
    distance_m: float
    confidence: float
    is_speaking: bool = False
    source: str = "camera"
    last_seen_time: float = field(default_factory=time.time)
    embedding: Optional[np.ndarray] = None  # face embedding used for re-ID
    hsv_histogram: Optional[np.ndarray] = None  # HSV histogram used as re-ID fallback
    bbox: Optional[Tuple[int, int, int, int]] = None  # (x, y, w, h)
    is_occluded: bool = False
    occlusion_start_time: float = 0.0

    def age_seconds(self) -> float:
        """Return how many seconds have elapsed since the last detection."""
        now = time.time()
        return now - self.last_seen_time

    def mark_seen(
        self,
        bearing: float,
        distance: float,
        confidence: float,
        is_speaking: bool = False,
        embedding: Optional[np.ndarray] = None,
        hsv_hist: Optional[np.ndarray] = None,
        bbox: Optional[Tuple] = None,
    ) -> None:
        """Record a fresh detection of this person.

        The measurement fields are always overwritten and the occlusion flag is
        cleared. ``embedding``, ``hsv_hist`` and ``bbox`` replace the stored
        values only when they are not ``None``.
        """
        self.bearing_rad, self.distance_m = bearing, distance
        self.confidence, self.is_speaking = confidence, is_speaking
        self.last_seen_time = time.time()
        self.is_occluded = False

        # Optional features: a missing value keeps whatever was stored before.
        optional_updates = (
            ("embedding", embedding),
            ("hsv_histogram", hsv_hist),
            ("bbox", bbox),
        )
        for attr_name, new_value in optional_updates:
            if new_value is not None:
                setattr(self, attr_name, new_value)

    def mark_occluded(self) -> None:
        """Flag the person as occluded, timestamping only the first call."""
        if self.is_occluded:
            # Already occluded: keep the original start time.
            return
        self.is_occluded = True
        self.occlusion_start_time = time.time()
class MultiPersonTracker:
    """Track several people at once with re-identification and target choice.

    Each call to :meth:`update` prunes expired tracks, matches the frame's
    detections against the remaining tracks, refreshes matched tracks, starts
    the occlusion grace period for tracks that were not detected, creates
    tracks for new detections (up to ``max_people``) and re-selects the active
    follow target.
    """

    # Tracks not seen for this long are dropped regardless of occlusion state.
    MAX_UNSEEN_AGE_S = 10.0

    def __init__(
        self,
        max_people: int = 10,
        occlusion_grace_s: float = 3.0,
        embedding_similarity_threshold: float = 0.75,
        hsv_similarity_threshold: float = 0.60,
        spatial_similarity_threshold: float = 0.80,
    ):
        """Create an empty tracker.

        Args:
            max_people: Maximum number of simultaneous tracks.
            occlusion_grace_s: Seconds an undetected track is kept before removal.
            embedding_similarity_threshold: Minimum cosine similarity for a
                face-embedding match.
            hsv_similarity_threshold: Minimum cosine similarity for an HSV
                histogram match.
            spatial_similarity_threshold: Minimum proximity score for a match
                when no appearance feature is shared by track and detection.
                With the default of 0.80 a detection matches within roughly
                0.25 rad of bearing (less if the distance also changed).
        """
        self.max_people = max_people
        self.occlusion_grace_s = occlusion_grace_s
        self.embedding_threshold = embedding_similarity_threshold
        self.hsv_threshold = hsv_similarity_threshold
        self.spatial_threshold = spatial_similarity_threshold
        self.tracked_people: Dict[int, TrackedPerson] = {}
        self.next_track_id = 0
        self.active_target_id: Optional[int] = None
        self.searching: bool = False
        self.search_bearing: float = 0.0

    # ── Main tracking update ────────────────────────────────────────────────
    def update(
        self,
        detections: List[Dict],
        speaking_person_id: Optional[int] = None,
    ) -> Tuple[List[TrackedPerson], Optional[int]]:
        """Update the tracker with one frame of detections.

        Args:
            detections: List of dicts with keys:
                - bearing_rad, distance_m, confidence (required)
                - is_speaking (optional, defaults to False)
                - embedding (optional): face embedding vector
                - hsv_histogram (optional): HSV color histogram
                - bbox (optional): (x, y, w, h)
            speaking_person_id: Track ID of the wake-word speaker (highest
                target priority), if any.

        Returns:
            (all_tracked_people, active_target_id)

        Raises:
            KeyError: If a detection lacks one of the required keys.
        """
        # Drop tracks whose grace period or maximum unseen age has expired.
        self._prune_lost_people()

        matched_pairs, unmatched_detections = self._match_detections(detections)

        # Refresh tracks that were re-detected this frame.
        seen_track_ids = set()
        for track_id, det_idx in matched_pairs:
            det = detections[det_idx]
            self.tracked_people[track_id].mark_seen(
                bearing=det["bearing_rad"],
                distance=det["distance_m"],
                confidence=det["confidence"],
                is_speaking=det.get("is_speaking", False),
                embedding=det.get("embedding"),
                hsv_hist=det.get("hsv_histogram"),
                bbox=det.get("bbox"),
            )
            seen_track_ids.add(track_id)

        # Start (or continue) the occlusion grace period for every track that
        # was not detected; without this the grace period never takes effect.
        for track_id, person in self.tracked_people.items():
            if track_id not in seen_track_ids:
                person.mark_occluded()

        # Create tracks for detections nobody claimed, respecting the cap.
        for det_idx in unmatched_detections:
            if len(self.tracked_people) >= self.max_people:
                break
            det = detections[det_idx]
            track_id = self.next_track_id
            self.next_track_id += 1
            self.tracked_people[track_id] = TrackedPerson(
                track_id=track_id,
                bearing_rad=det["bearing_rad"],
                distance_m=det["distance_m"],
                confidence=det["confidence"],
                is_speaking=det.get("is_speaking", False),
                embedding=det.get("embedding"),
                hsv_histogram=det.get("hsv_histogram"),
                bbox=det.get("bbox"),
            )

        self._update_target_priority(speaking_person_id)
        return list(self.tracked_people.values()), self.active_target_id

    # ── Detection-to-track matching ──────────────────────────────────────────
    def _match_detections(
        self, detections: List[Dict]
    ) -> Tuple[List[Tuple[int, int]], List[int]]:
        """Match detections to tracked people via re-ID.

        Every (track, detection) pair is scored and kept as a candidate only
        if it reaches the threshold of the metric that produced the score.
        Candidates are then assigned best-score-first so a track cannot take a
        detection that fits another track better. This method does not modify
        any track.

        Returns:
            (matched_pairs, unmatched_detection_indices) where each matched
            pair is (track_id, detection_index).
        """
        if not detections or not self.tracked_people:
            return [], list(range(len(detections)))

        candidates = []
        for track_id, person in self.tracked_people.items():
            for det_idx, det in enumerate(detections):
                score, threshold = self._score_with_threshold(person, det)
                if score >= threshold:
                    candidates.append((score, track_id, det_idx))

        # Highest score first; IDs break ties so the result is deterministic.
        candidates.sort(key=lambda c: (-c[0], c[1], c[2]))

        matched = []
        used_track_ids = set()
        used_det_indices = set()
        for _score, track_id, det_idx in candidates:
            if track_id in used_track_ids or det_idx in used_det_indices:
                continue
            matched.append((track_id, det_idx))
            used_track_ids.add(track_id)
            used_det_indices.add(det_idx)

        unmatched_det_indices = [
            i for i in range(len(detections)) if i not in used_det_indices
        ]
        return matched, unmatched_det_indices

    def _compute_similarity(self, person: TrackedPerson, detection: Dict) -> float:
        """Compute similarity between tracked person and detection.

        Uses face embeddings when both sides have one, else HSV histograms
        when both sides have one, else spatial proximity.
        """
        return self._score_with_threshold(person, detection)[0]

    def _score_with_threshold(
        self, person: TrackedPerson, detection: Dict
    ) -> Tuple[float, float]:
        """Return (similarity, threshold) for the metric both sides support."""
        det_embedding = detection.get("embedding")
        if person.embedding is not None and det_embedding is not None:
            score = self._cosine_similarity(person.embedding, det_embedding)
            return score, self.embedding_threshold

        det_hist = detection.get("hsv_histogram")
        if person.hsv_histogram is not None and det_hist is not None:
            # Cosine similarity is computed directly: OpenCV's compareHist has
            # no cosine method, so cv2.HISTCMP_COSINE raises AttributeError.
            score = self._cosine_similarity(person.hsv_histogram, det_hist)
            return score, self.hsv_threshold

        # Fallback: spatial proximity. The bearing difference is wrapped so
        # that bearings on either side of +/-pi still count as close.
        raw_diff = person.bearing_rad - detection["bearing_rad"]
        bearing_diff = abs(float(np.arctan2(np.sin(raw_diff), np.cos(raw_diff))))
        distance_diff = abs(person.distance_m - detection["distance_m"])
        score = 1.0 / (1.0 + bearing_diff + distance_diff * 0.1)
        return score, self.spatial_threshold

    @staticmethod
    def _cosine_similarity(a, b) -> float:
        """Return the cosine similarity of two vectors (0.0 if either is all-zero)."""
        vec_a = np.asarray(a, dtype=np.float64).ravel()
        vec_b = np.asarray(b, dtype=np.float64).ravel()
        denom = float(np.linalg.norm(vec_a) * np.linalg.norm(vec_b))
        if denom == 0.0:
            return 0.0
        return float(np.dot(vec_a, vec_b) / denom)

    # ── Target priority and group handling ────────────────────────────────────
    def _update_target_priority(self, speaking_person_id: Optional[int]) -> None:
        """Update active target based on priority:

        1. Wake-word speaker
        2. Closest non-occluded person
        3. Largest bounding box
        4. The current target, if it is still tracked (occlusion grace period)

        ``searching`` becomes True exactly when no target could be selected, so
        the target ID and the searching flag can never disagree.
        """
        if self.tracked_people:
            self.active_target_id = self._select_target_id(speaking_person_id)
        else:
            self.active_target_id = None
        self.searching = self.active_target_id is None

    def _select_target_id(self, speaking_person_id: Optional[int]) -> Optional[int]:
        """Return the track ID to follow, or None if no track qualifies."""
        people = self.tracked_people

        if speaking_person_id is not None and speaking_person_id in people:
            return speaking_person_id

        visible = [
            tid for tid, p in people.items()
            if not p.is_occluded and p.distance_m > 0
        ]
        if visible:
            return min(visible, key=lambda tid: people[tid].distance_m)

        areas = {
            tid: p.bbox[2] * p.bbox[3]
            for tid, p in people.items()
            if p.bbox is not None
        }
        sized = [tid for tid, area in areas.items() if area > 0]
        if sized:
            return max(sized, key=areas.__getitem__)

        # Nobody is visible: hold the current target while its track survives
        # so a brief dropout does not immediately trigger a search.
        if self.active_target_id in people:
            return self.active_target_id
        return None

    def get_group_centroid(self) -> Optional[Tuple[float, float]]:
        """Get centroid (bearing, distance) of all visible tracked people.

        Occluded people and people with a non-positive distance are ignored.

        Returns:
            (mean_bearing_rad, mean_distance_m), or None if nobody qualifies.
            The bearing is a circular mean, so bearings near +pi and -pi
            average to about pi rather than 0.
        """
        visible = [
            p for p in self.tracked_people.values()
            if not p.is_occluded and p.distance_m > 0
        ]
        if not visible:
            return None

        bearings = np.array([p.bearing_rad for p in visible], dtype=np.float64)
        mean_bearing = float(
            np.arctan2(np.mean(np.sin(bearings)), np.mean(np.cos(bearings)))
        )
        mean_distance = float(np.mean([p.distance_m for p in visible]))
        return mean_bearing, mean_distance

    # ── Pruning and cleanup ──────────────────────────────────────────────────
    def _prune_lost_people(self) -> None:
        """Remove people occluded past the grace period or unseen for too long.

        If the active target is removed, the target is cleared and
        ``searching`` is set.
        """
        now = time.time()
        expired = []
        for track_id, person in self.tracked_people.items():
            occluded_too_long = (
                person.is_occluded
                and now - person.occlusion_start_time > self.occlusion_grace_s
            )
            if occluded_too_long or person.age_seconds() > self.MAX_UNSEEN_AGE_S:
                expired.append(track_id)

        for track_id in expired:
            del self.tracked_people[track_id]
            if self.active_target_id == track_id:
                self.active_target_id = None
                self.searching = True

View File

@ -0,0 +1,185 @@
"""multi_person_tracker_node.py — Multi-person tracking node for Issue #423.
Subscribes to person detections with re-ID embeddings and publishes:
/saltybot/tracked_people (PersonArray) all tracked people + active target
/saltybot/follow_target (Person) the current follow target
Implements:
- Up to 10 tracked people
- Target priority: wake-word speaker > closest > largest bbox
- Occlusion handoff (3s grace)
- Re-ID via face embedding + HSV
- Group detection
- Lost target: SEARCHING state
"""
from __future__ import annotations
import time
import numpy as np
from typing import Dict, List, Optional
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, QoSReliabilityPolicy
from rclpy.duration import Duration
from std_msgs.msg import String
from geometry_msgs.msg import Twist
from saltybot_social_msgs.msg import Person, PersonArray
from .multi_person_tracker import MultiPersonTracker, TrackedPerson
class MultiPersonTrackerNode(Node):
    """ROS2 node that publishes the tracker's people, follow target and state.

    Publishes on a timer:
        /saltybot/tracked_people (PersonArray)  all tracks plus the active ID
        /saltybot/follow_target  (Person)       the active target, when one exists
        /saltybot/tracker_state  (String)       "SEARCHING" or "TRACKING"
        /saltybot/cmd_vel        (Twist)        search rotation while SEARCHING

    No detection subscription exists yet, so the tracker is only read here and
    never fed; see the placeholder note in ``__init__``.
    """

    # Used when the publish_hz parameter is zero or negative.
    _FALLBACK_PUBLISH_HZ = 15.0
    # Yaw rate commanded while searching for a lost target (rad/s).
    _SEARCH_YAW_RATE = 0.5

    def __init__(self) -> None:
        super().__init__("multi_person_tracker_node")

        # ── Parameters ──────────────────────────────────────────────────────
        self.declare_parameter("max_people", 10)
        self.declare_parameter("occlusion_grace_s", 3.0)
        self.declare_parameter("embedding_threshold", 0.75)
        self.declare_parameter("hsv_threshold", 0.60)
        self.declare_parameter("publish_hz", 15.0)
        self.declare_parameter("enable_group_follow", True)
        self.declare_parameter("announce_state", True)

        max_people = self.get_parameter("max_people").value
        occlusion_grace = self.get_parameter("occlusion_grace_s").value
        embed_threshold = self.get_parameter("embedding_threshold").value
        hsv_threshold = self.get_parameter("hsv_threshold").value
        publish_hz = float(self.get_parameter("publish_hz").value)
        # Read but not used yet: group following is not implemented in this node.
        self._enable_group = self.get_parameter("enable_group_follow").value
        self._announce = self.get_parameter("announce_state").value

        # A non-positive rate would make the timer period below divide by zero.
        if publish_hz <= 0.0:
            self.get_logger().warn(
                f"publish_hz={publish_hz} is not positive; "
                f"using {self._FALLBACK_PUBLISH_HZ} Hz"
            )
            publish_hz = self._FALLBACK_PUBLISH_HZ

        # ── Tracker instance ────────────────────────────────────────────────
        self._tracker = MultiPersonTracker(
            max_people=max_people,
            occlusion_grace_s=occlusion_grace,
            embedding_similarity_threshold=embed_threshold,
            hsv_similarity_threshold=hsv_threshold,
        )

        # ── Reliable QoS ────────────────────────────────────────────────────
        qos = QoSProfile(depth=5)
        qos.reliability = QoSReliabilityPolicy.RELIABLE

        self._tracked_pub = self.create_publisher(
            PersonArray,
            "/saltybot/tracked_people",
            qos,
        )
        self._target_pub = self.create_publisher(
            Person,
            "/saltybot/follow_target",
            qos,
        )
        self._state_pub = self.create_publisher(
            String,
            "/saltybot/tracker_state",
            qos,
        )
        self._cmd_vel_pub = self.create_publisher(
            Twist,
            "/saltybot/cmd_vel",
            qos,
        )

        # ── Placeholder: in real implementation, subscribe to detection topics
        # For now, a timer demonstrates the node structure.
        self._update_timer = self.create_timer(1.0 / publish_hz, self._update_tracker)
        self._last_state = "IDLE"

        self.get_logger().info(
            f"multi_person_tracker_node ready (max_people={max_people}, "
            f"occlusion_grace={occlusion_grace}s, publish_hz={publish_hz} Hz)"
        )

    @staticmethod
    def _to_person_msg(tracked: TrackedPerson, header) -> Person:
        """Convert one tracked person into a Person message using ``header``."""
        msg = Person()
        msg.header = header
        msg.track_id = tracked.track_id
        msg.bearing_rad = tracked.bearing_rad
        msg.distance_m = tracked.distance_m
        msg.confidence = tracked.confidence
        msg.is_speaking = tracked.is_speaking
        msg.source = tracked.source
        return msg

    def _update_tracker(self) -> None:
        """Publish the tracker's current people, target and state (timer callback)."""
        # In real implementation, this would be triggered by detection callbacks.
        tracked_people = list(self._tracker.tracked_people.values())
        active_target_id = self._tracker.active_target_id

        # All tracked people, stamped with one shared header.
        person_array_msg = PersonArray()
        person_array_msg.header.stamp = self.get_clock().now().to_msg()
        person_array_msg.header.frame_id = "base_link"
        # -1 marks "no active target".
        person_array_msg.active_id = active_target_id if active_target_id is not None else -1
        for tracked in tracked_people:
            person_array_msg.persons.append(
                self._to_person_msg(tracked, person_array_msg.header)
            )
        self._tracked_pub.publish(person_array_msg)

        # Active target, only if it is still a live track.
        if active_target_id is not None and active_target_id in self._tracker.tracked_people:
            target = self._tracker.tracked_people[active_target_id]
            self._target_pub.publish(
                self._to_person_msg(target, person_array_msg.header)
            )

        # Tracker state; transitions are handled once, not on every tick.
        state = "SEARCHING" if self._tracker.searching else "TRACKING"
        previous_state = self._last_state
        if state != previous_state:
            # Always remember the new state, even when announcements are off.
            self._last_state = state
            if self._announce:
                self.get_logger().info(f"State: {state}")
                if state == "SEARCHING":
                    self.get_logger().warn("Lost target, searching...")
            if previous_state == "SEARCHING":
                # Cancel the search rotation this node was commanding.
                self._cmd_vel_pub.publish(Twist())

        state_msg = String()
        state_msg.data = state
        self._state_pub.publish(state_msg)

        if state == "SEARCHING":
            self._handle_lost_target()

    def _handle_lost_target(self) -> None:
        """Command the lost-target behavior: stop translating and rotate slowly.

        A single Twist carries both parts (zero linear velocity plus a yaw
        rate). Publishing a separate stop first would be overridden at once.
        """
        search_twist = Twist()  # linear components default to zero
        search_twist.angular.z = self._SEARCH_YAW_RATE
        self._cmd_vel_pub.publish(search_twist)
def main(args=None) -> None:
    """Initialize rclpy, spin the tracker node until interrupted, then clean up.

    Args:
        args: Command-line arguments forwarded to ``rclpy.init``.
    """
    rclpy.init(args=args)
    node = None
    try:
        # Constructed inside the try so a constructor failure still shuts rclpy down.
        node = MultiPersonTrackerNode()
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        if node is not None:
            node.destroy_node()
        # The context may already be shut down (e.g. by rclpy's signal handling
        # on Ctrl-C); shutting it down a second time raises.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,5 @@
[develop]
script_dir=$base/lib/saltybot_multi_person_tracker
[install]
install_scripts=$base/lib/saltybot_multi_person_tracker

View File

@ -0,0 +1,29 @@
import os
from glob import glob

from setuptools import setup

package_name = 'saltybot_multi_person_tracker'

setup(
    name=package_name,
    version='0.1.0',
    packages=[package_name],
    data_files=[
        ('share/ament_index/resource_index/packages',
            ['resource/' + package_name]),
        ('share/' + package_name, ['package.xml']),
        # The launch file loads <share>/config/multi_person_tracker_params.yaml,
        # so the launch and config files must be installed into the share
        # directory. NOTE(review): assumes the sources live in launch/ and
        # config/ next to this file; an empty glob simply installs nothing.
        (os.path.join('share', package_name, 'launch'),
            glob(os.path.join('launch', '*.launch.py'))),
        (os.path.join('share', package_name, 'config'),
            glob(os.path.join('config', '*.yaml'))),
    ],
    install_requires=['setuptools'],
    zip_safe=True,
    author='sl-perception',
    author_email='sl-perception@saltylab.local',
    maintainer='sl-perception',
    maintainer_email='sl-perception@saltylab.local',
    url='https://gitea.vayrette.com/seb/saltylab-firmware',
    description='Multi-person tracker with group handling and target priority',
    license='MIT',
    tests_require=['pytest'],
    entry_points={
        'console_scripts': [
            'multi_person_tracker_node = saltybot_multi_person_tracker.multi_person_tracker_node:main',
        ],
    },
)

View File

@ -0,0 +1,122 @@
"""Tests for multi-person tracker (Issue #423)."""
import numpy as np
from saltybot_multi_person_tracker.multi_person_tracker import (
MultiPersonTracker,
TrackedPerson,
)
def test_create_new_tracks():
    """Two detections on an empty tracker become tracks 0 and 1, in order."""
    tracker = MultiPersonTracker(max_people=10)
    first = {"bearing_rad": 0.1, "distance_m": 1.5, "confidence": 0.9}
    second = {"bearing_rad": -0.2, "distance_m": 2.0, "confidence": 0.85}

    people, _active_id = tracker.update([first, second])

    assert len(people) == 2
    assert people[0].track_id == 0
    assert people[1].track_id == 1
def test_match_existing_tracks():
    """Slightly moved detections keep their track IDs instead of spawning new ones."""
    tracker = MultiPersonTracker(max_people=10)

    # First frame establishes tracks 0 and 1.
    initial_frame = [
        {"bearing_rad": 0.1, "distance_m": 1.5, "confidence": 0.9},
        {"bearing_rad": -0.2, "distance_m": 2.0, "confidence": 0.85},
    ]
    tracker.update(initial_frame)

    # Second frame: the same two people, nudged a little.
    nudged_frame = [
        {"bearing_rad": 0.12, "distance_m": 1.48, "confidence": 0.88},
        {"bearing_rad": -0.22, "distance_m": 2.05, "confidence": 0.84},
    ]
    people, _ = tracker.update(nudged_frame)

    # Same IDs means both detections were matched to existing tracks.
    assert len(people) == 2
    assert {person.track_id for person in people} == {0, 1}
def test_max_people_limit():
    """Test that tracker fills up to, and never beyond, max_people."""
    tracker = MultiPersonTracker(max_people=3)
    detections = [
        {"bearing_rad": 0.1 * i, "distance_m": 1.5 + i * 0.2, "confidence": 0.9}
        for i in range(5)
    ]
    tracked, _ = tracker.update(detections)
    # Five detections against a cap of three must yield exactly three tracks;
    # "<= 3" would also pass if the tracker created none at all.
    assert len(tracked) == 3
def test_wake_word_speaker_priority():
    """The wake-word speaker outranks the closest person as follow target."""
    tracker = MultiPersonTracker(max_people=10)
    frame = [
        {"bearing_rad": 0.1, "distance_m": 5.0, "confidence": 0.9},  # track 0: farthest
        {"bearing_rad": -0.2, "distance_m": 1.0, "confidence": 0.85},  # track 1: closest
        {"bearing_rad": 0.3, "distance_m": 2.0, "confidence": 0.8},  # track 2: in between
    ]

    # No speaker given: the closest person (track 1) is chosen.
    tracker.update(frame)
    assert tracker.active_target_id == 1  # Closest person

    # Track 0 is the speaker: it wins even though it is the farthest.
    _, chosen_target = tracker.update(frame, speaking_person_id=0)
    assert chosen_target == 0
def test_group_centroid():
    """The group centroid of three symmetric people sits near bearing zero."""
    tracker = MultiPersonTracker(max_people=10)
    # (bearing_rad, distance_m, confidence) for each person in the group.
    layout = ((0.0, 1.0, 0.9), (0.2, 2.0, 0.85), (-0.2, 1.5, 0.8))
    frame = [
        {"bearing_rad": bearing, "distance_m": distance, "confidence": confidence}
        for bearing, distance, confidence in layout
    ]
    tracker.update(frame)

    centroid = tracker.get_group_centroid()
    assert centroid is not None

    mean_bearing, mean_distance = centroid
    # Bearings are symmetric about zero, so their mean is close to 0.
    assert abs(mean_bearing) < 0.1
    # Distances span 1.0 to 2.0, so their mean lies strictly between.
    assert 1.0 < mean_distance < 2.0
def test_empty_tracker():
    """With no detections the tracker has no people, no target, and is searching."""
    tracker = MultiPersonTracker()

    result = tracker.update([])
    people, active_target = result

    assert len(people) == 0
    assert active_target is None
    assert tracker.searching is True
if __name__ == "__main__":
    # Run every test in definition order when executed as a script.
    for run_case in (
        test_create_new_tracks,
        test_match_existing_tracks,
        test_max_people_limit,
        test_wake_word_speaker_priority,
        test_group_centroid,
        test_empty_tracker,
    ):
        run_case()
    print("All tests passed!")