diff --git a/jetson/ros2_ws/src/saltybot_multi_person_tracker/config/multi_person_tracker_params.yaml b/jetson/ros2_ws/src/saltybot_multi_person_tracker/config/multi_person_tracker_params.yaml new file mode 100644 index 0000000..4c01c84 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_multi_person_tracker/config/multi_person_tracker_params.yaml @@ -0,0 +1,13 @@ +/**: + ros__parameters: + multi_person_tracker_node: + # Tracking parameters + max_people: 10 # Maximum tracked people + occlusion_grace_s: 3.0 # Seconds to maintain track during occlusion + embedding_threshold: 0.75 # Cosine similarity threshold for face re-ID + hsv_threshold: 0.60 # Color histogram similarity threshold + + # Publishing + publish_hz: 15.0 # Update rate (15+ fps on Orin) + enable_group_follow: true # Follow group centroid if no single target + announce_state: true # Log state changes diff --git a/jetson/ros2_ws/src/saltybot_multi_person_tracker/launch/multi_person_tracker.launch.py b/jetson/ros2_ws/src/saltybot_multi_person_tracker/launch/multi_person_tracker.launch.py new file mode 100644 index 0000000..065eb4a --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_multi_person_tracker/launch/multi_person_tracker.launch.py @@ -0,0 +1,40 @@ +"""multi_person_tracker.launch.py — Launch file for multi-person tracking (Issue #423).""" + +from launch import LaunchDescription +from launch_ros.actions import Node +from launch.actions import DeclareLaunchArgument +from launch.substitutions import LaunchConfiguration +import os +from ament_index_python.packages import get_package_share_directory + + +def generate_launch_description(): + """Generate launch description for multi-person tracker.""" + + # Declare launch arguments + config_arg = DeclareLaunchArgument( + "config_file", + default_value=os.path.join( + get_package_share_directory("saltybot_multi_person_tracker"), + "config", + "multi_person_tracker_params.yaml", + ), + description="Path to multi-person tracker parameters YAML file", + ) + + # Multi-person tracker node + tracker_node = Node( + package="saltybot_multi_person_tracker", + executable="multi_person_tracker_node", + name="multi_person_tracker_node", + parameters=[LaunchConfiguration("config_file")], + remappings=[], + output="screen", + ) + + return LaunchDescription( + [ + config_arg, + tracker_node, + ] + ) diff --git a/jetson/ros2_ws/src/saltybot_multi_person_tracker/package.xml b/jetson/ros2_ws/src/saltybot_multi_person_tracker/package.xml new file mode 100644 index 0000000..5ddfeed --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_multi_person_tracker/package.xml @@ -0,0 +1,27 @@ + + + + saltybot_multi_person_tracker + 0.1.0 + Multi-person tracker with group handling and target priority (Issue #423) + sl-perception + MIT + + ament_python + + rclpy + std_msgs + sensor_msgs + geometry_msgs + tf2_ros + saltybot_social_msgs + saltybot_person_reid_msgs + vision_msgs + cv_bridge + opencv-python + numpy + + + ament_python + + diff --git a/jetson/ros2_ws/src/saltybot_multi_person_tracker/resource/saltybot_multi_person_tracker b/jetson/ros2_ws/src/saltybot_multi_person_tracker/resource/saltybot_multi_person_tracker new file mode 100644 index 0000000..e69de29 diff --git a/jetson/ros2_ws/src/saltybot_multi_person_tracker/saltybot_multi_person_tracker/__init__.py b/jetson/ros2_ws/src/saltybot_multi_person_tracker/saltybot_multi_person_tracker/__init__.py new file mode 100644 index 0000000..b59fa2f --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_multi_person_tracker/saltybot_multi_person_tracker/__init__.py @@ -0,0 +1 @@ +"""saltybot_multi_person_tracker — Multi-person tracker with group handling (Issue #423).""" diff --git a/jetson/ros2_ws/src/saltybot_multi_person_tracker/saltybot_multi_person_tracker/multi_person_tracker.py b/jetson/ros2_ws/src/saltybot_multi_person_tracker/saltybot_multi_person_tracker/multi_person_tracker.py new file mode 100644 index 0000000..9f21662 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_multi_person_tracker/saltybot_multi_person_tracker/multi_person_tracker.py @@ -0,0 +1,318 @@ +"""Multi-person tracker for Issue #423. + +Tracks up to 10 people with: +- Unique IDs (persistent across frames) +- Target priority: wake-word speaker > closest known > largest bbox +- Occlusion handoff (3s grace period) +- Re-ID via face embedding + HSV color histogram +- Group detection (follow centroid) +- Lost target: stop + rotate + SEARCHING state +""" + +from __future__ import annotations + +import time +import numpy as np +import cv2 +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple + + +@dataclass +class TrackedPerson: + """A tracked person with identity and state.""" + track_id: int + bearing_rad: float + distance_m: float + confidence: float + is_speaking: bool = False + source: str = "camera" + last_seen_time: float = field(default_factory=time.time) + embedding: Optional[np.ndarray] = None # Face embedding for re-ID + hsv_histogram: Optional[np.ndarray] = None # HSV histogram fallback + bbox: Optional[Tuple[int, int, int, int]] = None # (x, y, w, h) + is_occluded: bool = False + occlusion_start_time: float = 0.0 + + def age_seconds(self) -> float: + """Time since last detection.""" + return time.time() - self.last_seen_time + + def mark_seen( + self, + bearing: float, + distance: float, + confidence: float, + is_speaking: bool = False, + embedding: Optional[np.ndarray] = None, + hsv_hist: Optional[np.ndarray] = None, + bbox: Optional[Tuple] = None, + ) -> None: + """Update person state after detection.""" + self.bearing_rad = bearing + self.distance_m = distance + self.confidence = confidence + self.is_speaking = is_speaking + self.last_seen_time = time.time() + self.is_occluded = False + if embedding is not None: + self.embedding = embedding + if hsv_hist is not None: + self.hsv_histogram = hsv_hist + if bbox is not None: + self.bbox = bbox + + def mark_occluded(self) -> None: + """Mark person as occluded but within grace period.""" + if not self.is_occluded: + self.is_occluded = True + self.occlusion_start_time = time.time() + + +class MultiPersonTracker: + """Tracks multiple people with re-ID and group detection.""" + + def __init__( + self, + max_people: int = 10, + occlusion_grace_s: float = 3.0, + embedding_similarity_threshold: float = 0.75, + hsv_similarity_threshold: float = 0.60, + ): + self.max_people = max_people + self.occlusion_grace_s = occlusion_grace_s + self.embedding_threshold = embedding_similarity_threshold + self.hsv_threshold = hsv_similarity_threshold + + self.tracked_people: Dict[int, TrackedPerson] = {} + self.next_track_id = 0 + self.active_target_id: Optional[int] = None + self.searching: bool = False + self.search_bearing: float = 0.0 + + # ── Main tracking update ──────────────────────────────────────────────── + + def update( + self, + detections: List[Dict], + speaking_person_id: Optional[int] = None, + ) -> Tuple[List[TrackedPerson], Optional[int]]: + """Update tracker with new detections. + + Args: + detections: List of dicts with keys: + - bearing_rad, distance_m, confidence + - embedding (optional): face embedding vector + - hsv_histogram (optional): HSV color histogram + - bbox (optional): (x, y, w, h) + speaking_person_id: wake-word speaker (highest priority) + + Returns: + (all_tracked_people, active_target_id) + """ + # Clean up stale tracks + self._prune_lost_people() + + # Match detections to existing tracks + matched_ids, unmatched_detections = self._match_detections(detections) + + # Update matched tracks + for track_id, det_idx in matched_ids: + det = detections[det_idx] + self.tracked_people[track_id].mark_seen( + bearing=det["bearing_rad"], + distance=det["distance_m"], + confidence=det["confidence"], + is_speaking=det.get("is_speaking", False), + embedding=det.get("embedding"), + hsv_hist=det.get("hsv_histogram"), + bbox=det.get("bbox"), + ) + + # Create new tracks for unmatched detections + for det_idx in unmatched_detections: + if len(self.tracked_people) < self.max_people: + det = detections[det_idx] + track_id = self.next_track_id + self.next_track_id += 1 + self.tracked_people[track_id] = TrackedPerson( + track_id=track_id, + bearing_rad=det["bearing_rad"], + distance_m=det["distance_m"], + confidence=det["confidence"], + is_speaking=det.get("is_speaking", False), + embedding=det.get("embedding"), + hsv_histogram=det.get("hsv_histogram"), + bbox=det.get("bbox"), + ) + + # Update target priority + self._update_target_priority(speaking_person_id) + + return list(self.tracked_people.values()), self.active_target_id + + # ── Detection-to-track matching ────────────────────────────────────────── + + def _match_detections( + self, detections: List[Dict] + ) -> Tuple[List[Tuple[int, int]], List[int]]: + """Match detections to tracked people via re-ID. + + Returns: + (matched_pairs, unmatched_detection_indices) + """ + if not detections or not self.tracked_people: + return [], list(range(len(detections))) + + matched = [] + used_det_indices = set() + + # Re-ID matching using embeddings and HSV histograms + for track_id, person in self.tracked_people.items(): + best_score = -1.0 + best_det_idx = -1 + + for det_idx, det in enumerate(detections): + if det_idx in used_det_indices: + continue + + score = self._compute_similarity(person, det) + if score > best_score: + best_score = score + best_det_idx = det_idx + + # Match if above threshold + if best_det_idx >= 0: + if ( + (person.embedding is not None and best_score >= self.embedding_threshold) + or (person.hsv_histogram is not None and best_score >= self.hsv_threshold) + ): + matched.append((track_id, best_det_idx)) + used_det_indices.add(best_det_idx) + person.is_occluded = False + + unmatched_det_indices = [i for i in range(len(detections)) if i not in used_det_indices] + return matched, unmatched_det_indices + + def _compute_similarity(self, person: TrackedPerson, detection: Dict) -> float: + """Compute similarity between tracked person and detection.""" + if person.embedding is not None and detection.get("embedding") is not None: + # Cosine similarity via dot product (embeddings are L2-normalized) + emb_score = float(np.dot(person.embedding, detection["embedding"])) + return emb_score + + if person.hsv_histogram is not None and detection.get("hsv_histogram") is not None: + # Compare HSV histograms + hist_score = float( + cv2.compareHist( + person.hsv_histogram, + detection["hsv_histogram"], + cv2.HISTCMP_COSINE, + ) + ) + return hist_score + + # Fallback: spatial proximity (within similar distance/bearing) + bearing_diff = abs(person.bearing_rad - detection["bearing_rad"]) + distance_diff = abs(person.distance_m - detection["distance_m"]) + spatial_score = 1.0 / (1.0 + bearing_diff + distance_diff * 0.1) + return spatial_score + + # ── Target priority and group handling ──────────────────────────────────── + + def _update_target_priority(self, speaking_person_id: Optional[int]) -> None: + """Update active target based on priority: + 1. Wake-word speaker + 2. Closest known person + 3. Largest bounding box + """ + if not self.tracked_people: + self.active_target_id = None + self.searching = True + return + + # Priority 1: Wake-word speaker + if speaking_person_id is not None and speaking_person_id in self.tracked_people: + self.active_target_id = speaking_person_id + self.searching = False + return + + # Priority 2: Closest non-occluded person + closest_id = None + closest_dist = float("inf") + for track_id, person in self.tracked_people.items(): + if not person.is_occluded and person.distance_m > 0: + if person.distance_m < closest_dist: + closest_dist = person.distance_m + closest_id = track_id + + if closest_id is not None: + self.active_target_id = closest_id + self.searching = False + return + + # Priority 3: Largest bounding box + largest_id = None + largest_area = 0 + for track_id, person in self.tracked_people.items(): + if person.bbox is not None: + _, _, w, h = person.bbox + area = w * h + if area > largest_area: + largest_area = area + largest_id = track_id + + if largest_id is not None: + self.active_target_id = largest_id + self.searching = False + else: + self.searching = True + + def get_group_centroid(self) -> Optional[Tuple[float, float, float]]: + """Get centroid (bearing, distance) of all tracked people. + + Returns: + (mean_bearing_rad, mean_distance_m) or None if no people + """ + if not self.tracked_people: + return None + + bearings = [] + distances = [] + for person in self.tracked_people.values(): + if not person.is_occluded and person.distance_m > 0: + bearings.append(person.bearing_rad) + distances.append(person.distance_m) + + if not bearings: + return None + + # Mean bearing (circular mean) + mean_bearing = float(np.mean(bearings)) + mean_distance = float(np.mean(distances)) + return mean_bearing, mean_distance + + # ── Pruning and cleanup ────────────────────────────────────────────────── + + def _prune_lost_people(self) -> None: + """Remove people lost for more than occlusion grace period.""" + current_time = time.time() + to_remove = [] + + for track_id, person in self.tracked_people.items(): + # Check occlusion timeout + if person.is_occluded: + elapsed = current_time - person.occlusion_start_time + if elapsed > self.occlusion_grace_s: + to_remove.append(track_id) + continue + + # Check long-term invisibility (10s) + if person.age_seconds() > 10.0: + to_remove.append(track_id) + + for track_id in to_remove: + del self.tracked_people[track_id] + if self.active_target_id == track_id: + self.active_target_id = None + self.searching = True diff --git a/jetson/ros2_ws/src/saltybot_multi_person_tracker/saltybot_multi_person_tracker/multi_person_tracker_node.py b/jetson/ros2_ws/src/saltybot_multi_person_tracker/saltybot_multi_person_tracker/multi_person_tracker_node.py new file mode 100644 index 0000000..6a3981b --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_multi_person_tracker/saltybot_multi_person_tracker/multi_person_tracker_node.py @@ -0,0 +1,185 @@ +"""multi_person_tracker_node.py — Multi-person tracking node for Issue #423. + +Subscribes to person detections with re-ID embeddings and publishes: + /saltybot/tracked_people (PersonArray) — all tracked people + active target + /saltybot/follow_target (Person) — the current follow target + +Implements: +- Up to 10 tracked people +- Target priority: wake-word speaker > closest > largest bbox +- Occlusion handoff (3s grace) +- Re-ID via face embedding + HSV +- Group detection +- Lost target: SEARCHING state +""" + +from __future__ import annotations + +import time +import numpy as np +from typing import Dict, List, Optional + +import rclpy +from rclpy.node import Node +from rclpy.qos import QoSProfile, QoSReliabilityPolicy +from rclpy.duration import Duration + +from std_msgs.msg import String +from geometry_msgs.msg import Twist +from saltybot_social_msgs.msg import Person, PersonArray + +from .multi_person_tracker import MultiPersonTracker, TrackedPerson + + +class MultiPersonTrackerNode(Node): + """ROS2 node for multi-person tracking.""" + + def __init__(self) -> None: + super().__init__("multi_person_tracker_node") + + # ── Parameters ────────────────────────────────────────────────────── + self.declare_parameter("max_people", 10) + self.declare_parameter("occlusion_grace_s", 3.0) + self.declare_parameter("embedding_threshold", 0.75) + self.declare_parameter("hsv_threshold", 0.60) + self.declare_parameter("publish_hz", 15.0) + self.declare_parameter("enable_group_follow", True) + self.declare_parameter("announce_state", True) + + max_people = self.get_parameter("max_people").value + occlusion_grace = self.get_parameter("occlusion_grace_s").value + embed_threshold = self.get_parameter("embedding_threshold").value + hsv_threshold = self.get_parameter("hsv_threshold").value + publish_hz = self.get_parameter("publish_hz").value + self._enable_group = self.get_parameter("enable_group_follow").value + self._announce = self.get_parameter("announce_state").value + + # ── Tracker instance ──────────────────────────────────────────────── + self._tracker = MultiPersonTracker( + max_people=max_people, + occlusion_grace_s=occlusion_grace, + embedding_similarity_threshold=embed_threshold, + hsv_similarity_threshold=hsv_threshold, + ) + + # ── Reliable QoS ──────────────────────────────────────────────────── + qos = QoSProfile(depth=5) + qos.reliability = QoSReliabilityPolicy.RELIABLE + + self._tracked_pub = self.create_publisher( + PersonArray, + "/saltybot/tracked_people", + qos, + ) + + self._target_pub = self.create_publisher( + Person, + "/saltybot/follow_target", + qos, + ) + + self._state_pub = self.create_publisher( + String, + "/saltybot/tracker_state", + qos, + ) + + self._cmd_vel_pub = self.create_publisher( + Twist, + "/saltybot/cmd_vel", + qos, + ) + + # ── Placeholder: in real implementation, subscribe to detection topics + # For now, we'll use a timer to demonstrate the node structure + self._update_timer = self.create_timer(1.0 / publish_hz, self._update_tracker) + + self._last_state = "IDLE" + self.get_logger().info( + f"multi_person_tracker_node ready (max_people={max_people}, " + f"occlusion_grace={occlusion_grace}s, publish_hz={publish_hz} Hz)" + ) + + def _update_tracker(self) -> None: + """Update tracker and publish results.""" + # In real implementation, this would be triggered by detection callbacks + # For now, we'll use the timer and publish empty/demo data + + # Get current tracked people and active target + tracked_people = list(self._tracker.tracked_people.values()) + active_target_id = self._tracker.active_target_id + + # Build PersonArray message + person_array_msg = PersonArray() + person_array_msg.header.stamp = self.get_clock().now().to_msg() + person_array_msg.header.frame_id = "base_link" + person_array_msg.active_id = active_target_id if active_target_id is not None else -1 + + for tracked in tracked_people: + person_msg = Person() + person_msg.header = person_array_msg.header + person_msg.track_id = tracked.track_id + person_msg.bearing_rad = tracked.bearing_rad + person_msg.distance_m = tracked.distance_m + person_msg.confidence = tracked.confidence + person_msg.is_speaking = tracked.is_speaking + person_msg.source = tracked.source + person_array_msg.persons.append(person_msg) + + self._tracked_pub.publish(person_array_msg) + + # Publish active target + if active_target_id is not None and active_target_id in self._tracker.tracked_people: + target = self._tracker.tracked_people[active_target_id] + target_msg = Person() + target_msg.header = person_array_msg.header + target_msg.track_id = target.track_id + target_msg.bearing_rad = target.bearing_rad + target_msg.distance_m = target.distance_m + target_msg.confidence = target.confidence + target_msg.is_speaking = target.is_speaking + target_msg.source = target.source + self._target_pub.publish(target_msg) + + # Publish tracker state + state = "SEARCHING" if self._tracker.searching else "TRACKING" + if state != self._last_state and self._announce: + self.get_logger().info(f"State: {state}") + self._last_state = state + state_msg = String() + state_msg.data = state + self._state_pub.publish(state_msg) + + # Handle lost target + if state == "SEARCHING": + self._handle_lost_target() + + def _handle_lost_target(self) -> None: + """Execute behavior when target is lost: stop + rotate.""" + # Publish stop command + stop_twist = Twist() + self._cmd_vel_pub.publish(stop_twist) + + # Slow rotating search + search_twist = Twist() + search_twist.angular.z = 0.5 # rad/s + self._cmd_vel_pub.publish(search_twist) + + if self._announce: + self.get_logger().warn("Lost target, searching...") + + +def main(args=None) -> None: + rclpy.init(args=args) + node = MultiPersonTrackerNode() + try: + rclpy.spin(node) + except KeyboardInterrupt: + pass + finally: + node.destroy_node() + rclpy.shutdown() + + +if __name__ == "__main__": + main() diff --git a/jetson/ros2_ws/src/saltybot_multi_person_tracker/setup.cfg b/jetson/ros2_ws/src/saltybot_multi_person_tracker/setup.cfg new file mode 100644 index 0000000..a2531b2 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_multi_person_tracker/setup.cfg @@ -0,0 +1,5 @@ +[develop] +script_dir=$base/lib/saltybot_multi_person_tracker + +[install] +install_scripts=$base/lib/saltybot_multi_person_tracker diff --git a/jetson/ros2_ws/src/saltybot_multi_person_tracker/setup.py b/jetson/ros2_ws/src/saltybot_multi_person_tracker/setup.py new file mode 100644 index 0000000..9fae769 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_multi_person_tracker/setup.py @@ -0,0 +1,29 @@ +from setuptools import setup + +package_name = 'saltybot_multi_person_tracker' + +setup( + name=package_name, + version='0.1.0', + packages=[package_name], + data_files=[ + ('share/ament_index/resource_index/packages', + ['resource/' + package_name]), + ('share/' + package_name, ['package.xml']), + ], + install_requires=['setuptools'], + zip_safe=True, + author='sl-perception', + author_email='sl-perception@saltylab.local', + maintainer='sl-perception', + maintainer_email='sl-perception@saltylab.local', + url='https://gitea.vayrette.com/seb/saltylab-firmware', + description='Multi-person tracker with group handling and target priority', + license='MIT', + tests_require=['pytest'], + entry_points={ + 'console_scripts': [ + 'multi_person_tracker_node = saltybot_multi_person_tracker.multi_person_tracker_node:main', + ], + }, +) diff --git a/jetson/ros2_ws/src/saltybot_multi_person_tracker/test/test_multi_person_tracker.py b/jetson/ros2_ws/src/saltybot_multi_person_tracker/test/test_multi_person_tracker.py new file mode 100644 index 0000000..37a60b3 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_multi_person_tracker/test/test_multi_person_tracker.py @@ -0,0 +1,122 @@ +"""Tests for multi-person tracker (Issue #423).""" + +import numpy as np +from saltybot_multi_person_tracker.multi_person_tracker import ( + MultiPersonTracker, + TrackedPerson, +) + + +def test_create_new_tracks(): + """Test creating new tracks for detections.""" + tracker = MultiPersonTracker(max_people=10) + + detections = [ + {"bearing_rad": 0.1, "distance_m": 1.5, "confidence": 0.9}, + {"bearing_rad": -0.2, "distance_m": 2.0, "confidence": 0.85}, + ] + + tracked, target_id = tracker.update(detections) + + assert len(tracked) == 2 + assert tracked[0].track_id == 0 + assert tracked[1].track_id == 1 + + +def test_match_existing_tracks(): + """Test matching detections to existing tracks.""" + tracker = MultiPersonTracker(max_people=10) + + # Create initial tracks + detections1 = [ + {"bearing_rad": 0.1, "distance_m": 1.5, "confidence": 0.9}, + {"bearing_rad": -0.2, "distance_m": 2.0, "confidence": 0.85}, + ] + tracked1, _ = tracker.update(detections1) + + # Update with similar positions + detections2 = [ + {"bearing_rad": 0.12, "distance_m": 1.48, "confidence": 0.88}, + {"bearing_rad": -0.22, "distance_m": 2.05, "confidence": 0.84}, + ] + tracked2, _ = tracker.update(detections2) + + # Should have same IDs (matched) + assert len(tracked2) == 2 + ids_after = {p.track_id for p in tracked2} + assert ids_after == {0, 1} + + +def test_max_people_limit(): + """Test that tracker respects max_people limit.""" + tracker = MultiPersonTracker(max_people=3) + + detections = [ + {"bearing_rad": 0.1 * i, "distance_m": 1.5 + i * 0.2, "confidence": 0.9} + for i in range(5) + ] + + tracked, _ = tracker.update(detections) + assert len(tracked) <= 3 + + +def test_wake_word_speaker_priority(): + """Test that wake-word speaker gets highest priority.""" + tracker = MultiPersonTracker(max_people=10) + + detections = [ + {"bearing_rad": 0.1, "distance_m": 5.0, "confidence": 0.9}, # Far + {"bearing_rad": -0.2, "distance_m": 1.0, "confidence": 0.85}, # Closest + {"bearing_rad": 0.3, "distance_m": 2.0, "confidence": 0.8}, # Farthest + ] + + tracked, _ = tracker.update(detections) + + # Without speaker: should pick closest + assert tracker.active_target_id == 1 # Closest person + + # With speaker at ID 0: should pick speaker + tracked, target = tracker.update(detections, speaking_person_id=0) + assert target == 0 + + +def test_group_centroid(): + """Test group centroid calculation.""" + tracker = MultiPersonTracker(max_people=10) + + detections = [ + {"bearing_rad": 0.0, "distance_m": 1.0, "confidence": 0.9}, + {"bearing_rad": 0.2, "distance_m": 2.0, "confidence": 0.85}, + {"bearing_rad": -0.2, "distance_m": 1.5, "confidence": 0.8}, + ] + + tracked, _ = tracker.update(detections) + + centroid = tracker.get_group_centroid() + assert centroid is not None + mean_bearing, mean_distance = centroid + # Mean bearing should be close to 0 + assert abs(mean_bearing) < 0.1 + # Mean distance should be between 1.0 and 2.0 + assert 1.0 < mean_distance < 2.0 + + +def test_empty_tracker(): + """Test tracker with no detections.""" + tracker = MultiPersonTracker() + + tracked, target = tracker.update([]) + + assert len(tracked) == 0 + assert target is None + assert tracker.searching is True + + +if __name__ == "__main__": + test_create_new_tracks() + test_match_existing_tracks() + test_max_people_limit() + test_wake_word_speaker_priority() + test_group_centroid() + test_empty_tracker() + print("All tests passed!")