Merge pull request 'feat: multi-person tracker (Issue #423)' (#426) from sl-perception/issue-423-multi-person into main
This commit is contained in:
commit
524d2545ed
@ -0,0 +1,13 @@
|
|||||||
|
/**:
|
||||||
|
ros__parameters:
|
||||||
|
multi_person_tracker_node:
|
||||||
|
# Tracking parameters
|
||||||
|
max_people: 10 # Maximum tracked people
|
||||||
|
occlusion_grace_s: 3.0 # Seconds to maintain track during occlusion
|
||||||
|
embedding_threshold: 0.75 # Cosine similarity threshold for face re-ID
|
||||||
|
hsv_threshold: 0.60 # Color histogram similarity threshold
|
||||||
|
|
||||||
|
# Publishing
|
||||||
|
publish_hz: 15.0 # Update rate (15+ fps on Orin)
|
||||||
|
enable_group_follow: true # Follow group centroid if no single target
|
||||||
|
announce_state: true # Log state changes
|
||||||
@ -0,0 +1,40 @@
|
|||||||
|
"""multi_person_tracker.launch.py — Launch file for multi-person tracking (Issue #423)."""
|
||||||
|
|
||||||
|
from launch import LaunchDescription
|
||||||
|
from launch_ros.actions import Node
|
||||||
|
from launch.actions import DeclareLaunchArgument
|
||||||
|
from launch.substitutions import LaunchConfiguration
|
||||||
|
import os
|
||||||
|
from ament_index_python.packages import get_package_share_directory
|
||||||
|
|
||||||
|
|
||||||
|
def generate_launch_description():
    """Build the launch description for the multi-person tracker node.

    Declares a ``config_file`` launch argument whose default is the
    ``config/multi_person_tracker_params.yaml`` file under this package's
    share directory, and starts ``multi_person_tracker_node`` with that file
    as its parameter source.
    """
    package = "saltybot_multi_person_tracker"

    # Default parameter file, resolved inside the installed share directory.
    share_dir = get_package_share_directory(package)
    default_params_path = os.path.join(
        share_dir, "config", "multi_person_tracker_params.yaml"
    )

    declare_config_file = DeclareLaunchArgument(
        "config_file",
        default_value=default_params_path,
        description="Path to multi-person tracker parameters YAML file",
    )

    # The node reads its parameters from whatever `config_file` resolves to.
    params_source = LaunchConfiguration("config_file")
    node_action = Node(
        package=package,
        executable="multi_person_tracker_node",
        name="multi_person_tracker_node",
        parameters=[params_source],
        remappings=[],
        output="screen",
    )

    actions = [declare_config_file, node_action]
    return LaunchDescription(actions)
|
||||||
27
jetson/ros2_ws/src/saltybot_multi_person_tracker/package.xml
Normal file
27
jetson/ros2_ws/src/saltybot_multi_person_tracker/package.xml
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
<?xml version="1.0"?>
|
||||||
|
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||||
|
<package format="3">
|
||||||
|
<name>saltybot_multi_person_tracker</name>
|
||||||
|
<version>0.1.0</version>
|
||||||
|
<description>Multi-person tracker with group handling and target priority (Issue #423)</description>
|
||||||
|
<maintainer email="sl-perception@saltylab.local">sl-perception</maintainer>
|
||||||
|
<license>MIT</license>
|
||||||
|
|
||||||
|
<buildtool_depend>ament_python</buildtool_depend>
|
||||||
|
|
||||||
|
<depend>rclpy</depend>
|
||||||
|
<depend>std_msgs</depend>
|
||||||
|
<depend>sensor_msgs</depend>
|
||||||
|
<depend>geometry_msgs</depend>
|
||||||
|
<depend>tf2_ros</depend>
|
||||||
|
<depend>saltybot_social_msgs</depend>
|
||||||
|
<depend>saltybot_person_reid_msgs</depend>
|
||||||
|
<depend>vision_msgs</depend>
|
||||||
|
<depend>cv_bridge</depend>
|
||||||
|
<!-- rosdep keys (not pip package names) so `rosdep install` can resolve them -->
<depend>python3-opencv</depend>
<depend>python3-numpy</depend>
|
||||||
|
|
||||||
|
<export>
|
||||||
|
<build_type>ament_python</build_type>
|
||||||
|
</export>
|
||||||
|
</package>
|
||||||
@ -0,0 +1 @@
|
|||||||
|
"""saltybot_multi_person_tracker — Multi-person tracker with group handling (Issue #423)."""
|
||||||
@ -0,0 +1,318 @@
|
|||||||
|
"""Multi-person tracker for Issue #423.
|
||||||
|
|
||||||
|
Tracks up to 10 people with:
|
||||||
|
- Unique IDs (persistent across frames)
|
||||||
|
- Target priority: wake-word speaker > closest known > largest bbox
|
||||||
|
- Occlusion handoff (3s grace period)
|
||||||
|
- Re-ID via face embedding + HSV color histogram
|
||||||
|
- Group detection (follow centroid)
|
||||||
|
- Lost target: stop + rotate + SEARCHING state
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TrackedPerson:
    """State kept for one tracked person: identity, pose and re-ID cues."""

    track_id: int
    bearing_rad: float
    distance_m: float
    confidence: float
    is_speaking: bool = False
    source: str = "camera"
    # Wall-clock time (time.time()) of the most recent detection.
    last_seen_time: float = field(default_factory=time.time)
    # Appearance cues used for re-identification; None until one is observed.
    embedding: Optional[np.ndarray] = None
    hsv_histogram: Optional[np.ndarray] = None
    # Bounding box as (x, y, w, h).
    bbox: Optional[Tuple[int, int, int, int]] = None
    is_occluded: bool = False
    # Wall-clock time at which the current occlusion began (0.0 if never set).
    occlusion_start_time: float = 0.0

    def age_seconds(self) -> float:
        """Return the number of seconds elapsed since the last detection."""
        now = time.time()
        return now - self.last_seen_time

    def mark_seen(
        self,
        bearing: float,
        distance: float,
        confidence: float,
        is_speaking: bool = False,
        embedding: Optional[np.ndarray] = None,
        hsv_hist: Optional[np.ndarray] = None,
        bbox: Optional[Tuple] = None,
    ) -> None:
        """Record a fresh detection of this person.

        Pose, confidence and speaking flag are always overwritten and the
        occlusion flag is cleared. The optional cues (``embedding``,
        ``hsv_hist``, ``bbox``) replace the stored value only when provided,
        so a detection lacking them keeps the previous ones.
        """
        self.last_seen_time = time.time()
        self.is_occluded = False

        self.bearing_rad, self.distance_m = bearing, distance
        self.confidence, self.is_speaking = confidence, is_speaking

        optional_updates = (
            ("embedding", embedding),
            ("hsv_histogram", hsv_hist),
            ("bbox", bbox),
        )
        for attribute, value in optional_updates:
            if value is not None:
                setattr(self, attribute, value)

    def mark_occluded(self) -> None:
        """Flag the person as occluded, stamping the start time only once."""
        if self.is_occluded:
            # Already occluded: keep the original start time.
            return
        self.is_occluded = True
        self.occlusion_start_time = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
class MultiPersonTracker:
    """Track several people across frames with re-ID and target selection.

    Each call to :meth:`update` prunes expired tracks, associates the new
    detections with existing tracks, marks tracks that received no detection
    as occluded, creates tracks for leftover detections (up to
    ``max_people``) and re-evaluates the active follow target.

    Association uses the strongest cue shared by a track and a detection:
    face-embedding cosine similarity, then HSV-histogram cosine similarity,
    then a spatial proximity score. Each cue is compared against its own
    threshold.
    """

    # Cue reliability ranks; a higher rank wins when assignments conflict.
    _CUE_SPATIAL = 0
    _CUE_HSV = 1
    _CUE_EMBEDDING = 2

    def __init__(
        self,
        max_people: int = 10,
        occlusion_grace_s: float = 3.0,
        embedding_similarity_threshold: float = 0.75,
        hsv_similarity_threshold: float = 0.60,
        spatial_similarity_threshold: float = 0.70,
        max_unseen_s: float = 10.0,
    ):
        """Create an empty tracker.

        Args:
            max_people: Maximum number of simultaneously tracked people.
            occlusion_grace_s: Seconds an occluded track is kept before it
                is dropped.
            embedding_similarity_threshold: Minimum cosine similarity for a
                face-embedding match.
            hsv_similarity_threshold: Minimum cosine similarity for an
                HSV-histogram match.
            spatial_similarity_threshold: Minimum proximity score for a
                match when no appearance cue is shared.
            max_unseen_s: Seconds without any detection after which a track
                is dropped regardless of its occlusion state.
        """
        self.max_people = max_people
        self.occlusion_grace_s = occlusion_grace_s
        self.embedding_threshold = embedding_similarity_threshold
        self.hsv_threshold = hsv_similarity_threshold
        self.spatial_threshold = spatial_similarity_threshold
        self.max_unseen_s = max_unseen_s

        self.tracked_people: Dict[int, TrackedPerson] = {}
        self.next_track_id = 0
        self.active_target_id: Optional[int] = None
        self.searching: bool = False
        self.search_bearing: float = 0.0

    # ── Main tracking update ────────────────────────────────────────────────

    def update(
        self,
        detections: List[Dict],
        speaking_person_id: Optional[int] = None,
    ) -> "Tuple[List[TrackedPerson], Optional[int]]":
        """Update tracker with new detections.

        Args:
            detections: List of dicts with keys:
                - bearing_rad, distance_m, confidence
                - is_speaking (optional)
                - embedding (optional): face embedding vector
                - hsv_histogram (optional): HSV color histogram
                - bbox (optional): (x, y, w, h)
            speaking_person_id: track id of the wake-word speaker
                (highest priority)

        Returns:
            (all_tracked_people, active_target_id)
        """
        # Clean up stale tracks
        self._prune_lost_people()

        # Match detections to existing tracks
        matched_pairs, unmatched_detections = self._match_detections(detections)

        # Update matched tracks
        seen_track_ids = set()
        for track_id, det_idx in matched_pairs:
            det = detections[det_idx]
            self.tracked_people[track_id].mark_seen(
                bearing=det["bearing_rad"],
                distance=det["distance_m"],
                confidence=det["confidence"],
                is_speaking=det.get("is_speaking", False),
                embedding=det.get("embedding"),
                hsv_hist=det.get("hsv_histogram"),
                bbox=det.get("bbox"),
            )
            seen_track_ids.add(track_id)

        # Tracks with no detection this frame enter (or stay in) occlusion so
        # the grace period in _prune_lost_people() actually applies. This must
        # run before new tracks are created below.
        for track_id, person in self.tracked_people.items():
            if track_id not in seen_track_ids:
                person.mark_occluded()

        # Create new tracks for unmatched detections
        for det_idx in unmatched_detections:
            if len(self.tracked_people) >= self.max_people:
                break
            det = detections[det_idx]
            track_id = self.next_track_id
            self.next_track_id += 1
            self.tracked_people[track_id] = TrackedPerson(
                track_id=track_id,
                bearing_rad=det["bearing_rad"],
                distance_m=det["distance_m"],
                confidence=det["confidence"],
                is_speaking=det.get("is_speaking", False),
                embedding=det.get("embedding"),
                hsv_histogram=det.get("hsv_histogram"),
                bbox=det.get("bbox"),
            )

        # Update target priority
        self._update_target_priority(speaking_person_id)

        return list(self.tracked_people.values()), self.active_target_id

    # ── Detection-to-track matching ──────────────────────────────────────────

    def _match_detections(
        self, detections: List[Dict]
    ) -> Tuple[List[Tuple[int, int]], List[int]]:
        """Match detections to tracked people via re-ID.

        Every (track, detection) pair is scored; pairs whose score reaches
        the threshold of the cue that produced it become candidates.
        Candidates are then assigned one-to-one, best first (more reliable
        cue first, then higher score).

        Returns:
            (matched_pairs, unmatched_detection_indices) where each matched
            pair is (track_id, detection_index).
        """
        if not detections or not self.tracked_people:
            return [], list(range(len(detections)))

        candidates = []
        for track_id, person in self.tracked_people.items():
            for det_idx, det in enumerate(detections):
                score, threshold, cue_rank = self._score_detection(person, det)
                if score >= threshold:
                    candidates.append((cue_rank, score, track_id, det_idx))

        # Stable sort: ties keep track/detection iteration order.
        candidates.sort(key=lambda c: (c[0], c[1]), reverse=True)

        matched: List[Tuple[int, int]] = []
        used_track_ids = set()
        used_det_indices = set()
        for _cue_rank, _score, track_id, det_idx in candidates:
            if track_id in used_track_ids or det_idx in used_det_indices:
                continue
            matched.append((track_id, det_idx))
            used_track_ids.add(track_id)
            used_det_indices.add(det_idx)

        unmatched_det_indices = [
            i for i in range(len(detections)) if i not in used_det_indices
        ]
        return matched, unmatched_det_indices

    def _compute_similarity(self, person: "TrackedPerson", detection: Dict) -> float:
        """Compute similarity between tracked person and detection.

        Uses the strongest cue available on both sides (embedding, then HSV
        histogram, then spatial proximity) and returns only the score.
        """
        return self._score_detection(person, detection)[0]

    def _score_detection(
        self, person: "TrackedPerson", detection: Dict
    ) -> Tuple[float, float, int]:
        """Return (score, threshold, cue_rank) for a track/detection pair."""
        det_embedding = detection.get("embedding")
        if person.embedding is not None and det_embedding is not None:
            score = self._cosine_similarity(person.embedding, det_embedding)
            return score, self.embedding_threshold, self._CUE_EMBEDDING

        det_hist = detection.get("hsv_histogram")
        if person.hsv_histogram is not None and det_hist is not None:
            # OpenCV has no cosine histogram-comparison method, so the cosine
            # similarity is computed directly on the flattened histograms.
            score = self._cosine_similarity(person.hsv_histogram, det_hist)
            return score, self.hsv_threshold, self._CUE_HSV

        # Fallback: spatial proximity (within similar distance/bearing).
        # The bearing difference is wrapped to [-pi, pi] so bearings on
        # either side of +/-pi count as close.
        bearing_delta = float(person.bearing_rad) - float(detection["bearing_rad"])
        bearing_diff = abs(
            float(np.arctan2(np.sin(bearing_delta), np.cos(bearing_delta)))
        )
        distance_diff = abs(person.distance_m - detection["distance_m"])
        score = 1.0 / (1.0 + bearing_diff + distance_diff * 0.1)
        return score, self.spatial_threshold, self._CUE_SPATIAL

    @staticmethod
    def _cosine_similarity(a, b) -> float:
        """Cosine similarity of two array-likes, flattened.

        Returns 0.0 when the shapes differ or either vector has zero norm.
        """
        vec_a = np.asarray(a, dtype=np.float64).ravel()
        vec_b = np.asarray(b, dtype=np.float64).ravel()
        if vec_a.shape != vec_b.shape:
            return 0.0
        denom = float(np.linalg.norm(vec_a) * np.linalg.norm(vec_b))
        if denom == 0.0:
            return 0.0
        return float(np.dot(vec_a, vec_b) / denom)

    # ── Target priority and group handling ────────────────────────────────────

    def _update_target_priority(self, speaking_person_id: Optional[int]) -> None:
        """Update active target based on priority:

        1. Wake-word speaker
        2. Closest non-occluded person with a positive distance
        3. Largest bounding box

        Sets ``searching`` to True when no target can be chosen.
        """
        if not self.tracked_people:
            self.active_target_id = None
            self.searching = True
            return

        # Priority 1: Wake-word speaker
        if speaking_person_id is not None and speaking_person_id in self.tracked_people:
            self.active_target_id = speaking_person_id
            self.searching = False
            return

        # Priority 2: Closest non-occluded person
        visible = {
            track_id: person
            for track_id, person in self.tracked_people.items()
            if not person.is_occluded and person.distance_m > 0
        }
        if visible:
            self.active_target_id = min(
                visible, key=lambda track_id: visible[track_id].distance_m
            )
            self.searching = False
            return

        # Priority 3: Largest bounding box (area must be positive)
        areas = {
            track_id: person.bbox[2] * person.bbox[3]
            for track_id, person in self.tracked_people.items()
            if person.bbox is not None
        }
        largest_id = max(areas, key=areas.get, default=None)
        if largest_id is not None and areas[largest_id] > 0:
            self.active_target_id = largest_id
            self.searching = False
        else:
            self.searching = True

    def get_group_centroid(self) -> Optional[Tuple[float, float]]:
        """Get centroid (bearing, distance) of all visible tracked people.

        Only non-occluded people with a positive distance contribute.

        Returns:
            (mean_bearing_rad, mean_distance_m) or None if nobody qualifies.
            The bearing is a circular mean, so bearings either side of +/-pi
            average to ~pi rather than 0.
        """
        if not self.tracked_people:
            return None

        bearings = []
        distances = []
        for person in self.tracked_people.values():
            if not person.is_occluded and person.distance_m > 0:
                bearings.append(person.bearing_rad)
                distances.append(person.distance_m)

        if not bearings:
            return None

        # Circular mean: average the unit vectors, then take the angle.
        mean_bearing = float(
            np.arctan2(np.mean(np.sin(bearings)), np.mean(np.cos(bearings)))
        )
        mean_distance = float(np.mean(distances))
        return mean_bearing, mean_distance

    # ── Pruning and cleanup ──────────────────────────────────────────────────

    def _prune_lost_people(self) -> None:
        """Remove people occluded past the grace period or unseen too long."""
        current_time = time.time()
        to_remove = []

        for track_id, person in self.tracked_people.items():
            # Check occlusion timeout
            if person.is_occluded:
                elapsed = current_time - person.occlusion_start_time
                if elapsed > self.occlusion_grace_s:
                    to_remove.append(track_id)
                    continue

            # Check long-term invisibility
            if person.age_seconds() > self.max_unseen_s:
                to_remove.append(track_id)

        for track_id in to_remove:
            del self.tracked_people[track_id]
            if self.active_target_id == track_id:
                self.active_target_id = None
                self.searching = True
|
||||||
@ -0,0 +1,185 @@
|
|||||||
|
"""multi_person_tracker_node.py — Multi-person tracking node for Issue #423.
|
||||||
|
|
||||||
|
Subscribes to person detections with re-ID embeddings and publishes:
|
||||||
|
/saltybot/tracked_people (PersonArray) — all tracked people + active target
|
||||||
|
/saltybot/follow_target (Person) — the current follow target
|
||||||
|
|
||||||
|
Implements:
|
||||||
|
- Up to 10 tracked people
|
||||||
|
- Target priority: wake-word speaker > closest > largest bbox
|
||||||
|
- Occlusion handoff (3s grace)
|
||||||
|
- Re-ID via face embedding + HSV
|
||||||
|
- Group detection
|
||||||
|
- Lost target: SEARCHING state
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile, QoSReliabilityPolicy
|
||||||
|
from rclpy.duration import Duration
|
||||||
|
|
||||||
|
from std_msgs.msg import String
|
||||||
|
from geometry_msgs.msg import Twist
|
||||||
|
from saltybot_social_msgs.msg import Person, PersonArray
|
||||||
|
|
||||||
|
from .multi_person_tracker import MultiPersonTracker, TrackedPerson
|
||||||
|
|
||||||
|
|
||||||
|
class MultiPersonTrackerNode(Node):
    """ROS2 node that publishes the state of a MultiPersonTracker.

    Publishers (all RELIABLE, depth 5):
        /saltybot/tracked_people (PersonArray) — every tracked person
        /saltybot/follow_target (Person) — the active target, when one exists
        /saltybot/tracker_state (String) — "SEARCHING" or "TRACKING"
        /saltybot/cmd_vel (Twist) — rotate-in-place command while searching,
            and a single zero command when the search ends

    NOTE: no detection subscription is wired up yet; the timer only publishes
    whatever the tracker currently holds.
    """

    # Angular velocity (rad/s) used for the rotate-in-place search.
    _SEARCH_ANGULAR_Z = 0.5
    # Minimum interval between repeated "lost target" warnings, in seconds.
    _LOST_WARN_PERIOD_S = 5.0
    # Fallback timer rate when the `publish_hz` parameter is not positive.
    _DEFAULT_PUBLISH_HZ = 15.0

    def __init__(self) -> None:
        super().__init__("multi_person_tracker_node")

        # ── Parameters ──────────────────────────────────────────────────────
        self.declare_parameter("max_people", 10)
        self.declare_parameter("occlusion_grace_s", 3.0)
        self.declare_parameter("embedding_threshold", 0.75)
        self.declare_parameter("hsv_threshold", 0.60)
        self.declare_parameter("publish_hz", 15.0)
        self.declare_parameter("enable_group_follow", True)
        self.declare_parameter("announce_state", True)

        max_people = self.get_parameter("max_people").value
        occlusion_grace = self.get_parameter("occlusion_grace_s").value
        embed_threshold = self.get_parameter("embedding_threshold").value
        hsv_threshold = self.get_parameter("hsv_threshold").value
        publish_hz = self.get_parameter("publish_hz").value
        self._enable_group = self.get_parameter("enable_group_follow").value
        self._announce = self.get_parameter("announce_state").value

        # A non-positive rate would make the timer period invalid
        # (division by zero or a negative period).
        if publish_hz is None or publish_hz <= 0.0:
            self.get_logger().warn(
                f"Invalid publish_hz={publish_hz}; "
                f"using {self._DEFAULT_PUBLISH_HZ} Hz"
            )
            publish_hz = self._DEFAULT_PUBLISH_HZ

        # ── Tracker instance ────────────────────────────────────────────────
        self._tracker = MultiPersonTracker(
            max_people=max_people,
            occlusion_grace_s=occlusion_grace,
            embedding_similarity_threshold=embed_threshold,
            hsv_similarity_threshold=hsv_threshold,
        )

        # ── Reliable QoS ────────────────────────────────────────────────────
        qos = QoSProfile(depth=5)
        qos.reliability = QoSReliabilityPolicy.RELIABLE

        self._tracked_pub = self.create_publisher(
            PersonArray, "/saltybot/tracked_people", qos
        )
        self._target_pub = self.create_publisher(
            Person, "/saltybot/follow_target", qos
        )
        self._state_pub = self.create_publisher(
            String, "/saltybot/tracker_state", qos
        )
        self._cmd_vel_pub = self.create_publisher(
            Twist, "/saltybot/cmd_vel", qos
        )

        # ── Placeholder: in real implementation, subscribe to detection topics
        # For now, a timer demonstrates the node structure.
        self._update_timer = self.create_timer(1.0 / publish_hz, self._update_tracker)

        self._last_state = "IDLE"
        self.get_logger().info(
            f"multi_person_tracker_node ready (max_people={max_people}, "
            f"occlusion_grace={occlusion_grace}s, publish_hz={publish_hz} Hz)"
        )

    @staticmethod
    def _to_person_msg(tracked: TrackedPerson, header) -> Person:
        """Convert a TrackedPerson into a Person message using `header`."""
        msg = Person()
        msg.header = header
        msg.track_id = tracked.track_id
        msg.bearing_rad = tracked.bearing_rad
        msg.distance_m = tracked.distance_m
        msg.confidence = tracked.confidence
        msg.is_speaking = tracked.is_speaking
        msg.source = tracked.source
        return msg

    def _update_tracker(self) -> None:
        """Timer callback: publish tracked people, target and tracker state."""
        # In real implementation, this would be triggered by detection callbacks.
        tracked_people = list(self._tracker.tracked_people.values())
        active_target_id = self._tracker.active_target_id

        # Build PersonArray message
        person_array_msg = PersonArray()
        person_array_msg.header.stamp = self.get_clock().now().to_msg()
        person_array_msg.header.frame_id = "base_link"
        # -1 signals "no active target" to subscribers.
        person_array_msg.active_id = active_target_id if active_target_id is not None else -1

        header = person_array_msg.header
        for tracked in tracked_people:
            person_array_msg.persons.append(self._to_person_msg(tracked, header))

        self._tracked_pub.publish(person_array_msg)

        # Publish active target
        target = (
            self._tracker.tracked_people.get(active_target_id)
            if active_target_id is not None
            else None
        )
        if target is not None:
            self._target_pub.publish(self._to_person_msg(target, header))

        # Publish tracker state
        state = "SEARCHING" if self._tracker.searching else "TRACKING"
        previous_state = self._last_state
        if state != previous_state:
            if self._announce:
                self.get_logger().info(f"State: {state}")
            self._last_state = state
        state_msg = String()
        state_msg.data = state
        self._state_pub.publish(state_msg)

        # Handle lost target
        if state == "SEARCHING":
            self._handle_lost_target()
        elif previous_state == "SEARCHING":
            # Search just ended: cancel the rotation command once so the
            # last published velocity is not a spin.
            self._cmd_vel_pub.publish(Twist())

    def _handle_lost_target(self) -> None:
        """Execute behavior when target is lost: stop translating and rotate."""
        # A default Twist has zero linear velocity, so one message both stops
        # forward motion and starts the slow rotating search.
        search_twist = Twist()
        search_twist.angular.z = self._SEARCH_ANGULAR_Z  # rad/s
        self._cmd_vel_pub.publish(search_twist)

        if self._announce:
            # Throttled: this runs on every timer tick while searching.
            self.get_logger().warn(
                "Lost target, searching...",
                throttle_duration_sec=self._LOST_WARN_PERIOD_S,
            )
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None) -> None:
    """Entry point: initialise rclpy, spin the tracker node, then clean up.

    Args:
        args: Command-line arguments forwarded to ``rclpy.init``.
    """
    rclpy.init(args=args)
    node = None
    try:
        node = MultiPersonTrackerNode()
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        # `node` stays None if construction raised.
        if node is not None:
            node.destroy_node()
        # The context may already be shut down (e.g. by the SIGINT handler);
        # shutting it down a second time raises.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == "__main__":
    main()
|
||||||
@ -0,0 +1,5 @@
|
|||||||
|
[develop]
|
||||||
|
script_dir=$base/lib/saltybot_multi_person_tracker
|
||||||
|
|
||||||
|
[install]
|
||||||
|
install_scripts=$base/lib/saltybot_multi_person_tracker
|
||||||
29
jetson/ros2_ws/src/saltybot_multi_person_tracker/setup.py
Normal file
29
jetson/ros2_ws/src/saltybot_multi_person_tracker/setup.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from glob import glob

from setuptools import setup

package_name = 'saltybot_multi_person_tracker'

setup(
    name=package_name,
    version='0.1.0',
    packages=[package_name],
    data_files=[
        ('share/ament_index/resource_index/packages',
            ['resource/' + package_name]),
        ('share/' + package_name, ['package.xml']),
        # The launch file looks up config/multi_person_tracker_params.yaml in
        # the package share directory, so both must be installed there.
        # NOTE(review): assumes the sources live in launch/ and config/ —
        # confirm against the package layout (glob yields [] if they don't).
        ('share/' + package_name + '/launch', glob('launch/*.launch.py')),
        ('share/' + package_name + '/config', glob('config/*.yaml')),
    ],
    install_requires=['setuptools'],
    zip_safe=True,
    author='sl-perception',
    author_email='sl-perception@saltylab.local',
    maintainer='sl-perception',
    maintainer_email='sl-perception@saltylab.local',
    url='https://gitea.vayrette.com/seb/saltylab-firmware',
    description='Multi-person tracker with group handling and target priority',
    license='MIT',
    tests_require=['pytest'],
    entry_points={
        'console_scripts': [
            'multi_person_tracker_node = saltybot_multi_person_tracker.multi_person_tracker_node:main',
        ],
    },
)
|
||||||
@ -0,0 +1,122 @@
|
|||||||
|
"""Tests for multi-person tracker (Issue #423)."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from saltybot_multi_person_tracker.multi_person_tracker import (
|
||||||
|
MultiPersonTracker,
|
||||||
|
TrackedPerson,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_new_tracks():
    """Two detections on an empty tracker yield tracks with ids 0 and 1."""
    tracker = MultiPersonTracker(max_people=10)

    first = {"bearing_rad": 0.1, "distance_m": 1.5, "confidence": 0.9}
    second = {"bearing_rad": -0.2, "distance_m": 2.0, "confidence": 0.85}

    people, _ = tracker.update([first, second])

    assert len(people) == 2
    # Ids are handed out in detection order.
    assert people[0].track_id == 0
    assert people[1].track_id == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_match_existing_tracks():
    """Detections near existing tracks keep the original track ids."""
    tracker = MultiPersonTracker(max_people=10)

    # First frame: two people create tracks 0 and 1.
    initial_frame = [
        {"bearing_rad": 0.1, "distance_m": 1.5, "confidence": 0.9},
        {"bearing_rad": -0.2, "distance_m": 2.0, "confidence": 0.85},
    ]
    tracker.update(initial_frame)

    # Second frame: the same two people, slightly moved.
    moved_frame = [
        {"bearing_rad": 0.12, "distance_m": 1.48, "confidence": 0.88},
        {"bearing_rad": -0.22, "distance_m": 2.05, "confidence": 0.84},
    ]
    people_after, _ = tracker.update(moved_frame)

    # No new ids may appear: both detections must match the existing tracks.
    assert len(people_after) == 2
    assert {person.track_id for person in people_after} == {0, 1}
|
||||||
|
|
||||||
|
|
||||||
|
def test_max_people_limit():
    """The tracker never holds more tracks than ``max_people``."""
    limit = 3
    tracker = MultiPersonTracker(max_people=limit)

    # Five distinct detections, more than the tracker may keep.
    detections = []
    for i in range(5):
        detections.append(
            {"bearing_rad": 0.1 * i, "distance_m": 1.5 + i * 0.2, "confidence": 0.9}
        )

    people, _ = tracker.update(detections)
    assert len(people) <= 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_wake_word_speaker_priority():
    """The wake-word speaker overrides the closest-person choice."""
    tracker = MultiPersonTracker(max_people=10)

    detections = [
        {"bearing_rad": 0.1, "distance_m": 5.0, "confidence": 0.9},    # farthest
        {"bearing_rad": -0.2, "distance_m": 1.0, "confidence": 0.85},  # closest
        {"bearing_rad": 0.3, "distance_m": 2.0, "confidence": 0.8},    # in between
    ]

    tracker.update(detections)

    # No speaker given: the closest person (track 1) becomes the target.
    assert tracker.active_target_id == 1

    # Speaker is track 0: it wins even though it is the farthest.
    _, chosen = tracker.update(detections, speaking_person_id=0)
    assert chosen == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_group_centroid():
    """The group centroid is the mean bearing and distance of the group."""
    tracker = MultiPersonTracker(max_people=10)

    # Bearings are symmetric around 0; distances span 1.0 m to 2.0 m.
    group = [
        {"bearing_rad": 0.0, "distance_m": 1.0, "confidence": 0.9},
        {"bearing_rad": 0.2, "distance_m": 2.0, "confidence": 0.85},
        {"bearing_rad": -0.2, "distance_m": 1.5, "confidence": 0.8},
    ]
    tracker.update(group)

    result = tracker.get_group_centroid()
    assert result is not None

    bearing, distance = result
    # Symmetric bearings average out to roughly straight ahead.
    assert abs(bearing) < 0.1
    # The mean distance lies strictly inside the observed range.
    assert 1.0 < distance < 2.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_tracker():
    """With no detections there are no tracks, no target, and a search."""
    tracker = MultiPersonTracker()

    people, chosen = tracker.update([])

    assert len(people) == 0
    assert chosen is None
    # Nobody to follow, so the tracker reports the searching state.
    assert tracker.searching is True
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Allow running the module directly, without pytest: run every test in
    # order and report success only if none of them raised.
    for run_test in (
        test_create_new_tracks,
        test_match_existing_tracks,
        test_max_people_limit,
        test_wake_word_speaker_priority,
        test_group_centroid,
        test_empty_tracker,
    ):
        run_test()
    print("All tests passed!")
|
||||||
Loading…
x
Reference in New Issue
Block a user