feat: person detection + tracking (YOLOv8n TensorRT) #54
11
jetson/ros2_ws/src/saltybot_perception/.gitignore
vendored
Normal file
11
jetson/ros2_ws/src/saltybot_perception/.gitignore
vendored
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
# TensorRT engines are hardware-specific — don't commit them
|
||||||
|
models/*.engine
|
||||||
|
models/*.onnx
|
||||||
|
|
||||||
|
# Python bytecode
|
||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
|
||||||
|
# Test cache
|
||||||
|
.pytest_cache/
|
||||||
@ -0,0 +1,37 @@
|
|||||||
|
person_detector:
|
||||||
|
ros__parameters:
|
||||||
|
# ── Model paths ──────────────────────────────────────────────────────────
|
||||||
|
# TensorRT FP16 engine (built with scripts/build_trt_engine.py)
|
||||||
|
# Stored on NVMe for fast load and hardware-specific optimisation.
|
||||||
|
engine_path: "/mnt/nvme/saltybot/models/yolov8n.engine"
|
||||||
|
|
||||||
|
# ONNX fallback — used when engine_path not found (dev / CI environments)
|
||||||
|
onnx_path: "/mnt/nvme/saltybot/models/yolov8n.onnx"
|
||||||
|
|
||||||
|
# ── Detection thresholds ─────────────────────────────────────────────────
|
||||||
|
confidence_threshold: 0.40 # YOLO class confidence (0–1). Lower → more detections but more FP.
|
||||||
|
nms_iou_threshold: 0.45 # NMS IoU threshold. Higher → fewer suppressed boxes.
|
||||||
|
|
||||||
|
# ── Depth filtering ───────────────────────────────────────────────────────
|
||||||
|
# Only consider persons within this depth range (metres).
|
||||||
|
# RealSense D435i reliable range: 0.3–5.0 m
|
||||||
|
min_depth: 0.5 # ignore very close objects (robot body artefacts)
|
||||||
|
max_depth: 5.0 # ignore persons beyond following range
|
||||||
|
|
||||||
|
# ── Tracker settings ──────────────────────────────────────────────────────
|
||||||
|
# Hold last known track position for this many seconds after losing detection.
|
||||||
|
# Handles brief occlusion (person walks behind furniture).
|
||||||
|
track_hold_duration: 2.0 # seconds
|
||||||
|
|
||||||
|
# Minimum IoU between current detection and existing track to re-associate.
|
||||||
|
track_iou_threshold: 0.25
|
||||||
|
|
||||||
|
# ── Output coordinate frame ───────────────────────────────────────────────
|
||||||
|
# Frame for /person/target PoseStamped. Must be reachable via TF.
|
||||||
|
# sl-controls follow loop expects base_link.
|
||||||
|
target_frame: "base_link"
|
||||||
|
|
||||||
|
# ── Debug ─────────────────────────────────────────────────────────────────
|
||||||
|
# Publish annotated RGB image to /person/debug_image.
|
||||||
|
# Adds ~5ms overhead. Disable on production hardware.
|
||||||
|
publish_debug_image: false
|
||||||
@ -0,0 +1,98 @@
|
|||||||
|
"""
|
||||||
|
person_detection.launch.py — Launch person detection node with config.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ros2 launch saltybot_perception person_detection.launch.py
|
||||||
|
|
||||||
|
# Override engine path:
|
||||||
|
ros2 launch saltybot_perception person_detection.launch.py \\
|
||||||
|
engine_path:=/mnt/nvme/saltybot/models/yolov8n.engine
|
||||||
|
|
||||||
|
# Use ONNX fallback (dev/CI):
|
||||||
|
ros2 launch saltybot_perception person_detection.launch.py \\
|
||||||
|
onnx_path:=/mnt/nvme/saltybot/models/yolov8n.onnx
|
||||||
|
|
||||||
|
# Enable debug image stream:
|
||||||
|
ros2 launch saltybot_perception person_detection.launch.py \\
|
||||||
|
publish_debug_image:=true
|
||||||
|
|
||||||
|
Prerequisites:
|
||||||
|
- RealSense D435i node running and publishing:
|
||||||
|
/camera/color/image_raw
|
||||||
|
/camera/depth/image_rect_raw
|
||||||
|
/camera/color/camera_info
|
||||||
|
- TF tree containing base_link ← camera_color_optical_frame
|
||||||
|
- YOLOv8n TensorRT engine (build with scripts/build_trt_engine.py)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from launch import LaunchDescription
|
||||||
|
from launch.actions import DeclareLaunchArgument
|
||||||
|
from launch.substitutions import LaunchConfiguration
|
||||||
|
from launch_ros.actions import Node
|
||||||
|
from ament_index_python.packages import get_package_share_directory
|
||||||
|
|
||||||
|
|
||||||
|
def generate_launch_description():
    """Assemble the person-detection launch description.

    Declares six CLI-overridable arguments and starts the person_detector
    node with the packaged YAML defaults, overlaid by those arguments.
    """
    share_dir = get_package_share_directory('saltybot_perception')
    params_yaml = os.path.join(share_dir, 'config', 'person_detection_params.yaml')

    # (name, default, description) for every launch argument we expose.
    arg_specs = [
        ('engine_path', '/mnt/nvme/saltybot/models/yolov8n.engine',
         'Path to TensorRT .engine file (built by build_trt_engine.py)'),
        ('onnx_path', '/mnt/nvme/saltybot/models/yolov8n.onnx',
         'Path to ONNX model (fallback when engine_path not found)'),
        ('publish_debug_image', 'false',
         'Publish annotated debug image on /person/debug_image'),
        ('target_frame', 'base_link',
         'TF frame for /person/target PoseStamped output'),
        ('confidence_threshold', '0.4',
         'Minimum YOLO detection confidence'),
        ('max_depth', '5.0',
         'Maximum person tracking distance in metres'),
    ]

    declared_args = [
        DeclareLaunchArgument(name, default_value=default, description=desc)
        for name, default, desc in arg_specs
    ]

    # Per-argument overrides applied on top of the YAML file defaults.
    overrides = {name: LaunchConfiguration(name) for name, _, _ in arg_specs}

    detector_node = Node(
        package='saltybot_perception',
        executable='person_detector',
        name='person_detector',
        output='screen',
        parameters=[
            params_yaml,
            overrides,
        ],
        remappings=[
            # Standard RealSense topic names — no remapping needed by default
            # ('/camera/color/image_raw', '/camera/color/image_raw'),
            # ('/camera/depth/image_rect_raw', '/camera/depth/image_rect_raw'),
        ],
    )

    return LaunchDescription(declared_args + [detector_node])
|
||||||
42
jetson/ros2_ws/src/saltybot_perception/package.xml
Normal file
42
jetson/ros2_ws/src/saltybot_perception/package.xml
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
<?xml version="1.0"?>
|
||||||
|
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||||
|
<package format="3">
|
||||||
|
<name>saltybot_perception</name>
|
||||||
|
<version>0.1.0</version>
|
||||||
|
<description>
|
||||||
|
Person detection and tracking for saltybot person-following mode.
|
||||||
|
Uses YOLOv8n with TensorRT FP16 on Jetson Orin Nano Super (67 TOPS).
|
||||||
|
Publishes person bounding boxes and 3D target position for the follow loop.
|
||||||
|
</description>
|
||||||
|
<maintainer email="seb@vayrette.com">seb</maintainer>
|
||||||
|
<license>MIT</license>
|
||||||
|
|
||||||
|
<depend>rclpy</depend>
|
||||||
|
<depend>sensor_msgs</depend>
|
||||||
|
<depend>geometry_msgs</depend>
|
||||||
|
<depend>vision_msgs</depend>
|
||||||
|
<depend>tf2_ros</depend>
|
||||||
|
<depend>tf2_geometry_msgs</depend>
|
||||||
|
<depend>cv_bridge</depend>
|
||||||
|
<depend>image_transport</depend>
|
||||||
|
|
||||||
|
<exec_depend>python3-numpy</exec_depend>
|
||||||
|
<exec_depend>python3-opencv</exec_depend>
|
||||||
|
<exec_depend>python3-launch-ros</exec_depend>
|
||||||
|
|
||||||
|
<!-- TensorRT (Jetson) — optional, falls back to onnxruntime -->
|
||||||
|
<!-- exec_depend>python3-tensorrt</exec_depend -->
|
||||||
|
<!-- exec_depend>python3-pycuda</exec_depend -->
|
||||||
|
|
||||||
|
<!-- ONNX Runtime fallback -->
|
||||||
|
<!-- exec_depend>python3-onnxruntime</exec_depend -->
|
||||||
|
|
||||||
|
<test_depend>ament_copyright</test_depend>
|
||||||
|
<test_depend>ament_flake8</test_depend>
|
||||||
|
<test_depend>ament_pep257</test_depend>
|
||||||
|
<test_depend>python3-pytest</test_depend>
|
||||||
|
|
||||||
|
<export>
|
||||||
|
<build_type>ament_python</build_type>
|
||||||
|
</export>
|
||||||
|
</package>
|
||||||
@ -0,0 +1,124 @@
|
|||||||
|
"""
|
||||||
|
detection_utils.py — Pure-Python helpers with no ROS2 dependencies.
|
||||||
|
Importable in tests without a running ROS2 environment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def nms(boxes: np.ndarray, scores: np.ndarray,
        iou_threshold: float = 0.45) -> list[int]:
    """Greedy non-maximum suppression.

    Args:
        boxes: [N, 4] float array of (x1, y1, x2, y2) corner boxes
        scores: [N] float array of confidence scores
        iou_threshold: boxes overlapping an already-kept box with IoU
            above this value are discarded

    Returns:
        Indices of the surviving boxes, highest score first.
    """
    if len(boxes) == 0:
        return []

    bx1, by1, bx2, by2 = (boxes[:, k] for k in range(4))
    box_areas = (bx2 - bx1) * (by2 - by1)

    remaining = np.argsort(scores)[::-1]  # candidate indices, best first
    selected: list[int] = []

    while remaining.size:
        best = remaining[0]
        selected.append(int(best))
        rest = remaining[1:]
        if rest.size == 0:
            break
        # Intersection rectangle between `best` and every remaining box.
        left = np.maximum(bx1[best], bx1[rest])
        top = np.maximum(by1[best], by1[rest])
        right = np.minimum(bx2[best], bx2[rest])
        bottom = np.minimum(by2[best], by2[rest])
        inter = np.maximum(0.0, right - left) * np.maximum(0.0, bottom - top)
        # Small epsilon guards against division by zero for degenerate boxes.
        overlap = inter / (box_areas[best] + box_areas[rest] - inter + 1e-6)
        remaining = rest[overlap < iou_threshold]

    return selected
|
||||||
|
|
||||||
|
|
||||||
|
def letterbox(image: np.ndarray, size: int = 640, pad_value: int = 114):
    """
    Letterbox-resize `image` to `size`×`size`, preserving aspect ratio.

    The image is scaled to fit inside the square canvas and centred; the
    remaining border is filled with `pad_value`.

    Returns:
        (canvas, scale, pad_w, pad_h)
        canvas: uint8 [size, size, C] — C = input channel count (1 for 2-D input)
        scale: float — resize scale factor
        pad_w: int — horizontal padding applied
        pad_h: int — vertical padding applied
    """
    import cv2
    h, w = image.shape[:2]
    scale = min(size / w, size / h)
    new_w = int(round(w * scale))
    new_h = int(round(h * scale))
    pad_w = (size - new_w) // 2
    pad_h = (size - new_h) // 2

    resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    channels = image.shape[2] if image.ndim == 3 else 1
    canvas = np.full((size, size, channels), pad_value, dtype=np.uint8)
    # BUGFIX: cv2.resize returns a 2-D array for single-channel input, but the
    # canvas always has a trailing channel axis — the paste below would raise a
    # broadcast error for grayscale images. Restore the channel axis first.
    if resized.ndim == 2:
        resized = resized[:, :, np.newaxis]
    canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized
    return canvas, scale, pad_w, pad_h
|
||||||
|
|
||||||
|
|
||||||
|
def remap_bbox(x1, y1, x2, y2, scale, pad_w, pad_h, orig_w, orig_h):
    """Undo the letterbox transform: map a bbox from padded/resized image
    coordinates back to the original image, clipped to its bounds."""
    def _undo(coord, pad, limit):
        # Remove padding, undo scaling, then clamp into the original frame.
        return float(np.clip((coord - pad) / scale, 0, limit))

    return (_undo(x1, pad_w, orig_w),
            _undo(y1, pad_h, orig_h),
            _undo(x2, pad_w, orig_w),
            _undo(y2, pad_h, orig_h))
|
||||||
|
|
||||||
|
|
||||||
|
def get_depth_at(depth_img: np.ndarray, u: float, v: float,
                 window: int = 7, min_d: float = 0.3, max_d: float = 6.0) -> float:
    """
    Robust depth sample: median over a `window`×`window` patch at (u, v).

    Pixels outside the (min_d, max_d) range are treated as invalid (sensor
    dropouts, saturation) and excluded from the median.

    Args:
        depth_img: float32 depth image in metres
        u, v: pixel coordinates (column, row)
        window: patch side length
        min_d, max_d: valid depth range

    Returns:
        Median valid depth in metres, or 0.0 when the patch holds no
        valid pixels.
    """
    rows, cols = depth_img.shape
    col, row = int(u), int(v)
    radius = window // 2

    # Clamp the patch to the image so edge pixels still get a sample.
    c0, c1 = max(0, col - radius), min(cols, col + radius + 1)
    r0, r1 = max(0, row - radius), min(rows, row + radius + 1)

    patch = depth_img[r0:r1, c0:c1]
    samples = patch[(patch > min_d) & (patch < max_d)]
    if samples.size == 0:
        return 0.0
    return float(np.median(samples))
|
||||||
|
|
||||||
|
|
||||||
|
def pixel_to_3d(u: float, v: float, depth_m: float, K) -> tuple[float, float, float]:
    """
    Back-project pixel (u, v) at `depth_m` into the camera optical frame
    using the pinhole model.

    Args:
        u, v: pixel coordinates
        depth_m: depth in metres
        K: camera intrinsic matrix (3×3 array or flat row-major 9-element seq)

    Returns:
        (X, Y, Z) in the camera optical frame; Z equals depth_m.
    """
    intrinsics = np.asarray(K).ravel()
    fx, cx = intrinsics[0], intrinsics[2]
    fy, cy = intrinsics[4], intrinsics[5]
    x_metric = depth_m * (u - cx) / fx
    y_metric = depth_m * (v - cy) / fy
    return x_metric, y_metric, depth_m
|
||||||
@ -0,0 +1,471 @@
|
|||||||
|
"""
|
||||||
|
person_detector_node.py — Person detection + tracking for saltybot.
|
||||||
|
|
||||||
|
Pipeline:
|
||||||
|
1. Receive synchronized color + depth frames from RealSense D435i.
|
||||||
|
2. Run YOLOv8n (TensorRT FP16 on Orin Nano Super, ONNX fallback elsewhere).
|
||||||
|
3. Filter detections for class 'person' (COCO class 0).
|
||||||
|
4. Estimate 3D position from aligned depth at bounding box centre.
|
||||||
|
5. Track target person across frames (SimplePersonTracker).
|
||||||
|
6. Publish:
|
||||||
|
/person/detections — vision_msgs/Detection2DArray (all detected persons)
|
||||||
|
/person/target — geometry_msgs/PoseStamped (tracked person 3D pos)
|
||||||
|
/person/debug_image — sensor_msgs/Image (annotated RGB, lazy)
|
||||||
|
|
||||||
|
TensorRT engine:
|
||||||
|
- Build once with scripts/build_trt_engine.py
|
||||||
|
- Place engine at path specified by `engine_path` param
|
||||||
|
- Falls back to ONNX Runtime (onnxruntime[-gpu]) if engine not found
|
||||||
|
|
||||||
|
Coordinate frame:
|
||||||
|
/person/target is published in `base_link` frame.
|
||||||
|
If TF unavailable, falls back to `camera_color_optical_frame`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.duration import Duration
|
||||||
|
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy, DurabilityPolicy
|
||||||
|
|
||||||
|
import message_filters
|
||||||
|
import cv2
|
||||||
|
from cv_bridge import CvBridge
|
||||||
|
|
||||||
|
from sensor_msgs.msg import Image, CameraInfo
|
||||||
|
from geometry_msgs.msg import PoseStamped, PointStamped
|
||||||
|
from vision_msgs.msg import (
|
||||||
|
Detection2D,
|
||||||
|
Detection2DArray,
|
||||||
|
ObjectHypothesisWithPose,
|
||||||
|
BoundingBox2D,
|
||||||
|
)
|
||||||
|
import tf2_ros
|
||||||
|
import tf2_geometry_msgs # noqa: F401 — registers PointStamped transform support
|
||||||
|
|
||||||
|
from .tracker import SimplePersonTracker
|
||||||
|
from .detection_utils import nms, letterbox, remap_bbox, get_depth_at, pixel_to_3d
|
||||||
|
|
||||||
|
_PERSON_CLASS_ID = 0 # COCO class index for 'person'
|
||||||
|
_YOLO_INPUT_SIZE = 640
|
||||||
|
|
||||||
|
|
||||||
|
# ── Inference backends ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _TRTBackend:
    """TensorRT inference engine (Jetson Orin).

    Deserializes a pre-built .engine file and pre-allocates one page-locked
    host buffer plus one device buffer per engine I/O tensor, reused across
    every call to :meth:`infer`. Imports tensorrt/pycuda lazily so this
    module stays importable on machines without them.
    """

    def __init__(self, engine_path: str):
        # Lazy imports: only needed when this backend is actually selected.
        import tensorrt as trt
        import pycuda.driver as cuda
        import pycuda.autoinit  # noqa: F401

        self._cuda = cuda
        logger = trt.Logger(trt.Logger.WARNING)
        with open(engine_path, 'rb') as f, trt.Runtime(logger) as runtime:
            self._engine = runtime.deserialize_cuda_engine(f.read())
        self._context = self._engine.create_execution_context()

        # One {host, device} buffer pair per I/O tensor; `_bindings` holds the
        # raw device pointers in engine binding order for execute_async_v2.
        self._inputs = []
        self._outputs = []
        self._bindings = []
        for i in range(self._engine.num_io_tensors):
            name = self._engine.get_tensor_name(i)
            dtype = trt.nptype(self._engine.get_tensor_dtype(name))
            shape = tuple(self._engine.get_tensor_shape(name))
            nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
            # Page-locked host memory enables async DMA transfers below.
            host_mem = cuda.pagelocked_empty(shape, dtype)
            device_mem = cuda.mem_alloc(nbytes)
            self._bindings.append(int(device_mem))
            if self._engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                self._inputs.append({'host': host_mem, 'device': device_mem})
            else:
                self._outputs.append({'host': host_mem, 'device': device_mem})
        self._stream = cuda.Stream()

    def infer(self, input_data: np.ndarray) -> list[np.ndarray]:
        """Run one forward pass.

        Copies `input_data` into the first (assumed only) input buffer,
        executes the engine on the private stream, and returns copies of all
        output host buffers (copied so later calls cannot overwrite them).
        """
        np.copyto(self._inputs[0]['host'], input_data.ravel())
        self._cuda.memcpy_htod_async(
            self._inputs[0]['device'], self._inputs[0]['host'], self._stream)
        self._context.execute_async_v2(self._bindings, self._stream.handle)
        for out in self._outputs:
            self._cuda.memcpy_dtoh_async(out['host'], out['device'], self._stream)
        # Block until all async copies/execution complete before reading.
        self._stream.synchronize()
        return [out['host'].copy() for out in self._outputs]
|
||||||
|
|
||||||
|
|
||||||
|
class _ONNXBackend:
    """ONNX Runtime inference (CPU / CUDA EP — fallback for non-Jetson)."""

    def __init__(self, onnx_path: str):
        # Lazy import: onnxruntime is only required when this backend is used.
        import onnxruntime as ort

        # Prefer CUDA when available; ORT skips providers it cannot load,
        # so CPU remains the guaranteed fallback.
        self._session = ort.InferenceSession(
            onnx_path,
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self._input_name = self._session.get_inputs()[0].name

    def infer(self, input_data: np.ndarray) -> list[np.ndarray]:
        """Run one forward pass and return all model outputs."""
        feed = {self._input_name: input_data}
        return self._session.run(None, feed)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Node ──────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class PersonDetectorNode(Node):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__('person_detector')
|
||||||
|
self._bridge = CvBridge()
|
||||||
|
self._camera_info: CameraInfo | None = None
|
||||||
|
self._backend = None
|
||||||
|
|
||||||
|
# ── Parameters ──────────────────────────────────────────────────────
|
||||||
|
self.declare_parameter('engine_path', '')
|
||||||
|
self.declare_parameter('onnx_path', '')
|
||||||
|
self.declare_parameter('confidence_threshold', 0.4)
|
||||||
|
self.declare_parameter('nms_iou_threshold', 0.45)
|
||||||
|
self.declare_parameter('min_depth', 0.5)
|
||||||
|
self.declare_parameter('max_depth', 5.0)
|
||||||
|
self.declare_parameter('track_hold_duration', 2.0)
|
||||||
|
self.declare_parameter('track_iou_threshold', 0.25)
|
||||||
|
self.declare_parameter('target_frame', 'base_link')
|
||||||
|
self.declare_parameter('publish_debug_image', False)
|
||||||
|
|
||||||
|
self._conf_thresh = self.get_parameter('confidence_threshold').value
|
||||||
|
self._nms_thresh = self.get_parameter('nms_iou_threshold').value
|
||||||
|
self._min_depth = self.get_parameter('min_depth').value
|
||||||
|
self._max_depth = self.get_parameter('max_depth').value
|
||||||
|
self._target_frame = self.get_parameter('target_frame').value
|
||||||
|
self._pub_debug = self.get_parameter('publish_debug_image').value
|
||||||
|
|
||||||
|
hold_dur = self.get_parameter('track_hold_duration').value
|
||||||
|
track_iou = self.get_parameter('track_iou_threshold').value
|
||||||
|
self._tracker = SimplePersonTracker(
|
||||||
|
hold_duration=hold_dur,
|
||||||
|
iou_threshold=track_iou,
|
||||||
|
min_depth=self._min_depth,
|
||||||
|
max_depth=self._max_depth,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Letterbox state (set during preprocessing)
|
||||||
|
self._scale = 1.0
|
||||||
|
self._pad_w = 0
|
||||||
|
self._pad_h = 0
|
||||||
|
self._orig_w = 0
|
||||||
|
self._orig_h = 0
|
||||||
|
|
||||||
|
# ── TF ──────────────────────────────────────────────────────────────
|
||||||
|
self._tf_buffer = tf2_ros.Buffer()
|
||||||
|
self._tf_listener = tf2_ros.TransformListener(self._tf_buffer, self)
|
||||||
|
|
||||||
|
# ── Publishers ──────────────────────────────────────────────────────
|
||||||
|
best_effort_qos = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=1,
|
||||||
|
)
|
||||||
|
self._pub_detections = self.create_publisher(
|
||||||
|
Detection2DArray, '/person/detections', best_effort_qos)
|
||||||
|
self._pub_target = self.create_publisher(
|
||||||
|
PoseStamped, '/person/target', best_effort_qos)
|
||||||
|
if self._pub_debug:
|
||||||
|
self._pub_debug_img = self.create_publisher(
|
||||||
|
Image, '/person/debug_image', best_effort_qos)
|
||||||
|
|
||||||
|
# ── Camera info subscriber ───────────────────────────────────────────
|
||||||
|
self.create_subscription(
|
||||||
|
CameraInfo,
|
||||||
|
'/camera/color/camera_info',
|
||||||
|
self._on_camera_info,
|
||||||
|
QoSProfile(reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||||
|
history=HistoryPolicy.KEEP_LAST, depth=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Synchronized color + depth subscribers ───────────────────────────
|
||||||
|
color_sub = message_filters.Subscriber(
|
||||||
|
self, Image, '/camera/color/image_raw',
|
||||||
|
qos_profile=QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||||
|
history=HistoryPolicy.KEEP_LAST, depth=4))
|
||||||
|
depth_sub = message_filters.Subscriber(
|
||||||
|
self, Image, '/camera/depth/image_rect_raw',
|
||||||
|
qos_profile=QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||||
|
history=HistoryPolicy.KEEP_LAST, depth=4))
|
||||||
|
self._sync = message_filters.ApproximateTimeSynchronizer(
|
||||||
|
[color_sub, depth_sub], queue_size=4, slop=0.05)
|
||||||
|
self._sync.registerCallback(self._on_frame)
|
||||||
|
|
||||||
|
# ── Load model ───────────────────────────────────────────────────────
|
||||||
|
self._load_backend()
|
||||||
|
self.get_logger().info('PersonDetectorNode ready.')
|
||||||
|
|
||||||
|
# ── Model loading ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _load_backend(self):
|
||||||
|
engine_path = self.get_parameter('engine_path').value
|
||||||
|
onnx_path = self.get_parameter('onnx_path').value
|
||||||
|
|
||||||
|
if engine_path and os.path.isfile(engine_path):
|
||||||
|
try:
|
||||||
|
self._backend = _TRTBackend(engine_path)
|
||||||
|
self.get_logger().info(f'TensorRT backend loaded: {engine_path}')
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
self.get_logger().warn(f'TRT load failed ({e}), falling back to ONNX')
|
||||||
|
|
||||||
|
if onnx_path and os.path.isfile(onnx_path):
|
||||||
|
try:
|
||||||
|
self._backend = _ONNXBackend(onnx_path)
|
||||||
|
self.get_logger().info(f'ONNX backend loaded: {onnx_path}')
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
self.get_logger().error(f'ONNX load failed: {e}')
|
||||||
|
|
||||||
|
self.get_logger().error(
|
||||||
|
'No model found. Set engine_path or onnx_path parameter. '
|
||||||
|
'Detection disabled — node spinning without publishing.'
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Callbacks ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _on_camera_info(self, msg: CameraInfo):
|
||||||
|
self._camera_info = msg
|
||||||
|
|
||||||
|
def _on_frame(self, color_msg: Image, depth_msg: Image):
|
||||||
|
if self._backend is None or self._camera_info is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
t0 = time.monotonic()
|
||||||
|
|
||||||
|
# Decode images
|
||||||
|
try:
|
||||||
|
bgr = self._bridge.imgmsg_to_cv2(color_msg, desired_encoding='bgr8')
|
||||||
|
depth = self._bridge.imgmsg_to_cv2(depth_msg, desired_encoding='passthrough')
|
||||||
|
except Exception as e:
|
||||||
|
self.get_logger().error(f'Image decode error: {e}', throttle_duration_sec=5.0)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Depth image should be float32 metres (realsense2_camera default)
|
||||||
|
if depth.dtype != np.float32:
|
||||||
|
depth = depth.astype(np.float32)
|
||||||
|
if depth.max() > 100.0: # uint16 mm → float32 m
|
||||||
|
depth /= 1000.0
|
||||||
|
|
||||||
|
# Run detection
|
||||||
|
tensor = self._preprocess(bgr)
|
||||||
|
try:
|
||||||
|
raw_outputs = self._backend.infer(tensor)
|
||||||
|
except Exception as e:
|
||||||
|
self.get_logger().error(f'Inference error: {e}', throttle_duration_sec=5.0)
|
||||||
|
return
|
||||||
|
|
||||||
|
detections_px = self._postprocess_yolov8(raw_outputs[0])
|
||||||
|
|
||||||
|
# Get depth for each detection
|
||||||
|
detections = []
|
||||||
|
for x1, y1, x2, y2, conf in detections_px:
|
||||||
|
cx = (x1 + x2) / 2.0
|
||||||
|
cy = (y1 + y2) / 2.0
|
||||||
|
d = self._get_depth_at(depth, cx, cy)
|
||||||
|
detections.append(((x1, y1, x2, y2), d, conf))
|
||||||
|
|
||||||
|
# Update tracker
|
||||||
|
track = self._tracker.update(detections)
|
||||||
|
|
||||||
|
# Publish Detection2DArray
|
||||||
|
det_array = Detection2DArray()
|
||||||
|
det_array.header = color_msg.header
|
||||||
|
for (x1, y1, x2, y2), d, conf in detections:
|
||||||
|
det = self._make_detection2d(
|
||||||
|
color_msg.header, x1, y1, x2, y2, conf)
|
||||||
|
det_array.detections.append(det)
|
||||||
|
self._pub_detections.publish(det_array)
|
||||||
|
|
||||||
|
# Publish target PoseStamped
|
||||||
|
if track is not None:
|
||||||
|
x1, y1, x2, y2 = track.bbox
|
||||||
|
cx = (x1 + x2) / 2.0
|
||||||
|
cy = (y1 + y2) / 2.0
|
||||||
|
d = track.depth
|
||||||
|
|
||||||
|
if d > 0:
|
||||||
|
X, Y, Z = self._pixel_to_3d(cx, cy, d)
|
||||||
|
track.position_3d = (X, Y, Z)
|
||||||
|
|
||||||
|
pose = PoseStamped()
|
||||||
|
pose.header = color_msg.header
|
||||||
|
pose.header.frame_id = 'camera_color_optical_frame'
|
||||||
|
pose.pose.position.x = X
|
||||||
|
pose.pose.position.y = Y
|
||||||
|
pose.pose.position.z = Z
|
||||||
|
pose.pose.orientation.w = 1.0
|
||||||
|
|
||||||
|
# Transform to target_frame
|
||||||
|
if self._target_frame != 'camera_color_optical_frame':
|
||||||
|
pose = self._transform_pose(pose)
|
||||||
|
|
||||||
|
self._pub_target.publish(pose)
|
||||||
|
|
||||||
|
# Debug image
|
||||||
|
if self._pub_debug and hasattr(self, '_pub_debug_img'):
|
||||||
|
debug = self._draw_debug(bgr, detections, track)
|
||||||
|
self._pub_debug_img.publish(
|
||||||
|
self._bridge.cv2_to_imgmsg(debug, encoding='bgr8'))
|
||||||
|
|
||||||
|
dt = (time.monotonic() - t0) * 1000
|
||||||
|
self.get_logger().debug(
|
||||||
|
f'Frame: {len(detections)} persons, track={track is not None}, {dt:.1f}ms',
|
||||||
|
throttle_duration_sec=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Preprocessing ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _preprocess(self, bgr: np.ndarray) -> np.ndarray:
|
||||||
|
"""Letterbox resize to 640×640, normalise, HWC→CHW, add batch dim."""
|
||||||
|
h, w = bgr.shape[:2]
|
||||||
|
canvas, scale, pad_w, pad_h = letterbox(bgr, _YOLO_INPUT_SIZE)
|
||||||
|
self._scale = scale
|
||||||
|
self._pad_w = pad_w
|
||||||
|
self._pad_h = pad_h
|
||||||
|
self._orig_w = w
|
||||||
|
self._orig_h = h
|
||||||
|
|
||||||
|
rgb = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
|
||||||
|
tensor = rgb.astype(np.float32) / 255.0
|
||||||
|
tensor = tensor.transpose(2, 0, 1) # HWC → CHW
|
||||||
|
return np.ascontiguousarray(tensor[np.newaxis]) # [1, 3, H, W]
|
||||||
|
|
||||||
|
def _remap_bbox(self, x1, y1, x2, y2):
|
||||||
|
"""Map bbox from 640×640 space back to original image space."""
|
||||||
|
return remap_bbox(x1, y1, x2, y2,
|
||||||
|
self._scale, self._pad_w, self._pad_h,
|
||||||
|
self._orig_w, self._orig_h)
|
||||||
|
|
||||||
|
# ── Post-processing ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _postprocess_yolov8(self, raw: np.ndarray) -> list:
|
||||||
|
"""
|
||||||
|
Parse YOLOv8n output tensor and return person detections.
|
||||||
|
|
||||||
|
YOLOv8n output shape: [1, 84, 8400] or [84, 8400]
|
||||||
|
rows 0-3: cx, cy, w, h (in 640×640 input space)
|
||||||
|
rows 4-83: class scores (no objectness score in v8)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of (x1, y1, x2, y2, confidence) in original image space
|
||||||
|
"""
|
||||||
|
pred = raw.squeeze() # [84, 8400]
|
||||||
|
if pred.ndim != 2 or pred.shape[0] < 5:
|
||||||
|
return []
|
||||||
|
|
||||||
|
person_scores = pred[4 + _PERSON_CLASS_ID, :] # [8400]
|
||||||
|
mask = person_scores > self._conf_thresh
|
||||||
|
if not mask.any():
|
||||||
|
return []
|
||||||
|
|
||||||
|
scores = person_scores[mask]
|
||||||
|
boxes_raw = pred[:4, mask] # cx, cy, w, h — [4, N]
|
||||||
|
|
||||||
|
# cx,cy,w,h → x1,y1,x2,y2
|
||||||
|
cx, cy = boxes_raw[0], boxes_raw[1]
|
||||||
|
w2, h2 = boxes_raw[2] / 2.0, boxes_raw[3] / 2.0
|
||||||
|
x1, y1 = cx - w2, cy - h2
|
||||||
|
x2, y2 = cx + w2, cy + h2
|
||||||
|
|
||||||
|
boxes = np.stack([x1, y1, x2, y2], axis=1) # [N, 4]
|
||||||
|
keep = nms(boxes, scores, self._nms_thresh)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for i in keep:
|
||||||
|
rx1, ry1, rx2, ry2 = self._remap_bbox(
|
||||||
|
boxes[i, 0], boxes[i, 1], boxes[i, 2], boxes[i, 3])
|
||||||
|
# Skip degenerate boxes
|
||||||
|
if rx2 - rx1 < 4 or ry2 - ry1 < 4:
|
||||||
|
continue
|
||||||
|
results.append((rx1, ry1, rx2, ry2, float(scores[i])))
|
||||||
|
return results
|
||||||
|
|
||||||
|
# ── Depth & 3D ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _get_depth_at(self, depth_img: np.ndarray, u: float, v: float,
|
||||||
|
window: int = 7) -> float:
|
||||||
|
"""Median depth in a window around pixel (u, v). Returns 0 if invalid."""
|
||||||
|
return get_depth_at(depth_img, u, v, window,
|
||||||
|
self._min_depth, self._max_depth)
|
||||||
|
|
||||||
|
def _pixel_to_3d(self, u: float, v: float, depth_m: float):
|
||||||
|
"""Back-project pixel (u, v) at depth_m to 3D point in camera frame."""
|
||||||
|
return pixel_to_3d(u, v, depth_m, self._camera_info.k)
|
||||||
|
|
||||||
|
# ── TF transform ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _transform_pose(self, pose_in: PoseStamped) -> PoseStamped:
|
||||||
|
try:
|
||||||
|
return self._tf_buffer.transform(
|
||||||
|
pose_in, self._target_frame,
|
||||||
|
timeout=Duration(seconds=0.05))
|
||||||
|
except Exception as e:
|
||||||
|
self.get_logger().warn(
|
||||||
|
f'TF {pose_in.header.frame_id}→{self._target_frame} failed: {e}',
|
||||||
|
throttle_duration_sec=5.0)
|
||||||
|
return pose_in # publish in camera frame as fallback
|
||||||
|
|
||||||
|
# ── Message builders ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _make_detection2d(self, header, x1, y1, x2, y2, conf) -> Detection2D:
|
||||||
|
det = Detection2D()
|
||||||
|
det.header = header
|
||||||
|
|
||||||
|
hyp = ObjectHypothesisWithPose()
|
||||||
|
hyp.hypothesis.class_id = 'person'
|
||||||
|
hyp.hypothesis.score = conf
|
||||||
|
det.results.append(hyp)
|
||||||
|
|
||||||
|
det.bbox.center.position.x = (x1 + x2) / 2.0
|
||||||
|
det.bbox.center.position.y = (y1 + y2) / 2.0
|
||||||
|
det.bbox.center.theta = 0.0
|
||||||
|
det.bbox.size_x = x2 - x1
|
||||||
|
det.bbox.size_y = y2 - y1
|
||||||
|
return det
|
||||||
|
|
||||||
|
# ── Debug visualisation ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _draw_debug(self, bgr, detections, track):
    """Render raw detections (green) and the active track (red) on a copy."""
    canvas = bgr.copy()

    # All raw detections: thin green boxes labelled confidence + depth.
    for (x1, y1, x2, y2), d, conf in detections:
        top_left = (int(x1), int(y1))
        bottom_right = (int(x2), int(y2))
        cv2.rectangle(canvas, top_left, bottom_right, (0, 255, 0), 2)
        cv2.putText(canvas, f'{conf:.2f} {d:.1f}m',
                    (top_left[0], top_left[1] - 6),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

    # The followed person: thick red box with ID, depth, and hold time.
    if track is not None:
        x1, y1, x2, y2 = track.bbox
        top_left = (int(x1), int(y1))
        bottom_right = (int(x2), int(y2))
        cv2.rectangle(canvas, top_left, bottom_right, (0, 0, 255), 3)
        label = f'ID:{track.track_id} {track.depth:.1f}m'
        if track.age > 0.05:
            label += f' (held {track.age:.1f}s)'
        cv2.putText(canvas, label, (top_left[0], top_left[1] - 6),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

    return canvas
|
||||||
|
|
||||||
|
|
||||||
|
# ── Entry point ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def main(args=None):
    """Spin the person detector node until shutdown or Ctrl-C."""
    rclpy.init(args=args)
    detector = PersonDetectorNode()
    try:
        rclpy.spin(detector)
    except KeyboardInterrupt:
        pass
    finally:
        # Always release node resources, even after an unexpected error.
        detector.destroy_node()
        rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
# Direct-execution entry point (`ros2 run` uses the console_scripts hook).
if __name__ == '__main__':
    main()
|
||||||
@ -0,0 +1,179 @@
|
|||||||
|
"""
|
||||||
|
tracker.py — Single-target person tracker for saltybot person-following mode.
|
||||||
|
|
||||||
|
Strategy:
|
||||||
|
- Select closest valid detection each frame (smallest depth within range)
|
||||||
|
- Re-associate across frames using IoU matching with existing track
|
||||||
|
- Hold last known track for `hold_duration` seconds when detections are lost
|
||||||
|
- Assign monotonically increasing track IDs
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
tracker = SimplePersonTracker(hold_duration=2.0)
|
||||||
|
track = tracker.update(detections) # detections: list of (bbox, depth, conf)
|
||||||
|
if track is not None:
|
||||||
|
print(track.bbox, track.depth, track.track_id)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class PersonTrack:
    """State for one tracked person: bbox, depth, confidence, and timing."""

    def __init__(self, bbox, depth, confidence, track_id):
        """
        Args:
            bbox: (x1, y1, x2, y2) in pixels
            depth: distance to person in metres
            confidence: detection confidence 0–1
            track_id: unique integer ID
        """
        self.bbox = bbox
        self.depth = depth
        self.confidence = confidence
        self.track_id = track_id
        # (X, Y, Z) in camera frame — filled in later by the detector node.
        self.position_3d = None
        now = time.monotonic()
        self._first_seen = now
        self._last_seen = now

    def touch(self, bbox, depth, confidence):
        """Refresh the track with a newly associated detection."""
        self.bbox = bbox
        self.depth = depth
        self.confidence = confidence
        self._last_seen = time.monotonic()

    @property
    def age(self):
        """Seconds elapsed since the last associated detection."""
        return time.monotonic() - self._last_seen

    @property
    def is_stale(self):
        # Any non-zero age counts as stale here; callers compare `age`
        # against their own hold_duration for the actual expiry decision.
        return self.age > 0

    @property
    def center(self):
        """Bbox centre as (cx, cy) in pixels."""
        x1, y1, x2, y2 = self.bbox
        return ((x1 + x2) / 2.0, (y1 + y2) / 2.0)

    @property
    def area(self):
        """Bbox area in square pixels, clamped to be non-negative."""
        x1, y1, x2, y2 = self.bbox
        return max(0.0, (x2 - x1) * (y2 - y1))
|
||||||
|
|
||||||
|
|
||||||
|
class SimplePersonTracker:
    """
    Lightweight single-target person tracker.

    Maintains at most one active PersonTrack. Each update:
      1. keeps only detections whose depth is inside the valid range,
      2. tries to re-associate the current track via IoU overlap,
      3. holds a stale track for up to `hold_duration` seconds,
      4. otherwise starts a fresh track on the closest detection.
    """

    def __init__(
        self,
        hold_duration: float = 2.0,
        iou_threshold: float = 0.25,
        min_depth: float = 0.3,
        max_depth: float = 5.0,
    ):
        """
        Args:
            hold_duration: seconds to hold last known position after losing track
            iou_threshold: minimum IoU to accept a re-association
            min_depth: minimum valid depth in metres
            max_depth: maximum valid depth in metres
        """
        self._hold_duration = hold_duration
        self._iou_threshold = iou_threshold
        self._min_depth = min_depth
        self._max_depth = max_depth
        self._track = None  # current PersonTrack, or None
        self._next_id = 1

    def update(self, detections):
        """
        Advance the tracker by one frame.

        Args:
            detections: list of (bbox, depth, confidence) tuples, with
                bbox = (x1, y1, x2, y2) pixels, depth in metres
                (0 = invalid), confidence in [0, 1].

        Returns:
            The current PersonTrack; a stale one (track.age > 0) while
            inside the hold window; or None when nothing is tracked.
        """
        in_range = [
            (bbox, depth, conf) for bbox, depth, conf in detections
            if self._min_depth < depth < self._max_depth
        ]

        if not in_range:
            # Nothing usable this frame: hold the existing track or drop it.
            if self._track is not None and self._track.age <= self._hold_duration:
                return self._track  # hold last known
            self._track = None
            return None

        if self._track is not None:
            match = self._best_iou_match(in_range)
            if match is not None:
                self._track.touch(*match)
                return self._track
            # No spatial match — keep the stale track while inside the hold
            # window; otherwise fall through and re-acquire below.
            if self._track.age <= self._hold_duration:
                return self._track
            self._track = None  # lost — start fresh

        # Acquire: follow the closest valid person.
        bbox, depth, conf = min(in_range, key=lambda det: det[1])
        self._track = PersonTrack(bbox, depth, conf, self._next_id)
        self._next_id += 1
        return self._track

    def _best_iou_match(self, candidates):
        """Candidate with the highest IoU strictly above threshold, or None."""
        best = None
        best_iou = self._iou_threshold
        for bbox, depth, conf in candidates:
            overlap = _iou(bbox, self._track.bbox)
            if overlap > best_iou:
                best_iou = overlap
                best = (bbox, depth, conf)
        return best

    def reset(self):
        """Forget the current track entirely."""
        self._track = None

    @property
    def active(self):
        """True while a track (fresh or held) exists."""
        return self._track is not None
|
||||||
|
|
||||||
|
|
||||||
|
def _iou(bbox_a, bbox_b):
|
||||||
|
"""Compute Intersection-over-Union of two bounding boxes (x1,y1,x2,y2)."""
|
||||||
|
ax1, ay1, ax2, ay2 = bbox_a
|
||||||
|
bx1, by1, bx2, by2 = bbox_b
|
||||||
|
|
||||||
|
ix1 = max(ax1, bx1)
|
||||||
|
iy1 = max(ay1, by1)
|
||||||
|
ix2 = min(ax2, bx2)
|
||||||
|
iy2 = min(ay2, by2)
|
||||||
|
|
||||||
|
if ix2 <= ix1 or iy2 <= iy1:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
intersection = (ix2 - ix1) * (iy2 - iy1)
|
||||||
|
area_a = (ax2 - ax1) * (ay2 - ay1)
|
||||||
|
area_b = (bx2 - bx1) * (by2 - by1)
|
||||||
|
union = area_a + area_b - intersection
|
||||||
|
return intersection / union if union > 0 else 0.0
|
||||||
@ -0,0 +1,162 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
build_trt_engine.py — Convert ONNX model to TensorRT .engine file.
|
||||||
|
|
||||||
|
Run this ONCE on the Jetson Orin Nano Super to build the optimised engine.
|
||||||
|
The engine is hardware-specific and cannot be shared between GPU families.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 build_trt_engine.py --onnx yolov8n.onnx --engine yolov8n.engine
|
||||||
|
python3 build_trt_engine.py --onnx yolov8n.onnx --engine yolov8n.engine --fp16
|
||||||
|
python3 build_trt_engine.py --onnx yolov8n.onnx --engine yolov8n.engine --fp16 --batch 1
|
||||||
|
|
||||||
|
Alternatively, use the trtexec CLI tool (ships with JetPack):
|
||||||
|
/usr/src/tensorrt/bin/trtexec \\
|
||||||
|
--onnx=yolov8n.onnx \\
|
||||||
|
--fp16 \\
|
||||||
|
--saveEngine=yolov8n.engine \\
|
||||||
|
--workspace=2048
|
||||||
|
|
||||||
|
Model download (YOLOv8n):
|
||||||
|
pip3 install ultralytics
|
||||||
|
python3 -c "from ultralytics import YOLO; YOLO('yolov8n.pt').export(format='onnx', imgsz=640)"
|
||||||
|
|
||||||
|
Model download (YOLOv5s):
    # Note: torch.hub YOLOv5 models have no .export(format=...) method;
    # ONNX export goes through the yolov5 repo's export.py script:
    git clone https://github.com/ultralytics/yolov5 && cd yolov5
    python3 export.py --weights yolov5s.pt --include onnx
|
||||||
|
|
||||||
|
Engine location:
|
||||||
|
Place the built .engine file at the path specified by `engine_path`
|
||||||
|
in person_detection_params.yaml (default: models/yolov8n.engine)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def build_engine(onnx_path: str, engine_path: str, fp16: bool, batch_size: int,
                 workspace_mb: int) -> bool:
    """Build and serialise a TensorRT engine from an ONNX model.

    Args:
        onnx_path: source ONNX file.
        engine_path: destination .engine file (parent dirs are created).
        fp16: request FP16 precision when the platform supports it.
        batch_size: optimisation-profile opt/max batch size (min stays 1).
        workspace_mb: builder workspace limit in megabytes.

    Returns:
        True on success; False on missing TensorRT, ONNX parse errors,
        or engine build failure.
    """
    try:
        import tensorrt as trt
    except ImportError:
        print('ERROR: tensorrt not found. Run on Jetson with JetPack installed.')
        print('       Alternatively use trtexec (see script header).')
        return False

    trt_logger = trt.Logger(trt.Logger.VERBOSE)
    builder = trt.Builder(trt_logger)
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(explicit_batch)
    onnx_parser = trt.OnnxParser(network, trt_logger)

    print(f'[+] Parsing ONNX: {onnx_path}')
    with open(onnx_path, 'rb') as f:
        parsed_ok = onnx_parser.parse(f.read())
    if not parsed_ok:
        for i in range(onnx_parser.num_errors):
            print(f'  Parse error {i}: {onnx_parser.get_error(i)}')
        return False
    print(f'  Network inputs: {network.num_inputs}')
    print(f'  Network outputs: {network.num_outputs}')

    config = builder.create_builder_config()
    config.set_memory_pool_limit(
        trt.MemoryPoolType.WORKSPACE, workspace_mb * 1024 * 1024)

    if fp16 and builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)
        print('[+] FP16 mode enabled')
    elif fp16:
        print('[!] FP16 not supported on this platform, using FP32')

    # One optimisation profile: min batch 1, opt/max at the requested size.
    first_input = network.get_input(0)
    dims = first_input.shape  # e.g., [-1, 3, 640, 640]
    profile = builder.create_optimization_profile()
    profile.set_shape(
        first_input.name,
        (1, dims[1], dims[2], dims[3]),
        (batch_size, dims[1], dims[2], dims[3]),
        (batch_size, dims[1], dims[2], dims[3]))
    config.add_optimization_profile(profile)

    print(f'[+] Building engine (this may take 5–15 minutes on first run)...')
    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        print('ERROR: Engine build failed.')
        return False

    os.makedirs(os.path.dirname(os.path.abspath(engine_path)), exist_ok=True)
    with open(engine_path, 'wb') as out:
        out.write(serialized)

    size_mb = os.path.getsize(engine_path) / (1024 * 1024)
    print(f'[+] Engine saved: {engine_path} ({size_mb:.1f} MB)')
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def verify_engine(engine_path: str) -> bool:
    """Sanity-check a built engine: deserialise it and print I/O tensor info.

    Best-effort: any failure (missing TensorRT, unreadable or corrupt
    engine) is reported to stdout and returns False rather than raising.
    """
    try:
        import tensorrt as trt
        # (Removed an unused `import numpy as np` — nothing here needs it.)
        logger = trt.Logger(trt.Logger.WARNING)
        with open(engine_path, 'rb') as f, trt.Runtime(logger) as rt:
            engine = rt.deserialize_cuda_engine(f.read())
        print(f'\n[+] Engine verified: {engine.num_io_tensors} tensors')
        for i in range(engine.num_io_tensors):
            name = engine.get_tensor_name(i)
            shape = engine.get_tensor_shape(name)
            dtype = engine.get_tensor_dtype(name)
            mode = 'IN ' if engine.get_tensor_mode(name).name == 'INPUT' else 'OUT'
            print(f'  {mode} [{i}] {name}: {list(shape)} {dtype}')
        return True
    except Exception as e:  # deliberate broad catch: verification is optional
        print(f'[!] Engine verification failed: {e}')
        return False
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: parse arguments, build the engine, optionally verify."""
    parser = argparse.ArgumentParser(
        description='Build TensorRT engine from ONNX model')
    parser.add_argument('--onnx', required=True,
                        help='Path to input ONNX model')
    parser.add_argument('--engine', required=True,
                        help='Path for output .engine file')
    parser.add_argument('--fp16', action='store_true', default=True,
                        help='Enable FP16 precision (default: True)')
    parser.add_argument('--no-fp16', dest='fp16', action='store_false',
                        help='Disable FP16, use FP32')
    parser.add_argument('--batch', type=int, default=1,
                        help='Batch size (default: 1)')
    parser.add_argument('--workspace', type=int, default=2048,
                        help='Builder workspace in MB (default: 2048)')
    parser.add_argument('--verify', action='store_true',
                        help='Verify engine after build')
    args = parser.parse_args()

    # Fail fast with download instructions when the ONNX file is missing.
    if not os.path.isfile(args.onnx):
        print(f'ERROR: ONNX model not found: {args.onnx}')
        print()
        print('Download YOLOv8n ONNX:')
        print('  pip3 install ultralytics')
        print('  python3 -c "from ultralytics import YOLO; '
              'YOLO(\'yolov8n.pt\').export(format=\'onnx\', imgsz=640)"')
        sys.exit(1)

    built = build_engine(args.onnx, args.engine, args.fp16, args.batch,
                         args.workspace)
    if not built:
        sys.exit(1)

    if args.verify:
        verify_engine(args.engine)

    print('\nNext step: update person_detection_params.yaml:')
    print(f'  engine_path: "{os.path.abspath(args.engine)}"')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
4
jetson/ros2_ws/src/saltybot_perception/setup.cfg
Normal file
4
jetson/ros2_ws/src/saltybot_perception/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
[develop]
|
||||||
|
script_dir=$base/lib/saltybot_perception
|
||||||
|
[install]
|
||||||
|
install_scripts=$base/lib/saltybot_perception
|
||||||
32
jetson/ros2_ws/src/saltybot_perception/setup.py
Normal file
32
jetson/ros2_ws/src/saltybot_perception/setup.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from setuptools import setup
|
||||||
|
import os
|
||||||
|
from glob import glob
|
||||||
|
|
||||||
|
package_name = 'saltybot_perception'

# Files installed into the ament share tree so ROS2 can discover the
# package marker, manifest, launch files, and parameter YAML.
_data_files = [
    ('share/ament_index/resource_index/packages',
     ['resource/' + package_name]),
    ('share/' + package_name, ['package.xml']),
    (os.path.join('share', package_name, 'launch'), glob('launch/*.py')),
    (os.path.join('share', package_name, 'config'), glob('config/*.yaml')),
]

setup(
    name=package_name,
    version='0.1.0',
    packages=[package_name],
    data_files=_data_files,
    install_requires=['setuptools'],
    zip_safe=True,
    maintainer='seb',
    maintainer_email='seb@vayrette.com',
    description='Person detection and tracking for saltybot (YOLOv8n + TensorRT)',
    license='MIT',
    tests_require=['pytest'],
    entry_points={
        'console_scripts': [
            'person_detector = saltybot_perception.person_detector_node:main',
        ],
    },
)
|
||||||
@ -0,0 +1,67 @@
|
|||||||
|
"""
|
||||||
|
test_postprocess.py — Tests for YOLOv8 post-processing and NMS.
|
||||||
|
|
||||||
|
Tests the _nms() helper and validate post-processing logic without
|
||||||
|
requiring a GPU, TRT, or running ROS2 node.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from saltybot_perception.detection_utils import nms as _nms
|
||||||
|
|
||||||
|
|
||||||
|
class TestNMS:
    """Unit tests for the NMS helper — pure numpy, no GPU/TensorRT needed."""

    def test_single_box_kept(self):
        # A lone box can never be suppressed.
        boxes = np.array([[0, 0, 10, 10]], dtype=float)
        scores = np.array([0.9])
        assert _nms(boxes, scores) == [0]

    def test_empty_input(self):
        # No boxes in → no indices out.
        assert _nms(np.zeros((0, 4)), np.array([])) == []

    def test_suppresses_overlapping_box(self):
        # Two heavily overlapping boxes — keep highest score
        boxes = np.array([
            [0, 0, 10, 10],  # score 0.9 — keep
            [1, 1, 11, 11],  # score 0.8 — suppress (high IoU with first)
        ], dtype=float)
        scores = np.array([0.9, 0.8])
        keep = _nms(boxes, scores, iou_threshold=0.45)
        assert keep == [0]

    def test_keeps_non_overlapping_boxes(self):
        # Three disjoint boxes — all survive NMS regardless of score.
        boxes = np.array([
            [0, 0, 10, 10],
            [50, 50, 60, 60],
            [100, 100, 110, 110],
        ], dtype=float)
        scores = np.array([0.9, 0.85, 0.8])
        keep = _nms(boxes, scores, iou_threshold=0.45)
        assert sorted(keep) == [0, 1, 2]

    def test_score_ordering(self):
        # Lower score box overlaps with higher — higher should be kept
        boxes = np.array([
            [1, 1, 11, 11],  # score 0.6
            [0, 0, 10, 10],  # score 0.95 — should be kept
        ], dtype=float)
        scores = np.array([0.6, 0.95])
        keep = _nms(boxes, scores, iou_threshold=0.45)
        assert 1 in keep  # higher score (index 1) kept
        assert 0 not in keep  # lower score (index 0) suppressed

    def test_iou_threshold_controls_suppression(self):
        # Two boxes with IoU = 25/175 ≈ 0.14 (intersection 5×5, union 175).
        boxes = np.array([
            [0, 0, 10, 10],
            [5, 5, 15, 15],
        ], dtype=float)
        scores = np.array([0.9, 0.8])
        # High threshold — both boxes kept (IoU ~0.14 < 0.5)
        keep_high = _nms(boxes, scores, iou_threshold=0.5)
        assert sorted(keep_high) == [0, 1]
        # Low threshold — only first kept
        keep_low = _nms(boxes, scores, iou_threshold=0.0)
        assert keep_low == [0]
|
||||||
156
jetson/ros2_ws/src/saltybot_perception/test/test_tracker.py
Normal file
156
jetson/ros2_ws/src/saltybot_perception/test/test_tracker.py
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
"""
|
||||||
|
test_tracker.py — Unit tests for SimplePersonTracker.
|
||||||
|
|
||||||
|
Run with: pytest test/test_tracker.py -v
|
||||||
|
No ROS2 runtime required.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import pytest
|
||||||
|
from saltybot_perception.tracker import SimplePersonTracker, PersonTrack, _iou
|
||||||
|
|
||||||
|
|
||||||
|
# ── IoU helper ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestIoU:
    """Geometry checks for the module-level `_iou` helper."""

    def test_identical_boxes(self):
        # Perfect overlap ⇒ IoU of exactly 1.
        assert _iou((0, 0, 10, 10), (0, 0, 10, 10)) == pytest.approx(1.0)

    def test_no_overlap(self):
        # Fully disjoint boxes ⇒ IoU of 0.
        assert _iou((0, 0, 5, 5), (10, 10, 15, 15)) == pytest.approx(0.0)

    def test_partial_overlap(self):
        # Intersection 5×5 = 25; union = 100 + 100 − 25 = 175.
        expected = 25 / 175
        assert _iou((0, 0, 10, 10), (5, 5, 15, 15)) == pytest.approx(expected)

    def test_contained_box(self):
        # Inner 6×6 box fully inside the outer 10×10 box:
        # intersection = 36, union = 100.
        outer = (0, 0, 10, 10)
        inner = (2, 2, 8, 8)
        assert _iou(outer, inner) == pytest.approx(36 / 100)

    def test_touching_edges(self):
        # Shared edge only — zero overlap area.
        assert _iou((0, 0, 5, 5), (5, 0, 10, 5)) == pytest.approx(0.0)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tracker ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestSimplePersonTracker:
|
||||||
|
|
||||||
|
def _make_det(self, x1=10, y1=10, x2=60, y2=160, depth=2.0, conf=0.85):
|
||||||
|
return ((x1, y1, x2, y2), depth, conf)
|
||||||
|
|
||||||
|
# ── Basic update ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_no_detections_returns_none(self):
|
||||||
|
t = SimplePersonTracker()
|
||||||
|
assert t.update([]) is None
|
||||||
|
|
||||||
|
def test_single_detection_creates_track(self):
|
||||||
|
t = SimplePersonTracker()
|
||||||
|
track = t.update([self._make_det()])
|
||||||
|
assert track is not None
|
||||||
|
assert track.track_id == 1
|
||||||
|
assert track.depth == pytest.approx(2.0)
|
||||||
|
|
||||||
|
def test_track_id_increments(self):
|
||||||
|
t = SimplePersonTracker(hold_duration=0.0)
|
||||||
|
t.update([self._make_det()])
|
||||||
|
t.update([]) # lose track immediately
|
||||||
|
track2 = t.update([self._make_det(x1=100, y1=100, x2=150, y2=250)])
|
||||||
|
assert track2.track_id == 2
|
||||||
|
|
||||||
|
# ── Closest-first selection ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_picks_closest_person(self):
|
||||||
|
t = SimplePersonTracker()
|
||||||
|
dets = [
|
||||||
|
((10, 10, 60, 160), 4.0, 0.8), # far
|
||||||
|
((70, 10, 120, 160), 1.5, 0.9), # closest
|
||||||
|
((130, 10, 180, 160), 3.0, 0.75), # mid
|
||||||
|
]
|
||||||
|
track = t.update(dets)
|
||||||
|
assert track.depth == pytest.approx(1.5)
|
||||||
|
|
||||||
|
# ── Depth filtering ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_rejects_beyond_max_depth(self):
|
||||||
|
t = SimplePersonTracker(max_depth=5.0)
|
||||||
|
assert t.update([self._make_det(depth=6.0)]) is None
|
||||||
|
|
||||||
|
def test_rejects_below_min_depth(self):
|
||||||
|
t = SimplePersonTracker(min_depth=0.3)
|
||||||
|
assert t.update([self._make_det(depth=0.1)]) is None
|
||||||
|
|
||||||
|
def test_accepts_within_depth_range(self):
|
||||||
|
t = SimplePersonTracker(min_depth=0.3, max_depth=5.0)
|
||||||
|
track = t.update([self._make_det(depth=2.5)])
|
||||||
|
assert track is not None
|
||||||
|
|
||||||
|
# ── Re-association ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_iou_reassociation_keeps_same_id(self):
|
||||||
|
t = SimplePersonTracker(iou_threshold=0.2)
|
||||||
|
# Frame 1
|
||||||
|
track1 = t.update([self._make_det(x1=10, y1=10, x2=60, y2=160)])
|
||||||
|
id1 = track1.track_id
|
||||||
|
# Frame 2 — slightly shifted (good IoU)
|
||||||
|
track2 = t.update([self._make_det(x1=12, y1=12, x2=62, y2=162)])
|
||||||
|
assert track2.track_id == id1
|
||||||
|
|
||||||
|
def test_poor_iou_loses_track_after_hold(self):
|
||||||
|
t = SimplePersonTracker(hold_duration=0.0, iou_threshold=0.5)
|
||||||
|
# Frame 1 — track person at left
|
||||||
|
t.update([self._make_det(x1=0, y1=0, x2=50, y2=150)])
|
||||||
|
# Frame 2 — completely different position, bad IoU
|
||||||
|
# hold_duration=0, so old track expires, new track started
|
||||||
|
track = t.update([self._make_det(x1=500, y1=0, x2=550, y2=150)])
|
||||||
|
assert track.track_id == 2 # new track
|
||||||
|
|
||||||
|
# ── Hold duration ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_holds_last_known_within_duration(self):
|
||||||
|
t = SimplePersonTracker(hold_duration=10.0)
|
||||||
|
track = t.update([self._make_det()])
|
||||||
|
track_id = track.track_id
|
||||||
|
# No detections — should hold
|
||||||
|
held = t.update([])
|
||||||
|
assert held is not None
|
||||||
|
assert held.track_id == track_id
|
||||||
|
assert held.age >= 0
|
||||||
|
|
||||||
|
def test_releases_track_after_hold_duration(self):
|
||||||
|
t = SimplePersonTracker(hold_duration=0.0)
|
||||||
|
t.update([self._make_det()])
|
||||||
|
# Immediately lose
|
||||||
|
result = t.update([])
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
# ── Reset ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_reset_clears_track(self):
|
||||||
|
t = SimplePersonTracker()
|
||||||
|
t.update([self._make_det()])
|
||||||
|
t.reset()
|
||||||
|
assert t.active is False
|
||||||
|
assert t.update([]) is None
|
||||||
|
|
||||||
|
# ── PersonTrack properties ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_track_center(self):
|
||||||
|
track = PersonTrack((10, 20, 50, 100), 2.0, 0.9, 1)
|
||||||
|
cx, cy = track.center
|
||||||
|
assert cx == pytest.approx(30.0)
|
||||||
|
assert cy == pytest.approx(60.0)
|
||||||
|
|
||||||
|
def test_track_area(self):
|
||||||
|
track = PersonTrack((10, 20, 50, 100), 2.0, 0.9, 1)
|
||||||
|
assert track.area == pytest.approx(40 * 80)
|
||||||
|
|
||||||
|
def test_track_age_increases(self):
|
||||||
|
track = PersonTrack((0, 0, 50, 100), 2.0, 0.9, 1)
|
||||||
|
time.sleep(0.05)
|
||||||
|
assert track.age >= 0.04
|
||||||
Loading…
x
Reference in New Issue
Block a user