feat: Add Issue #423 - Multi-person tracker with group handling + target priority
Implement multi-person tracking with: - Track up to 10 people with persistent unique IDs - Target priority: wake-word speaker > closest known > largest bbox - Occlusion handoff with 3-second grace period - Re-ID via face embedding (cosine similarity) + HSV color histogram - Group detection and centroid calculation - Lost target behavior: stop + rotate + SEARCHING state - 15+ fps on Jetson Orin Nano Super - PersonArray message publishing with active target tracking - Configurable similarity thresholds and grace periods - Unit tests for tracking, matching, priority, and re-ID Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
a06821a8c8
commit
31cfb9dcb9
@ -0,0 +1,13 @@
|
||||
/**:
|
||||
ros__parameters:
|
||||
multi_person_tracker_node:
|
||||
# Tracking parameters
|
||||
max_people: 10 # Maximum tracked people
|
||||
occlusion_grace_s: 3.0 # Seconds to maintain track during occlusion
|
||||
embedding_threshold: 0.75 # Cosine similarity threshold for face re-ID
|
||||
hsv_threshold: 0.60 # Color histogram similarity threshold
|
||||
|
||||
# Publishing
|
||||
publish_hz: 15.0 # Update rate (15+ fps on Orin)
|
||||
enable_group_follow: true # Follow group centroid if no single target
|
||||
announce_state: true # Log state changes
|
||||
@ -0,0 +1,40 @@
|
||||
"""multi_person_tracker.launch.py — Launch file for multi-person tracking (Issue #423)."""
|
||||
|
||||
from launch import LaunchDescription
|
||||
from launch_ros.actions import Node
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
"""Generate launch description for multi-person tracker."""
|
||||
|
||||
# Declare launch arguments
|
||||
config_arg = DeclareLaunchArgument(
|
||||
"config_file",
|
||||
default_value=os.path.join(
|
||||
get_package_share_directory("saltybot_multi_person_tracker"),
|
||||
"config",
|
||||
"multi_person_tracker_params.yaml",
|
||||
),
|
||||
description="Path to multi-person tracker parameters YAML file",
|
||||
)
|
||||
|
||||
# Multi-person tracker node
|
||||
tracker_node = Node(
|
||||
package="saltybot_multi_person_tracker",
|
||||
executable="multi_person_tracker_node",
|
||||
name="multi_person_tracker_node",
|
||||
parameters=[LaunchConfiguration("config_file")],
|
||||
remappings=[],
|
||||
output="screen",
|
||||
)
|
||||
|
||||
return LaunchDescription(
|
||||
[
|
||||
config_arg,
|
||||
tracker_node,
|
||||
]
|
||||
)
|
||||
27
jetson/ros2_ws/src/saltybot_multi_person_tracker/package.xml
Normal file
27
jetson/ros2_ws/src/saltybot_multi_person_tracker/package.xml
Normal file
@ -0,0 +1,27 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_multi_person_tracker</name>
|
||||
<version>0.1.0</version>
|
||||
<description>Multi-person tracker with group handling and target priority (Issue #423)</description>
|
||||
<maintainer email="sl-perception@saltylab.local">sl-perception</maintainer>
|
||||
<license>MIT</license>
|
||||
|
||||
<buildtool_depend>ament_python</buildtool_depend>
|
||||
|
||||
<depend>rclpy</depend>
|
||||
<depend>std_msgs</depend>
|
||||
<depend>sensor_msgs</depend>
|
||||
<depend>geometry_msgs</depend>
|
||||
<depend>tf2_ros</depend>
|
||||
<depend>saltybot_social_msgs</depend>
|
||||
<depend>saltybot_person_reid_msgs</depend>
|
||||
<depend>vision_msgs</depend>
|
||||
<depend>cv_bridge</depend>
|
||||
<depend>opencv-python</depend>
|
||||
<depend>numpy</depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1 @@
|
||||
"""saltybot_multi_person_tracker — Multi-person tracker with group handling (Issue #423)."""
|
||||
@ -0,0 +1,318 @@
|
||||
"""Multi-person tracker for Issue #423.
|
||||
|
||||
Tracks up to 10 people with:
|
||||
- Unique IDs (persistent across frames)
|
||||
- Target priority: wake-word speaker > closest known > largest bbox
|
||||
- Occlusion handoff (3s grace period)
|
||||
- Re-ID via face embedding + HSV color histogram
|
||||
- Group detection (follow centroid)
|
||||
- Lost target: stop + rotate + SEARCHING state
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import numpy as np
|
||||
import cv2
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrackedPerson:
|
||||
"""A tracked person with identity and state."""
|
||||
track_id: int
|
||||
bearing_rad: float
|
||||
distance_m: float
|
||||
confidence: float
|
||||
is_speaking: bool = False
|
||||
source: str = "camera"
|
||||
last_seen_time: float = field(default_factory=time.time)
|
||||
embedding: Optional[np.ndarray] = None # Face embedding for re-ID
|
||||
hsv_histogram: Optional[np.ndarray] = None # HSV histogram fallback
|
||||
bbox: Optional[Tuple[int, int, int, int]] = None # (x, y, w, h)
|
||||
is_occluded: bool = False
|
||||
occlusion_start_time: float = 0.0
|
||||
|
||||
def age_seconds(self) -> float:
|
||||
"""Time since last detection."""
|
||||
return time.time() - self.last_seen_time
|
||||
|
||||
def mark_seen(
|
||||
self,
|
||||
bearing: float,
|
||||
distance: float,
|
||||
confidence: float,
|
||||
is_speaking: bool = False,
|
||||
embedding: Optional[np.ndarray] = None,
|
||||
hsv_hist: Optional[np.ndarray] = None,
|
||||
bbox: Optional[Tuple] = None,
|
||||
) -> None:
|
||||
"""Update person state after detection."""
|
||||
self.bearing_rad = bearing
|
||||
self.distance_m = distance
|
||||
self.confidence = confidence
|
||||
self.is_speaking = is_speaking
|
||||
self.last_seen_time = time.time()
|
||||
self.is_occluded = False
|
||||
if embedding is not None:
|
||||
self.embedding = embedding
|
||||
if hsv_hist is not None:
|
||||
self.hsv_histogram = hsv_hist
|
||||
if bbox is not None:
|
||||
self.bbox = bbox
|
||||
|
||||
def mark_occluded(self) -> None:
|
||||
"""Mark person as occluded but within grace period."""
|
||||
if not self.is_occluded:
|
||||
self.is_occluded = True
|
||||
self.occlusion_start_time = time.time()
|
||||
|
||||
|
||||
class MultiPersonTracker:
|
||||
"""Tracks multiple people with re-ID and group detection."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_people: int = 10,
|
||||
occlusion_grace_s: float = 3.0,
|
||||
embedding_similarity_threshold: float = 0.75,
|
||||
hsv_similarity_threshold: float = 0.60,
|
||||
):
|
||||
self.max_people = max_people
|
||||
self.occlusion_grace_s = occlusion_grace_s
|
||||
self.embedding_threshold = embedding_similarity_threshold
|
||||
self.hsv_threshold = hsv_similarity_threshold
|
||||
|
||||
self.tracked_people: Dict[int, TrackedPerson] = {}
|
||||
self.next_track_id = 0
|
||||
self.active_target_id: Optional[int] = None
|
||||
self.searching: bool = False
|
||||
self.search_bearing: float = 0.0
|
||||
|
||||
# ── Main tracking update ────────────────────────────────────────────────
|
||||
|
||||
def update(
|
||||
self,
|
||||
detections: List[Dict],
|
||||
speaking_person_id: Optional[int] = None,
|
||||
) -> Tuple[List[TrackedPerson], Optional[int]]:
|
||||
"""Update tracker with new detections.
|
||||
|
||||
Args:
|
||||
detections: List of dicts with keys:
|
||||
- bearing_rad, distance_m, confidence
|
||||
- embedding (optional): face embedding vector
|
||||
- hsv_histogram (optional): HSV color histogram
|
||||
- bbox (optional): (x, y, w, h)
|
||||
speaking_person_id: wake-word speaker (highest priority)
|
||||
|
||||
Returns:
|
||||
(all_tracked_people, active_target_id)
|
||||
"""
|
||||
# Clean up stale tracks
|
||||
self._prune_lost_people()
|
||||
|
||||
# Match detections to existing tracks
|
||||
matched_ids, unmatched_detections = self._match_detections(detections)
|
||||
|
||||
# Update matched tracks
|
||||
for track_id, det_idx in matched_ids:
|
||||
det = detections[det_idx]
|
||||
self.tracked_people[track_id].mark_seen(
|
||||
bearing=det["bearing_rad"],
|
||||
distance=det["distance_m"],
|
||||
confidence=det["confidence"],
|
||||
is_speaking=det.get("is_speaking", False),
|
||||
embedding=det.get("embedding"),
|
||||
hsv_hist=det.get("hsv_histogram"),
|
||||
bbox=det.get("bbox"),
|
||||
)
|
||||
|
||||
# Create new tracks for unmatched detections
|
||||
for det_idx in unmatched_detections:
|
||||
if len(self.tracked_people) < self.max_people:
|
||||
det = detections[det_idx]
|
||||
track_id = self.next_track_id
|
||||
self.next_track_id += 1
|
||||
self.tracked_people[track_id] = TrackedPerson(
|
||||
track_id=track_id,
|
||||
bearing_rad=det["bearing_rad"],
|
||||
distance_m=det["distance_m"],
|
||||
confidence=det["confidence"],
|
||||
is_speaking=det.get("is_speaking", False),
|
||||
embedding=det.get("embedding"),
|
||||
hsv_histogram=det.get("hsv_histogram"),
|
||||
bbox=det.get("bbox"),
|
||||
)
|
||||
|
||||
# Update target priority
|
||||
self._update_target_priority(speaking_person_id)
|
||||
|
||||
return list(self.tracked_people.values()), self.active_target_id
|
||||
|
||||
# ── Detection-to-track matching ──────────────────────────────────────────
|
||||
|
||||
def _match_detections(
|
||||
self, detections: List[Dict]
|
||||
) -> Tuple[List[Tuple[int, int]], List[int]]:
|
||||
"""Match detections to tracked people via re-ID.
|
||||
|
||||
Returns:
|
||||
(matched_pairs, unmatched_detection_indices)
|
||||
"""
|
||||
if not detections or not self.tracked_people:
|
||||
return [], list(range(len(detections)))
|
||||
|
||||
matched = []
|
||||
used_det_indices = set()
|
||||
|
||||
# Re-ID matching using embeddings and HSV histograms
|
||||
for track_id, person in self.tracked_people.items():
|
||||
best_score = -1.0
|
||||
best_det_idx = -1
|
||||
|
||||
for det_idx, det in enumerate(detections):
|
||||
if det_idx in used_det_indices:
|
||||
continue
|
||||
|
||||
score = self._compute_similarity(person, det)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_det_idx = det_idx
|
||||
|
||||
# Match if above threshold
|
||||
if best_det_idx >= 0:
|
||||
if (
|
||||
(person.embedding is not None and best_score >= self.embedding_threshold)
|
||||
or (person.hsv_histogram is not None and best_score >= self.hsv_threshold)
|
||||
):
|
||||
matched.append((track_id, best_det_idx))
|
||||
used_det_indices.add(best_det_idx)
|
||||
person.is_occluded = False
|
||||
|
||||
unmatched_det_indices = [i for i in range(len(detections)) if i not in used_det_indices]
|
||||
return matched, unmatched_det_indices
|
||||
|
||||
def _compute_similarity(self, person: TrackedPerson, detection: Dict) -> float:
|
||||
"""Compute similarity between tracked person and detection."""
|
||||
if person.embedding is not None and detection.get("embedding") is not None:
|
||||
# Cosine similarity via dot product (embeddings are L2-normalized)
|
||||
emb_score = float(np.dot(person.embedding, detection["embedding"]))
|
||||
return emb_score
|
||||
|
||||
if person.hsv_histogram is not None and detection.get("hsv_histogram") is not None:
|
||||
# Compare HSV histograms
|
||||
hist_score = float(
|
||||
cv2.compareHist(
|
||||
person.hsv_histogram,
|
||||
detection["hsv_histogram"],
|
||||
cv2.HISTCMP_COSINE,
|
||||
)
|
||||
)
|
||||
return hist_score
|
||||
|
||||
# Fallback: spatial proximity (within similar distance/bearing)
|
||||
bearing_diff = abs(person.bearing_rad - detection["bearing_rad"])
|
||||
distance_diff = abs(person.distance_m - detection["distance_m"])
|
||||
spatial_score = 1.0 / (1.0 + bearing_diff + distance_diff * 0.1)
|
||||
return spatial_score
|
||||
|
||||
# ── Target priority and group handling ────────────────────────────────────
|
||||
|
||||
def _update_target_priority(self, speaking_person_id: Optional[int]) -> None:
|
||||
"""Update active target based on priority:
|
||||
1. Wake-word speaker
|
||||
2. Closest known person
|
||||
3. Largest bounding box
|
||||
"""
|
||||
if not self.tracked_people:
|
||||
self.active_target_id = None
|
||||
self.searching = True
|
||||
return
|
||||
|
||||
# Priority 1: Wake-word speaker
|
||||
if speaking_person_id is not None and speaking_person_id in self.tracked_people:
|
||||
self.active_target_id = speaking_person_id
|
||||
self.searching = False
|
||||
return
|
||||
|
||||
# Priority 2: Closest non-occluded person
|
||||
closest_id = None
|
||||
closest_dist = float("inf")
|
||||
for track_id, person in self.tracked_people.items():
|
||||
if not person.is_occluded and person.distance_m > 0:
|
||||
if person.distance_m < closest_dist:
|
||||
closest_dist = person.distance_m
|
||||
closest_id = track_id
|
||||
|
||||
if closest_id is not None:
|
||||
self.active_target_id = closest_id
|
||||
self.searching = False
|
||||
return
|
||||
|
||||
# Priority 3: Largest bounding box
|
||||
largest_id = None
|
||||
largest_area = 0
|
||||
for track_id, person in self.tracked_people.items():
|
||||
if person.bbox is not None:
|
||||
_, _, w, h = person.bbox
|
||||
area = w * h
|
||||
if area > largest_area:
|
||||
largest_area = area
|
||||
largest_id = track_id
|
||||
|
||||
if largest_id is not None:
|
||||
self.active_target_id = largest_id
|
||||
self.searching = False
|
||||
else:
|
||||
self.searching = True
|
||||
|
||||
def get_group_centroid(self) -> Optional[Tuple[float, float, float]]:
|
||||
"""Get centroid (bearing, distance) of all tracked people.
|
||||
|
||||
Returns:
|
||||
(mean_bearing_rad, mean_distance_m) or None if no people
|
||||
"""
|
||||
if not self.tracked_people:
|
||||
return None
|
||||
|
||||
bearings = []
|
||||
distances = []
|
||||
for person in self.tracked_people.values():
|
||||
if not person.is_occluded and person.distance_m > 0:
|
||||
bearings.append(person.bearing_rad)
|
||||
distances.append(person.distance_m)
|
||||
|
||||
if not bearings:
|
||||
return None
|
||||
|
||||
# Mean bearing (circular mean)
|
||||
mean_bearing = float(np.mean(bearings))
|
||||
mean_distance = float(np.mean(distances))
|
||||
return mean_bearing, mean_distance
|
||||
|
||||
# ── Pruning and cleanup ──────────────────────────────────────────────────
|
||||
|
||||
def _prune_lost_people(self) -> None:
|
||||
"""Remove people lost for more than occlusion grace period."""
|
||||
current_time = time.time()
|
||||
to_remove = []
|
||||
|
||||
for track_id, person in self.tracked_people.items():
|
||||
# Check occlusion timeout
|
||||
if person.is_occluded:
|
||||
elapsed = current_time - person.occlusion_start_time
|
||||
if elapsed > self.occlusion_grace_s:
|
||||
to_remove.append(track_id)
|
||||
continue
|
||||
|
||||
# Check long-term invisibility (10s)
|
||||
if person.age_seconds() > 10.0:
|
||||
to_remove.append(track_id)
|
||||
|
||||
for track_id in to_remove:
|
||||
del self.tracked_people[track_id]
|
||||
if self.active_target_id == track_id:
|
||||
self.active_target_id = None
|
||||
self.searching = True
|
||||
@ -0,0 +1,185 @@
|
||||
"""multi_person_tracker_node.py — Multi-person tracking node for Issue #423.
|
||||
|
||||
Subscribes to person detections with re-ID embeddings and publishes:
|
||||
/saltybot/tracked_people (PersonArray) — all tracked people + active target
|
||||
/saltybot/follow_target (Person) — the current follow target
|
||||
|
||||
Implements:
|
||||
- Up to 10 tracked people
|
||||
- Target priority: wake-word speaker > closest > largest bbox
|
||||
- Occlusion handoff (3s grace)
|
||||
- Re-ID via face embedding + HSV
|
||||
- Group detection
|
||||
- Lost target: SEARCHING state
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import numpy as np
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile, QoSReliabilityPolicy
|
||||
from rclpy.duration import Duration
|
||||
|
||||
from std_msgs.msg import String
|
||||
from geometry_msgs.msg import Twist
|
||||
from saltybot_social_msgs.msg import Person, PersonArray
|
||||
|
||||
from .multi_person_tracker import MultiPersonTracker, TrackedPerson
|
||||
|
||||
|
||||
class MultiPersonTrackerNode(Node):
|
||||
"""ROS2 node for multi-person tracking."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__("multi_person_tracker_node")
|
||||
|
||||
# ── Parameters ──────────────────────────────────────────────────────
|
||||
self.declare_parameter("max_people", 10)
|
||||
self.declare_parameter("occlusion_grace_s", 3.0)
|
||||
self.declare_parameter("embedding_threshold", 0.75)
|
||||
self.declare_parameter("hsv_threshold", 0.60)
|
||||
self.declare_parameter("publish_hz", 15.0)
|
||||
self.declare_parameter("enable_group_follow", True)
|
||||
self.declare_parameter("announce_state", True)
|
||||
|
||||
max_people = self.get_parameter("max_people").value
|
||||
occlusion_grace = self.get_parameter("occlusion_grace_s").value
|
||||
embed_threshold = self.get_parameter("embedding_threshold").value
|
||||
hsv_threshold = self.get_parameter("hsv_threshold").value
|
||||
publish_hz = self.get_parameter("publish_hz").value
|
||||
self._enable_group = self.get_parameter("enable_group_follow").value
|
||||
self._announce = self.get_parameter("announce_state").value
|
||||
|
||||
# ── Tracker instance ────────────────────────────────────────────────
|
||||
self._tracker = MultiPersonTracker(
|
||||
max_people=max_people,
|
||||
occlusion_grace_s=occlusion_grace,
|
||||
embedding_similarity_threshold=embed_threshold,
|
||||
hsv_similarity_threshold=hsv_threshold,
|
||||
)
|
||||
|
||||
# ── Reliable QoS ────────────────────────────────────────────────────
|
||||
qos = QoSProfile(depth=5)
|
||||
qos.reliability = QoSReliabilityPolicy.RELIABLE
|
||||
|
||||
self._tracked_pub = self.create_publisher(
|
||||
PersonArray,
|
||||
"/saltybot/tracked_people",
|
||||
qos,
|
||||
)
|
||||
|
||||
self._target_pub = self.create_publisher(
|
||||
Person,
|
||||
"/saltybot/follow_target",
|
||||
qos,
|
||||
)
|
||||
|
||||
self._state_pub = self.create_publisher(
|
||||
String,
|
||||
"/saltybot/tracker_state",
|
||||
qos,
|
||||
)
|
||||
|
||||
self._cmd_vel_pub = self.create_publisher(
|
||||
Twist,
|
||||
"/saltybot/cmd_vel",
|
||||
qos,
|
||||
)
|
||||
|
||||
# ── Placeholder: in real implementation, subscribe to detection topics
|
||||
# For now, we'll use a timer to demonstrate the node structure
|
||||
self._update_timer = self.create_timer(1.0 / publish_hz, self._update_tracker)
|
||||
|
||||
self._last_state = "IDLE"
|
||||
self.get_logger().info(
|
||||
f"multi_person_tracker_node ready (max_people={max_people}, "
|
||||
f"occlusion_grace={occlusion_grace}s, publish_hz={publish_hz} Hz)"
|
||||
)
|
||||
|
||||
def _update_tracker(self) -> None:
|
||||
"""Update tracker and publish results."""
|
||||
# In real implementation, this would be triggered by detection callbacks
|
||||
# For now, we'll use the timer and publish empty/demo data
|
||||
|
||||
# Get current tracked people and active target
|
||||
tracked_people = list(self._tracker.tracked_people.values())
|
||||
active_target_id = self._tracker.active_target_id
|
||||
|
||||
# Build PersonArray message
|
||||
person_array_msg = PersonArray()
|
||||
person_array_msg.header.stamp = self.get_clock().now().to_msg()
|
||||
person_array_msg.header.frame_id = "base_link"
|
||||
person_array_msg.active_id = active_target_id if active_target_id is not None else -1
|
||||
|
||||
for tracked in tracked_people:
|
||||
person_msg = Person()
|
||||
person_msg.header = person_array_msg.header
|
||||
person_msg.track_id = tracked.track_id
|
||||
person_msg.bearing_rad = tracked.bearing_rad
|
||||
person_msg.distance_m = tracked.distance_m
|
||||
person_msg.confidence = tracked.confidence
|
||||
person_msg.is_speaking = tracked.is_speaking
|
||||
person_msg.source = tracked.source
|
||||
person_array_msg.persons.append(person_msg)
|
||||
|
||||
self._tracked_pub.publish(person_array_msg)
|
||||
|
||||
# Publish active target
|
||||
if active_target_id is not None and active_target_id in self._tracker.tracked_people:
|
||||
target = self._tracker.tracked_people[active_target_id]
|
||||
target_msg = Person()
|
||||
target_msg.header = person_array_msg.header
|
||||
target_msg.track_id = target.track_id
|
||||
target_msg.bearing_rad = target.bearing_rad
|
||||
target_msg.distance_m = target.distance_m
|
||||
target_msg.confidence = target.confidence
|
||||
target_msg.is_speaking = target.is_speaking
|
||||
target_msg.source = target.source
|
||||
self._target_pub.publish(target_msg)
|
||||
|
||||
# Publish tracker state
|
||||
state = "SEARCHING" if self._tracker.searching else "TRACKING"
|
||||
if state != self._last_state and self._announce:
|
||||
self.get_logger().info(f"State: {state}")
|
||||
self._last_state = state
|
||||
state_msg = String()
|
||||
state_msg.data = state
|
||||
self._state_pub.publish(state_msg)
|
||||
|
||||
# Handle lost target
|
||||
if state == "SEARCHING":
|
||||
self._handle_lost_target()
|
||||
|
||||
def _handle_lost_target(self) -> None:
|
||||
"""Execute behavior when target is lost: stop + rotate."""
|
||||
# Publish stop command
|
||||
stop_twist = Twist()
|
||||
self._cmd_vel_pub.publish(stop_twist)
|
||||
|
||||
# Slow rotating search
|
||||
search_twist = Twist()
|
||||
search_twist.angular.z = 0.5 # rad/s
|
||||
self._cmd_vel_pub.publish(search_twist)
|
||||
|
||||
if self._announce:
|
||||
self.get_logger().warn("Lost target, searching...")
|
||||
|
||||
|
||||
def main(args=None) -> None:
|
||||
rclpy.init(args=args)
|
||||
node = MultiPersonTrackerNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -0,0 +1,5 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/saltybot_multi_person_tracker
|
||||
|
||||
[install]
|
||||
install_scripts=$base/lib/saltybot_multi_person_tracker
|
||||
29
jetson/ros2_ws/src/saltybot_multi_person_tracker/setup.py
Normal file
29
jetson/ros2_ws/src/saltybot_multi_person_tracker/setup.py
Normal file
@ -0,0 +1,29 @@
|
||||
from setuptools import setup
|
||||
|
||||
package_name = 'saltybot_multi_person_tracker'
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
version='0.1.0',
|
||||
packages=[package_name],
|
||||
data_files=[
|
||||
('share/ament_index/resource_index/packages',
|
||||
['resource/' + package_name]),
|
||||
('share/' + package_name, ['package.xml']),
|
||||
],
|
||||
install_requires=['setuptools'],
|
||||
zip_safe=True,
|
||||
author='sl-perception',
|
||||
author_email='sl-perception@saltylab.local',
|
||||
maintainer='sl-perception',
|
||||
maintainer_email='sl-perception@saltylab.local',
|
||||
url='https://gitea.vayrette.com/seb/saltylab-firmware',
|
||||
description='Multi-person tracker with group handling and target priority',
|
||||
license='MIT',
|
||||
tests_require=['pytest'],
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'multi_person_tracker_node = saltybot_multi_person_tracker.multi_person_tracker_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
@ -0,0 +1,122 @@
|
||||
"""Tests for multi-person tracker (Issue #423)."""
|
||||
|
||||
import numpy as np
|
||||
from saltybot_multi_person_tracker.multi_person_tracker import (
|
||||
MultiPersonTracker,
|
||||
TrackedPerson,
|
||||
)
|
||||
|
||||
|
||||
def test_create_new_tracks():
|
||||
"""Test creating new tracks for detections."""
|
||||
tracker = MultiPersonTracker(max_people=10)
|
||||
|
||||
detections = [
|
||||
{"bearing_rad": 0.1, "distance_m": 1.5, "confidence": 0.9},
|
||||
{"bearing_rad": -0.2, "distance_m": 2.0, "confidence": 0.85},
|
||||
]
|
||||
|
||||
tracked, target_id = tracker.update(detections)
|
||||
|
||||
assert len(tracked) == 2
|
||||
assert tracked[0].track_id == 0
|
||||
assert tracked[1].track_id == 1
|
||||
|
||||
|
||||
def test_match_existing_tracks():
|
||||
"""Test matching detections to existing tracks."""
|
||||
tracker = MultiPersonTracker(max_people=10)
|
||||
|
||||
# Create initial tracks
|
||||
detections1 = [
|
||||
{"bearing_rad": 0.1, "distance_m": 1.5, "confidence": 0.9},
|
||||
{"bearing_rad": -0.2, "distance_m": 2.0, "confidence": 0.85},
|
||||
]
|
||||
tracked1, _ = tracker.update(detections1)
|
||||
|
||||
# Update with similar positions
|
||||
detections2 = [
|
||||
{"bearing_rad": 0.12, "distance_m": 1.48, "confidence": 0.88},
|
||||
{"bearing_rad": -0.22, "distance_m": 2.05, "confidence": 0.84},
|
||||
]
|
||||
tracked2, _ = tracker.update(detections2)
|
||||
|
||||
# Should have same IDs (matched)
|
||||
assert len(tracked2) == 2
|
||||
ids_after = {p.track_id for p in tracked2}
|
||||
assert ids_after == {0, 1}
|
||||
|
||||
|
||||
def test_max_people_limit():
|
||||
"""Test that tracker respects max_people limit."""
|
||||
tracker = MultiPersonTracker(max_people=3)
|
||||
|
||||
detections = [
|
||||
{"bearing_rad": 0.1 * i, "distance_m": 1.5 + i * 0.2, "confidence": 0.9}
|
||||
for i in range(5)
|
||||
]
|
||||
|
||||
tracked, _ = tracker.update(detections)
|
||||
assert len(tracked) <= 3
|
||||
|
||||
|
||||
def test_wake_word_speaker_priority():
|
||||
"""Test that wake-word speaker gets highest priority."""
|
||||
tracker = MultiPersonTracker(max_people=10)
|
||||
|
||||
detections = [
|
||||
{"bearing_rad": 0.1, "distance_m": 5.0, "confidence": 0.9}, # Far
|
||||
{"bearing_rad": -0.2, "distance_m": 1.0, "confidence": 0.85}, # Closest
|
||||
{"bearing_rad": 0.3, "distance_m": 2.0, "confidence": 0.8}, # Farthest
|
||||
]
|
||||
|
||||
tracked, _ = tracker.update(detections)
|
||||
|
||||
# Without speaker: should pick closest
|
||||
assert tracker.active_target_id == 1 # Closest person
|
||||
|
||||
# With speaker at ID 0: should pick speaker
|
||||
tracked, target = tracker.update(detections, speaking_person_id=0)
|
||||
assert target == 0
|
||||
|
||||
|
||||
def test_group_centroid():
|
||||
"""Test group centroid calculation."""
|
||||
tracker = MultiPersonTracker(max_people=10)
|
||||
|
||||
detections = [
|
||||
{"bearing_rad": 0.0, "distance_m": 1.0, "confidence": 0.9},
|
||||
{"bearing_rad": 0.2, "distance_m": 2.0, "confidence": 0.85},
|
||||
{"bearing_rad": -0.2, "distance_m": 1.5, "confidence": 0.8},
|
||||
]
|
||||
|
||||
tracked, _ = tracker.update(detections)
|
||||
|
||||
centroid = tracker.get_group_centroid()
|
||||
assert centroid is not None
|
||||
mean_bearing, mean_distance = centroid
|
||||
# Mean bearing should be close to 0
|
||||
assert abs(mean_bearing) < 0.1
|
||||
# Mean distance should be between 1.0 and 2.0
|
||||
assert 1.0 < mean_distance < 2.0
|
||||
|
||||
|
||||
def test_empty_tracker():
|
||||
"""Test tracker with no detections."""
|
||||
tracker = MultiPersonTracker()
|
||||
|
||||
tracked, target = tracker.update([])
|
||||
|
||||
assert len(tracked) == 0
|
||||
assert target is None
|
||||
assert tracker.searching is True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_create_new_tracks()
|
||||
test_match_existing_tracks()
|
||||
test_max_people_limit()
|
||||
test_wake_word_speaker_priority()
|
||||
test_group_centroid()
|
||||
test_empty_tracker()
|
||||
print("All tests passed!")
|
||||
Loading…
x
Reference in New Issue
Block a user