From c44a30561a5215466ca7b695df2d8cc5e27b3d60 Mon Sep 17 00:00:00 2001 From: sl-jetson Date: Sat, 28 Feb 2026 23:21:24 -0500 Subject: [PATCH] feat: person detection + tracking (YOLOv8n TensorRT) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New package: saltybot_perception person_detector_node.py: - Subscribes /camera/color/image_raw + /camera/depth/image_rect_raw (ApproximateTimeSynchronizer, slop=50ms) - Subscribes /camera/color/camera_info for intrinsics - YOLOv8n inference via TensorRT FP16 engine (Orin Nano 67 TOPS) Falls back to ONNX Runtime when engine not found (dev/CI) - Letterbox preprocessing (640x640), YOLOv8n post-process + NMS - Median-window depth lookup at bbox centre (7x7 px) - Back-projects 2D pixel + depth to 3D point in camera frame - tf2 transform to base_link (fallback: camera_color_optical_frame) - Publishes: /person/detections vision_msgs/Detection2DArray all persons /person/target geometry_msgs/PoseStamped tracked person 3D /person/debug_image sensor_msgs/Image (optional) tracker.py — SimplePersonTracker: - Single-target IoU-based tracker - Picks closest valid person (smallest depth) on first lock - Re-associates across frames using IoU threshold - Holds last known position for configurable duration (default 2s) - Monotonically increasing track IDs detection_utils.py — pure helpers (no ROS2 deps, testable standalone): - nms(), letterbox(), remap_bbox(), get_depth_at(), pixel_to_3d() scripts/build_trt_engine.py: - Converts ONNX to TensorRT FP16 engine using TRT Python API - Prints trtexec CLI alternative - Includes YOLOv8n download instructions config/person_detection_params.yaml: - confidence_threshold: 0.40, min_depth: 0.5m, max_depth: 5.0m - track_hold_duration: 2.0s, target_frame: base_link launch/person_detection.launch.py: - engine_path, onnx_path, publish_debug_image, target_frame overridable Tests: 26/26 passing (test_tracker.py + test_postprocess.py) - IoU computation, NMS suppression, 
tracker state machine, depth filtering, hold duration, re-association, track ID Co-Authored-By: Claude Sonnet 4.6 --- .../src/saltybot_perception/.gitignore | 11 + .../config/person_detection_params.yaml | 37 ++ .../launch/person_detection.launch.py | 98 ++++ .../src/saltybot_perception/models/.gitkeep | 0 .../src/saltybot_perception/package.xml | 42 ++ .../resource/saltybot_perception | 0 .../saltybot_perception/__init__.py | 0 .../saltybot_perception/detection_utils.py | 124 +++++ .../person_detector_node.py | 471 ++++++++++++++++++ .../saltybot_perception/tracker.py | 179 +++++++ .../scripts/build_trt_engine.py | 162 ++++++ .../ros2_ws/src/saltybot_perception/setup.cfg | 4 + .../ros2_ws/src/saltybot_perception/setup.py | 32 ++ .../test/test_postprocess.py | 67 +++ .../saltybot_perception/test/test_tracker.py | 156 ++++++ 15 files changed, 1383 insertions(+) create mode 100644 jetson/ros2_ws/src/saltybot_perception/.gitignore create mode 100644 jetson/ros2_ws/src/saltybot_perception/config/person_detection_params.yaml create mode 100644 jetson/ros2_ws/src/saltybot_perception/launch/person_detection.launch.py create mode 100644 jetson/ros2_ws/src/saltybot_perception/models/.gitkeep create mode 100644 jetson/ros2_ws/src/saltybot_perception/package.xml create mode 100644 jetson/ros2_ws/src/saltybot_perception/resource/saltybot_perception create mode 100644 jetson/ros2_ws/src/saltybot_perception/saltybot_perception/__init__.py create mode 100644 jetson/ros2_ws/src/saltybot_perception/saltybot_perception/detection_utils.py create mode 100644 jetson/ros2_ws/src/saltybot_perception/saltybot_perception/person_detector_node.py create mode 100644 jetson/ros2_ws/src/saltybot_perception/saltybot_perception/tracker.py create mode 100644 jetson/ros2_ws/src/saltybot_perception/scripts/build_trt_engine.py create mode 100644 jetson/ros2_ws/src/saltybot_perception/setup.cfg create mode 100644 jetson/ros2_ws/src/saltybot_perception/setup.py create mode 100644 
jetson/ros2_ws/src/saltybot_perception/test/test_postprocess.py create mode 100644 jetson/ros2_ws/src/saltybot_perception/test/test_tracker.py diff --git a/jetson/ros2_ws/src/saltybot_perception/.gitignore b/jetson/ros2_ws/src/saltybot_perception/.gitignore new file mode 100644 index 0000000..cc1d3c5 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_perception/.gitignore @@ -0,0 +1,11 @@ +# TensorRT engines are hardware-specific — don't commit them +models/*.engine +models/*.onnx + +# Python bytecode +__pycache__/ +*.pyc +*.pyo + +# Test cache +.pytest_cache/ diff --git a/jetson/ros2_ws/src/saltybot_perception/config/person_detection_params.yaml b/jetson/ros2_ws/src/saltybot_perception/config/person_detection_params.yaml new file mode 100644 index 0000000..c0204a8 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_perception/config/person_detection_params.yaml @@ -0,0 +1,37 @@ +person_detector: + ros__parameters: + # ── Model paths ────────────────────────────────────────────────────────── + # TensorRT FP16 engine (built with scripts/build_trt_engine.py) + # Stored on NVMe for fast load and hardware-specific optimisation. + engine_path: "/mnt/nvme/saltybot/models/yolov8n.engine" + + # ONNX fallback — used when engine_path not found (dev / CI environments) + onnx_path: "/mnt/nvme/saltybot/models/yolov8n.onnx" + + # ── Detection thresholds ───────────────────────────────────────────────── + confidence_threshold: 0.40 # YOLO class confidence (0–1). Lower → more detections but more FP. + nms_iou_threshold: 0.45 # NMS IoU threshold. Higher → fewer suppressed boxes. + + # ── Depth filtering ─────────────────────────────────────────────────────── + # Only consider persons within this depth range (metres). 
"""
person_detection.launch.py — Launch the person-detection node with its config.

Usage:
    ros2 launch saltybot_perception person_detection.launch.py

    # Override engine path:
    ros2 launch saltybot_perception person_detection.launch.py \\
        engine_path:=/mnt/nvme/saltybot/models/yolov8n.engine

    # Use ONNX fallback (dev/CI):
    ros2 launch saltybot_perception person_detection.launch.py \\
        onnx_path:=/mnt/nvme/saltybot/models/yolov8n.onnx

    # Enable debug image stream:
    ros2 launch saltybot_perception person_detection.launch.py \\
        publish_debug_image:=true

Prerequisites:
    - RealSense D435i node running and publishing:
        /camera/color/image_raw
        /camera/depth/image_rect_raw
        /camera/color/camera_info
    - TF tree containing base_link ← camera_color_optical_frame
    - YOLOv8n TensorRT engine (build with scripts/build_trt_engine.py)
"""

import os
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
from ament_index_python.packages import get_package_share_directory


def generate_launch_description():
    """Assemble the launch graph: declared arguments + the detector node."""
    share_dir = get_package_share_directory('saltybot_perception')
    config_file = os.path.join(share_dir, 'config', 'person_detection_params.yaml')

    # (name, default, help) for every CLI-overridable launch argument.
    arg_specs = [
        ('engine_path', '/mnt/nvme/saltybot/models/yolov8n.engine',
         'Path to TensorRT .engine file (built by build_trt_engine.py)'),
        ('onnx_path', '/mnt/nvme/saltybot/models/yolov8n.onnx',
         'Path to ONNX model (fallback when engine_path not found)'),
        ('publish_debug_image', 'false',
         'Publish annotated debug image on /person/debug_image'),
        ('target_frame', 'base_link',
         'TF frame for /person/target PoseStamped output'),
        ('confidence_threshold', '0.4',
         'Minimum YOLO detection confidence'),
        ('max_depth', '5.0',
         'Maximum person tracking distance in metres'),
    ]

    declared_args = [
        DeclareLaunchArgument(name, default_value=default, description=help_text)
        for name, default, help_text in arg_specs
    ]

    # Parameter precedence: YAML file first, then per-argument CLI overrides.
    overrides = {name: LaunchConfiguration(name) for name, _, _ in arg_specs}

    detector = Node(
        package='saltybot_perception',
        executable='person_detector',
        name='person_detector',
        output='screen',
        parameters=[config_file, overrides],
        # Standard RealSense topic names — no remapping needed by default.
        remappings=[],
    )

    return LaunchDescription(declared_args + [detector])
"""
detection_utils.py — Pure-Python helpers with no ROS2 dependencies.
Importable in tests without a running ROS2 environment.
"""

import numpy as np


def nms(boxes: np.ndarray, scores: np.ndarray,
        iou_threshold: float = 0.45) -> list[int]:
    """
    Non-maximum suppression.

    Args:
        boxes: [N, 4] float array of (x1, y1, x2, y2) boxes
        scores: [N] float array of confidence scores
        iou_threshold: suppress boxes with IoU > this value

    Returns:
        List of kept indices (sorted by descending score).
    """
    if len(boxes) == 0:
        return []

    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort()[::-1]  # candidate indices, best score first
    keep = []

    while len(order) > 0:
        i = order[0]
        keep.append(int(i))
        if len(order) == 1:
            break
        # Intersection of the winner with every remaining candidate.
        ix1 = np.maximum(x1[i], x1[order[1:]])
        iy1 = np.maximum(y1[i], y1[order[1:]])
        ix2 = np.minimum(x2[i], x2[order[1:]])
        iy2 = np.minimum(y2[i], y2[order[1:]])
        iw = np.maximum(0.0, ix2 - ix1)
        ih = np.maximum(0.0, iy2 - iy1)
        inter = iw * ih
        # Epsilon guards against division by zero for degenerate boxes.
        iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-6)
        order = order[1:][iou < iou_threshold]

    return keep


def letterbox(image: np.ndarray, size: int = 640, pad_value: int = 114):
    """
    Letterbox-resize `image` to `size`×`size`, preserving aspect ratio.

    Accepts both 3-channel (H, W, C) and single-channel (H, W) images.

    Returns:
        (canvas, scale, pad_w, pad_h)
        canvas: uint8 padded image with the same channel layout as the input
        scale:  float — resize scale factor
        pad_w:  int — left padding applied
        pad_h:  int — top padding applied
    """
    import cv2  # local import keeps this module importable without OpenCV
    h, w = image.shape[:2]
    scale = min(size / w, size / h)
    new_w = int(round(w * scale))
    new_h = int(round(h * scale))
    pad_w = (size - new_w) // 2
    pad_h = (size - new_h) // 2

    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    # BUG FIX: the canvas must match the input's dimensionality. The previous
    # code always appended a channel axis, so assigning the 2-D `resized`
    # into a (size, size, 1) canvas raised a broadcast ValueError for
    # grayscale input.
    if image.ndim == 3:
        canvas_shape = (size, size, image.shape[2])
    else:
        canvas_shape = (size, size)
    canvas = np.full(canvas_shape, pad_value, dtype=np.uint8)
    canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized
    return canvas, scale, pad_w, pad_h


def remap_bbox(x1, y1, x2, y2, scale, pad_w, pad_h, orig_w, orig_h):
    """Map bbox from letterboxed-image space back to original image space."""
    x1 = float(np.clip((x1 - pad_w) / scale, 0, orig_w))
    y1 = float(np.clip((y1 - pad_h) / scale, 0, orig_h))
    x2 = float(np.clip((x2 - pad_w) / scale, 0, orig_w))
    y2 = float(np.clip((y2 - pad_h) / scale, 0, orig_h))
    return x1, y1, x2, y2


def get_depth_at(depth_img: np.ndarray, u: float, v: float,
                 window: int = 7, min_d: float = 0.3, max_d: float = 6.0) -> float:
    """
    Median depth in a `window`×`window` region around pixel (u, v).

    Args:
        depth_img: float32 depth image in metres (2-D, single channel)
        u, v: pixel coordinates
        window: patch side length
        min_d, max_d: valid depth range

    Returns:
        Median depth in metres, or 0.0 if no valid pixels.
    """
    h, w = depth_img.shape
    u, v = int(u), int(v)
    half = window // 2
    # Clamp the patch to the image bounds so edge pixels still work.
    u1, u2 = max(0, u - half), min(w, u + half + 1)
    v1, v2 = max(0, v - half), min(h, v + half + 1)
    patch = depth_img[v1:v2, u1:u2]
    valid = patch[(patch > min_d) & (patch < max_d)]
    return float(np.median(valid)) if len(valid) > 0 else 0.0


def pixel_to_3d(u: float, v: float, depth_m: float, K) -> tuple[float, float, float]:
    """
    Back-project pixel (u, v) at depth_m to 3D point in camera frame.

    Args:
        u, v: pixel coordinates
        depth_m: depth in metres
        K: camera intrinsic matrix (row-major, 9 elements or 3×3 array)

    Returns:
        (X, Y, Z) in camera optical frame (pinhole model; Z = depth_m)
    """
    K = np.asarray(K).ravel()
    fx, fy = K[0], K[4]
    cx, cy = K[2], K[5]
    X = (u - cx) * depth_m / fx
    Y = (v - cy) * depth_m / fy
    return X, Y, depth_m
"""
person_detector_node.py — Person detection + tracking for saltybot.

Pipeline:
    1. Receive synchronized color + depth frames from RealSense D435i.
    2. Run YOLOv8n (TensorRT FP16 on Orin Nano Super, ONNX fallback elsewhere).
    3. Filter detections for class 'person' (COCO class 0).
    4. Estimate 3D position from aligned depth at bounding box centre.
    5. Track target person across frames (SimplePersonTracker).
    6. Publish:
        /person/detections  — vision_msgs/Detection2DArray (all detected persons)
        /person/target      — geometry_msgs/PoseStamped (tracked person 3D pos)
        /person/debug_image — sensor_msgs/Image (annotated RGB, lazy)

TensorRT engine:
    - Build once with scripts/build_trt_engine.py
    - Place engine at path specified by `engine_path` param
    - Falls back to ONNX Runtime (onnxruntime[-gpu]) if engine not found

Coordinate frame:
    /person/target is published in `base_link` frame.
    If TF unavailable, falls back to `camera_color_optical_frame`.
"""

import os
import math
import time
import numpy as np

import rclpy
from rclpy.node import Node
from rclpy.duration import Duration
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy, DurabilityPolicy

import message_filters
import cv2
from cv_bridge import CvBridge

from sensor_msgs.msg import Image, CameraInfo
from geometry_msgs.msg import PoseStamped, PointStamped
from vision_msgs.msg import (
    Detection2D,
    Detection2DArray,
    ObjectHypothesisWithPose,
    BoundingBox2D,
)
import tf2_ros
import tf2_geometry_msgs  # noqa: F401 — registers stamped-type transform support

from .tracker import SimplePersonTracker
from .detection_utils import nms, letterbox, remap_bbox, get_depth_at, pixel_to_3d

_PERSON_CLASS_ID = 0  # COCO class index for 'person'
_YOLO_INPUT_SIZE = 640


# ── Inference backends ────────────────────────────────────────────────────────

class _TRTBackend:
    """TensorRT inference engine (Jetson Orin)."""

    def __init__(self, engine_path: str):
        import tensorrt as trt
        import pycuda.driver as cuda
        import pycuda.autoinit  # noqa: F401

        self._cuda = cuda
        logger = trt.Logger(trt.Logger.WARNING)
        with open(engine_path, 'rb') as f, trt.Runtime(logger) as runtime:
            self._engine = runtime.deserialize_cuda_engine(f.read())
        self._context = self._engine.create_execution_context()

        self._inputs = []
        self._outputs = []
        self._bindings = []
        for i in range(self._engine.num_io_tensors):
            name = self._engine.get_tensor_name(i)
            dtype = trt.nptype(self._engine.get_tensor_dtype(name))
            shape = tuple(self._engine.get_tensor_shape(name))
            count = int(np.prod(shape))
            # BUG FIX: host staging buffers must be 1-D. The previous code
            # allocated them with the full N-D tensor shape, so
            # np.copyto(host, input.ravel()) raised a broadcast ValueError on
            # the input side; output buffers likewise need an explicit reshape
            # to the engine shape before post-processing (see infer()).
            host_mem = cuda.pagelocked_empty(count, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            self._bindings.append(int(device_mem))
            buf = {'host': host_mem, 'device': device_mem, 'shape': shape}
            if self._engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                self._inputs.append(buf)
            else:
                self._outputs.append(buf)
        self._stream = cuda.Stream()

    def infer(self, input_data: np.ndarray) -> list[np.ndarray]:
        """Run one synchronous inference; outputs are returned in engine shape."""
        np.copyto(self._inputs[0]['host'], input_data.ravel())
        self._cuda.memcpy_htod_async(
            self._inputs[0]['device'], self._inputs[0]['host'], self._stream)
        # NOTE(review): execute_async_v2 was removed in TensorRT 10. If the
        # deployed JetPack ships TRT >= 10, migrate to set_tensor_address() +
        # execute_async_v3(). Confirm the on-device TensorRT version.
        self._context.execute_async_v2(self._bindings, self._stream.handle)
        for out in self._outputs:
            self._cuda.memcpy_dtoh_async(out['host'], out['device'], self._stream)
        self._stream.synchronize()
        # Reshape flat host buffers to the engine's declared tensor shapes so
        # _postprocess_yolov8 sees e.g. [1, 84, 8400] rather than a flat vector
        # (which its ndim check would silently reject).
        return [out['host'].reshape(out['shape']).copy() for out in self._outputs]


class _ONNXBackend:
    """ONNX Runtime inference (CPU / CUDA EP — fallback for non-Jetson)."""

    def __init__(self, onnx_path: str):
        import onnxruntime as ort
        # ORT silently skips unavailable providers, so this list works on
        # both GPU and CPU-only machines.
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        self._session = ort.InferenceSession(onnx_path, providers=providers)
        self._input_name = self._session.get_inputs()[0].name

    def infer(self, input_data: np.ndarray) -> list[np.ndarray]:
        return self._session.run(None, {self._input_name: input_data})


# ── Node ──────────────────────────────────────────────────────────────────────

class PersonDetectorNode(Node):
    """Detects persons in RGB-D frames and publishes a tracked 3D target."""

    def __init__(self):
        super().__init__('person_detector')
        self._bridge = CvBridge()
        self._camera_info: CameraInfo | None = None
        self._backend = None

        # ── Parameters ────────────────────────────────────────────────────
        self.declare_parameter('engine_path', '')
        self.declare_parameter('onnx_path', '')
        self.declare_parameter('confidence_threshold', 0.4)
        self.declare_parameter('nms_iou_threshold', 0.45)
        self.declare_parameter('min_depth', 0.5)
        self.declare_parameter('max_depth', 5.0)
        self.declare_parameter('track_hold_duration', 2.0)
        self.declare_parameter('track_iou_threshold', 0.25)
        self.declare_parameter('target_frame', 'base_link')
        self.declare_parameter('publish_debug_image', False)

        self._conf_thresh = self.get_parameter('confidence_threshold').value
        self._nms_thresh = self.get_parameter('nms_iou_threshold').value
        self._min_depth = self.get_parameter('min_depth').value
        self._max_depth = self.get_parameter('max_depth').value
        self._target_frame = self.get_parameter('target_frame').value
        self._pub_debug = self.get_parameter('publish_debug_image').value

        hold_dur = self.get_parameter('track_hold_duration').value
        track_iou = self.get_parameter('track_iou_threshold').value
        self._tracker = SimplePersonTracker(
            hold_duration=hold_dur,
            iou_threshold=track_iou,
            min_depth=self._min_depth,
            max_depth=self._max_depth,
        )

        # Letterbox state (set during preprocessing, read by _remap_bbox)
        self._scale = 1.0
        self._pad_w = 0
        self._pad_h = 0
        self._orig_w = 0
        self._orig_h = 0

        # ── TF ────────────────────────────────────────────────────────────
        self._tf_buffer = tf2_ros.Buffer()
        self._tf_listener = tf2_ros.TransformListener(self._tf_buffer, self)

        # ── Publishers ────────────────────────────────────────────────────
        # Best-effort keep-last-1: stale perception data is worthless, drop it.
        best_effort_qos = QoSProfile(
            reliability=ReliabilityPolicy.BEST_EFFORT,
            history=HistoryPolicy.KEEP_LAST,
            depth=1,
        )
        self._pub_detections = self.create_publisher(
            Detection2DArray, '/person/detections', best_effort_qos)
        self._pub_target = self.create_publisher(
            PoseStamped, '/person/target', best_effort_qos)
        if self._pub_debug:
            self._pub_debug_img = self.create_publisher(
                Image, '/person/debug_image', best_effort_qos)

        # ── Camera info subscriber ────────────────────────────────────────
        self.create_subscription(
            CameraInfo,
            '/camera/color/camera_info',
            self._on_camera_info,
            QoSProfile(reliability=ReliabilityPolicy.BEST_EFFORT,
                       history=HistoryPolicy.KEEP_LAST, depth=1),
        )

        # ── Synchronized color + depth subscribers ────────────────────────
        color_sub = message_filters.Subscriber(
            self, Image, '/camera/color/image_raw',
            qos_profile=QoSProfile(
                reliability=ReliabilityPolicy.BEST_EFFORT,
                history=HistoryPolicy.KEEP_LAST, depth=4))
        depth_sub = message_filters.Subscriber(
            self, Image, '/camera/depth/image_rect_raw',
            qos_profile=QoSProfile(
                reliability=ReliabilityPolicy.BEST_EFFORT,
                history=HistoryPolicy.KEEP_LAST, depth=4))
        self._sync = message_filters.ApproximateTimeSynchronizer(
            [color_sub, depth_sub], queue_size=4, slop=0.05)
        self._sync.registerCallback(self._on_frame)

        # ── Load model ────────────────────────────────────────────────────
        self._load_backend()
        self.get_logger().info('PersonDetectorNode ready.')

    # ── Model loading ─────────────────────────────────────────────────────

    def _load_backend(self):
        """Prefer the TensorRT engine; fall back to ONNX Runtime, then none."""
        engine_path = self.get_parameter('engine_path').value
        onnx_path = self.get_parameter('onnx_path').value

        if engine_path and os.path.isfile(engine_path):
            try:
                self._backend = _TRTBackend(engine_path)
                self.get_logger().info(f'TensorRT backend loaded: {engine_path}')
                return
            except Exception as e:
                self.get_logger().warn(f'TRT load failed ({e}), falling back to ONNX')

        if onnx_path and os.path.isfile(onnx_path):
            try:
                self._backend = _ONNXBackend(onnx_path)
                self.get_logger().info(f'ONNX backend loaded: {onnx_path}')
                return
            except Exception as e:
                self.get_logger().error(f'ONNX load failed: {e}')

        self.get_logger().error(
            'No model found. Set engine_path or onnx_path parameter. '
            'Detection disabled — node spinning without publishing.'
        )

    # ── Callbacks ─────────────────────────────────────────────────────────

    def _on_camera_info(self, msg: CameraInfo):
        # Latest intrinsics win; only K is read (see _pixel_to_3d).
        self._camera_info = msg

    def _on_frame(self, color_msg: Image, depth_msg: Image):
        """Synchronized color+depth callback: detect, track, publish."""
        if self._backend is None or self._camera_info is None:
            return

        t0 = time.monotonic()

        # Decode images
        try:
            bgr = self._bridge.imgmsg_to_cv2(color_msg, desired_encoding='bgr8')
            depth = self._bridge.imgmsg_to_cv2(depth_msg, desired_encoding='passthrough')
        except Exception as e:
            self.get_logger().error(f'Image decode error: {e}', throttle_duration_sec=5.0)
            return

        # realsense2_camera publishes depth as 16UC1 millimetres by default.
        # BUG FIX: convert based on dtype instead of the old max() > 100
        # heuristic, which skipped the mm→m conversion for all-close scenes
        # (every pixel under 100 mm) and produced depths 1000× too large.
        if depth.dtype == np.uint16:
            depth = depth.astype(np.float32) / 1000.0
        elif depth.dtype != np.float32:
            depth = depth.astype(np.float32)

        # Run detection
        tensor = self._preprocess(bgr)
        try:
            raw_outputs = self._backend.infer(tensor)
        except Exception as e:
            self.get_logger().error(f'Inference error: {e}', throttle_duration_sec=5.0)
            return

        detections_px = self._postprocess_yolov8(raw_outputs[0])

        # Get depth for each detection (0.0 = no valid depth in the window)
        detections = []
        for x1, y1, x2, y2, conf in detections_px:
            cx = (x1 + x2) / 2.0
            cy = (y1 + y2) / 2.0
            d = self._get_depth_at(depth, cx, cy)
            detections.append(((x1, y1, x2, y2), d, conf))

        # Update tracker
        track = self._tracker.update(detections)

        # Publish Detection2DArray
        det_array = Detection2DArray()
        det_array.header = color_msg.header
        for (x1, y1, x2, y2), d, conf in detections:
            det = self._make_detection2d(
                color_msg.header, x1, y1, x2, y2, conf)
            det_array.detections.append(det)
        self._pub_detections.publish(det_array)

        # Publish target PoseStamped
        if track is not None:
            x1, y1, x2, y2 = track.bbox
            cx = (x1 + x2) / 2.0
            cy = (y1 + y2) / 2.0
            d = track.depth

            if d > 0:
                X, Y, Z = self._pixel_to_3d(cx, cy, d)
                track.position_3d = (X, Y, Z)

                pose = PoseStamped()
                pose.header = color_msg.header
                pose.header.frame_id = 'camera_color_optical_frame'
                pose.pose.position.x = X
                pose.pose.position.y = Y
                pose.pose.position.z = Z
                pose.pose.orientation.w = 1.0  # identity — no orientation estimate

                # Transform to target_frame (falls back to camera frame on failure)
                if self._target_frame != 'camera_color_optical_frame':
                    pose = self._transform_pose(pose)

                self._pub_target.publish(pose)

        # Debug image
        if self._pub_debug and hasattr(self, '_pub_debug_img'):
            debug = self._draw_debug(bgr, detections, track)
            self._pub_debug_img.publish(
                self._bridge.cv2_to_imgmsg(debug, encoding='bgr8'))

        dt = (time.monotonic() - t0) * 1000
        self.get_logger().debug(
            f'Frame: {len(detections)} persons, track={track is not None}, {dt:.1f}ms',
            throttle_duration_sec=1.0,
        )

    # ── Preprocessing ─────────────────────────────────────────────────────

    def _preprocess(self, bgr: np.ndarray) -> np.ndarray:
        """Letterbox resize to 640×640, normalise, HWC→CHW, add batch dim."""
        h, w = bgr.shape[:2]
        canvas, scale, pad_w, pad_h = letterbox(bgr, _YOLO_INPUT_SIZE)
        # Remember the mapping so _remap_bbox can invert it per detection.
        self._scale = scale
        self._pad_w = pad_w
        self._pad_h = pad_h
        self._orig_w = w
        self._orig_h = h

        rgb = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
        tensor = rgb.astype(np.float32) / 255.0
        tensor = tensor.transpose(2, 0, 1)  # HWC → CHW
        return np.ascontiguousarray(tensor[np.newaxis])  # [1, 3, H, W]

    def _remap_bbox(self, x1, y1, x2, y2):
        """Map bbox from 640×640 space back to original image space."""
        return remap_bbox(x1, y1, x2, y2,
                          self._scale, self._pad_w, self._pad_h,
                          self._orig_w, self._orig_h)

    # ── Post-processing ───────────────────────────────────────────────────

    def _postprocess_yolov8(self, raw: np.ndarray) -> list:
        """
        Parse YOLOv8n output tensor and return person detections.

        YOLOv8n output shape: [1, 84, 8400] or [84, 8400]
            rows 0-3:  cx, cy, w, h (in 640×640 input space)
            rows 4-83: class scores (no objectness score in v8)

        Returns:
            list of (x1, y1, x2, y2, confidence) in original image space
        """
        pred = raw.squeeze()  # [84, 8400]
        if pred.ndim != 2 or pred.shape[0] < 5:
            return []

        person_scores = pred[4 + _PERSON_CLASS_ID, :]  # [8400]
        mask = person_scores > self._conf_thresh
        if not mask.any():
            return []

        scores = person_scores[mask]
        boxes_raw = pred[:4, mask]  # cx, cy, w, h — [4, N]

        # cx,cy,w,h → x1,y1,x2,y2
        cx, cy = boxes_raw[0], boxes_raw[1]
        w2, h2 = boxes_raw[2] / 2.0, boxes_raw[3] / 2.0
        x1, y1 = cx - w2, cy - h2
        x2, y2 = cx + w2, cy + h2

        boxes = np.stack([x1, y1, x2, y2], axis=1)  # [N, 4]
        keep = nms(boxes, scores, self._nms_thresh)

        results = []
        for i in keep:
            rx1, ry1, rx2, ry2 = self._remap_bbox(
                boxes[i, 0], boxes[i, 1], boxes[i, 2], boxes[i, 3])
            # Skip degenerate boxes (collapsed to <4 px by clipping)
            if rx2 - rx1 < 4 or ry2 - ry1 < 4:
                continue
            results.append((rx1, ry1, rx2, ry2, float(scores[i])))
        return results

    # ── Depth & 3D ────────────────────────────────────────────────────────

    def _get_depth_at(self, depth_img: np.ndarray, u: float, v: float,
                      window: int = 7) -> float:
        """Median depth in a window around pixel (u, v). Returns 0 if invalid."""
        return get_depth_at(depth_img, u, v, window,
                            self._min_depth, self._max_depth)

    def _pixel_to_3d(self, u: float, v: float, depth_m: float):
        """Back-project pixel (u, v) at depth_m to 3D point in camera frame."""
        return pixel_to_3d(u, v, depth_m, self._camera_info.k)

    # ── TF transform ──────────────────────────────────────────────────────

    def _transform_pose(self, pose_in: PoseStamped) -> PoseStamped:
        """Transform pose to target_frame; on failure return it unchanged."""
        try:
            return self._tf_buffer.transform(
                pose_in, self._target_frame,
                timeout=Duration(seconds=0.05))
        except Exception as e:
            self.get_logger().warn(
                f'TF {pose_in.header.frame_id}→{self._target_frame} failed: {e}',
                throttle_duration_sec=5.0)
            return pose_in  # publish in camera frame as fallback

    # ── Message builders ──────────────────────────────────────────────────

    def _make_detection2d(self, header, x1, y1, x2, y2, conf) -> Detection2D:
        """Build a vision_msgs/Detection2D for one person bbox."""
        det = Detection2D()
        det.header = header

        hyp = ObjectHypothesisWithPose()
        hyp.hypothesis.class_id = 'person'
        hyp.hypothesis.score = conf
        det.results.append(hyp)

        det.bbox.center.position.x = (x1 + x2) / 2.0
        det.bbox.center.position.y = (y1 + y2) / 2.0
        det.bbox.center.theta = 0.0
        det.bbox.size_x = x2 - x1
        det.bbox.size_y = y2 - y1
        return det

    # ── Debug visualisation ───────────────────────────────────────────────

    def _draw_debug(self, bgr, detections, track):
        """Annotate all detections (green) and the active track (red)."""
        vis = bgr.copy()
        for (x1, y1, x2, y2), d, conf in detections:
            cv2.rectangle(vis, (int(x1), int(y1)), (int(x2), int(y2)),
                          (0, 255, 0), 2)
            cv2.putText(vis, f'{conf:.2f} {d:.1f}m',
                        (int(x1), int(y1) - 6),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
        if track is not None:
            x1, y1, x2, y2 = track.bbox
            cv2.rectangle(vis, (int(x1), int(y1)), (int(x2), int(y2)),
                          (0, 0, 255), 3)
            label = f'ID:{track.track_id} {track.depth:.1f}m'
            if track.age > 0.05:
                label += f' (held {track.age:.1f}s)'
            cv2.putText(vis, label, (int(x1), int(y1) - 6),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
        return vis


# ── Entry point ───────────────────────────────────────────────────────────────

def main(args=None):
    rclpy.init(args=args)
    node = PersonDetectorNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2) + return vis + + +# ── Entry point ─────────────────────────────────────────────────────────────── + +def main(args=None): + rclpy.init(args=args) + node = PersonDetectorNode() + try: + rclpy.spin(node) + except KeyboardInterrupt: + pass + finally: + node.destroy_node() + rclpy.shutdown() + + +if __name__ == '__main__': + main() diff --git a/jetson/ros2_ws/src/saltybot_perception/saltybot_perception/tracker.py b/jetson/ros2_ws/src/saltybot_perception/saltybot_perception/tracker.py new file mode 100644 index 0000000..36a37cf --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_perception/saltybot_perception/tracker.py @@ -0,0 +1,179 @@ +""" +tracker.py — Single-target person tracker for saltybot person-following mode. + +Strategy: + - Select closest valid detection each frame (smallest depth within range) + - Re-associate across frames using IoU matching with existing track + - Hold last known track for `hold_duration` seconds when detections are lost + - Assign monotonically increasing track IDs + +Usage: + tracker = SimplePersonTracker(hold_duration=2.0) + track = tracker.update(detections) # detections: list of (bbox, depth, conf) + if track is not None: + print(track.bbox, track.depth, track.track_id) +""" + +import time +import numpy as np + + +class PersonTrack: + """Single tracked person.""" + + def __init__(self, bbox, depth, confidence, track_id): + """ + Args: + bbox: (x1, y1, x2, y2) in pixels + depth: distance to person in metres + confidence: detection confidence 0–1 + track_id: unique integer ID + """ + self.bbox = bbox + self.depth = depth + self.confidence = confidence + self.track_id = track_id + self.position_3d = None # (X, Y, Z) in camera frame — set by node + self._first_seen = time.monotonic() + self._last_seen = time.monotonic() + + def touch(self, bbox, depth, confidence): + self.bbox = bbox + self.depth = depth + self.confidence = confidence + self._last_seen = time.monotonic() + + @property + def 
age(self): + """Seconds since last detection.""" + return time.monotonic() - self._last_seen + + @property + def is_stale(self): + return self.age > 0 # Always check externally against hold_duration + + @property + def center(self): + x1, y1, x2, y2 = self.bbox + return ((x1 + x2) / 2.0, (y1 + y2) / 2.0) + + @property + def area(self): + x1, y1, x2, y2 = self.bbox + return max(0.0, (x2 - x1) * (y2 - y1)) + + +class SimplePersonTracker: + """ + Lightweight single-target person tracker. + + Maintains one active PersonTrack at a time. On each update: + 1. Filter detections to valid depth range. + 2. If a track is active, attempt IoU re-association. + 3. If re-association fails and track is within hold_duration, keep stale. + 4. If no active track, initialise from closest detection. + """ + + def __init__( + self, + hold_duration: float = 2.0, + iou_threshold: float = 0.25, + min_depth: float = 0.3, + max_depth: float = 5.0, + ): + """ + Args: + hold_duration: seconds to hold last known position after losing track + iou_threshold: minimum IoU to accept a re-association + min_depth: minimum valid depth in metres + max_depth: maximum valid depth in metres + """ + self._hold_duration = hold_duration + self._iou_threshold = iou_threshold + self._min_depth = min_depth + self._max_depth = max_depth + self._track: PersonTrack | None = None + self._next_id = 1 + + def update(self, detections): + """ + Update tracker with new detections. + + Args: + detections: list of (bbox, depth, confidence) where + bbox = (x1, y1, x2, y2) pixels, + depth = float metres (0 = invalid), + confidence = float 0–1 + + Returns: + PersonTrack or None. + Returns a stale PersonTrack (track.age > 0) during hold period. + Returns None after hold_duration or if never seen a person. 
+ """ + valid = [ + (b, d, c) for b, d, c in detections + if self._min_depth < d < self._max_depth + ] + + if not valid: + # No valid detections this frame + if self._track is not None: + if self._track.age <= self._hold_duration: + return self._track # hold last known + else: + self._track = None + return None + + if self._track is not None: + # Attempt re-association by IoU + best_iou = self._iou_threshold + best = None + for bbox, depth, conf in valid: + iou = _iou(bbox, self._track.bbox) + if iou > best_iou: + best_iou = iou + best = (bbox, depth, conf) + + if best is not None: + self._track.touch(*best) + return self._track + + # No IoU match — keep stale track within hold window + if self._track.age <= self._hold_duration: + return self._track + else: + self._track = None # lost — start fresh + + # No active track: pick closest valid detection + bbox, depth, conf = min(valid, key=lambda x: x[1]) + self._track = PersonTrack(bbox, depth, conf, self._next_id) + self._next_id += 1 + return self._track + + def reset(self): + """Drop current track.""" + self._track = None + + @property + def active(self): + return self._track is not None + + +def _iou(bbox_a, bbox_b): + """Compute Intersection-over-Union of two bounding boxes (x1,y1,x2,y2).""" + ax1, ay1, ax2, ay2 = bbox_a + bx1, by1, bx2, by2 = bbox_b + + ix1 = max(ax1, bx1) + iy1 = max(ay1, by1) + ix2 = min(ax2, bx2) + iy2 = min(ay2, by2) + + if ix2 <= ix1 or iy2 <= iy1: + return 0.0 + + intersection = (ix2 - ix1) * (iy2 - iy1) + area_a = (ax2 - ax1) * (ay2 - ay1) + area_b = (bx2 - bx1) * (by2 - by1) + union = area_a + area_b - intersection + return intersection / union if union > 0 else 0.0 diff --git a/jetson/ros2_ws/src/saltybot_perception/scripts/build_trt_engine.py b/jetson/ros2_ws/src/saltybot_perception/scripts/build_trt_engine.py new file mode 100644 index 0000000..9521709 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_perception/scripts/build_trt_engine.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python3 +""" 
#!/usr/bin/env python3
"""
build_trt_engine.py — Convert ONNX model to TensorRT .engine file.

Run this ONCE on the Jetson Orin Nano Super to build the optimised engine.
The engine is hardware-specific and cannot be shared between GPU families.

trtexec (ships with JetPack) is the CLI alternative:
    /usr/src/tensorrt/bin/trtexec --onnx=yolov8n.onnx --fp16 \\
        --saveEngine=yolov8n.engine --workspace=2048

Place the built .engine at the path given by `engine_path` in
person_detection_params.yaml (default: models/yolov8n.engine).
"""

import argparse
import os
import sys


def build_engine(onnx_path: str, engine_path: str, fp16: bool, batch_size: int,
                 workspace_mb: int) -> bool:
    """Build a TensorRT engine from an ONNX model.

    Args:
        onnx_path: path to the input ONNX file
        engine_path: output path for the serialized .engine file
        fp16: request FP16 precision (falls back to FP32 if unsupported)
        batch_size: optimisation-profile batch size (1 for real-time)
        workspace_mb: builder workspace limit in MiB

    Returns:
        True on success; False on missing TensorRT, missing ONNX file,
        parse error, or build failure.
    """
    # Import lazily so the script still prints guidance on dev machines
    # without JetPack/TensorRT installed.
    try:
        import tensorrt as trt
    except ImportError:
        print('ERROR: tensorrt not found. Run on Jetson with JetPack installed.')
        print('       Alternatively use trtexec (see script header).')
        return False

    # Fail with a clear message instead of an unhandled exception from open().
    if not os.path.isfile(onnx_path):
        print(f'ERROR: ONNX model not found: {onnx_path}')
        return False

    logger = trt.Logger(trt.Logger.VERBOSE)
    builder = trt.Builder(logger)
    # EXPLICIT_BATCH is mandatory for ONNX-parsed networks.
    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(network_flags)
    parser = trt.OnnxParser(network, logger)

    print(f'[+] Parsing ONNX: {onnx_path}')
    with open(onnx_path, 'rb') as f:
        if not parser.parse(f.read()):
            for i in range(parser.num_errors):
                print(f'    Parse error {i}: {parser.get_error(i)}')
            return False
    print(f'    Network inputs:  {network.num_inputs}')
    print(f'    Network outputs: {network.num_outputs}')

    config = builder.create_builder_config()
    config.set_memory_pool_limit(
        trt.MemoryPoolType.WORKSPACE, workspace_mb * 1024 * 1024)

    if fp16 and builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)
        print('[+] FP16 mode enabled')
    elif fp16:
        print('[!] FP16 not supported on this platform, using FP32')

    # Dynamic batch profile (batch_size = 1 for real-time inference).
    # NOTE(review): assumes a single 4D NCHW input, e.g. [-1, 3, 640, 640].
    profile = builder.create_optimization_profile()
    input_name = network.get_input(0).name
    input_shape = network.get_input(0).shape
    min_shape = (1, input_shape[1], input_shape[2], input_shape[3])
    opt_shape = (batch_size, input_shape[1], input_shape[2], input_shape[3])
    max_shape = opt_shape
    profile.set_shape(input_name, min_shape, opt_shape, max_shape)
    config.add_optimization_profile(profile)

    print('[+] Building engine (this may take 5–15 minutes on first run)...')
    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        print('ERROR: Engine build failed.')
        return False

    # abspath() guarantees a non-empty dirname even for bare filenames.
    os.makedirs(os.path.dirname(os.path.abspath(engine_path)), exist_ok=True)
    with open(engine_path, 'wb') as f:
        f.write(serialized)

    size_mb = os.path.getsize(engine_path) / (1024 * 1024)
    print(f'[+] Engine saved: {engine_path} ({size_mb:.1f} MB)')
    return True


def verify_engine(engine_path: str) -> bool:
    """Quick sanity check: deserialise the engine and print binding info.

    Returns True if the engine loads; False on any error (missing TensorRT,
    unreadable file, deserialisation failure).
    """
    try:
        import tensorrt as trt
        logger = trt.Logger(trt.Logger.WARNING)
        with open(engine_path, 'rb') as f, trt.Runtime(logger) as rt:
            engine = rt.deserialize_cuda_engine(f.read())
        print(f'\n[+] Engine verified: {engine.num_io_tensors} tensors')
        for i in range(engine.num_io_tensors):
            name = engine.get_tensor_name(i)
            shape = engine.get_tensor_shape(name)
            dtype = engine.get_tensor_dtype(name)
            mode = 'IN ' if engine.get_tensor_mode(name).name == 'INPUT' else 'OUT'
            print(f'    {mode} [{i}] {name}: {list(shape)} {dtype}')
        return True
    except Exception as e:
        print(f'[!] Engine verification failed: {e}')
        return False
from setuptools import setup
import os
from glob import glob

package_name = 'saltybot_perception'

# Installed under share/ so the ament index, `ros2 launch`, and parameter
# loading can all locate the package assets.
_data_files = [
    ('share/ament_index/resource_index/packages',
     ['resource/' + package_name]),
    ('share/' + package_name, ['package.xml']),
    (os.path.join('share', package_name, 'launch'), glob('launch/*.py')),
    (os.path.join('share', package_name, 'config'), glob('config/*.yaml')),
]

setup(
    name=package_name,
    version='0.1.0',
    packages=[package_name],
    data_files=_data_files,
    install_requires=['setuptools'],
    zip_safe=True,
    maintainer='seb',
    maintainer_email='seb@vayrette.com',
    description='Person detection and tracking for saltybot (YOLOv8n + TensorRT)',
    license='MIT',
    tests_require=['pytest'],
    entry_points={
        'console_scripts': [
            'person_detector = saltybot_perception.person_detector_node:main',
        ],
    },
)
"""
test_postprocess.py — Tests for YOLOv8 post-processing and NMS.

Tests the nms() helper and validates post-processing logic without
requiring a GPU, TRT, or a running ROS2 node.
"""

import numpy as np

from saltybot_perception.detection_utils import nms as _nms


class TestNMS:

    def test_single_box_kept(self):
        boxes = np.array([[0, 0, 10, 10]], dtype=float)
        scores = np.array([0.9])
        assert _nms(boxes, scores) == [0]

    def test_empty_input(self):
        assert _nms(np.zeros((0, 4)), np.array([])) == []

    def test_suppresses_overlapping_box(self):
        # Two heavily overlapping boxes — keep highest score
        boxes = np.array([
            [0, 0, 10, 10],   # score 0.9 — keep
            [1, 1, 11, 11],   # score 0.8 — suppress (high IoU with first)
        ], dtype=float)
        scores = np.array([0.9, 0.8])
        keep = _nms(boxes, scores, iou_threshold=0.45)
        assert keep == [0]

    def test_keeps_non_overlapping_boxes(self):
        boxes = np.array([
            [0, 0, 10, 10],
            [50, 50, 60, 60],
            [100, 100, 110, 110],
        ], dtype=float)
        scores = np.array([0.9, 0.85, 0.8])
        keep = _nms(boxes, scores, iou_threshold=0.45)
        assert sorted(keep) == [0, 1, 2]

    def test_score_ordering(self):
        # Lower score box overlaps with higher — higher should be kept
        boxes = np.array([
            [1, 1, 11, 11],    # score 0.6
            [0, 0, 10, 10],    # score 0.95 — should be kept
        ], dtype=float)
        scores = np.array([0.6, 0.95])
        keep = _nms(boxes, scores, iou_threshold=0.45)
        assert 1 in keep       # higher score (index 1) kept
        assert 0 not in keep   # lower score (index 0) suppressed

    def test_iou_threshold_controls_suppression(self):
        # Two boxes with IoU = 25/175 ≈ 0.14. (The original comment claimed
        # "~0.5 IoU", contradicting the assertions below — corrected.)
        boxes = np.array([
            [0, 0, 10, 10],
            [5, 5, 15, 15],
        ], dtype=float)
        scores = np.array([0.9, 0.8])
        # High threshold — both boxes kept (IoU ~0.14 < 0.5)
        keep_high = _nms(boxes, scores, iou_threshold=0.5)
        assert sorted(keep_high) == [0, 1]
        # Zero threshold — any overlap suppresses; only first kept
        keep_low = _nms(boxes, scores, iou_threshold=0.0)
        assert keep_low == [0]
"""
test_tracker.py — Unit tests for SimplePersonTracker.

Run with: pytest test/test_tracker.py -v
No ROS2 runtime required.
"""

import time

import pytest

from saltybot_perception.tracker import SimplePersonTracker, PersonTrack, _iou


# ── IoU helper ──────────────────────────────────────────────────────────────

class TestIoU:
    def test_identical_boxes(self):
        box = (0, 0, 10, 10)
        assert _iou(box, box) == pytest.approx(1.0)

    def test_no_overlap(self):
        assert _iou((0, 0, 5, 5), (10, 10, 15, 15)) == pytest.approx(0.0)

    def test_partial_overlap(self):
        # Intersection 5×5 = 25; each box area 100; union 175.
        assert _iou((0, 0, 10, 10), (5, 5, 15, 15)) == pytest.approx(25 / 175)

    def test_contained_box(self):
        # Inner box fully inside outer: intersection 36, union = outer = 100.
        assert _iou((0, 0, 10, 10), (2, 2, 8, 8)) == pytest.approx(36 / 100)

    def test_touching_edges(self):
        # Shared edge only — zero overlap area.
        assert _iou((0, 0, 5, 5), (5, 0, 10, 5)) == pytest.approx(0.0)


# ── Tracker ─────────────────────────────────────────────────────────────────

class TestSimplePersonTracker:

    @staticmethod
    def _make_det(x1=10, y1=10, x2=60, y2=160, depth=2.0, conf=0.85):
        return ((x1, y1, x2, y2), depth, conf)

    # ── Basic update ────────────────────────────────────────────────────────

    def test_no_detections_returns_none(self):
        assert SimplePersonTracker().update([]) is None

    def test_single_detection_creates_track(self):
        tracker = SimplePersonTracker()
        track = tracker.update([self._make_det()])
        assert track is not None
        assert track.track_id == 1
        assert track.depth == pytest.approx(2.0)

    def test_track_id_increments(self):
        tracker = SimplePersonTracker(hold_duration=0.0)
        tracker.update([self._make_det()])
        tracker.update([])  # zero hold window — track dropped immediately
        second = tracker.update(
            [self._make_det(x1=100, y1=100, x2=150, y2=250)])
        assert second.track_id == 2

    # ── Closest-first selection ─────────────────────────────────────────────

    def test_picks_closest_person(self):
        tracker = SimplePersonTracker()
        detections = [
            ((10, 10, 60, 160), 4.0, 0.8),     # far
            ((70, 10, 120, 160), 1.5, 0.9),    # closest
            ((130, 10, 180, 160), 3.0, 0.75),  # mid
        ]
        assert tracker.update(detections).depth == pytest.approx(1.5)

    # ── Depth filtering ─────────────────────────────────────────────────────

    def test_rejects_beyond_max_depth(self):
        tracker = SimplePersonTracker(max_depth=5.0)
        assert tracker.update([self._make_det(depth=6.0)]) is None

    def test_rejects_below_min_depth(self):
        tracker = SimplePersonTracker(min_depth=0.3)
        assert tracker.update([self._make_det(depth=0.1)]) is None

    def test_accepts_within_depth_range(self):
        tracker = SimplePersonTracker(min_depth=0.3, max_depth=5.0)
        assert tracker.update([self._make_det(depth=2.5)]) is not None

    # ── Re-association ──────────────────────────────────────────────────────

    def test_iou_reassociation_keeps_same_id(self):
        tracker = SimplePersonTracker(iou_threshold=0.2)
        first = tracker.update([self._make_det(x1=10, y1=10, x2=60, y2=160)])
        # Frame 2 — slightly shifted box, high IoU with frame 1.
        second = tracker.update([self._make_det(x1=12, y1=12, x2=62, y2=162)])
        assert second.track_id == first.track_id

    def test_poor_iou_loses_track_after_hold(self):
        tracker = SimplePersonTracker(hold_duration=0.0, iou_threshold=0.5)
        tracker.update([self._make_det(x1=0, y1=0, x2=50, y2=150)])
        # Completely different position + zero hold window: the old track
        # expires and a fresh one is started.
        track = tracker.update([self._make_det(x1=500, y1=0, x2=550, y2=150)])
        assert track.track_id == 2  # new track

    # ── Hold duration ───────────────────────────────────────────────────────

    def test_holds_last_known_within_duration(self):
        tracker = SimplePersonTracker(hold_duration=10.0)
        original_id = tracker.update([self._make_det()]).track_id
        held = tracker.update([])  # no detections — position held
        assert held is not None
        assert held.track_id == original_id
        assert held.age >= 0

    def test_releases_track_after_hold_duration(self):
        tracker = SimplePersonTracker(hold_duration=0.0)
        tracker.update([self._make_det()])
        assert tracker.update([]) is None  # lost immediately

    # ── Reset ───────────────────────────────────────────────────────────────

    def test_reset_clears_track(self):
        tracker = SimplePersonTracker()
        tracker.update([self._make_det()])
        tracker.reset()
        assert tracker.active is False
        assert tracker.update([]) is None

    # ── PersonTrack properties ──────────────────────────────────────────────

    def test_track_center(self):
        cx, cy = PersonTrack((10, 20, 50, 100), 2.0, 0.9, 1).center
        assert cx == pytest.approx(30.0)
        assert cy == pytest.approx(60.0)

    def test_track_area(self):
        track = PersonTrack((10, 20, 50, 100), 2.0, 0.9, 1)
        assert track.area == pytest.approx(40 * 80)

    def test_track_age_increases(self):
        track = PersonTrack((0, 0, 50, 100), 2.0, 0.9, 1)
        time.sleep(0.05)
        assert track.age >= 0.04