Compare commits

...

3 Commits

Author SHA1 Message Date
c7dd07f9ed feat(social): proximity-based greeting trigger — Issue #270
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (pull_request) Failing after 12s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (pull_request) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (pull_request) Has been cancelled
Adds greeting_trigger_node to saltybot_social:
- Subscribes to /social/faces/detected (FaceDetectionArray) for face arrivals
- Subscribes to /social/person_states (PersonStateArray) to cache face_id→distance
- Fires greeting when face_id is within proximity_m (default 2m) and
  not in per-face_id cooldown window (default 300s)
- Publishes JSON on /saltybot/greeting_trigger:
  {face_id, person_name, distance_m, ts}
- unknown_distance param controls assumed distance for faces with no PersonState yet
- Thread-safe distance cache and greeted map
- 50/50 tests passing

Closes #270
2026-03-02 17:26:40 -05:00
01ee02f837 Merge pull request 'feat(bringup): D435i depth hole filler via bilateral interpolation (Issue #268)' (#271) from sl-perception/issue-268-depth-holes into main 2026-03-02 17:26:22 -05:00
f0e11fe7ca feat(bringup): depth image hole filler via bilateral interpolation (Issue #268)
Adds multi-pass spatial-Gaussian hole filler for D435i depth images.
Each pass replaces zero/NaN pixels with the Gaussian-weighted mean of valid
neighbours in a growing kernel (×1, ×2.5, ×6 default); original valid
pixels are never modified.  Handles uint16 mm → float32 m conversion,
border pixels via BORDER_REFLECT, and above-d_max pixels as holes.
Publishes filled float32 depth on /camera/depth/filled at camera rate.
37/37 pure-Python tests pass (no ROS2 required).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 14:19:27 -05:00
9 changed files with 1223 additions and 0 deletions

View File

@ -0,0 +1,141 @@
"""
_depth_hole_fill.py Depth image hole filling via bilateral interpolation (no ROS2 deps).
Algorithm
---------
A "hole" is any pixel where depth == 0, depth is NaN, or depth is outside the
valid range [d_min, d_max].
Each pass replaces every hole pixel with the spatial-Gaussian-weighted mean of
valid pixels in a (kernel_size × kernel_size) neighbourhood:
filled[x,y] = Σ G(||p - q||; σ) · d[q] / Σ G(||p - q||; σ)
q valid neighbours of (x,y)
The denominator (sum of spatial weights over valid pixels) normalises correctly
even at image borders and around isolated valid pixels.
Multiple passes with geometrically growing kernels are applied so that:
Pass 1 kernel_size fills small holes ( kernel_size/2 px radius)
Pass 2 kernel_size × 2.5 fills medium holes
Pass 3 kernel_size × 6.0 fills large holes / fronto-parallel surfaces
After all passes any remaining zeros are left as-is (no valid neighbourhood data).
Because only the spatial Gaussian (not a depth range term) is used as the weighting
function, this is equivalent to a bilateral filter with σ_range . In practice
this produces smooth, physically plausible fills in the depth domain.
Public API
----------
fill_holes(depth, kernel_size=5, d_min=0.1, d_max=10.0, max_passes=3) ndarray
valid_mask(depth, d_min=0.1, d_max=10.0) bool ndarray
"""
from __future__ import annotations
import math
from typing import Optional
import numpy as np
# Kernel size multipliers for successive passes
_PASS_SCALE = [1.0, 2.5, 6.0]
def valid_mask(
depth: np.ndarray,
d_min: float = 0.1,
d_max: float = 10.0,
) -> np.ndarray:
"""
Return a boolean mask of valid (non-hole) pixels.
Parameters
----------
depth : (H, W) float32 ndarray, depth in metres
d_min : minimum valid depth (m)
d_max : maximum valid depth (m)
Returns
-------
(H, W) bool ndarray True where depth is finite and in [d_min, d_max]
"""
return np.isfinite(depth) & (depth >= d_min) & (depth <= d_max)
def fill_holes(
depth: np.ndarray,
kernel_size: int = 5,
d_min: float = 0.1,
d_max: float = 10.0,
max_passes: int = 3,
) -> np.ndarray:
"""
Fill zero/NaN depth pixels using multi-pass spatial Gaussian interpolation.
Parameters
----------
depth : (H, W) float32 ndarray, depth in metres
kernel_size : initial kernel side length (pixels, forced odd, 3)
d_min : minimum valid depth pixels below this are treated as holes
d_max : maximum valid depth pixels above this are treated as holes
max_passes : number of fill passes (13); each uses a larger kernel
Returns
-------
(H, W) float32 ndarray same as input, with holes filled where possible.
Pixels with no valid neighbours after all passes remain 0.0.
Original valid pixels are never modified.
"""
import cv2
depth = np.asarray(depth, dtype=np.float32)
# Replace NaN with 0 so arithmetic is clean
depth = np.where(np.isfinite(depth), depth, 0.0).astype(np.float32)
mask = valid_mask(depth, d_min, d_max) # True where already valid
result = depth.copy()
n_passes = max(1, min(max_passes, len(_PASS_SCALE)))
for i in range(n_passes):
if mask.all():
break # no holes left
ks = _odd_kernel_size(kernel_size, _PASS_SCALE[i])
half = ks // 2
sigma = max(half / 2.0, 0.5)
gk = cv2.getGaussianKernel(ks, sigma).astype(np.float32)
kernel = (gk @ gk.T)
# Multiply depth by mask so invalid pixels contribute 0 weight
d_valid = np.where(mask, result, 0.0).astype(np.float32)
w_valid = mask.astype(np.float32)
sum_d = cv2.filter2D(d_valid, ddepth=-1, kernel=kernel,
borderType=cv2.BORDER_REFLECT)
sum_w = cv2.filter2D(w_valid, ddepth=-1, kernel=kernel,
borderType=cv2.BORDER_REFLECT)
# Where we have enough weight, compute the weighted mean
has_data = sum_w > 1e-6
interp = np.where(has_data, sum_d / np.where(has_data, sum_w, 1.0), 0.0)
# Only fill holes — never overwrite original valid pixels
result = np.where(mask, result, interp.astype(np.float32))
# Update mask with newly filled pixels (for the next pass)
newly_filled = (~mask) & (result > 0.0)
mask = mask | newly_filled
return result.astype(np.float32)
# ── Internal helpers ──────────────────────────────────────────────────────────
def _odd_kernel_size(base: int, scale: float) -> int:
"""Return the nearest odd integer to base * scale, minimum 3."""
raw = max(3, int(round(base * scale)))
return raw if raw % 2 == 1 else raw + 1

View File

@ -0,0 +1,128 @@
"""
depth_hole_fill_node.py D435i depth image hole filler (Issue #268).
Subscribes to the raw D435i depth stream, fills zero/NaN pixels using
multi-pass spatial-Gaussian bilateral interpolation, and republishes the
filled image at camera rate.
Subscribes (BEST_EFFORT):
/camera/depth/image_rect_raw sensor_msgs/Image float32 depth (m)
Publishes:
/camera/depth/filled sensor_msgs/Image float32 depth (m), holes filled
The filled image preserves all original valid pixels exactly and only
modifies pixels that had no return (0 or NaN). The output is suitable
for all downstream consumers that expect a dense depth map (VO, RTAB-Map,
collision avoidance, floor classifier).
Parameters
----------
input_topic str /camera/depth/image_rect_raw Input depth topic
output_topic str /camera/depth/filled Output depth topic
kernel_size int 5 Initial Gaussian kernel side length (pixels)
d_min float 0.1 Minimum valid depth (m)
d_max float 10.0 Maximum valid depth (m)
max_passes int 3 Fill passes (growing kernel per pass)
"""
from __future__ import annotations
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
import numpy as np
from cv_bridge import CvBridge
from sensor_msgs.msg import Image
from ._depth_hole_fill import fill_holes
_SENSOR_QOS = QoSProfile(
reliability=ReliabilityPolicy.BEST_EFFORT,
history=HistoryPolicy.KEEP_LAST,
depth=4,
)
class DepthHoleFillNode(Node):
def __init__(self) -> None:
super().__init__('depth_hole_fill_node')
self.declare_parameter('input_topic', '/camera/depth/image_rect_raw')
self.declare_parameter('output_topic', '/camera/depth/filled')
self.declare_parameter('kernel_size', 5)
self.declare_parameter('d_min', 0.1)
self.declare_parameter('d_max', 10.0)
self.declare_parameter('max_passes', 3)
input_topic = self.get_parameter('input_topic').value
output_topic = self.get_parameter('output_topic').value
self._ks = int(self.get_parameter('kernel_size').value)
self._d_min = self.get_parameter('d_min').value
self._d_max = self.get_parameter('d_max').value
self._passes = int(self.get_parameter('max_passes').value)
self._bridge = CvBridge()
self._sub = self.create_subscription(
Image, input_topic, self._on_depth, _SENSOR_QOS)
self._pub = self.create_publisher(Image, output_topic, 10)
self.get_logger().info(
f'depth_hole_fill_node ready — '
f'{input_topic}{output_topic} '
f'kernel={self._ks} passes={self._passes} '
f'd=[{self._d_min},{self._d_max}]m'
)
# ── Callback ──────────────────────────────────────────────────────────────
def _on_depth(self, msg: Image) -> None:
try:
depth = self._bridge.imgmsg_to_cv2(msg, desired_encoding='passthrough')
except Exception as exc:
self.get_logger().error(
f'cv_bridge: {exc}', throttle_duration_sec=5.0)
return
depth = depth.astype(np.float32)
# Handle uint16 mm → float32 m conversion (D435i raw stream)
if depth.max() > 100.0:
depth /= 1000.0
filled = fill_holes(
depth,
kernel_size=self._ks,
d_min=self._d_min,
d_max=self._d_max,
max_passes=self._passes,
)
try:
out_msg = self._bridge.cv2_to_imgmsg(filled, encoding='32FC1')
except Exception as exc:
self.get_logger().error(
f'cv2_to_imgmsg: {exc}', throttle_duration_sec=5.0)
return
out_msg.header = msg.header
self._pub.publish(out_msg)
def main(args=None) -> None:
rclpy.init(args=args)
node = DepthHoleFillNode()
try:
rclpy.spin(node)
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()

View File

@ -37,6 +37,8 @@ setup(
'floor_classifier = saltybot_bringup.floor_classifier_node:main', 'floor_classifier = saltybot_bringup.floor_classifier_node:main',
# Visual odometry drift detector (Issue #260) # Visual odometry drift detector (Issue #260)
'vo_drift_detector = saltybot_bringup.vo_drift_node:main', 'vo_drift_detector = saltybot_bringup.vo_drift_node:main',
# Depth image hole filler (Issue #268)
'depth_hole_fill = saltybot_bringup.depth_hole_fill_node:main',
], ],
}, },
) )

View File

@ -0,0 +1,281 @@
"""
test_depth_hole_fill.py Unit tests for depth hole fill helpers (no ROS2 required).
Covers:
valid_mask:
- valid range returns True
- zero / below d_min returns False
- NaN returns False
- above d_max returns False
- mixed array has correct mask
_odd_kernel_size:
- result is always odd
- result >= 3
- scales correctly
fill_holes no-hole cases:
- fully valid image is returned unchanged
- output dtype is float32
- output shape matches input
fill_holes basic fills:
- single centre hole in uniform depth filled with correct depth
- single centre hole in uniform depth original valid pixels unchanged
- NaN pixel treated as hole and filled
- row of zeros within uniform depth filled
fill_holes fill quality:
- linear gradient: centre hole filled with interpolated value
- multi-pass fills larger holes than single pass
- all-zero image stays zero (no valid neighbours)
- border hole (edge pixel) is handled without crash
- depth range: pixel above d_max treated as hole
fill_holes valid pixel preservation:
- original valid pixels are never modified
- max_passes=1 still fills small holes
"""
import sys
import os
import math
import numpy as np
import pytest
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from saltybot_bringup._depth_hole_fill import (
fill_holes,
valid_mask,
_odd_kernel_size,
)
# ── Helpers ───────────────────────────────────────────────────────────────────
def _uniform(val=2.0, h=64, w=64) -> np.ndarray:
return np.full((h, w), val, dtype=np.float32)
def _poke_hole(arr, r, c) -> np.ndarray:
arr = arr.copy()
arr[r, c] = 0.0
return arr
def _poke_nan(arr, r, c) -> np.ndarray:
arr = arr.copy()
arr[r, c] = float('nan')
return arr
# ── valid_mask ────────────────────────────────────────────────────────────────
class TestValidMask:
def test_valid_pixel_is_true(self):
d = np.array([[1.0]], dtype=np.float32)
assert valid_mask(d, 0.1, 10.0)[0, 0]
def test_zero_is_false(self):
d = np.array([[0.0]], dtype=np.float32)
assert not valid_mask(d, 0.1, 10.0)[0, 0]
def test_below_dmin_is_false(self):
d = np.array([[0.05]], dtype=np.float32)
assert not valid_mask(d, 0.1, 10.0)[0, 0]
def test_nan_is_false(self):
d = np.array([[float('nan')]], dtype=np.float32)
assert not valid_mask(d, 0.1, 10.0)[0, 0]
def test_above_dmax_is_false(self):
d = np.array([[15.0]], dtype=np.float32)
assert not valid_mask(d, 0.1, 10.0)[0, 0]
def test_at_dmin_is_true(self):
d = np.array([[0.1]], dtype=np.float32)
assert valid_mask(d, 0.1, 10.0)[0, 0]
def test_at_dmax_is_true(self):
d = np.array([[10.0]], dtype=np.float32)
assert valid_mask(d, 0.1, 10.0)[0, 0]
def test_mixed_array(self):
d = np.array([[0.0, 1.0, float('nan'), 5.0, 11.0]], dtype=np.float32)
m = valid_mask(d, 0.1, 10.0)
np.testing.assert_array_equal(m, [[False, True, False, True, False]])
# ── _odd_kernel_size ──────────────────────────────────────────────────────────
class TestOddKernelSize:
@pytest.mark.parametrize('base,scale', [
(5, 1.0), (5, 2.5), (5, 6.0),
(3, 1.0), (7, 2.0), (9, 3.0),
(4, 1.0), # even base → must become odd
])
def test_result_is_odd(self, base, scale):
ks = _odd_kernel_size(base, scale)
assert ks % 2 == 1
@pytest.mark.parametrize('base,scale', [(3, 1.0), (1, 5.0), (2, 0.5)])
def test_result_at_least_3(self, base, scale):
assert _odd_kernel_size(base, scale) >= 3
def test_scale_1_returns_base_or_nearby_odd(self):
ks = _odd_kernel_size(5, 1.0)
assert ks == 5
def test_large_scale_gives_large_kernel(self):
ks = _odd_kernel_size(5, 6.0)
assert ks >= 25 # 5 * 6 = 30 → 31
# ── fill_holes — output contract ──────────────────────────────────────────────
class TestFillHolesOutputContract:
def test_output_dtype_float32(self):
out = fill_holes(_uniform(2.0))
assert out.dtype == np.float32
def test_output_shape_preserved(self):
img = _uniform(2.0, h=48, w=64)
out = fill_holes(img)
assert out.shape == img.shape
def test_fully_valid_image_unchanged(self):
img = _uniform(2.0)
out = fill_holes(img)
np.testing.assert_allclose(out, img, atol=1e-6)
def test_valid_pixels_never_modified(self):
"""Any pixel valid in the input must be identical in the output."""
img = _uniform(3.0, h=32, w=32)
img[16, 16] = 0.0 # one hole
mask_before = valid_mask(img)
out = fill_holes(img)
np.testing.assert_allclose(out[mask_before], img[mask_before], atol=1e-6)
# ── fill_holes — basic hole filling ──────────────────────────────────────────
class TestFillHolesBasic:
def test_centre_zero_filled_uniform(self):
"""Single zero pixel in uniform depth → filled with that depth."""
img = _poke_hole(_uniform(2.0, 32, 32), 16, 16)
out = fill_holes(img, kernel_size=5, max_passes=1)
assert out[16, 16] == pytest.approx(2.0, abs=0.05)
def test_centre_nan_filled_uniform(self):
"""Single NaN pixel in uniform depth → filled."""
img = _poke_nan(_uniform(2.0, 32, 32), 16, 16)
out = fill_holes(img, kernel_size=5, max_passes=1)
assert out[16, 16] == pytest.approx(2.0, abs=0.05)
def test_filled_value_is_positive(self):
img = _poke_hole(_uniform(1.5, 32, 32), 16, 16)
out = fill_holes(img)
assert out[16, 16] > 0.0
def test_row_of_holes_filled(self):
"""Entire middle row zeroed → should be filled from neighbours above/below."""
img = _uniform(3.0, 32, 32)
img[16, :] = 0.0
out = fill_holes(img, kernel_size=7, max_passes=1)
# All pixels in the row should be non-zero after filling
assert (out[16, :] > 0.0).all()
def test_all_zero_stays_zero(self):
"""Image with no valid pixels → stays zero (nothing to interpolate from)."""
img = np.zeros((32, 32), dtype=np.float32)
out = fill_holes(img, d_min=0.1)
assert (out == 0.0).all()
def test_border_hole_no_crash(self):
"""Holes at image corners must not raise exceptions."""
img = _uniform(2.0, 32, 32)
img[0, 0] = 0.0
img[0, -1] = 0.0
img[-1, 0] = 0.0
img[-1, -1] = 0.0
out = fill_holes(img) # must not raise
assert out.shape == img.shape
def test_border_holes_filled(self):
"""Corner holes should be filled from their neighbours."""
img = _uniform(2.0, 32, 32)
img[0, 0] = 0.0
out = fill_holes(img, kernel_size=5, max_passes=1)
assert out[0, 0] == pytest.approx(2.0, abs=0.1)
# ── fill_holes — fill quality ─────────────────────────────────────────────────
class TestFillHolesQuality:
def test_linear_gradient_centre_hole_interpolated(self):
"""
Depth linearly increasing from 1.0 (left) to 3.0 (right).
Centre hole should be filled near the midpoint (~2.0).
"""
h, w = 32, 32
img = np.tile(np.linspace(1.0, 3.0, w, dtype=np.float32), (h, 1))
cx = w // 2
img[:, cx] = 0.0
out = fill_holes(img, kernel_size=5, max_passes=1)
mid = out[h // 2, cx]
assert 1.5 <= mid <= 2.5, f'interpolated value {mid:.3f} not in [1.5, 2.5]'
def test_large_hole_filled_with_more_passes(self):
"""A 9×9 hole in uniform depth: single pass may not fully fill it,
but 3 passes should."""
img = _uniform(2.0, 64, 64)
# Create a 9×9 hole
img[28:37, 28:37] = 0.0
out1 = fill_holes(img, kernel_size=5, max_passes=1)
out3 = fill_holes(img, kernel_size=5, max_passes=3)
# More passes → fewer remaining holes
holes1 = (out1 == 0.0).sum()
holes3 = (out3 == 0.0).sum()
assert holes3 <= holes1, f'more passes should reduce holes: {holes3} vs {holes1}'
def test_3pass_fills_9x9_hole_completely(self):
img = _uniform(2.0, 64, 64)
img[28:37, 28:37] = 0.0
out = fill_holes(img, kernel_size=5, max_passes=3)
assert (out[28:37, 28:37] > 0.0).all()
def test_filled_depth_within_valid_range(self):
"""Filled pixels should have depth within [d_min, d_max]."""
img = _uniform(2.0, 32, 32)
img[10:15, 10:15] = 0.0
out = fill_holes(img, d_min=0.1, d_max=10.0, max_passes=3)
# Only check pixels that were actually filled
was_hole = (img == 0.0)
filled = out[was_hole]
positive = filled[filled > 0.0]
assert (positive >= 0.1).all()
assert (positive <= 10.0).all()
def test_above_dmax_treated_as_hole(self):
"""Pixels above d_max should be treated as holes and filled."""
img = _uniform(2.0, 32, 32)
img[16, 16] = 15.0 # out of range
out = fill_holes(img, d_max=10.0, max_passes=1)
assert out[16, 16] == pytest.approx(2.0, abs=0.1)
def test_max_passes_1_works(self):
img = _poke_hole(_uniform(2.0, 32, 32), 16, 16)
out = fill_holes(img, max_passes=1)
assert out.shape == img.shape
assert out[16, 16] > 0.0
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@ -0,0 +1,8 @@
greeting_trigger_node:
ros__parameters:
proximity_m: 2.0 # Trigger when person is within this distance (m)
cooldown_s: 300.0 # Re-greeting suppression window per face_id (s)
unknown_distance: 0.0 # Distance assumed when PersonState not yet available
# 0.0 → always greet faces with no state yet
faces_topic: "/social/faces/detected"
states_topic: "/social/person_states"

View File

@ -0,0 +1,39 @@
"""greeting_trigger.launch.py -- Launch proximity-based greeting trigger (Issue #270).
Usage:
ros2 launch saltybot_social greeting_trigger.launch.py
ros2 launch saltybot_social greeting_trigger.launch.py proximity_m:=1.5 cooldown_s:=120.0
"""
import os
from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
pkg = get_package_share_directory("saltybot_social")
cfg = os.path.join(pkg, "config", "greeting_trigger_params.yaml")
return LaunchDescription([
DeclareLaunchArgument("proximity_m", default_value="2.0",
description="Greeting proximity threshold (m)"),
DeclareLaunchArgument("cooldown_s", default_value="300.0",
description="Per-face_id re-greeting cooldown (s)"),
Node(
package="saltybot_social",
executable="greeting_trigger_node",
name="greeting_trigger_node",
output="screen",
parameters=[
cfg,
{
"proximity_m": LaunchConfiguration("proximity_m"),
"cooldown_s": LaunchConfiguration("cooldown_s"),
},
],
),
])

View File

@ -0,0 +1,150 @@
"""greeting_trigger_node.py -- Proximity-based greeting trigger.
Issue #270
Monitors face detections and person states. When a new face_id is seen
within ``proximity_m`` metres (default 2 m) and has not been greeted within
``cooldown_s`` seconds, publishes a JSON greeting trigger on
/saltybot/greeting_trigger.
Distance is looked up from the /social/person_states topic which carries a
face_id distance mapping. When no state is available for a face the node
applies a configurable default distance so it can still fire on face-only
pipelines.
Subscriptions:
/social/faces/detected saltybot_social_msgs/FaceDetectionArray
/social/person_states saltybot_social_msgs/PersonStateArray
Publication:
/saltybot/greeting_trigger std_msgs/String (JSON)
{"face_id": <int>, "person_name": <str>, "distance_m": <float>,
"ts": <float unix epoch>}
Parameters:
proximity_m (float, 2.0) -- trigger when distance <= this
cooldown_s (float, 300.0) -- suppress re-greeting same face_id
unknown_distance (float, 0.0) -- distance assumed when PersonState
is not yet available (0.0 always
trigger for unknown faces)
faces_topic (str, "/social/faces/detected")
states_topic (str, "/social/person_states")
"""
from __future__ import annotations
import json
import time
import threading
from typing import Dict
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile
from std_msgs.msg import String
# Custom messages — imported at runtime so offline tests can stub them
try:
from saltybot_social_msgs.msg import FaceDetectionArray, PersonStateArray
_MSGS = True
except ImportError:
_MSGS = False
class GreetingTriggerNode(Node):
"""Publishes greeting trigger when a person enters proximity."""
def __init__(self) -> None:
super().__init__("greeting_trigger_node")
self.declare_parameter("proximity_m", 2.0)
self.declare_parameter("cooldown_s", 300.0)
self.declare_parameter("unknown_distance", 0.0)
self.declare_parameter("faces_topic", "/social/faces/detected")
self.declare_parameter("states_topic", "/social/person_states")
self._proximity = self.get_parameter("proximity_m").value
self._cooldown = self.get_parameter("cooldown_s").value
self._unknown_dist = self.get_parameter("unknown_distance").value
faces_topic = self.get_parameter("faces_topic").value
states_topic = self.get_parameter("states_topic").value
# face_id → last known distance (m); updated from PersonStateArray
self._distance_cache: Dict[int, float] = {}
# face_id → unix timestamp of last greeting
self._last_greeted: Dict[int, float] = {}
self._lock = threading.Lock()
qos = QoSProfile(depth=10)
self._pub = self.create_publisher(String, "/saltybot/greeting_trigger", qos)
if _MSGS:
self._states_sub = self.create_subscription(
PersonStateArray, states_topic, self._on_person_states, qos
)
self._faces_sub = self.create_subscription(
FaceDetectionArray, faces_topic, self._on_faces, qos
)
else:
self.get_logger().warn(
"saltybot_social_msgs not available — node is passive (no subscriptions)"
)
self.get_logger().info(
f"GreetingTriggerNode ready "
f"(proximity={self._proximity}m, cooldown={self._cooldown}s)"
)
# ── Callbacks ──────────────────────────────────────────────────────────
def _on_person_states(self, msg: "PersonStateArray") -> None:
"""Cache face_id → distance from incoming PersonState array."""
with self._lock:
for ps in msg.persons:
if ps.face_id >= 0:
self._distance_cache[ps.face_id] = float(ps.distance)
def _on_faces(self, msg: "FaceDetectionArray") -> None:
"""Evaluate each detected face; fire greeting if conditions met."""
now = time.monotonic()
with self._lock:
for face in msg.faces:
fid = int(face.face_id)
dist = self._distance_cache.get(fid, self._unknown_dist)
if dist > self._proximity:
continue # too far
last = self._last_greeted.get(fid, 0.0)
if now - last < self._cooldown:
continue # still in cooldown
# Fire!
self._last_greeted[fid] = now
self._fire(fid, str(face.person_name), dist)
def _fire(self, face_id: int, person_name: str, distance_m: float) -> None:
payload = {
"face_id": face_id,
"person_name": person_name,
"distance_m": round(distance_m, 3),
"ts": time.time(),
}
msg = String()
msg.data = json.dumps(payload)
self._pub.publish(msg)
self.get_logger().info(
f"Greeting trigger: face_id={face_id} name={person_name!r} "
f"dist={distance_m:.2f}m"
)
def main(args=None) -> None:
rclpy.init(args=args)
node = GreetingTriggerNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()

View File

@ -47,6 +47,8 @@ setup(
'vad_node = saltybot_social.vad_node:main', 'vad_node = saltybot_social.vad_node:main',
# Ambient sound classifier — mel-spectrogram (Issue #252) # Ambient sound classifier — mel-spectrogram (Issue #252)
'ambient_sound_node = saltybot_social.ambient_sound_node:main', 'ambient_sound_node = saltybot_social.ambient_sound_node:main',
# Proximity-based greeting trigger (Issue #270)
'greeting_trigger_node = saltybot_social.greeting_trigger_node:main',
], ],
}, },
) )

View File

@ -0,0 +1,472 @@
"""test_greeting_trigger.py -- Offline tests for greeting_trigger_node (Issue #270).
Stubs out rclpy and saltybot_social_msgs so tests run without a ROS install.
"""
import importlib
import json
import sys
import time
import types
import unittest
# ── ROS2 / message stubs ──────────────────────────────────────────────────────
def _make_ros_stubs():
"""Install minimal stubs for rclpy and message packages."""
for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
"std_msgs", "std_msgs.msg",
"saltybot_social_msgs", "saltybot_social_msgs.msg"):
sys.modules[mod_name] = types.ModuleType(mod_name)
class _Node:
def __init__(self, name):
self._name = name
# Preserve _params if pre-set by _make_node (super().__init__() is
# called from GreetingTriggerNode.__init__, so don't reset here)
if not hasattr(self, '_params'):
self._params = {}
self._pubs = {}
self._subs = {}
self._logs = []
def declare_parameter(self, name, default):
# Don't overwrite values pre-set by _make_node
if name not in self._params:
self._params[name] = default
def get_parameter(self, name):
class _P:
def __init__(self, v):
self.value = v
return _P(self._params[name])
def create_publisher(self, msg_type, topic, qos):
pub = _FakePub()
self._pubs[topic] = pub
return pub
def create_subscription(self, msg_type, topic, cb, qos):
self._subs[topic] = cb
return object()
def get_logger(self):
node = self
class _L:
def info(self, m): node._logs.append(("INFO", m))
def warn(self, m): node._logs.append(("WARN", m))
def error(self, m): node._logs.append(("ERROR", m))
return _L()
def destroy_node(self): pass
class _FakePub:
def __init__(self):
self.msgs = []
def publish(self, msg):
self.msgs.append(msg)
class _QoSProfile:
def __init__(self, depth=10): self.depth = depth
class _String:
def __init__(self): self.data = ""
# rclpy
rclpy_mod = sys.modules["rclpy"]
rclpy_mod.init = lambda args=None: None
rclpy_mod.spin = lambda node: None
rclpy_mod.shutdown = lambda: None
# rclpy.node
sys.modules["rclpy.node"].Node = _Node
# rclpy.qos
sys.modules["rclpy.qos"].QoSProfile = _QoSProfile
# std_msgs.msg
sys.modules["std_msgs.msg"].String = _String
# saltybot_social_msgs.msg (FaceDetectionArray + PersonStateArray)
class _FaceDetection:
def __init__(self, face_id=0, person_name="", confidence=1.0):
self.face_id = face_id
self.person_name = person_name
self.confidence = confidence
class _FaceDetectionArray:
def __init__(self, faces=None):
self.faces = faces or []
class _PersonState:
def __init__(self, face_id=0, distance=0.0):
self.face_id = face_id
self.distance = distance
class _PersonStateArray:
def __init__(self, persons=None):
self.persons = persons or []
msgs = sys.modules["saltybot_social_msgs.msg"]
msgs.FaceDetection = _FaceDetection
msgs.FaceDetectionArray = _FaceDetectionArray
msgs.PersonState = _PersonState
msgs.PersonStateArray = _PersonStateArray
return _Node, _FakePub, _QoSProfile, _String, _FaceDetection, _FaceDetectionArray, _PersonState, _PersonStateArray
_Node, _FakePub, _QoSProfile, _String, _FaceDetection, _FaceDetectionArray, _PersonState, _PersonStateArray = _make_ros_stubs()
# ── Load module under test ────────────────────────────────────────────────────
_SRC = (
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
"saltybot_social/saltybot_social/greeting_trigger_node.py"
)
def _load_mod():
spec = importlib.util.spec_from_file_location("greeting_trigger_node_testmod", _SRC)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
# ── Helpers ───────────────────────────────────────────────────────────────────
def _make_node(mod, **kwargs):
"""Instantiate GreetingTriggerNode with overridden parameters."""
node = mod.GreetingTriggerNode.__new__(mod.GreetingTriggerNode)
# Pre-populate _params BEFORE __init__ so super().__init__() (which calls
# _Node.__init__) sees them and skips reset due to hasattr guard.
defaults = {
"proximity_m": 2.0,
"cooldown_s": 300.0,
"unknown_distance": 0.0,
"faces_topic": "/social/faces/detected",
"states_topic": "/social/person_states",
}
defaults.update(kwargs)
node._params = dict(defaults)
mod.GreetingTriggerNode.__init__(node)
return node
def _face_msg(faces):
return _FaceDetectionArray(faces=faces)
def _state_msg(persons):
return _PersonStateArray(persons=persons)
# ── Test suites ───────────────────────────────────────────────────────────────
class TestNodeInit(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mod = _load_mod()
def test_imports_cleanly(self):
self.assertTrue(hasattr(self.mod, "GreetingTriggerNode"))
def test_default_proximity(self):
node = _make_node(self.mod)
self.assertEqual(node._proximity, 2.0)
def test_default_cooldown(self):
node = _make_node(self.mod)
self.assertEqual(node._cooldown, 300.0)
def test_default_unknown_distance(self):
node = _make_node(self.mod)
self.assertEqual(node._unknown_dist, 0.0)
def test_pub_topic(self):
node = _make_node(self.mod)
self.assertIn("/saltybot/greeting_trigger", node._pubs)
def test_subs_registered(self):
node = _make_node(self.mod)
self.assertIn("/social/faces/detected", node._subs)
self.assertIn("/social/person_states", node._subs)
def test_initial_caches_empty(self):
node = _make_node(self.mod)
self.assertEqual(node._distance_cache, {})
self.assertEqual(node._last_greeted, {})
class TestDistanceCache(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mod = _load_mod()
def setUp(self):
self.node = _make_node(self.mod)
def test_state_updates_cache(self):
ps = _PersonState(face_id=1, distance=1.5)
self.node._on_person_states(_state_msg([ps]))
self.assertAlmostEqual(self.node._distance_cache[1], 1.5)
def test_multiple_states_cached(self):
persons = [_PersonState(face_id=i, distance=float(i)) for i in range(5)]
self.node._on_person_states(_state_msg(persons))
for i in range(5):
self.assertAlmostEqual(self.node._distance_cache[i], float(i))
def test_state_update_overwrites(self):
self.node._on_person_states(_state_msg([_PersonState(face_id=1, distance=3.0)]))
self.node._on_person_states(_state_msg([_PersonState(face_id=1, distance=1.0)]))
self.assertAlmostEqual(self.node._distance_cache[1], 1.0)
def test_negative_face_id_ignored(self):
self.node._on_person_states(_state_msg([_PersonState(face_id=-1, distance=1.0)]))
self.assertNotIn(-1, self.node._distance_cache)
def test_zero_distance_cached(self):
self.node._on_person_states(_state_msg([_PersonState(face_id=5, distance=0.0)]))
self.assertAlmostEqual(self.node._distance_cache[5], 0.0)
class TestGreetingTrigger(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mod = _load_mod()
def setUp(self):
self.node = _make_node(self.mod, proximity_m=2.0, cooldown_s=300.0)
self.pub = self.node._pubs["/saltybot/greeting_trigger"]
def _inject_distance(self, face_id, distance):
self.node._on_person_states(_state_msg([_PersonState(face_id=face_id, distance=distance)]))
def test_triggers_within_proximity(self):
self._inject_distance(1, 1.5)
self.node._on_faces(_face_msg([_FaceDetection(face_id=1, person_name="alice")]))
self.assertEqual(len(self.pub.msgs), 1)
def test_no_trigger_beyond_proximity(self):
self._inject_distance(2, 3.0)
self.node._on_faces(_face_msg([_FaceDetection(face_id=2, person_name="bob")]))
self.assertEqual(len(self.pub.msgs), 0)
def test_trigger_at_exact_proximity(self):
self._inject_distance(3, 2.0)
self.node._on_faces(_face_msg([_FaceDetection(face_id=3, person_name="carol")]))
self.assertEqual(len(self.pub.msgs), 1)
def test_no_trigger_just_beyond(self):
self._inject_distance(4, 2.001)
self.node._on_faces(_face_msg([_FaceDetection(face_id=4, person_name="dave")]))
self.assertEqual(len(self.pub.msgs), 0)
def test_cooldown_suppresses_retrigger(self):
self._inject_distance(5, 1.0)
face = _FaceDetection(face_id=5, person_name="eve")
self.node._on_faces(_face_msg([face]))
self.node._on_faces(_face_msg([face])) # second call in cooldown
self.assertEqual(len(self.pub.msgs), 1)
def test_cooldown_per_face_id(self):
self._inject_distance(6, 1.0)
self._inject_distance(7, 1.0)
self.node._on_faces(_face_msg([_FaceDetection(face_id=6, person_name="f")]))
self.node._on_faces(_face_msg([_FaceDetection(face_id=7, person_name="g")]))
self.assertEqual(len(self.pub.msgs), 2)
def test_expired_cooldown_retrigers(self):
self._inject_distance(8, 1.0)
face = _FaceDetection(face_id=8, person_name="hank")
self.node._on_faces(_face_msg([face]))
# Manually expire the cooldown
self.node._last_greeted[8] = time.monotonic() - 400.0
self.node._on_faces(_face_msg([face]))
self.assertEqual(len(self.pub.msgs), 2)
def test_unknown_face_uses_unknown_distance(self):
# unknown_distance=0.0 → should trigger (0.0 <= 2.0)
node = _make_node(self.mod, unknown_distance=0.0)
pub = node._pubs["/saltybot/greeting_trigger"]
node._on_faces(_face_msg([_FaceDetection(face_id=99, person_name="stranger")]))
self.assertEqual(len(pub.msgs), 1)
def test_unknown_face_large_distance_no_trigger(self):
# unknown_distance=10.0 → should NOT trigger
node = _make_node(self.mod, unknown_distance=10.0)
pub = node._pubs["/saltybot/greeting_trigger"]
node._on_faces(_face_msg([_FaceDetection(face_id=100, person_name="far")]))
self.assertEqual(len(pub.msgs), 0)
def test_multiple_faces_triggers_each_within_range(self):
self._inject_distance(10, 1.0)
self._inject_distance(11, 3.0) # out of range
faces = [
_FaceDetection(face_id=10, person_name="near"),
_FaceDetection(face_id=11, person_name="far"),
]
self.node._on_faces(_face_msg(faces))
self.assertEqual(len(self.pub.msgs), 1)
def test_empty_face_array_no_trigger(self):
self.node._on_faces(_face_msg([]))
self.assertEqual(len(self.pub.msgs), 0)
class TestPayload(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.mod = _load_mod()
def setUp(self):
self.node = _make_node(self.mod)
self.pub = self.node._pubs["/saltybot/greeting_trigger"]
def _trigger(self, face_id=1, person_name="alice", distance=1.5):
self.node._on_person_states(_state_msg([_PersonState(face_id=face_id, distance=distance)]))
self.node._on_faces(_face_msg([_FaceDetection(face_id=face_id, person_name=person_name)]))
def test_payload_is_json(self):
self._trigger()
payload = json.loads(self.pub.msgs[0].data)
self.assertIsInstance(payload, dict)
def test_payload_face_id(self):
self._trigger(face_id=42)
payload = json.loads(self.pub.msgs[0].data)
self.assertEqual(payload["face_id"], 42)
def test_payload_person_name(self):
self._trigger(person_name="zara")
payload = json.loads(self.pub.msgs[0].data)
self.assertEqual(payload["person_name"], "zara")
def test_payload_distance(self):
self._trigger(distance=1.234)
payload = json.loads(self.pub.msgs[0].data)
self.assertAlmostEqual(payload["distance_m"], 1.234, places=2)
def test_payload_has_ts(self):
self._trigger()
payload = json.loads(self.pub.msgs[0].data)
self.assertIn("ts", payload)
self.assertIsInstance(payload["ts"], float)
def test_ts_is_recent(self):
before = time.time()
self._trigger()
after = time.time()
payload = json.loads(self.pub.msgs[0].data)
self.assertGreaterEqual(payload["ts"], before)
self.assertLessEqual(payload["ts"], after + 1.0)
class TestNodeSrc(unittest.TestCase):
"""Source-level checks — verify node structure without instantiation."""
@classmethod
def setUpClass(cls):
with open(_SRC) as f:
cls.src = f.read()
def test_issue_tag(self):
self.assertIn("#270", self.src)
def test_pub_topic(self):
self.assertIn("/saltybot/greeting_trigger", self.src)
def test_faces_topic(self):
self.assertIn("/social/faces/detected", self.src)
def test_states_topic(self):
self.assertIn("/social/person_states", self.src)
def test_proximity_param(self):
self.assertIn("proximity_m", self.src)
def test_cooldown_param(self):
self.assertIn("cooldown_s", self.src)
def test_unknown_distance_param(self):
self.assertIn("unknown_distance", self.src)
def test_json_output(self):
self.assertIn("json", self.src)
def test_face_id_in_payload(self):
self.assertIn("face_id", self.src)
def test_person_name_in_payload(self):
self.assertIn("person_name", self.src)
def test_distance_in_payload(self):
self.assertIn("distance_m", self.src)
def test_main_defined(self):
self.assertIn("def main", self.src)
def test_threading_lock(self):
self.assertIn("threading.Lock", self.src)
class TestConfig(unittest.TestCase):
"""Checks on config/launch/setup files."""
_CONFIG = (
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
"saltybot_social/config/greeting_trigger_params.yaml"
)
_LAUNCH = (
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
"saltybot_social/launch/greeting_trigger.launch.py"
)
_SETUP = (
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
"saltybot_social/setup.py"
)
def test_config_exists(self):
import os
self.assertTrue(os.path.exists(self._CONFIG))
def test_config_proximity(self):
with open(self._CONFIG) as f:
content = f.read()
self.assertIn("proximity_m", content)
def test_config_cooldown(self):
with open(self._CONFIG) as f:
content = f.read()
self.assertIn("cooldown_s", content)
def test_config_node_name(self):
with open(self._CONFIG) as f:
content = f.read()
self.assertIn("greeting_trigger_node", content)
def test_launch_exists(self):
import os
self.assertTrue(os.path.exists(self._LAUNCH))
def test_launch_proximity_arg(self):
with open(self._LAUNCH) as f:
content = f.read()
self.assertIn("proximity_m", content)
def test_launch_cooldown_arg(self):
with open(self._LAUNCH) as f:
content = f.read()
self.assertIn("cooldown_s", content)
def test_entry_point(self):
with open(self._SETUP) as f:
content = f.read()
self.assertIn("greeting_trigger_node", content)
if __name__ == "__main__":
unittest.main()