Compare commits
3 Commits
94902f918b
...
c7dd07f9ed
| Author | SHA1 | Date | |
|---|---|---|---|
| c7dd07f9ed | |||
| 01ee02f837 | |||
| f0e11fe7ca |
@ -0,0 +1,141 @@
|
||||
"""
|
||||
_depth_hole_fill.py — Depth image hole filling via bilateral interpolation (no ROS2 deps).
|
||||
|
||||
Algorithm
|
||||
---------
|
||||
A "hole" is any pixel where depth == 0, depth is NaN, or depth is outside the
|
||||
valid range [d_min, d_max].
|
||||
|
||||
Each pass replaces every hole pixel with the spatial-Gaussian-weighted mean of
|
||||
valid pixels in a (kernel_size × kernel_size) neighbourhood:
|
||||
|
||||
filled[x,y] = Σ G(||p - q||; σ) · d[q] / Σ G(||p - q||; σ)
|
||||
q ∈ valid neighbours of (x,y)
|
||||
|
||||
The denominator (sum of spatial weights over valid pixels) normalises correctly
|
||||
even at image borders and around isolated valid pixels.
|
||||
|
||||
Multiple passes with geometrically growing kernels are applied so that:
|
||||
Pass 1 kernel_size — fills small holes (≤ kernel_size/2 px radius)
|
||||
Pass 2 kernel_size × 2.5 — fills medium holes
|
||||
Pass 3 kernel_size × 6.0 — fills large holes / fronto-parallel surfaces
|
||||
|
||||
After all passes any remaining zeros are left as-is (no valid neighbourhood data).
|
||||
|
||||
Because only the spatial Gaussian (not a depth range term) is used as the weighting
|
||||
function, this is equivalent to a bilateral filter with σ_range → ∞. In practice
|
||||
this produces smooth, physically plausible fills in the depth domain.
|
||||
|
||||
Public API
|
||||
----------
|
||||
fill_holes(depth, kernel_size=5, d_min=0.1, d_max=10.0, max_passes=3) → ndarray
|
||||
valid_mask(depth, d_min=0.1, d_max=10.0) → bool ndarray
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Kernel size multipliers for successive passes
|
||||
_PASS_SCALE = [1.0, 2.5, 6.0]
|
||||
|
||||
|
||||
def valid_mask(
|
||||
depth: np.ndarray,
|
||||
d_min: float = 0.1,
|
||||
d_max: float = 10.0,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Return a boolean mask of valid (non-hole) pixels.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
depth : (H, W) float32 ndarray, depth in metres
|
||||
d_min : minimum valid depth (m)
|
||||
d_max : maximum valid depth (m)
|
||||
|
||||
Returns
|
||||
-------
|
||||
(H, W) bool ndarray — True where depth is finite and in [d_min, d_max]
|
||||
"""
|
||||
return np.isfinite(depth) & (depth >= d_min) & (depth <= d_max)
|
||||
|
||||
|
||||
def fill_holes(
|
||||
depth: np.ndarray,
|
||||
kernel_size: int = 5,
|
||||
d_min: float = 0.1,
|
||||
d_max: float = 10.0,
|
||||
max_passes: int = 3,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Fill zero/NaN depth pixels using multi-pass spatial Gaussian interpolation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
depth : (H, W) float32 ndarray, depth in metres
|
||||
kernel_size : initial kernel side length (pixels, forced odd, ≥ 3)
|
||||
d_min : minimum valid depth — pixels below this are treated as holes
|
||||
d_max : maximum valid depth — pixels above this are treated as holes
|
||||
max_passes : number of fill passes (1–3); each uses a larger kernel
|
||||
|
||||
Returns
|
||||
-------
|
||||
(H, W) float32 ndarray — same as input, with holes filled where possible.
|
||||
Pixels with no valid neighbours after all passes remain 0.0.
|
||||
Original valid pixels are never modified.
|
||||
"""
|
||||
import cv2
|
||||
|
||||
depth = np.asarray(depth, dtype=np.float32)
|
||||
# Replace NaN with 0 so arithmetic is clean
|
||||
depth = np.where(np.isfinite(depth), depth, 0.0).astype(np.float32)
|
||||
|
||||
mask = valid_mask(depth, d_min, d_max) # True where already valid
|
||||
result = depth.copy()
|
||||
n_passes = max(1, min(max_passes, len(_PASS_SCALE)))
|
||||
|
||||
for i in range(n_passes):
|
||||
if mask.all():
|
||||
break # no holes left
|
||||
|
||||
ks = _odd_kernel_size(kernel_size, _PASS_SCALE[i])
|
||||
half = ks // 2
|
||||
sigma = max(half / 2.0, 0.5)
|
||||
|
||||
gk = cv2.getGaussianKernel(ks, sigma).astype(np.float32)
|
||||
kernel = (gk @ gk.T)
|
||||
|
||||
# Multiply depth by mask so invalid pixels contribute 0 weight
|
||||
d_valid = np.where(mask, result, 0.0).astype(np.float32)
|
||||
w_valid = mask.astype(np.float32)
|
||||
|
||||
sum_d = cv2.filter2D(d_valid, ddepth=-1, kernel=kernel,
|
||||
borderType=cv2.BORDER_REFLECT)
|
||||
sum_w = cv2.filter2D(w_valid, ddepth=-1, kernel=kernel,
|
||||
borderType=cv2.BORDER_REFLECT)
|
||||
|
||||
# Where we have enough weight, compute the weighted mean
|
||||
has_data = sum_w > 1e-6
|
||||
interp = np.where(has_data, sum_d / np.where(has_data, sum_w, 1.0), 0.0)
|
||||
|
||||
# Only fill holes — never overwrite original valid pixels
|
||||
result = np.where(mask, result, interp.astype(np.float32))
|
||||
|
||||
# Update mask with newly filled pixels (for the next pass)
|
||||
newly_filled = (~mask) & (result > 0.0)
|
||||
mask = mask | newly_filled
|
||||
|
||||
return result.astype(np.float32)
|
||||
|
||||
|
||||
# ── Internal helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
def _odd_kernel_size(base: int, scale: float) -> int:
|
||||
"""Return the nearest odd integer to base * scale, minimum 3."""
|
||||
raw = max(3, int(round(base * scale)))
|
||||
return raw if raw % 2 == 1 else raw + 1
|
||||
@ -0,0 +1,128 @@
|
||||
"""
|
||||
depth_hole_fill_node.py — D435i depth image hole filler (Issue #268).
|
||||
|
||||
Subscribes to the raw D435i depth stream, fills zero/NaN pixels using
|
||||
multi-pass spatial-Gaussian bilateral interpolation, and republishes the
|
||||
filled image at camera rate.
|
||||
|
||||
Subscribes (BEST_EFFORT):
|
||||
/camera/depth/image_rect_raw sensor_msgs/Image float32 depth (m)
|
||||
|
||||
Publishes:
|
||||
/camera/depth/filled sensor_msgs/Image float32 depth (m), holes filled
|
||||
|
||||
The filled image preserves all original valid pixels exactly and only
|
||||
modifies pixels that had no return (0 or NaN). The output is suitable
|
||||
for all downstream consumers that expect a dense depth map (VO, RTAB-Map,
|
||||
collision avoidance, floor classifier).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_topic str /camera/depth/image_rect_raw Input depth topic
|
||||
output_topic str /camera/depth/filled Output depth topic
|
||||
kernel_size int 5 Initial Gaussian kernel side length (pixels)
|
||||
d_min float 0.1 Minimum valid depth (m)
|
||||
d_max float 10.0 Maximum valid depth (m)
|
||||
max_passes int 3 Fill passes (growing kernel per pass)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||
|
||||
import numpy as np
|
||||
from cv_bridge import CvBridge
|
||||
|
||||
from sensor_msgs.msg import Image
|
||||
|
||||
from ._depth_hole_fill import fill_holes
|
||||
|
||||
|
||||
_SENSOR_QOS = QoSProfile(
|
||||
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||
history=HistoryPolicy.KEEP_LAST,
|
||||
depth=4,
|
||||
)
|
||||
|
||||
|
||||
class DepthHoleFillNode(Node):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__('depth_hole_fill_node')
|
||||
|
||||
self.declare_parameter('input_topic', '/camera/depth/image_rect_raw')
|
||||
self.declare_parameter('output_topic', '/camera/depth/filled')
|
||||
self.declare_parameter('kernel_size', 5)
|
||||
self.declare_parameter('d_min', 0.1)
|
||||
self.declare_parameter('d_max', 10.0)
|
||||
self.declare_parameter('max_passes', 3)
|
||||
|
||||
input_topic = self.get_parameter('input_topic').value
|
||||
output_topic = self.get_parameter('output_topic').value
|
||||
self._ks = int(self.get_parameter('kernel_size').value)
|
||||
self._d_min = self.get_parameter('d_min').value
|
||||
self._d_max = self.get_parameter('d_max').value
|
||||
self._passes = int(self.get_parameter('max_passes').value)
|
||||
|
||||
self._bridge = CvBridge()
|
||||
|
||||
self._sub = self.create_subscription(
|
||||
Image, input_topic, self._on_depth, _SENSOR_QOS)
|
||||
self._pub = self.create_publisher(Image, output_topic, 10)
|
||||
|
||||
self.get_logger().info(
|
||||
f'depth_hole_fill_node ready — '
|
||||
f'{input_topic} → {output_topic} '
|
||||
f'kernel={self._ks} passes={self._passes} '
|
||||
f'd=[{self._d_min},{self._d_max}]m'
|
||||
)
|
||||
|
||||
# ── Callback ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _on_depth(self, msg: Image) -> None:
|
||||
try:
|
||||
depth = self._bridge.imgmsg_to_cv2(msg, desired_encoding='passthrough')
|
||||
except Exception as exc:
|
||||
self.get_logger().error(
|
||||
f'cv_bridge: {exc}', throttle_duration_sec=5.0)
|
||||
return
|
||||
|
||||
depth = depth.astype(np.float32)
|
||||
|
||||
# Handle uint16 mm → float32 m conversion (D435i raw stream)
|
||||
if depth.max() > 100.0:
|
||||
depth /= 1000.0
|
||||
|
||||
filled = fill_holes(
|
||||
depth,
|
||||
kernel_size=self._ks,
|
||||
d_min=self._d_min,
|
||||
d_max=self._d_max,
|
||||
max_passes=self._passes,
|
||||
)
|
||||
|
||||
try:
|
||||
out_msg = self._bridge.cv2_to_imgmsg(filled, encoding='32FC1')
|
||||
except Exception as exc:
|
||||
self.get_logger().error(
|
||||
f'cv2_to_imgmsg: {exc}', throttle_duration_sec=5.0)
|
||||
return
|
||||
|
||||
out_msg.header = msg.header
|
||||
self._pub.publish(out_msg)
|
||||
|
||||
|
||||
def main(args=None) -> None:
|
||||
rclpy.init(args=args)
|
||||
node = DepthHoleFillNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -37,6 +37,8 @@ setup(
|
||||
'floor_classifier = saltybot_bringup.floor_classifier_node:main',
|
||||
# Visual odometry drift detector (Issue #260)
|
||||
'vo_drift_detector = saltybot_bringup.vo_drift_node:main',
|
||||
# Depth image hole filler (Issue #268)
|
||||
'depth_hole_fill = saltybot_bringup.depth_hole_fill_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
281
jetson/ros2_ws/src/saltybot_bringup/test/test_depth_hole_fill.py
Normal file
281
jetson/ros2_ws/src/saltybot_bringup/test/test_depth_hole_fill.py
Normal file
@ -0,0 +1,281 @@
|
||||
"""
|
||||
test_depth_hole_fill.py — Unit tests for depth hole fill helpers (no ROS2 required).
|
||||
|
||||
Covers:
|
||||
valid_mask:
|
||||
- valid range returns True
|
||||
- zero / below d_min returns False
|
||||
- NaN returns False
|
||||
- above d_max returns False
|
||||
- mixed array has correct mask
|
||||
|
||||
_odd_kernel_size:
|
||||
- result is always odd
|
||||
- result >= 3
|
||||
- scales correctly
|
||||
|
||||
fill_holes — no-hole cases:
|
||||
- fully valid image is returned unchanged
|
||||
- output dtype is float32
|
||||
- output shape matches input
|
||||
|
||||
fill_holes — basic fills:
|
||||
- single centre hole in uniform depth → filled with correct depth
|
||||
- single centre hole in uniform depth → original valid pixels unchanged
|
||||
- NaN pixel treated as hole and filled
|
||||
- row of zeros within uniform depth → filled
|
||||
|
||||
fill_holes — fill quality:
|
||||
- linear gradient: centre hole filled with interpolated value
|
||||
- multi-pass fills larger holes than single pass
|
||||
- all-zero image stays zero (no valid neighbours)
|
||||
- border hole (edge pixel) is handled without crash
|
||||
- depth range: pixel above d_max treated as hole
|
||||
|
||||
fill_holes — valid pixel preservation:
|
||||
- original valid pixels are never modified
|
||||
- max_passes=1 still fills small holes
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from saltybot_bringup._depth_hole_fill import (
|
||||
fill_holes,
|
||||
valid_mask,
|
||||
_odd_kernel_size,
|
||||
)
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _uniform(val=2.0, h=64, w=64) -> np.ndarray:
|
||||
return np.full((h, w), val, dtype=np.float32)
|
||||
|
||||
|
||||
def _poke_hole(arr, r, c) -> np.ndarray:
|
||||
arr = arr.copy()
|
||||
arr[r, c] = 0.0
|
||||
return arr
|
||||
|
||||
|
||||
def _poke_nan(arr, r, c) -> np.ndarray:
|
||||
arr = arr.copy()
|
||||
arr[r, c] = float('nan')
|
||||
return arr
|
||||
|
||||
|
||||
# ── valid_mask ────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestValidMask:
|
||||
|
||||
def test_valid_pixel_is_true(self):
|
||||
d = np.array([[1.0]], dtype=np.float32)
|
||||
assert valid_mask(d, 0.1, 10.0)[0, 0]
|
||||
|
||||
def test_zero_is_false(self):
|
||||
d = np.array([[0.0]], dtype=np.float32)
|
||||
assert not valid_mask(d, 0.1, 10.0)[0, 0]
|
||||
|
||||
def test_below_dmin_is_false(self):
|
||||
d = np.array([[0.05]], dtype=np.float32)
|
||||
assert not valid_mask(d, 0.1, 10.0)[0, 0]
|
||||
|
||||
def test_nan_is_false(self):
|
||||
d = np.array([[float('nan')]], dtype=np.float32)
|
||||
assert not valid_mask(d, 0.1, 10.0)[0, 0]
|
||||
|
||||
def test_above_dmax_is_false(self):
|
||||
d = np.array([[15.0]], dtype=np.float32)
|
||||
assert not valid_mask(d, 0.1, 10.0)[0, 0]
|
||||
|
||||
def test_at_dmin_is_true(self):
|
||||
d = np.array([[0.1]], dtype=np.float32)
|
||||
assert valid_mask(d, 0.1, 10.0)[0, 0]
|
||||
|
||||
def test_at_dmax_is_true(self):
|
||||
d = np.array([[10.0]], dtype=np.float32)
|
||||
assert valid_mask(d, 0.1, 10.0)[0, 0]
|
||||
|
||||
def test_mixed_array(self):
|
||||
d = np.array([[0.0, 1.0, float('nan'), 5.0, 11.0]], dtype=np.float32)
|
||||
m = valid_mask(d, 0.1, 10.0)
|
||||
np.testing.assert_array_equal(m, [[False, True, False, True, False]])
|
||||
|
||||
|
||||
# ── _odd_kernel_size ──────────────────────────────────────────────────────────
|
||||
|
||||
class TestOddKernelSize:
|
||||
|
||||
@pytest.mark.parametrize('base,scale', [
|
||||
(5, 1.0), (5, 2.5), (5, 6.0),
|
||||
(3, 1.0), (7, 2.0), (9, 3.0),
|
||||
(4, 1.0), # even base → must become odd
|
||||
])
|
||||
def test_result_is_odd(self, base, scale):
|
||||
ks = _odd_kernel_size(base, scale)
|
||||
assert ks % 2 == 1
|
||||
|
||||
@pytest.mark.parametrize('base,scale', [(3, 1.0), (1, 5.0), (2, 0.5)])
|
||||
def test_result_at_least_3(self, base, scale):
|
||||
assert _odd_kernel_size(base, scale) >= 3
|
||||
|
||||
def test_scale_1_returns_base_or_nearby_odd(self):
|
||||
ks = _odd_kernel_size(5, 1.0)
|
||||
assert ks == 5
|
||||
|
||||
def test_large_scale_gives_large_kernel(self):
|
||||
ks = _odd_kernel_size(5, 6.0)
|
||||
assert ks >= 25 # 5 * 6 = 30 → 31
|
||||
|
||||
|
||||
# ── fill_holes — output contract ──────────────────────────────────────────────
|
||||
|
||||
class TestFillHolesOutputContract:
|
||||
|
||||
def test_output_dtype_float32(self):
|
||||
out = fill_holes(_uniform(2.0))
|
||||
assert out.dtype == np.float32
|
||||
|
||||
def test_output_shape_preserved(self):
|
||||
img = _uniform(2.0, h=48, w=64)
|
||||
out = fill_holes(img)
|
||||
assert out.shape == img.shape
|
||||
|
||||
def test_fully_valid_image_unchanged(self):
|
||||
img = _uniform(2.0)
|
||||
out = fill_holes(img)
|
||||
np.testing.assert_allclose(out, img, atol=1e-6)
|
||||
|
||||
def test_valid_pixels_never_modified(self):
|
||||
"""Any pixel valid in the input must be identical in the output."""
|
||||
img = _uniform(3.0, h=32, w=32)
|
||||
img[16, 16] = 0.0 # one hole
|
||||
mask_before = valid_mask(img)
|
||||
out = fill_holes(img)
|
||||
np.testing.assert_allclose(out[mask_before], img[mask_before], atol=1e-6)
|
||||
|
||||
|
||||
# ── fill_holes — basic hole filling ──────────────────────────────────────────
|
||||
|
||||
class TestFillHolesBasic:
|
||||
|
||||
def test_centre_zero_filled_uniform(self):
|
||||
"""Single zero pixel in uniform depth → filled with that depth."""
|
||||
img = _poke_hole(_uniform(2.0, 32, 32), 16, 16)
|
||||
out = fill_holes(img, kernel_size=5, max_passes=1)
|
||||
assert out[16, 16] == pytest.approx(2.0, abs=0.05)
|
||||
|
||||
def test_centre_nan_filled_uniform(self):
|
||||
"""Single NaN pixel in uniform depth → filled."""
|
||||
img = _poke_nan(_uniform(2.0, 32, 32), 16, 16)
|
||||
out = fill_holes(img, kernel_size=5, max_passes=1)
|
||||
assert out[16, 16] == pytest.approx(2.0, abs=0.05)
|
||||
|
||||
def test_filled_value_is_positive(self):
|
||||
img = _poke_hole(_uniform(1.5, 32, 32), 16, 16)
|
||||
out = fill_holes(img)
|
||||
assert out[16, 16] > 0.0
|
||||
|
||||
def test_row_of_holes_filled(self):
|
||||
"""Entire middle row zeroed → should be filled from neighbours above/below."""
|
||||
img = _uniform(3.0, 32, 32)
|
||||
img[16, :] = 0.0
|
||||
out = fill_holes(img, kernel_size=7, max_passes=1)
|
||||
# All pixels in the row should be non-zero after filling
|
||||
assert (out[16, :] > 0.0).all()
|
||||
|
||||
def test_all_zero_stays_zero(self):
|
||||
"""Image with no valid pixels → stays zero (nothing to interpolate from)."""
|
||||
img = np.zeros((32, 32), dtype=np.float32)
|
||||
out = fill_holes(img, d_min=0.1)
|
||||
assert (out == 0.0).all()
|
||||
|
||||
def test_border_hole_no_crash(self):
|
||||
"""Holes at image corners must not raise exceptions."""
|
||||
img = _uniform(2.0, 32, 32)
|
||||
img[0, 0] = 0.0
|
||||
img[0, -1] = 0.0
|
||||
img[-1, 0] = 0.0
|
||||
img[-1, -1] = 0.0
|
||||
out = fill_holes(img) # must not raise
|
||||
assert out.shape == img.shape
|
||||
|
||||
def test_border_holes_filled(self):
|
||||
"""Corner holes should be filled from their neighbours."""
|
||||
img = _uniform(2.0, 32, 32)
|
||||
img[0, 0] = 0.0
|
||||
out = fill_holes(img, kernel_size=5, max_passes=1)
|
||||
assert out[0, 0] == pytest.approx(2.0, abs=0.1)
|
||||
|
||||
|
||||
# ── fill_holes — fill quality ─────────────────────────────────────────────────
|
||||
|
||||
class TestFillHolesQuality:
|
||||
|
||||
def test_linear_gradient_centre_hole_interpolated(self):
|
||||
"""
|
||||
Depth linearly increasing from 1.0 (left) to 3.0 (right).
|
||||
Centre hole should be filled near the midpoint (~2.0).
|
||||
"""
|
||||
h, w = 32, 32
|
||||
img = np.tile(np.linspace(1.0, 3.0, w, dtype=np.float32), (h, 1))
|
||||
cx = w // 2
|
||||
img[:, cx] = 0.0
|
||||
out = fill_holes(img, kernel_size=5, max_passes=1)
|
||||
mid = out[h // 2, cx]
|
||||
assert 1.5 <= mid <= 2.5, f'interpolated value {mid:.3f} not in [1.5, 2.5]'
|
||||
|
||||
def test_large_hole_filled_with_more_passes(self):
|
||||
"""A 9×9 hole in uniform depth: single pass may not fully fill it,
|
||||
but 3 passes should."""
|
||||
img = _uniform(2.0, 64, 64)
|
||||
# Create a 9×9 hole
|
||||
img[28:37, 28:37] = 0.0
|
||||
out1 = fill_holes(img, kernel_size=5, max_passes=1)
|
||||
out3 = fill_holes(img, kernel_size=5, max_passes=3)
|
||||
# More passes → fewer remaining holes
|
||||
holes1 = (out1 == 0.0).sum()
|
||||
holes3 = (out3 == 0.0).sum()
|
||||
assert holes3 <= holes1, f'more passes should reduce holes: {holes3} vs {holes1}'
|
||||
|
||||
def test_3pass_fills_9x9_hole_completely(self):
|
||||
img = _uniform(2.0, 64, 64)
|
||||
img[28:37, 28:37] = 0.0
|
||||
out = fill_holes(img, kernel_size=5, max_passes=3)
|
||||
assert (out[28:37, 28:37] > 0.0).all()
|
||||
|
||||
def test_filled_depth_within_valid_range(self):
|
||||
"""Filled pixels should have depth within [d_min, d_max]."""
|
||||
img = _uniform(2.0, 32, 32)
|
||||
img[10:15, 10:15] = 0.0
|
||||
out = fill_holes(img, d_min=0.1, d_max=10.0, max_passes=3)
|
||||
# Only check pixels that were actually filled
|
||||
was_hole = (img == 0.0)
|
||||
filled = out[was_hole]
|
||||
positive = filled[filled > 0.0]
|
||||
assert (positive >= 0.1).all()
|
||||
assert (positive <= 10.0).all()
|
||||
|
||||
def test_above_dmax_treated_as_hole(self):
|
||||
"""Pixels above d_max should be treated as holes and filled."""
|
||||
img = _uniform(2.0, 32, 32)
|
||||
img[16, 16] = 15.0 # out of range
|
||||
out = fill_holes(img, d_max=10.0, max_passes=1)
|
||||
assert out[16, 16] == pytest.approx(2.0, abs=0.1)
|
||||
|
||||
def test_max_passes_1_works(self):
|
||||
img = _poke_hole(_uniform(2.0, 32, 32), 16, 16)
|
||||
out = fill_holes(img, max_passes=1)
|
||||
assert out.shape == img.shape
|
||||
assert out[16, 16] > 0.0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
@ -0,0 +1,8 @@
|
||||
greeting_trigger_node:
|
||||
ros__parameters:
|
||||
proximity_m: 2.0 # Trigger when person is within this distance (m)
|
||||
cooldown_s: 300.0 # Re-greeting suppression window per face_id (s)
|
||||
unknown_distance: 0.0 # Distance assumed when PersonState not yet available
|
||||
# 0.0 → always greet faces with no state yet
|
||||
faces_topic: "/social/faces/detected"
|
||||
states_topic: "/social/person_states"
|
||||
@ -0,0 +1,39 @@
|
||||
"""greeting_trigger.launch.py -- Launch proximity-based greeting trigger (Issue #270).
|
||||
|
||||
Usage:
|
||||
ros2 launch saltybot_social greeting_trigger.launch.py
|
||||
ros2 launch saltybot_social greeting_trigger.launch.py proximity_m:=1.5 cooldown_s:=120.0
|
||||
"""
|
||||
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
from launch import LaunchDescription
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
pkg = get_package_share_directory("saltybot_social")
|
||||
cfg = os.path.join(pkg, "config", "greeting_trigger_params.yaml")
|
||||
|
||||
return LaunchDescription([
|
||||
DeclareLaunchArgument("proximity_m", default_value="2.0",
|
||||
description="Greeting proximity threshold (m)"),
|
||||
DeclareLaunchArgument("cooldown_s", default_value="300.0",
|
||||
description="Per-face_id re-greeting cooldown (s)"),
|
||||
|
||||
Node(
|
||||
package="saltybot_social",
|
||||
executable="greeting_trigger_node",
|
||||
name="greeting_trigger_node",
|
||||
output="screen",
|
||||
parameters=[
|
||||
cfg,
|
||||
{
|
||||
"proximity_m": LaunchConfiguration("proximity_m"),
|
||||
"cooldown_s": LaunchConfiguration("cooldown_s"),
|
||||
},
|
||||
],
|
||||
),
|
||||
])
|
||||
@ -0,0 +1,150 @@
|
||||
"""greeting_trigger_node.py -- Proximity-based greeting trigger.
|
||||
Issue #270
|
||||
|
||||
Monitors face detections and person states. When a new face_id is seen
|
||||
within ``proximity_m`` metres (default 2 m) and has not been greeted within
|
||||
``cooldown_s`` seconds, publishes a JSON greeting trigger on
|
||||
/saltybot/greeting_trigger.
|
||||
|
||||
Distance is looked up from the /social/person_states topic which carries a
|
||||
face_id → distance mapping. When no state is available for a face the node
|
||||
applies a configurable default distance so it can still fire on face-only
|
||||
pipelines.
|
||||
|
||||
Subscriptions:
|
||||
/social/faces/detected saltybot_social_msgs/FaceDetectionArray
|
||||
/social/person_states saltybot_social_msgs/PersonStateArray
|
||||
|
||||
Publication:
|
||||
/saltybot/greeting_trigger std_msgs/String (JSON)
|
||||
{"face_id": <int>, "person_name": <str>, "distance_m": <float>,
|
||||
"ts": <float unix epoch>}
|
||||
|
||||
Parameters:
|
||||
proximity_m (float, 2.0) -- trigger when distance <= this
|
||||
cooldown_s (float, 300.0) -- suppress re-greeting same face_id
|
||||
unknown_distance (float, 0.0) -- distance assumed when PersonState
|
||||
is not yet available (0.0 → always
|
||||
trigger for unknown faces)
|
||||
faces_topic (str, "/social/faces/detected")
|
||||
states_topic (str, "/social/person_states")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
from typing import Dict
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile
|
||||
from std_msgs.msg import String
|
||||
|
||||
# Custom messages — imported at runtime so offline tests can stub them
|
||||
try:
|
||||
from saltybot_social_msgs.msg import FaceDetectionArray, PersonStateArray
|
||||
_MSGS = True
|
||||
except ImportError:
|
||||
_MSGS = False
|
||||
|
||||
|
||||
class GreetingTriggerNode(Node):
|
||||
"""Publishes greeting trigger when a person enters proximity."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__("greeting_trigger_node")
|
||||
|
||||
self.declare_parameter("proximity_m", 2.0)
|
||||
self.declare_parameter("cooldown_s", 300.0)
|
||||
self.declare_parameter("unknown_distance", 0.0)
|
||||
self.declare_parameter("faces_topic", "/social/faces/detected")
|
||||
self.declare_parameter("states_topic", "/social/person_states")
|
||||
|
||||
self._proximity = self.get_parameter("proximity_m").value
|
||||
self._cooldown = self.get_parameter("cooldown_s").value
|
||||
self._unknown_dist = self.get_parameter("unknown_distance").value
|
||||
faces_topic = self.get_parameter("faces_topic").value
|
||||
states_topic = self.get_parameter("states_topic").value
|
||||
|
||||
# face_id → last known distance (m); updated from PersonStateArray
|
||||
self._distance_cache: Dict[int, float] = {}
|
||||
# face_id → unix timestamp of last greeting
|
||||
self._last_greeted: Dict[int, float] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
qos = QoSProfile(depth=10)
|
||||
self._pub = self.create_publisher(String, "/saltybot/greeting_trigger", qos)
|
||||
|
||||
if _MSGS:
|
||||
self._states_sub = self.create_subscription(
|
||||
PersonStateArray, states_topic, self._on_person_states, qos
|
||||
)
|
||||
self._faces_sub = self.create_subscription(
|
||||
FaceDetectionArray, faces_topic, self._on_faces, qos
|
||||
)
|
||||
else:
|
||||
self.get_logger().warn(
|
||||
"saltybot_social_msgs not available — node is passive (no subscriptions)"
|
||||
)
|
||||
|
||||
self.get_logger().info(
|
||||
f"GreetingTriggerNode ready "
|
||||
f"(proximity={self._proximity}m, cooldown={self._cooldown}s)"
|
||||
)
|
||||
|
||||
# ── Callbacks ──────────────────────────────────────────────────────────
|
||||
|
||||
def _on_person_states(self, msg: "PersonStateArray") -> None:
|
||||
"""Cache face_id → distance from incoming PersonState array."""
|
||||
with self._lock:
|
||||
for ps in msg.persons:
|
||||
if ps.face_id >= 0:
|
||||
self._distance_cache[ps.face_id] = float(ps.distance)
|
||||
|
||||
def _on_faces(self, msg: "FaceDetectionArray") -> None:
|
||||
"""Evaluate each detected face; fire greeting if conditions met."""
|
||||
now = time.monotonic()
|
||||
with self._lock:
|
||||
for face in msg.faces:
|
||||
fid = int(face.face_id)
|
||||
dist = self._distance_cache.get(fid, self._unknown_dist)
|
||||
|
||||
if dist > self._proximity:
|
||||
continue # too far
|
||||
|
||||
last = self._last_greeted.get(fid, 0.0)
|
||||
if now - last < self._cooldown:
|
||||
continue # still in cooldown
|
||||
|
||||
# Fire!
|
||||
self._last_greeted[fid] = now
|
||||
self._fire(fid, str(face.person_name), dist)
|
||||
|
||||
def _fire(self, face_id: int, person_name: str, distance_m: float) -> None:
|
||||
payload = {
|
||||
"face_id": face_id,
|
||||
"person_name": person_name,
|
||||
"distance_m": round(distance_m, 3),
|
||||
"ts": time.time(),
|
||||
}
|
||||
msg = String()
|
||||
msg.data = json.dumps(payload)
|
||||
self._pub.publish(msg)
|
||||
self.get_logger().info(
|
||||
f"Greeting trigger: face_id={face_id} name={person_name!r} "
|
||||
f"dist={distance_m:.2f}m"
|
||||
)
|
||||
|
||||
|
||||
def main(args=None) -> None:
|
||||
rclpy.init(args=args)
|
||||
node = GreetingTriggerNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
@ -47,6 +47,8 @@ setup(
|
||||
'vad_node = saltybot_social.vad_node:main',
|
||||
# Ambient sound classifier — mel-spectrogram (Issue #252)
|
||||
'ambient_sound_node = saltybot_social.ambient_sound_node:main',
|
||||
# Proximity-based greeting trigger (Issue #270)
|
||||
'greeting_trigger_node = saltybot_social.greeting_trigger_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
472
jetson/ros2_ws/src/saltybot_social/test/test_greeting_trigger.py
Normal file
472
jetson/ros2_ws/src/saltybot_social/test/test_greeting_trigger.py
Normal file
@ -0,0 +1,472 @@
|
||||
"""test_greeting_trigger.py -- Offline tests for greeting_trigger_node (Issue #270).
|
||||
|
||||
Stubs out rclpy and saltybot_social_msgs so tests run without a ROS install.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import types
|
||||
import unittest
|
||||
|
||||
|
||||
# ── ROS2 / message stubs ──────────────────────────────────────────────────────
|
||||
|
||||
def _make_ros_stubs():
|
||||
"""Install minimal stubs for rclpy and message packages."""
|
||||
for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
|
||||
"std_msgs", "std_msgs.msg",
|
||||
"saltybot_social_msgs", "saltybot_social_msgs.msg"):
|
||||
sys.modules[mod_name] = types.ModuleType(mod_name)
|
||||
|
||||
class _Node:
|
||||
def __init__(self, name):
|
||||
self._name = name
|
||||
# Preserve _params if pre-set by _make_node (super().__init__() is
|
||||
# called from GreetingTriggerNode.__init__, so don't reset here)
|
||||
if not hasattr(self, '_params'):
|
||||
self._params = {}
|
||||
self._pubs = {}
|
||||
self._subs = {}
|
||||
self._logs = []
|
||||
|
||||
def declare_parameter(self, name, default):
|
||||
# Don't overwrite values pre-set by _make_node
|
||||
if name not in self._params:
|
||||
self._params[name] = default
|
||||
|
||||
def get_parameter(self, name):
|
||||
class _P:
|
||||
def __init__(self, v):
|
||||
self.value = v
|
||||
return _P(self._params[name])
|
||||
|
||||
def create_publisher(self, msg_type, topic, qos):
|
||||
pub = _FakePub()
|
||||
self._pubs[topic] = pub
|
||||
return pub
|
||||
|
||||
def create_subscription(self, msg_type, topic, cb, qos):
|
||||
self._subs[topic] = cb
|
||||
return object()
|
||||
|
||||
def get_logger(self):
|
||||
node = self
|
||||
class _L:
|
||||
def info(self, m): node._logs.append(("INFO", m))
|
||||
def warn(self, m): node._logs.append(("WARN", m))
|
||||
def error(self, m): node._logs.append(("ERROR", m))
|
||||
return _L()
|
||||
|
||||
def destroy_node(self): pass
|
||||
|
||||
class _FakePub:
|
||||
def __init__(self):
|
||||
self.msgs = []
|
||||
def publish(self, msg):
|
||||
self.msgs.append(msg)
|
||||
|
||||
class _QoSProfile:
|
||||
def __init__(self, depth=10): self.depth = depth
|
||||
|
||||
class _String:
|
||||
def __init__(self): self.data = ""
|
||||
|
||||
# rclpy
|
||||
rclpy_mod = sys.modules["rclpy"]
|
||||
rclpy_mod.init = lambda args=None: None
|
||||
rclpy_mod.spin = lambda node: None
|
||||
rclpy_mod.shutdown = lambda: None
|
||||
|
||||
# rclpy.node
|
||||
sys.modules["rclpy.node"].Node = _Node
|
||||
|
||||
# rclpy.qos
|
||||
sys.modules["rclpy.qos"].QoSProfile = _QoSProfile
|
||||
|
||||
# std_msgs.msg
|
||||
sys.modules["std_msgs.msg"].String = _String
|
||||
|
||||
# saltybot_social_msgs.msg (FaceDetectionArray + PersonStateArray)
|
||||
class _FaceDetection:
|
||||
def __init__(self, face_id=0, person_name="", confidence=1.0):
|
||||
self.face_id = face_id
|
||||
self.person_name = person_name
|
||||
self.confidence = confidence
|
||||
|
||||
class _FaceDetectionArray:
|
||||
def __init__(self, faces=None):
|
||||
self.faces = faces or []
|
||||
|
||||
class _PersonState:
|
||||
def __init__(self, face_id=0, distance=0.0):
|
||||
self.face_id = face_id
|
||||
self.distance = distance
|
||||
|
||||
class _PersonStateArray:
|
||||
def __init__(self, persons=None):
|
||||
self.persons = persons or []
|
||||
|
||||
msgs = sys.modules["saltybot_social_msgs.msg"]
|
||||
msgs.FaceDetection = _FaceDetection
|
||||
msgs.FaceDetectionArray = _FaceDetectionArray
|
||||
msgs.PersonState = _PersonState
|
||||
msgs.PersonStateArray = _PersonStateArray
|
||||
|
||||
return _Node, _FakePub, _QoSProfile, _String, _FaceDetection, _FaceDetectionArray, _PersonState, _PersonStateArray
|
||||
|
||||
|
||||
_Node, _FakePub, _QoSProfile, _String, _FaceDetection, _FaceDetectionArray, _PersonState, _PersonStateArray = _make_ros_stubs()
|
||||
|
||||
|
||||
# ── Load module under test ────────────────────────────────────────────────────
|
||||
|
||||
_SRC = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/saltybot_social/greeting_trigger_node.py"
|
||||
)
|
||||
|
||||
def _load_mod():
|
||||
spec = importlib.util.spec_from_file_location("greeting_trigger_node_testmod", _SRC)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _make_node(mod, **kwargs):
|
||||
"""Instantiate GreetingTriggerNode with overridden parameters."""
|
||||
node = mod.GreetingTriggerNode.__new__(mod.GreetingTriggerNode)
|
||||
|
||||
# Pre-populate _params BEFORE __init__ so super().__init__() (which calls
|
||||
# _Node.__init__) sees them and skips reset due to hasattr guard.
|
||||
defaults = {
|
||||
"proximity_m": 2.0,
|
||||
"cooldown_s": 300.0,
|
||||
"unknown_distance": 0.0,
|
||||
"faces_topic": "/social/faces/detected",
|
||||
"states_topic": "/social/person_states",
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
node._params = dict(defaults)
|
||||
|
||||
mod.GreetingTriggerNode.__init__(node)
|
||||
return node
|
||||
|
||||
|
||||
def _face_msg(faces):
|
||||
return _FaceDetectionArray(faces=faces)
|
||||
|
||||
def _state_msg(persons):
|
||||
return _PersonStateArray(persons=persons)
|
||||
|
||||
|
||||
# ── Test suites ───────────────────────────────────────────────────────────────
|
||||
|
||||
class TestNodeInit(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mod = _load_mod()
|
||||
|
||||
def test_imports_cleanly(self):
|
||||
self.assertTrue(hasattr(self.mod, "GreetingTriggerNode"))
|
||||
|
||||
def test_default_proximity(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertEqual(node._proximity, 2.0)
|
||||
|
||||
def test_default_cooldown(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertEqual(node._cooldown, 300.0)
|
||||
|
||||
def test_default_unknown_distance(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertEqual(node._unknown_dist, 0.0)
|
||||
|
||||
def test_pub_topic(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIn("/saltybot/greeting_trigger", node._pubs)
|
||||
|
||||
def test_subs_registered(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIn("/social/faces/detected", node._subs)
|
||||
self.assertIn("/social/person_states", node._subs)
|
||||
|
||||
def test_initial_caches_empty(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertEqual(node._distance_cache, {})
|
||||
self.assertEqual(node._last_greeted, {})
|
||||
|
||||
|
||||
class TestDistanceCache(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mod = _load_mod()
|
||||
|
||||
def setUp(self):
|
||||
self.node = _make_node(self.mod)
|
||||
|
||||
def test_state_updates_cache(self):
|
||||
ps = _PersonState(face_id=1, distance=1.5)
|
||||
self.node._on_person_states(_state_msg([ps]))
|
||||
self.assertAlmostEqual(self.node._distance_cache[1], 1.5)
|
||||
|
||||
def test_multiple_states_cached(self):
|
||||
persons = [_PersonState(face_id=i, distance=float(i)) for i in range(5)]
|
||||
self.node._on_person_states(_state_msg(persons))
|
||||
for i in range(5):
|
||||
self.assertAlmostEqual(self.node._distance_cache[i], float(i))
|
||||
|
||||
def test_state_update_overwrites(self):
|
||||
self.node._on_person_states(_state_msg([_PersonState(face_id=1, distance=3.0)]))
|
||||
self.node._on_person_states(_state_msg([_PersonState(face_id=1, distance=1.0)]))
|
||||
self.assertAlmostEqual(self.node._distance_cache[1], 1.0)
|
||||
|
||||
def test_negative_face_id_ignored(self):
|
||||
self.node._on_person_states(_state_msg([_PersonState(face_id=-1, distance=1.0)]))
|
||||
self.assertNotIn(-1, self.node._distance_cache)
|
||||
|
||||
def test_zero_distance_cached(self):
|
||||
self.node._on_person_states(_state_msg([_PersonState(face_id=5, distance=0.0)]))
|
||||
self.assertAlmostEqual(self.node._distance_cache[5], 0.0)
|
||||
|
||||
|
||||
class TestGreetingTrigger(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mod = _load_mod()
|
||||
|
||||
def setUp(self):
|
||||
self.node = _make_node(self.mod, proximity_m=2.0, cooldown_s=300.0)
|
||||
self.pub = self.node._pubs["/saltybot/greeting_trigger"]
|
||||
|
||||
def _inject_distance(self, face_id, distance):
|
||||
self.node._on_person_states(_state_msg([_PersonState(face_id=face_id, distance=distance)]))
|
||||
|
||||
def test_triggers_within_proximity(self):
|
||||
self._inject_distance(1, 1.5)
|
||||
self.node._on_faces(_face_msg([_FaceDetection(face_id=1, person_name="alice")]))
|
||||
self.assertEqual(len(self.pub.msgs), 1)
|
||||
|
||||
def test_no_trigger_beyond_proximity(self):
|
||||
self._inject_distance(2, 3.0)
|
||||
self.node._on_faces(_face_msg([_FaceDetection(face_id=2, person_name="bob")]))
|
||||
self.assertEqual(len(self.pub.msgs), 0)
|
||||
|
||||
def test_trigger_at_exact_proximity(self):
|
||||
self._inject_distance(3, 2.0)
|
||||
self.node._on_faces(_face_msg([_FaceDetection(face_id=3, person_name="carol")]))
|
||||
self.assertEqual(len(self.pub.msgs), 1)
|
||||
|
||||
def test_no_trigger_just_beyond(self):
|
||||
self._inject_distance(4, 2.001)
|
||||
self.node._on_faces(_face_msg([_FaceDetection(face_id=4, person_name="dave")]))
|
||||
self.assertEqual(len(self.pub.msgs), 0)
|
||||
|
||||
def test_cooldown_suppresses_retrigger(self):
|
||||
self._inject_distance(5, 1.0)
|
||||
face = _FaceDetection(face_id=5, person_name="eve")
|
||||
self.node._on_faces(_face_msg([face]))
|
||||
self.node._on_faces(_face_msg([face])) # second call in cooldown
|
||||
self.assertEqual(len(self.pub.msgs), 1)
|
||||
|
||||
def test_cooldown_per_face_id(self):
|
||||
self._inject_distance(6, 1.0)
|
||||
self._inject_distance(7, 1.0)
|
||||
self.node._on_faces(_face_msg([_FaceDetection(face_id=6, person_name="f")]))
|
||||
self.node._on_faces(_face_msg([_FaceDetection(face_id=7, person_name="g")]))
|
||||
self.assertEqual(len(self.pub.msgs), 2)
|
||||
|
||||
def test_expired_cooldown_retrigers(self):
|
||||
self._inject_distance(8, 1.0)
|
||||
face = _FaceDetection(face_id=8, person_name="hank")
|
||||
self.node._on_faces(_face_msg([face]))
|
||||
# Manually expire the cooldown
|
||||
self.node._last_greeted[8] = time.monotonic() - 400.0
|
||||
self.node._on_faces(_face_msg([face]))
|
||||
self.assertEqual(len(self.pub.msgs), 2)
|
||||
|
||||
def test_unknown_face_uses_unknown_distance(self):
|
||||
# unknown_distance=0.0 → should trigger (0.0 <= 2.0)
|
||||
node = _make_node(self.mod, unknown_distance=0.0)
|
||||
pub = node._pubs["/saltybot/greeting_trigger"]
|
||||
node._on_faces(_face_msg([_FaceDetection(face_id=99, person_name="stranger")]))
|
||||
self.assertEqual(len(pub.msgs), 1)
|
||||
|
||||
def test_unknown_face_large_distance_no_trigger(self):
|
||||
# unknown_distance=10.0 → should NOT trigger
|
||||
node = _make_node(self.mod, unknown_distance=10.0)
|
||||
pub = node._pubs["/saltybot/greeting_trigger"]
|
||||
node._on_faces(_face_msg([_FaceDetection(face_id=100, person_name="far")]))
|
||||
self.assertEqual(len(pub.msgs), 0)
|
||||
|
||||
def test_multiple_faces_triggers_each_within_range(self):
|
||||
self._inject_distance(10, 1.0)
|
||||
self._inject_distance(11, 3.0) # out of range
|
||||
faces = [
|
||||
_FaceDetection(face_id=10, person_name="near"),
|
||||
_FaceDetection(face_id=11, person_name="far"),
|
||||
]
|
||||
self.node._on_faces(_face_msg(faces))
|
||||
self.assertEqual(len(self.pub.msgs), 1)
|
||||
|
||||
def test_empty_face_array_no_trigger(self):
|
||||
self.node._on_faces(_face_msg([]))
|
||||
self.assertEqual(len(self.pub.msgs), 0)
|
||||
|
||||
|
||||
class TestPayload(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mod = _load_mod()
|
||||
|
||||
def setUp(self):
|
||||
self.node = _make_node(self.mod)
|
||||
self.pub = self.node._pubs["/saltybot/greeting_trigger"]
|
||||
|
||||
def _trigger(self, face_id=1, person_name="alice", distance=1.5):
|
||||
self.node._on_person_states(_state_msg([_PersonState(face_id=face_id, distance=distance)]))
|
||||
self.node._on_faces(_face_msg([_FaceDetection(face_id=face_id, person_name=person_name)]))
|
||||
|
||||
def test_payload_is_json(self):
|
||||
self._trigger()
|
||||
payload = json.loads(self.pub.msgs[0].data)
|
||||
self.assertIsInstance(payload, dict)
|
||||
|
||||
def test_payload_face_id(self):
|
||||
self._trigger(face_id=42)
|
||||
payload = json.loads(self.pub.msgs[0].data)
|
||||
self.assertEqual(payload["face_id"], 42)
|
||||
|
||||
def test_payload_person_name(self):
|
||||
self._trigger(person_name="zara")
|
||||
payload = json.loads(self.pub.msgs[0].data)
|
||||
self.assertEqual(payload["person_name"], "zara")
|
||||
|
||||
def test_payload_distance(self):
|
||||
self._trigger(distance=1.234)
|
||||
payload = json.loads(self.pub.msgs[0].data)
|
||||
self.assertAlmostEqual(payload["distance_m"], 1.234, places=2)
|
||||
|
||||
def test_payload_has_ts(self):
|
||||
self._trigger()
|
||||
payload = json.loads(self.pub.msgs[0].data)
|
||||
self.assertIn("ts", payload)
|
||||
self.assertIsInstance(payload["ts"], float)
|
||||
|
||||
def test_ts_is_recent(self):
|
||||
before = time.time()
|
||||
self._trigger()
|
||||
after = time.time()
|
||||
payload = json.loads(self.pub.msgs[0].data)
|
||||
self.assertGreaterEqual(payload["ts"], before)
|
||||
self.assertLessEqual(payload["ts"], after + 1.0)
|
||||
|
||||
|
||||
class TestNodeSrc(unittest.TestCase):
|
||||
"""Source-level checks — verify node structure without instantiation."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
with open(_SRC) as f:
|
||||
cls.src = f.read()
|
||||
|
||||
def test_issue_tag(self):
|
||||
self.assertIn("#270", self.src)
|
||||
|
||||
def test_pub_topic(self):
|
||||
self.assertIn("/saltybot/greeting_trigger", self.src)
|
||||
|
||||
def test_faces_topic(self):
|
||||
self.assertIn("/social/faces/detected", self.src)
|
||||
|
||||
def test_states_topic(self):
|
||||
self.assertIn("/social/person_states", self.src)
|
||||
|
||||
def test_proximity_param(self):
|
||||
self.assertIn("proximity_m", self.src)
|
||||
|
||||
def test_cooldown_param(self):
|
||||
self.assertIn("cooldown_s", self.src)
|
||||
|
||||
def test_unknown_distance_param(self):
|
||||
self.assertIn("unknown_distance", self.src)
|
||||
|
||||
def test_json_output(self):
|
||||
self.assertIn("json", self.src)
|
||||
|
||||
def test_face_id_in_payload(self):
|
||||
self.assertIn("face_id", self.src)
|
||||
|
||||
def test_person_name_in_payload(self):
|
||||
self.assertIn("person_name", self.src)
|
||||
|
||||
def test_distance_in_payload(self):
|
||||
self.assertIn("distance_m", self.src)
|
||||
|
||||
def test_main_defined(self):
|
||||
self.assertIn("def main", self.src)
|
||||
|
||||
def test_threading_lock(self):
|
||||
self.assertIn("threading.Lock", self.src)
|
||||
|
||||
|
||||
class TestConfig(unittest.TestCase):
|
||||
"""Checks on config/launch/setup files."""
|
||||
|
||||
_CONFIG = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/config/greeting_trigger_params.yaml"
|
||||
)
|
||||
_LAUNCH = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/launch/greeting_trigger.launch.py"
|
||||
)
|
||||
_SETUP = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/setup.py"
|
||||
)
|
||||
|
||||
def test_config_exists(self):
|
||||
import os
|
||||
self.assertTrue(os.path.exists(self._CONFIG))
|
||||
|
||||
def test_config_proximity(self):
|
||||
with open(self._CONFIG) as f:
|
||||
content = f.read()
|
||||
self.assertIn("proximity_m", content)
|
||||
|
||||
def test_config_cooldown(self):
|
||||
with open(self._CONFIG) as f:
|
||||
content = f.read()
|
||||
self.assertIn("cooldown_s", content)
|
||||
|
||||
def test_config_node_name(self):
|
||||
with open(self._CONFIG) as f:
|
||||
content = f.read()
|
||||
self.assertIn("greeting_trigger_node", content)
|
||||
|
||||
def test_launch_exists(self):
|
||||
import os
|
||||
self.assertTrue(os.path.exists(self._LAUNCH))
|
||||
|
||||
def test_launch_proximity_arg(self):
|
||||
with open(self._LAUNCH) as f:
|
||||
content = f.read()
|
||||
self.assertIn("proximity_m", content)
|
||||
|
||||
def test_launch_cooldown_arg(self):
|
||||
with open(self._LAUNCH) as f:
|
||||
content = f.read()
|
||||
self.assertIn("cooldown_s", content)
|
||||
|
||||
def test_entry_point(self):
|
||||
with open(self._SETUP) as f:
|
||||
content = f.read()
|
||||
self.assertIn("greeting_trigger_node", content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
x
Reference in New Issue
Block a user