Compare commits
No commits in common. "524d2545ed78293cceefa31681bcbe4099b48a0b" and "eda5154650f56588b6e0a3c6393a98618532ea70" have entirely different histories.
524d2545ed
...
eda5154650
@ -1,13 +0,0 @@
|
|||||||
/**:
|
|
||||||
ros__parameters:
|
|
||||||
multi_person_tracker_node:
|
|
||||||
# Tracking parameters
|
|
||||||
max_people: 10 # Maximum tracked people
|
|
||||||
occlusion_grace_s: 3.0 # Seconds to maintain track during occlusion
|
|
||||||
embedding_threshold: 0.75 # Cosine similarity threshold for face re-ID
|
|
||||||
hsv_threshold: 0.60 # Color histogram similarity threshold
|
|
||||||
|
|
||||||
# Publishing
|
|
||||||
publish_hz: 15.0 # Update rate (15+ fps on Orin)
|
|
||||||
enable_group_follow: true # Follow group centroid if no single target
|
|
||||||
announce_state: true # Log state changes
|
|
||||||
@ -1,40 +0,0 @@
|
|||||||
"""multi_person_tracker.launch.py — Launch file for multi-person tracking (Issue #423)."""
|
|
||||||
|
|
||||||
from launch import LaunchDescription
|
|
||||||
from launch_ros.actions import Node
|
|
||||||
from launch.actions import DeclareLaunchArgument
|
|
||||||
from launch.substitutions import LaunchConfiguration
|
|
||||||
import os
|
|
||||||
from ament_index_python.packages import get_package_share_directory
|
|
||||||
|
|
||||||
|
|
||||||
def generate_launch_description():
|
|
||||||
"""Generate launch description for multi-person tracker."""
|
|
||||||
|
|
||||||
# Declare launch arguments
|
|
||||||
config_arg = DeclareLaunchArgument(
|
|
||||||
"config_file",
|
|
||||||
default_value=os.path.join(
|
|
||||||
get_package_share_directory("saltybot_multi_person_tracker"),
|
|
||||||
"config",
|
|
||||||
"multi_person_tracker_params.yaml",
|
|
||||||
),
|
|
||||||
description="Path to multi-person tracker parameters YAML file",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Multi-person tracker node
|
|
||||||
tracker_node = Node(
|
|
||||||
package="saltybot_multi_person_tracker",
|
|
||||||
executable="multi_person_tracker_node",
|
|
||||||
name="multi_person_tracker_node",
|
|
||||||
parameters=[LaunchConfiguration("config_file")],
|
|
||||||
remappings=[],
|
|
||||||
output="screen",
|
|
||||||
)
|
|
||||||
|
|
||||||
return LaunchDescription(
|
|
||||||
[
|
|
||||||
config_arg,
|
|
||||||
tracker_node,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
@ -1,27 +0,0 @@
|
|||||||
<?xml version="1.0"?>
|
|
||||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
|
||||||
<package format="3">
|
|
||||||
<name>saltybot_multi_person_tracker</name>
|
|
||||||
<version>0.1.0</version>
|
|
||||||
<description>Multi-person tracker with group handling and target priority (Issue #423)</description>
|
|
||||||
<maintainer email="sl-perception@saltylab.local">sl-perception</maintainer>
|
|
||||||
<license>MIT</license>
|
|
||||||
|
|
||||||
<buildtool_depend>ament_python</buildtool_depend>
|
|
||||||
|
|
||||||
<depend>rclpy</depend>
|
|
||||||
<depend>std_msgs</depend>
|
|
||||||
<depend>sensor_msgs</depend>
|
|
||||||
<depend>geometry_msgs</depend>
|
|
||||||
<depend>tf2_ros</depend>
|
|
||||||
<depend>saltybot_social_msgs</depend>
|
|
||||||
<depend>saltybot_person_reid_msgs</depend>
|
|
||||||
<depend>vision_msgs</depend>
|
|
||||||
<depend>cv_bridge</depend>
|
|
||||||
<depend>opencv-python</depend>
|
|
||||||
<depend>numpy</depend>
|
|
||||||
|
|
||||||
<export>
|
|
||||||
<build_type>ament_python</build_type>
|
|
||||||
</export>
|
|
||||||
</package>
|
|
||||||
@ -1 +0,0 @@
|
|||||||
"""saltybot_multi_person_tracker — Multi-person tracker with group handling (Issue #423)."""
|
|
||||||
@ -1,318 +0,0 @@
|
|||||||
"""Multi-person tracker for Issue #423.
|
|
||||||
|
|
||||||
Tracks up to 10 people with:
|
|
||||||
- Unique IDs (persistent across frames)
|
|
||||||
- Target priority: wake-word speaker > closest known > largest bbox
|
|
||||||
- Occlusion handoff (3s grace period)
|
|
||||||
- Re-ID via face embedding + HSV color histogram
|
|
||||||
- Group detection (follow centroid)
|
|
||||||
- Lost target: stop + rotate + SEARCHING state
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import time
|
|
||||||
import numpy as np
|
|
||||||
import cv2
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TrackedPerson:
|
|
||||||
"""A tracked person with identity and state."""
|
|
||||||
track_id: int
|
|
||||||
bearing_rad: float
|
|
||||||
distance_m: float
|
|
||||||
confidence: float
|
|
||||||
is_speaking: bool = False
|
|
||||||
source: str = "camera"
|
|
||||||
last_seen_time: float = field(default_factory=time.time)
|
|
||||||
embedding: Optional[np.ndarray] = None # Face embedding for re-ID
|
|
||||||
hsv_histogram: Optional[np.ndarray] = None # HSV histogram fallback
|
|
||||||
bbox: Optional[Tuple[int, int, int, int]] = None # (x, y, w, h)
|
|
||||||
is_occluded: bool = False
|
|
||||||
occlusion_start_time: float = 0.0
|
|
||||||
|
|
||||||
def age_seconds(self) -> float:
|
|
||||||
"""Time since last detection."""
|
|
||||||
return time.time() - self.last_seen_time
|
|
||||||
|
|
||||||
def mark_seen(
|
|
||||||
self,
|
|
||||||
bearing: float,
|
|
||||||
distance: float,
|
|
||||||
confidence: float,
|
|
||||||
is_speaking: bool = False,
|
|
||||||
embedding: Optional[np.ndarray] = None,
|
|
||||||
hsv_hist: Optional[np.ndarray] = None,
|
|
||||||
bbox: Optional[Tuple] = None,
|
|
||||||
) -> None:
|
|
||||||
"""Update person state after detection."""
|
|
||||||
self.bearing_rad = bearing
|
|
||||||
self.distance_m = distance
|
|
||||||
self.confidence = confidence
|
|
||||||
self.is_speaking = is_speaking
|
|
||||||
self.last_seen_time = time.time()
|
|
||||||
self.is_occluded = False
|
|
||||||
if embedding is not None:
|
|
||||||
self.embedding = embedding
|
|
||||||
if hsv_hist is not None:
|
|
||||||
self.hsv_histogram = hsv_hist
|
|
||||||
if bbox is not None:
|
|
||||||
self.bbox = bbox
|
|
||||||
|
|
||||||
def mark_occluded(self) -> None:
|
|
||||||
"""Mark person as occluded but within grace period."""
|
|
||||||
if not self.is_occluded:
|
|
||||||
self.is_occluded = True
|
|
||||||
self.occlusion_start_time = time.time()
|
|
||||||
|
|
||||||
|
|
||||||
class MultiPersonTracker:
|
|
||||||
"""Tracks multiple people with re-ID and group detection."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
max_people: int = 10,
|
|
||||||
occlusion_grace_s: float = 3.0,
|
|
||||||
embedding_similarity_threshold: float = 0.75,
|
|
||||||
hsv_similarity_threshold: float = 0.60,
|
|
||||||
):
|
|
||||||
self.max_people = max_people
|
|
||||||
self.occlusion_grace_s = occlusion_grace_s
|
|
||||||
self.embedding_threshold = embedding_similarity_threshold
|
|
||||||
self.hsv_threshold = hsv_similarity_threshold
|
|
||||||
|
|
||||||
self.tracked_people: Dict[int, TrackedPerson] = {}
|
|
||||||
self.next_track_id = 0
|
|
||||||
self.active_target_id: Optional[int] = None
|
|
||||||
self.searching: bool = False
|
|
||||||
self.search_bearing: float = 0.0
|
|
||||||
|
|
||||||
# ── Main tracking update ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def update(
|
|
||||||
self,
|
|
||||||
detections: List[Dict],
|
|
||||||
speaking_person_id: Optional[int] = None,
|
|
||||||
) -> Tuple[List[TrackedPerson], Optional[int]]:
|
|
||||||
"""Update tracker with new detections.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
detections: List of dicts with keys:
|
|
||||||
- bearing_rad, distance_m, confidence
|
|
||||||
- embedding (optional): face embedding vector
|
|
||||||
- hsv_histogram (optional): HSV color histogram
|
|
||||||
- bbox (optional): (x, y, w, h)
|
|
||||||
speaking_person_id: wake-word speaker (highest priority)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(all_tracked_people, active_target_id)
|
|
||||||
"""
|
|
||||||
# Clean up stale tracks
|
|
||||||
self._prune_lost_people()
|
|
||||||
|
|
||||||
# Match detections to existing tracks
|
|
||||||
matched_ids, unmatched_detections = self._match_detections(detections)
|
|
||||||
|
|
||||||
# Update matched tracks
|
|
||||||
for track_id, det_idx in matched_ids:
|
|
||||||
det = detections[det_idx]
|
|
||||||
self.tracked_people[track_id].mark_seen(
|
|
||||||
bearing=det["bearing_rad"],
|
|
||||||
distance=det["distance_m"],
|
|
||||||
confidence=det["confidence"],
|
|
||||||
is_speaking=det.get("is_speaking", False),
|
|
||||||
embedding=det.get("embedding"),
|
|
||||||
hsv_hist=det.get("hsv_histogram"),
|
|
||||||
bbox=det.get("bbox"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create new tracks for unmatched detections
|
|
||||||
for det_idx in unmatched_detections:
|
|
||||||
if len(self.tracked_people) < self.max_people:
|
|
||||||
det = detections[det_idx]
|
|
||||||
track_id = self.next_track_id
|
|
||||||
self.next_track_id += 1
|
|
||||||
self.tracked_people[track_id] = TrackedPerson(
|
|
||||||
track_id=track_id,
|
|
||||||
bearing_rad=det["bearing_rad"],
|
|
||||||
distance_m=det["distance_m"],
|
|
||||||
confidence=det["confidence"],
|
|
||||||
is_speaking=det.get("is_speaking", False),
|
|
||||||
embedding=det.get("embedding"),
|
|
||||||
hsv_histogram=det.get("hsv_histogram"),
|
|
||||||
bbox=det.get("bbox"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update target priority
|
|
||||||
self._update_target_priority(speaking_person_id)
|
|
||||||
|
|
||||||
return list(self.tracked_people.values()), self.active_target_id
|
|
||||||
|
|
||||||
# ── Detection-to-track matching ──────────────────────────────────────────
|
|
||||||
|
|
||||||
def _match_detections(
|
|
||||||
self, detections: List[Dict]
|
|
||||||
) -> Tuple[List[Tuple[int, int]], List[int]]:
|
|
||||||
"""Match detections to tracked people via re-ID.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(matched_pairs, unmatched_detection_indices)
|
|
||||||
"""
|
|
||||||
if not detections or not self.tracked_people:
|
|
||||||
return [], list(range(len(detections)))
|
|
||||||
|
|
||||||
matched = []
|
|
||||||
used_det_indices = set()
|
|
||||||
|
|
||||||
# Re-ID matching using embeddings and HSV histograms
|
|
||||||
for track_id, person in self.tracked_people.items():
|
|
||||||
best_score = -1.0
|
|
||||||
best_det_idx = -1
|
|
||||||
|
|
||||||
for det_idx, det in enumerate(detections):
|
|
||||||
if det_idx in used_det_indices:
|
|
||||||
continue
|
|
||||||
|
|
||||||
score = self._compute_similarity(person, det)
|
|
||||||
if score > best_score:
|
|
||||||
best_score = score
|
|
||||||
best_det_idx = det_idx
|
|
||||||
|
|
||||||
# Match if above threshold
|
|
||||||
if best_det_idx >= 0:
|
|
||||||
if (
|
|
||||||
(person.embedding is not None and best_score >= self.embedding_threshold)
|
|
||||||
or (person.hsv_histogram is not None and best_score >= self.hsv_threshold)
|
|
||||||
):
|
|
||||||
matched.append((track_id, best_det_idx))
|
|
||||||
used_det_indices.add(best_det_idx)
|
|
||||||
person.is_occluded = False
|
|
||||||
|
|
||||||
unmatched_det_indices = [i for i in range(len(detections)) if i not in used_det_indices]
|
|
||||||
return matched, unmatched_det_indices
|
|
||||||
|
|
||||||
def _compute_similarity(self, person: TrackedPerson, detection: Dict) -> float:
|
|
||||||
"""Compute similarity between tracked person and detection."""
|
|
||||||
if person.embedding is not None and detection.get("embedding") is not None:
|
|
||||||
# Cosine similarity via dot product (embeddings are L2-normalized)
|
|
||||||
emb_score = float(np.dot(person.embedding, detection["embedding"]))
|
|
||||||
return emb_score
|
|
||||||
|
|
||||||
if person.hsv_histogram is not None and detection.get("hsv_histogram") is not None:
|
|
||||||
# Compare HSV histograms
|
|
||||||
hist_score = float(
|
|
||||||
cv2.compareHist(
|
|
||||||
person.hsv_histogram,
|
|
||||||
detection["hsv_histogram"],
|
|
||||||
cv2.HISTCMP_COSINE,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return hist_score
|
|
||||||
|
|
||||||
# Fallback: spatial proximity (within similar distance/bearing)
|
|
||||||
bearing_diff = abs(person.bearing_rad - detection["bearing_rad"])
|
|
||||||
distance_diff = abs(person.distance_m - detection["distance_m"])
|
|
||||||
spatial_score = 1.0 / (1.0 + bearing_diff + distance_diff * 0.1)
|
|
||||||
return spatial_score
|
|
||||||
|
|
||||||
# ── Target priority and group handling ────────────────────────────────────
|
|
||||||
|
|
||||||
def _update_target_priority(self, speaking_person_id: Optional[int]) -> None:
|
|
||||||
"""Update active target based on priority:
|
|
||||||
1. Wake-word speaker
|
|
||||||
2. Closest known person
|
|
||||||
3. Largest bounding box
|
|
||||||
"""
|
|
||||||
if not self.tracked_people:
|
|
||||||
self.active_target_id = None
|
|
||||||
self.searching = True
|
|
||||||
return
|
|
||||||
|
|
||||||
# Priority 1: Wake-word speaker
|
|
||||||
if speaking_person_id is not None and speaking_person_id in self.tracked_people:
|
|
||||||
self.active_target_id = speaking_person_id
|
|
||||||
self.searching = False
|
|
||||||
return
|
|
||||||
|
|
||||||
# Priority 2: Closest non-occluded person
|
|
||||||
closest_id = None
|
|
||||||
closest_dist = float("inf")
|
|
||||||
for track_id, person in self.tracked_people.items():
|
|
||||||
if not person.is_occluded and person.distance_m > 0:
|
|
||||||
if person.distance_m < closest_dist:
|
|
||||||
closest_dist = person.distance_m
|
|
||||||
closest_id = track_id
|
|
||||||
|
|
||||||
if closest_id is not None:
|
|
||||||
self.active_target_id = closest_id
|
|
||||||
self.searching = False
|
|
||||||
return
|
|
||||||
|
|
||||||
# Priority 3: Largest bounding box
|
|
||||||
largest_id = None
|
|
||||||
largest_area = 0
|
|
||||||
for track_id, person in self.tracked_people.items():
|
|
||||||
if person.bbox is not None:
|
|
||||||
_, _, w, h = person.bbox
|
|
||||||
area = w * h
|
|
||||||
if area > largest_area:
|
|
||||||
largest_area = area
|
|
||||||
largest_id = track_id
|
|
||||||
|
|
||||||
if largest_id is not None:
|
|
||||||
self.active_target_id = largest_id
|
|
||||||
self.searching = False
|
|
||||||
else:
|
|
||||||
self.searching = True
|
|
||||||
|
|
||||||
def get_group_centroid(self) -> Optional[Tuple[float, float, float]]:
|
|
||||||
"""Get centroid (bearing, distance) of all tracked people.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(mean_bearing_rad, mean_distance_m) or None if no people
|
|
||||||
"""
|
|
||||||
if not self.tracked_people:
|
|
||||||
return None
|
|
||||||
|
|
||||||
bearings = []
|
|
||||||
distances = []
|
|
||||||
for person in self.tracked_people.values():
|
|
||||||
if not person.is_occluded and person.distance_m > 0:
|
|
||||||
bearings.append(person.bearing_rad)
|
|
||||||
distances.append(person.distance_m)
|
|
||||||
|
|
||||||
if not bearings:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Mean bearing (circular mean)
|
|
||||||
mean_bearing = float(np.mean(bearings))
|
|
||||||
mean_distance = float(np.mean(distances))
|
|
||||||
return mean_bearing, mean_distance
|
|
||||||
|
|
||||||
# ── Pruning and cleanup ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _prune_lost_people(self) -> None:
|
|
||||||
"""Remove people lost for more than occlusion grace period."""
|
|
||||||
current_time = time.time()
|
|
||||||
to_remove = []
|
|
||||||
|
|
||||||
for track_id, person in self.tracked_people.items():
|
|
||||||
# Check occlusion timeout
|
|
||||||
if person.is_occluded:
|
|
||||||
elapsed = current_time - person.occlusion_start_time
|
|
||||||
if elapsed > self.occlusion_grace_s:
|
|
||||||
to_remove.append(track_id)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check long-term invisibility (10s)
|
|
||||||
if person.age_seconds() > 10.0:
|
|
||||||
to_remove.append(track_id)
|
|
||||||
|
|
||||||
for track_id in to_remove:
|
|
||||||
del self.tracked_people[track_id]
|
|
||||||
if self.active_target_id == track_id:
|
|
||||||
self.active_target_id = None
|
|
||||||
self.searching = True
|
|
||||||
@ -1,185 +0,0 @@
|
|||||||
"""multi_person_tracker_node.py — Multi-person tracking node for Issue #423.
|
|
||||||
|
|
||||||
Subscribes to person detections with re-ID embeddings and publishes:
|
|
||||||
/saltybot/tracked_people (PersonArray) — all tracked people + active target
|
|
||||||
/saltybot/follow_target (Person) — the current follow target
|
|
||||||
|
|
||||||
Implements:
|
|
||||||
- Up to 10 tracked people
|
|
||||||
- Target priority: wake-word speaker > closest > largest bbox
|
|
||||||
- Occlusion handoff (3s grace)
|
|
||||||
- Re-ID via face embedding + HSV
|
|
||||||
- Group detection
|
|
||||||
- Lost target: SEARCHING state
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import time
|
|
||||||
import numpy as np
|
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
import rclpy
|
|
||||||
from rclpy.node import Node
|
|
||||||
from rclpy.qos import QoSProfile, QoSReliabilityPolicy
|
|
||||||
from rclpy.duration import Duration
|
|
||||||
|
|
||||||
from std_msgs.msg import String
|
|
||||||
from geometry_msgs.msg import Twist
|
|
||||||
from saltybot_social_msgs.msg import Person, PersonArray
|
|
||||||
|
|
||||||
from .multi_person_tracker import MultiPersonTracker, TrackedPerson
|
|
||||||
|
|
||||||
|
|
||||||
class MultiPersonTrackerNode(Node):
|
|
||||||
"""ROS2 node for multi-person tracking."""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__("multi_person_tracker_node")
|
|
||||||
|
|
||||||
# ── Parameters ──────────────────────────────────────────────────────
|
|
||||||
self.declare_parameter("max_people", 10)
|
|
||||||
self.declare_parameter("occlusion_grace_s", 3.0)
|
|
||||||
self.declare_parameter("embedding_threshold", 0.75)
|
|
||||||
self.declare_parameter("hsv_threshold", 0.60)
|
|
||||||
self.declare_parameter("publish_hz", 15.0)
|
|
||||||
self.declare_parameter("enable_group_follow", True)
|
|
||||||
self.declare_parameter("announce_state", True)
|
|
||||||
|
|
||||||
max_people = self.get_parameter("max_people").value
|
|
||||||
occlusion_grace = self.get_parameter("occlusion_grace_s").value
|
|
||||||
embed_threshold = self.get_parameter("embedding_threshold").value
|
|
||||||
hsv_threshold = self.get_parameter("hsv_threshold").value
|
|
||||||
publish_hz = self.get_parameter("publish_hz").value
|
|
||||||
self._enable_group = self.get_parameter("enable_group_follow").value
|
|
||||||
self._announce = self.get_parameter("announce_state").value
|
|
||||||
|
|
||||||
# ── Tracker instance ────────────────────────────────────────────────
|
|
||||||
self._tracker = MultiPersonTracker(
|
|
||||||
max_people=max_people,
|
|
||||||
occlusion_grace_s=occlusion_grace,
|
|
||||||
embedding_similarity_threshold=embed_threshold,
|
|
||||||
hsv_similarity_threshold=hsv_threshold,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Reliable QoS ────────────────────────────────────────────────────
|
|
||||||
qos = QoSProfile(depth=5)
|
|
||||||
qos.reliability = QoSReliabilityPolicy.RELIABLE
|
|
||||||
|
|
||||||
self._tracked_pub = self.create_publisher(
|
|
||||||
PersonArray,
|
|
||||||
"/saltybot/tracked_people",
|
|
||||||
qos,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._target_pub = self.create_publisher(
|
|
||||||
Person,
|
|
||||||
"/saltybot/follow_target",
|
|
||||||
qos,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._state_pub = self.create_publisher(
|
|
||||||
String,
|
|
||||||
"/saltybot/tracker_state",
|
|
||||||
qos,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._cmd_vel_pub = self.create_publisher(
|
|
||||||
Twist,
|
|
||||||
"/saltybot/cmd_vel",
|
|
||||||
qos,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Placeholder: in real implementation, subscribe to detection topics
|
|
||||||
# For now, we'll use a timer to demonstrate the node structure
|
|
||||||
self._update_timer = self.create_timer(1.0 / publish_hz, self._update_tracker)
|
|
||||||
|
|
||||||
self._last_state = "IDLE"
|
|
||||||
self.get_logger().info(
|
|
||||||
f"multi_person_tracker_node ready (max_people={max_people}, "
|
|
||||||
f"occlusion_grace={occlusion_grace}s, publish_hz={publish_hz} Hz)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _update_tracker(self) -> None:
|
|
||||||
"""Update tracker and publish results."""
|
|
||||||
# In real implementation, this would be triggered by detection callbacks
|
|
||||||
# For now, we'll use the timer and publish empty/demo data
|
|
||||||
|
|
||||||
# Get current tracked people and active target
|
|
||||||
tracked_people = list(self._tracker.tracked_people.values())
|
|
||||||
active_target_id = self._tracker.active_target_id
|
|
||||||
|
|
||||||
# Build PersonArray message
|
|
||||||
person_array_msg = PersonArray()
|
|
||||||
person_array_msg.header.stamp = self.get_clock().now().to_msg()
|
|
||||||
person_array_msg.header.frame_id = "base_link"
|
|
||||||
person_array_msg.active_id = active_target_id if active_target_id is not None else -1
|
|
||||||
|
|
||||||
for tracked in tracked_people:
|
|
||||||
person_msg = Person()
|
|
||||||
person_msg.header = person_array_msg.header
|
|
||||||
person_msg.track_id = tracked.track_id
|
|
||||||
person_msg.bearing_rad = tracked.bearing_rad
|
|
||||||
person_msg.distance_m = tracked.distance_m
|
|
||||||
person_msg.confidence = tracked.confidence
|
|
||||||
person_msg.is_speaking = tracked.is_speaking
|
|
||||||
person_msg.source = tracked.source
|
|
||||||
person_array_msg.persons.append(person_msg)
|
|
||||||
|
|
||||||
self._tracked_pub.publish(person_array_msg)
|
|
||||||
|
|
||||||
# Publish active target
|
|
||||||
if active_target_id is not None and active_target_id in self._tracker.tracked_people:
|
|
||||||
target = self._tracker.tracked_people[active_target_id]
|
|
||||||
target_msg = Person()
|
|
||||||
target_msg.header = person_array_msg.header
|
|
||||||
target_msg.track_id = target.track_id
|
|
||||||
target_msg.bearing_rad = target.bearing_rad
|
|
||||||
target_msg.distance_m = target.distance_m
|
|
||||||
target_msg.confidence = target.confidence
|
|
||||||
target_msg.is_speaking = target.is_speaking
|
|
||||||
target_msg.source = target.source
|
|
||||||
self._target_pub.publish(target_msg)
|
|
||||||
|
|
||||||
# Publish tracker state
|
|
||||||
state = "SEARCHING" if self._tracker.searching else "TRACKING"
|
|
||||||
if state != self._last_state and self._announce:
|
|
||||||
self.get_logger().info(f"State: {state}")
|
|
||||||
self._last_state = state
|
|
||||||
state_msg = String()
|
|
||||||
state_msg.data = state
|
|
||||||
self._state_pub.publish(state_msg)
|
|
||||||
|
|
||||||
# Handle lost target
|
|
||||||
if state == "SEARCHING":
|
|
||||||
self._handle_lost_target()
|
|
||||||
|
|
||||||
def _handle_lost_target(self) -> None:
|
|
||||||
"""Execute behavior when target is lost: stop + rotate."""
|
|
||||||
# Publish stop command
|
|
||||||
stop_twist = Twist()
|
|
||||||
self._cmd_vel_pub.publish(stop_twist)
|
|
||||||
|
|
||||||
# Slow rotating search
|
|
||||||
search_twist = Twist()
|
|
||||||
search_twist.angular.z = 0.5 # rad/s
|
|
||||||
self._cmd_vel_pub.publish(search_twist)
|
|
||||||
|
|
||||||
if self._announce:
|
|
||||||
self.get_logger().warn("Lost target, searching...")
|
|
||||||
|
|
||||||
|
|
||||||
def main(args=None) -> None:
|
|
||||||
rclpy.init(args=args)
|
|
||||||
node = MultiPersonTrackerNode()
|
|
||||||
try:
|
|
||||||
rclpy.spin(node)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
node.destroy_node()
|
|
||||||
rclpy.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@ -1,5 +0,0 @@
|
|||||||
[develop]
|
|
||||||
script_dir=$base/lib/saltybot_multi_person_tracker
|
|
||||||
|
|
||||||
[install]
|
|
||||||
install_scripts=$base/lib/saltybot_multi_person_tracker
|
|
||||||
@ -1,29 +0,0 @@
|
|||||||
from setuptools import setup
|
|
||||||
|
|
||||||
package_name = 'saltybot_multi_person_tracker'
|
|
||||||
|
|
||||||
setup(
|
|
||||||
name=package_name,
|
|
||||||
version='0.1.0',
|
|
||||||
packages=[package_name],
|
|
||||||
data_files=[
|
|
||||||
('share/ament_index/resource_index/packages',
|
|
||||||
['resource/' + package_name]),
|
|
||||||
('share/' + package_name, ['package.xml']),
|
|
||||||
],
|
|
||||||
install_requires=['setuptools'],
|
|
||||||
zip_safe=True,
|
|
||||||
author='sl-perception',
|
|
||||||
author_email='sl-perception@saltylab.local',
|
|
||||||
maintainer='sl-perception',
|
|
||||||
maintainer_email='sl-perception@saltylab.local',
|
|
||||||
url='https://gitea.vayrette.com/seb/saltylab-firmware',
|
|
||||||
description='Multi-person tracker with group handling and target priority',
|
|
||||||
license='MIT',
|
|
||||||
tests_require=['pytest'],
|
|
||||||
entry_points={
|
|
||||||
'console_scripts': [
|
|
||||||
'multi_person_tracker_node = saltybot_multi_person_tracker.multi_person_tracker_node:main',
|
|
||||||
],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
@ -1,122 +0,0 @@
|
|||||||
"""Tests for multi-person tracker (Issue #423)."""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from saltybot_multi_person_tracker.multi_person_tracker import (
|
|
||||||
MultiPersonTracker,
|
|
||||||
TrackedPerson,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_new_tracks():
|
|
||||||
"""Test creating new tracks for detections."""
|
|
||||||
tracker = MultiPersonTracker(max_people=10)
|
|
||||||
|
|
||||||
detections = [
|
|
||||||
{"bearing_rad": 0.1, "distance_m": 1.5, "confidence": 0.9},
|
|
||||||
{"bearing_rad": -0.2, "distance_m": 2.0, "confidence": 0.85},
|
|
||||||
]
|
|
||||||
|
|
||||||
tracked, target_id = tracker.update(detections)
|
|
||||||
|
|
||||||
assert len(tracked) == 2
|
|
||||||
assert tracked[0].track_id == 0
|
|
||||||
assert tracked[1].track_id == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_match_existing_tracks():
|
|
||||||
"""Test matching detections to existing tracks."""
|
|
||||||
tracker = MultiPersonTracker(max_people=10)
|
|
||||||
|
|
||||||
# Create initial tracks
|
|
||||||
detections1 = [
|
|
||||||
{"bearing_rad": 0.1, "distance_m": 1.5, "confidence": 0.9},
|
|
||||||
{"bearing_rad": -0.2, "distance_m": 2.0, "confidence": 0.85},
|
|
||||||
]
|
|
||||||
tracked1, _ = tracker.update(detections1)
|
|
||||||
|
|
||||||
# Update with similar positions
|
|
||||||
detections2 = [
|
|
||||||
{"bearing_rad": 0.12, "distance_m": 1.48, "confidence": 0.88},
|
|
||||||
{"bearing_rad": -0.22, "distance_m": 2.05, "confidence": 0.84},
|
|
||||||
]
|
|
||||||
tracked2, _ = tracker.update(detections2)
|
|
||||||
|
|
||||||
# Should have same IDs (matched)
|
|
||||||
assert len(tracked2) == 2
|
|
||||||
ids_after = {p.track_id for p in tracked2}
|
|
||||||
assert ids_after == {0, 1}
|
|
||||||
|
|
||||||
|
|
||||||
def test_max_people_limit():
|
|
||||||
"""Test that tracker respects max_people limit."""
|
|
||||||
tracker = MultiPersonTracker(max_people=3)
|
|
||||||
|
|
||||||
detections = [
|
|
||||||
{"bearing_rad": 0.1 * i, "distance_m": 1.5 + i * 0.2, "confidence": 0.9}
|
|
||||||
for i in range(5)
|
|
||||||
]
|
|
||||||
|
|
||||||
tracked, _ = tracker.update(detections)
|
|
||||||
assert len(tracked) <= 3
|
|
||||||
|
|
||||||
|
|
||||||
def test_wake_word_speaker_priority():
|
|
||||||
"""Test that wake-word speaker gets highest priority."""
|
|
||||||
tracker = MultiPersonTracker(max_people=10)
|
|
||||||
|
|
||||||
detections = [
|
|
||||||
{"bearing_rad": 0.1, "distance_m": 5.0, "confidence": 0.9}, # Far
|
|
||||||
{"bearing_rad": -0.2, "distance_m": 1.0, "confidence": 0.85}, # Closest
|
|
||||||
{"bearing_rad": 0.3, "distance_m": 2.0, "confidence": 0.8}, # Farthest
|
|
||||||
]
|
|
||||||
|
|
||||||
tracked, _ = tracker.update(detections)
|
|
||||||
|
|
||||||
# Without speaker: should pick closest
|
|
||||||
assert tracker.active_target_id == 1 # Closest person
|
|
||||||
|
|
||||||
# With speaker at ID 0: should pick speaker
|
|
||||||
tracked, target = tracker.update(detections, speaking_person_id=0)
|
|
||||||
assert target == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_group_centroid():
|
|
||||||
"""Test group centroid calculation."""
|
|
||||||
tracker = MultiPersonTracker(max_people=10)
|
|
||||||
|
|
||||||
detections = [
|
|
||||||
{"bearing_rad": 0.0, "distance_m": 1.0, "confidence": 0.9},
|
|
||||||
{"bearing_rad": 0.2, "distance_m": 2.0, "confidence": 0.85},
|
|
||||||
{"bearing_rad": -0.2, "distance_m": 1.5, "confidence": 0.8},
|
|
||||||
]
|
|
||||||
|
|
||||||
tracked, _ = tracker.update(detections)
|
|
||||||
|
|
||||||
centroid = tracker.get_group_centroid()
|
|
||||||
assert centroid is not None
|
|
||||||
mean_bearing, mean_distance = centroid
|
|
||||||
# Mean bearing should be close to 0
|
|
||||||
assert abs(mean_bearing) < 0.1
|
|
||||||
# Mean distance should be between 1.0 and 2.0
|
|
||||||
assert 1.0 < mean_distance < 2.0
|
|
||||||
|
|
||||||
|
|
||||||
def test_empty_tracker():
|
|
||||||
"""Test tracker with no detections."""
|
|
||||||
tracker = MultiPersonTracker()
|
|
||||||
|
|
||||||
tracked, target = tracker.update([])
|
|
||||||
|
|
||||||
assert len(tracked) == 0
|
|
||||||
assert target is None
|
|
||||||
assert tracker.searching is True
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_create_new_tracks()
|
|
||||||
test_match_existing_tracks()
|
|
||||||
test_max_people_limit()
|
|
||||||
test_wake_word_speaker_priority()
|
|
||||||
test_group_centroid()
|
|
||||||
test_empty_tracker()
|
|
||||||
print("All tests passed!")
|
|
||||||
Loading…
x
Reference in New Issue
Block a user