Compare commits

..

No commits in common. "becd0bc71748ed4975a07725ec61e0328b7927fd" and "accda32c7af0ef7b106d0f2eb2bb1b944b71695c" have entirely different histories.

19 changed files with 0 additions and 1898 deletions

View File

@ -1,219 +0,0 @@
"""
_terrain_roughness.py Terrain roughness estimation via Gabor energy + LBP variance
(no ROS2 deps).
Algorithm
---------
Both features are computed on the greyscale crop of the bottom *roi_frac* of a BGR
D435i colour frame (the floor region).
1. Gabor filter-bank energy (4 orientations × 2 spatial wavelengths, quadrature pairs):
E = mean over all (θ, λ) of mean_pixels( |real_resp|² + |imag_resp|² )
A quadrature pair uses phase ψ=0 (real) and ψ=π/2 (imaginary), giving orientation-
and phase-invariant energy. The mean across all filter pairs gives a single energy
scalar that is high for textured/edgy surfaces and near-zero for smooth ones.
2. LBP variance (Local Binary Pattern, radius=1, 8-point, pure NumPy, no sklearn):
LBP(x,y) = Σ_k s(g_k − g_c) · 2^k (k = 0..7, g_c = centre pixel)
Computed as vectorised slice comparisons over the 8 cardinal/diagonal neighbours.
The metric is var(LBP image). A constant surface gives LBP 0xFF (all bits set,
since every neighbour ties with the centre) → variance = 0. Irregular textures
produce a dispersed LBP histogram → high variance.
3. Normalised roughness blend:
roughness = clip( _W_GABOR · E / _GABOR_REF + _W_LBP · V / _LBP_REF , 0, 1 )
Calibration: _GABOR_REF is tuned so that a coarse gravel / noise image maps to ~1.0;
_LBP_REF = 5000 (≈ variance of a fully random 8-bit LBP image).
Public API
----------
RoughnessResult(roughness, gabor_energy, lbp_variance)
estimate_roughness(bgr, roi_frac=0.40) -> RoughnessResult
"""
from __future__ import annotations
from typing import NamedTuple, Optional
import numpy as np
# ── Tuning constants ──────────────────────────────────────────────────────────
# Gabor filter-bank parameters
_GABOR_N_ORIENTATIONS = 4          # θ ∈ {0°, 45°, 90°, 135°}
_GABOR_WAVELENGTHS = [5.0, 10.0]   # spatial wavelengths in pixels (medium + coarse)
_GABOR_KSIZE = 11                  # kernel side length (square kernel)
_GABOR_SIGMA = 3.0                 # Gaussian envelope std (pixels)
_GABOR_GAMMA = 0.5                 # spatial aspect ratio
# Normalization references — divide raw metrics by these before blending so
# each contributes a value in [0, 1].
_GABOR_REF = 500.0                 # Gabor mean power at which we consider maximum roughness
_LBP_REF = 5000.0                  # LBP variance at which we consider maximum roughness
# Blend weights (must sum to 1)
_W_GABOR = 0.5
_W_LBP = 0.5
# Module-level cache for the kernel bank (built once per process, lazily)
_gabor_kernels: Optional[list] = None
# ── Data types ────────────────────────────────────────────────────────────────
class RoughnessResult(NamedTuple):
    """Terrain roughness estimate for a single frame."""
    roughness: float      # blended score in [0, 1]; 0 = smooth, 1 = rough
    gabor_energy: float   # raw Gabor mean energy (>= 0), pre-normalisation
    lbp_variance: float   # raw LBP image variance (>= 0), pre-normalisation
# ── Internal helpers ──────────────────────────────────────────────────────────
def _build_gabor_kernels() -> list:
    """Create the quadrature Gabor filter bank.

    Returns a list of (real_kernel, imag_kernel) float32 pairs, one pair per
    (orientation, wavelength) combination. The two kernels of a pair differ
    only in phase (psi=0 vs psi=pi/2), forming a quadrature pair so that
    downstream energy |real|^2 + |imag|^2 is phase-invariant.
    """
    import cv2
    size = (_GABOR_KSIZE, _GABOR_KSIZE)
    orientations = [k * np.pi / _GABOR_N_ORIENTATIONS
                    for k in range(_GABOR_N_ORIENTATIONS)]
    bank = []
    for theta in orientations:
        for wavelength in _GABOR_WAVELENGTHS:
            pair = tuple(
                cv2.getGaborKernel(
                    size, _GABOR_SIGMA, theta, wavelength, _GABOR_GAMMA, psi,
                    ktype=cv2.CV_64F,
                ).astype(np.float32)
                for psi in (0.0, np.pi / 2.0)
            )
            bank.append(pair)
    return bank
def _get_gabor_kernels() -> list:
    """Return the Gabor kernel bank, building it on first use.

    The bank is cached in the module global ``_gabor_kernels`` so the kernels
    are constructed exactly once per process.
    """
    global _gabor_kernels
    if _gabor_kernels is None:
        _gabor_kernels = _build_gabor_kernels()
    return _gabor_kernels
def gabor_energy(grey: np.ndarray) -> float:
    """Mean quadrature Gabor energy across the filter bank.

    Parameters
    ----------
    grey : (H, W) uint8 or float ndarray

    Returns
    -------
    float
        Mean power (>= 0); higher for textured surfaces, 0.0 for an empty
        input or an empty kernel bank.
    """
    import cv2
    img = np.asarray(grey, dtype=np.float32)
    if img.size == 0:
        return 0.0
    # Subtract the mean so constant regions (DC) contribute zero energy.
    # This is standard in Gabor texture analysis and ensures that a flat
    # uniform surface always returns energy ~ 0.
    img = img - float(img.mean())
    bank = _get_gabor_kernels()
    if not bank:
        return 0.0
    powers = []
    for real_kernel, imag_kernel in bank:
        resp_r = cv2.filter2D(img, cv2.CV_32F, real_kernel)
        resp_i = cv2.filter2D(img, cv2.CV_32F, imag_kernel)
        powers.append(float(np.mean(resp_r * resp_r + resp_i * resp_i)))
    return sum(powers) / len(powers)
def lbp_variance(grey: np.ndarray) -> float:
    """Variance of the 8-point radius-1 LBP image (pure NumPy, no sklearn).

    Parameters
    ----------
    grey : (H, W) uint8 ndarray

    Returns
    -------
    float
        Variance of the LBP codes; 0 for flat surfaces (all codes identical)
        and for images smaller than 3x3, high for irregular textures.
    """
    g = np.asarray(grey, dtype=np.int32)
    h, w = g.shape
    if h < 3 or w < 3:
        # No interior pixels -> no LBP codes can be formed.
        return 0.0
    # Interior pixels only; the 1-px border has incomplete neighbourhoods.
    centre = g[1:h - 1, 1:w - 1]
    # Neighbour slice origins in scan order TL, T, TR, L, R, BL, B, BR,
    # mapped to bits 0..7 of the LBP code.
    origins = ((0, 0), (0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1), (2, 2))
    codes = np.zeros(centre.shape, dtype=np.int32)
    for bit, (dr, dc) in enumerate(origins):
        neighbour = g[dr:dr + h - 2, dc:dc + w - 2]
        codes |= (neighbour >= centre).astype(np.int32) << bit
    return float(codes.var())
# ── Public API ────────────────────────────────────────────────────────────────
def estimate_roughness(
    bgr: np.ndarray,
    roi_frac: float = 0.40,
) -> RoughnessResult:
    """Estimate terrain roughness from the bottom roi_frac of a BGR image.

    Parameters
    ----------
    bgr : (H, W, 3) uint8 BGR ndarray (or greyscale (H, W))
    roi_frac : fraction of image height to use as the floor ROI (from bottom);
        values above 1.0 are treated as 1.0, non-positive values yield an
        empty ROI and an all-zero result.

    Returns
    -------
    RoughnessResult(roughness, gabor_energy, lbp_variance)
        roughness is in [0, 1] where 0 = smooth, 1 = rough.
    """
    import cv2
    # Crop to the bottom roi_frac of the frame (the floor region).
    height = bgr.shape[0]
    frac = min(float(roi_frac), 1.0)
    top_row = int(height * max(0.0, 1.0 - frac))
    roi = bgr[top_row:, :]
    if roi.size == 0:
        return RoughnessResult(roughness=0.0, gabor_energy=0.0, lbp_variance=0.0)
    # Greyscale conversion (no-op copy path for already-grey input).
    if roi.ndim == 3:
        grey = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
    else:
        grey = np.asarray(roi, dtype=np.uint8)
    ge = gabor_energy(grey)
    lv = lbp_variance(grey)
    # Normalise each metric to [0, 1] against its reference, then blend.
    blended = (_W_GABOR * min(ge / _GABOR_REF, 1.0)
               + _W_LBP * min(lv / _LBP_REF, 1.0))
    roughness = float(np.clip(blended, 0.0, 1.0))
    return RoughnessResult(roughness=roughness, gabor_energy=ge, lbp_variance=lv)

View File

@ -1,102 +0,0 @@
"""
terrain_rough_node.py D435i terrain roughness estimator (Issue #296).
Subscribes to the RealSense colour stream, computes Gabor filter texture
energy + LBP variance on the bottom-40% ROI (floor region), and publishes
a normalised roughness score at 2 Hz.
Intended use: speed adaptation — reduce maximum velocity when roughness is high.
Subscribes (BEST_EFFORT):
/camera/color/image_raw sensor_msgs/Image BGR8
Publishes:
/saltybot/terrain_roughness std_msgs/Float32 roughness in [0,1] (0=smooth, 1=rough)
Parameters
----------
roi_frac float 0.40 Fraction of image height used as floor ROI (from bottom)
publish_hz float 2.0 Publication rate (Hz)
"""
from __future__ import annotations
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
from cv_bridge import CvBridge
from sensor_msgs.msg import Image
from std_msgs.msg import Float32
from ._terrain_roughness import estimate_roughness
# BEST_EFFORT sensor QoS: prefer dropping stale frames under load over
# blocking the camera stream; keep only the 4 most recent messages.
_SENSOR_QOS = QoSProfile(
    reliability=ReliabilityPolicy.BEST_EFFORT,
    history=HistoryPolicy.KEEP_LAST,
    depth=4,
)
class TerrainRoughNode(Node):
    """ROS2 node that publishes a [0, 1] terrain roughness score.

    Caches the newest colour frame from the camera subscription and, on a
    fixed-rate timer, runs estimate_roughness() on it and publishes the
    blended score as a Float32.
    """

    def __init__(self) -> None:
        super().__init__('terrain_rough_node')
        self.declare_parameter('roi_frac', 0.40)
        self.declare_parameter('publish_hz', 2.0)
        self._roi_frac = float(self.get_parameter('roi_frac').value)
        publish_hz = float(self.get_parameter('publish_hz').value)
        self._bridge = CvBridge()
        self._latest_bgr = None  # newest BGR frame; updated by subscription callback
        self._sub = self.create_subscription(
            Image, '/camera/color/image_raw', self._on_image, _SENSOR_QOS)
        self._pub = self.create_publisher(
            Float32, '/saltybot/terrain_roughness', 10)
        # Clamp to >= 0.1 Hz so a zero/negative parameter cannot divide by zero.
        period = 1.0 / max(publish_hz, 0.1)
        self._timer = self.create_timer(period, self._on_timer)
        self.get_logger().info(
            f'terrain_rough_node ready — roi_frac={self._roi_frac} '
            f'publish_hz={publish_hz}'
        )

    # ── Callbacks ─────────────────────────────────────────────────────────────
    def _on_image(self, msg: Image) -> None:
        """Cache the incoming frame as BGR; log (throttled) on conversion failure."""
        try:
            self._latest_bgr = self._bridge.imgmsg_to_cv2(
                msg, desired_encoding='bgr8')
        except Exception as exc:
            self.get_logger().error(
                f'cv_bridge: {exc}', throttle_duration_sec=5.0)

    def _on_timer(self) -> None:
        """Compute and publish roughness for the latest frame (no-op until one arrives)."""
        if self._latest_bgr is None:
            return
        result = estimate_roughness(self._latest_bgr, roi_frac=self._roi_frac)
        msg = Float32()
        msg.data = result.roughness
        self._pub.publish(msg)
def main(args=None) -> None:
    """Entry point: spin TerrainRoughNode until shutdown, then clean up."""
    rclpy.init(args=args)
    node = TerrainRoughNode()
    try:
        rclpy.spin(node)
    finally:
        # Always release node resources and the ROS context, even on error.
        node.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()

View File

@ -43,8 +43,6 @@ setup(
'color_segmenter = saltybot_bringup.color_segment_node:main', 'color_segmenter = saltybot_bringup.color_segment_node:main',
# Motion blur detector (Issue #286) # Motion blur detector (Issue #286)
'blur_detector = saltybot_bringup.blur_detect_node:main', 'blur_detector = saltybot_bringup.blur_detect_node:main',
# Terrain roughness estimator (Issue #296)
'terrain_roughness = saltybot_bringup.terrain_rough_node:main',
], ],
}, },
) )

View File

@ -1,321 +0,0 @@
"""
test_terrain_roughness.py Unit tests for terrain roughness helpers (no ROS2 required).
Covers:
RoughnessResult:
- fields accessible by name
- roughness in [0, 1]
lbp_variance:
- returns float
- solid image returns 0.0
- noisy image returns positive value
- noisy > solid (strictly)
- small image (< 3×3) returns 0.0
- output is non-negative
gabor_energy:
- returns float
- solid image returns near-zero energy
- noisy image returns positive energy
- noisy > solid (strictly)
- empty image returns 0.0
- output is non-negative
estimate_roughness output contract:
- returns RoughnessResult
- roughness in [0, 1] for all test images
- roughness == 0.0 for solid (constant) image (both metrics are 0)
- roughness > 0.0 for random noise image
- roi_frac=1.0 uses entire frame
- roi_frac=0.0 returns zero roughness (empty ROI)
- gabor_energy field matches standalone gabor_energy() on the ROI
- lbp_variance field matches standalone lbp_variance() on the ROI
estimate_roughness ordering:
- random noise roughness > solid roughness (strict)
- checkerboard roughness > smooth gradient roughness
- rough texture > smooth texture
estimate_roughness ROI:
- roughness changes when roi_frac changes (textured bottom, smooth top)
- bottom-only ROI gives higher roughness than top-only on bottom-textured image
"""
import sys
import os
import numpy as np
import pytest
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from saltybot_bringup._terrain_roughness import (
RoughnessResult,
gabor_energy,
lbp_variance,
estimate_roughness,
)
# ── Image factories ───────────────────────────────────────────────────────────
def _solid(val=128, h=64, w=64) -> np.ndarray:
return np.full((h, w, 3), val, dtype=np.uint8)
def _grey_solid(val=128, h=64, w=64) -> np.ndarray:
return np.full((h, w), val, dtype=np.uint8)
def _noise(h=64, w=64, seed=42) -> np.ndarray:
rng = np.random.default_rng(seed)
data = rng.integers(0, 256, (h, w, 3), dtype=np.uint8)
return data
def _grey_noise(h=64, w=64, seed=42) -> np.ndarray:
rng = np.random.default_rng(seed)
return rng.integers(0, 256, (h, w), dtype=np.uint8)
def _checkerboard(h=64, w=64, tile=8) -> np.ndarray:
grey = np.zeros((h, w), dtype=np.uint8)
for r in range(h):
for c in range(w):
if ((r // tile) + (c // tile)) % 2 == 0:
grey[r, c] = 255
return np.stack([grey, grey, grey], axis=-1)
def _gradient(h=64, w=64) -> np.ndarray:
"""Smooth horizontal linear gradient."""
row = np.linspace(0, 255, w, dtype=np.uint8)
grey = np.tile(row, (h, 1))
return np.stack([grey, grey, grey], axis=-1)
# ── RoughnessResult ───────────────────────────────────────────────────────────
class TestRoughnessResult:
    """RoughnessResult is a plain NamedTuple: named fields, no clamping logic."""

    def test_fields_accessible(self):
        r = RoughnessResult(roughness=0.5, gabor_energy=100.0, lbp_variance=200.0)
        assert r.roughness == pytest.approx(0.5)
        assert r.gabor_energy == pytest.approx(100.0)
        assert r.lbp_variance == pytest.approx(200.0)

    def test_roughness_in_range(self):
        # Representative valid scores; the tuple stores whatever it is given.
        for v in [0.0, 0.5, 1.0]:
            r = RoughnessResult(roughness=v, gabor_energy=0.0, lbp_variance=0.0)
            assert 0.0 <= r.roughness <= 1.0
# ── lbp_variance ─────────────────────────────────────────────────────────────
class TestLbpVariance:
    """lbp_variance: zero on constant images, positive on texture, zero below 3×3."""

    def test_returns_float(self):
        assert isinstance(lbp_variance(_grey_solid()), float)

    def test_non_negative(self):
        for img in [_grey_solid(), _grey_noise(), _grey_solid(0)]:
            assert lbp_variance(img) >= 0.0

    def test_solid_returns_zero(self):
        """Constant image: all neighbours equal centre → all LBP bits set → zero variance."""
        assert lbp_variance(_grey_solid(128)) == pytest.approx(0.0)

    def test_solid_black_returns_zero(self):
        assert lbp_variance(_grey_solid(0)) == pytest.approx(0.0)

    def test_solid_white_returns_zero(self):
        assert lbp_variance(_grey_solid(255)) == pytest.approx(0.0)

    def test_noise_returns_positive(self):
        assert lbp_variance(_grey_noise()) > 0.0

    def test_noise_greater_than_solid(self):
        assert lbp_variance(_grey_noise()) > lbp_variance(_grey_solid())

    def test_small_image_returns_zero(self):
        """Image smaller than 3×3 has no interior pixels."""
        assert lbp_variance(np.zeros((2, 2), dtype=np.uint8)) == pytest.approx(0.0)

    def test_checkerboard_high_variance(self):
        """Alternating 0/255 checkerboard → LBP alternates between two codes → very high variance."""
        v = lbp_variance(np.tile(
            np.array([[0, 255], [255, 0]], dtype=np.uint8), (16, 16)))
        assert v > 1000.0, f'checkerboard LBP variance too low: {v:.1f}'

    def test_noise_higher_than_tiled_gradient(self):
        """Random 2D noise has varying neighbourhood patterns → higher LBP variance
        than a horizontal gradient tiled into rows (which has a constant pattern).

        NOTE(review): the body compares against a *solid* image, not a gradient,
        duplicating test_noise_greater_than_solid — confirm intended comparison.
        """
        assert lbp_variance(_grey_noise()) > lbp_variance(_grey_solid())
# ── gabor_energy ──────────────────────────────────────────────────────────────
class TestGaborEnergy:
    """gabor_energy: ~0 on constants (after DC removal), positive on texture."""

    def test_returns_float(self):
        assert isinstance(gabor_energy(_grey_solid()), float)

    def test_non_negative(self):
        for img in [_grey_solid(), _grey_noise(), _grey_solid(0)]:
            assert gabor_energy(img) >= 0.0

    def test_solid_near_zero(self):
        """Constant image has zero Gabor response after DC subtraction."""
        e = gabor_energy(_grey_solid(128))
        assert e < 1.0, f'solid image gabor energy should be ~0 after DC removal, got {e:.4f}'

    def test_noise_positive(self):
        assert gabor_energy(_grey_noise()) > 0.0

    def test_noise_greater_than_solid(self):
        assert gabor_energy(_grey_noise()) > gabor_energy(_grey_solid())

    def test_empty_image_returns_zero(self):
        assert gabor_energy(np.zeros((0, 0), dtype=np.uint8)) == pytest.approx(0.0)

    def test_noise_higher_than_solid(self):
        """Random noise has non-zero energy at our Gabor filter frequencies; solid has none.

        NOTE(review): assertion is identical to test_noise_greater_than_solid —
        likely an accidental duplicate.
        """
        assert gabor_energy(_grey_noise()) > gabor_energy(_grey_solid())
# ── estimate_roughness — output contract ──────────────────────────────────────
class TestEstimateRoughnessContract:
    """Output contract: result type, [0, 1] range, zero cases, and agreement of
    the raw metric fields with the standalone metric functions on the same ROI."""

    def test_returns_roughness_result(self):
        r = estimate_roughness(_solid())
        assert isinstance(r, RoughnessResult)

    def test_roughness_in_range_solid(self):
        r = estimate_roughness(_solid())
        assert 0.0 <= r.roughness <= 1.0

    def test_roughness_in_range_noise(self):
        r = estimate_roughness(_noise())
        assert 0.0 <= r.roughness <= 1.0

    def test_roughness_in_range_checkerboard(self):
        r = estimate_roughness(_checkerboard())
        assert 0.0 <= r.roughness <= 1.0

    def test_solid_roughness_exactly_zero(self):
        """Both metrics are 0 for a constant image → roughness = 0.0 exactly."""
        r = estimate_roughness(_solid(128))
        assert r.roughness == pytest.approx(0.0, abs=1e-6)

    def test_noise_roughness_positive(self):
        r = estimate_roughness(_noise())
        assert r.roughness > 0.0

    def test_gabor_energy_field_nonneg(self):
        assert estimate_roughness(_solid()).gabor_energy >= 0.0
        assert estimate_roughness(_noise()).gabor_energy >= 0.0

    def test_lbp_variance_field_nonneg(self):
        assert estimate_roughness(_solid()).lbp_variance >= 0.0
        assert estimate_roughness(_noise()).lbp_variance >= 0.0

    def test_roi_frac_1_uses_full_frame(self):
        """roi_frac=1.0 should use the entire image."""
        img = _noise(h=64, w=64)
        r_full = estimate_roughness(img, roi_frac=1.0)
        assert r_full.roughness > 0.0

    def test_roi_frac_0_returns_zero(self):
        """roi_frac=0.0 → empty crop → all zeros."""
        r = estimate_roughness(_noise(), roi_frac=0.0)
        assert r.roughness == pytest.approx(0.0)
        assert r.gabor_energy == pytest.approx(0.0)
        assert r.lbp_variance == pytest.approx(0.0)

    def test_gabor_field_matches_standalone(self):
        """gabor_energy field should match standalone gabor_energy() on the ROI."""
        import cv2
        img = _noise(h=64, w=64)
        roi_frac = 0.5
        h = img.shape[0]
        # Replicate the crop arithmetic used inside estimate_roughness.
        r0 = int(h * (1.0 - roi_frac))
        roi_grey = cv2.cvtColor(img[r0:, :], cv2.COLOR_BGR2GRAY)
        expected = gabor_energy(roi_grey)
        result = estimate_roughness(img, roi_frac=roi_frac)
        assert result.gabor_energy == pytest.approx(expected, rel=1e-4)

    def test_lbp_field_matches_standalone(self):
        """lbp_variance field should match standalone lbp_variance() on the ROI."""
        import cv2
        img = _noise(h=64, w=64)
        roi_frac = 0.5
        h = img.shape[0]
        r0 = int(h * (1.0 - roi_frac))
        roi_grey = cv2.cvtColor(img[r0:, :], cv2.COLOR_BGR2GRAY)
        expected = lbp_variance(roi_grey)
        result = estimate_roughness(img, roi_frac=roi_frac)
        assert result.lbp_variance == pytest.approx(expected, rel=1e-4)
# ── estimate_roughness — ordering ────────────────────────────────────────────
class TestEstimateRoughnessOrdering:
    """Relative ordering: rougher textures must score strictly higher."""

    def test_noise_rougher_than_solid(self):
        r_noise = estimate_roughness(_noise())
        r_solid = estimate_roughness(_solid())
        assert r_noise.roughness > r_solid.roughness

    def test_checkerboard_rougher_than_gradient(self):
        r_cb = estimate_roughness(_checkerboard(64, 64, tile=4))
        r_grad = estimate_roughness(_gradient())
        assert r_cb.roughness > r_grad.roughness, (
            f'checkerboard={r_cb.roughness:.3f} should exceed gradient={r_grad.roughness:.3f}')

    def test_noise_roughness_exceeds_half(self):
        """For fully random noise, LBP component alone contributes ≥0.5 to roughness."""
        r = estimate_roughness(_noise(h=64, w=64), roi_frac=1.0)
        assert r.roughness >= 0.5, (
            f'random noise roughness should be ≥0.5, got {r.roughness:.3f}')

    def test_high_roi_frac_uses_more_data(self):
        """More floor area → metrics computed on more pixels; score should be consistent.

        NOTE(review): these assertions (>= 0.0) are trivially true — consider
        tightening to a strict-positivity or closeness check.
        """
        img = _noise(h=128, w=64)
        r_small = estimate_roughness(img, roi_frac=0.10)
        r_large = estimate_roughness(img, roi_frac=0.90)
        # Both should be positive since the full image is noise
        assert r_small.roughness >= 0.0
        assert r_large.roughness >= 0.0
# ── estimate_roughness — ROI sensitivity ─────────────────────────────────────
class TestEstimateRoughnessROI:
    """ROI sensitivity: the score must reflect the selected bottom fraction."""

    def test_rough_bottom_smooth_top(self):
        """
        Image where bottom half is noise and top half is solid.
        roi_frac=0.5 picks the noisy bottom → high roughness.
        """
        h, w = 128, 64
        img = np.zeros((h, w, 3), dtype=np.uint8)
        img[:h // 2, :] = 128  # smooth top
        rng = np.random.default_rng(0)
        img[h // 2:, :] = rng.integers(0, 256, (h // 2, w, 3), dtype=np.uint8)
        r_bottom = estimate_roughness(img, roi_frac=0.50)  # noisy half
        r_top = estimate_roughness(img, roi_frac=0.01)     # tiny slice near bottom edge
        assert r_bottom.roughness > r_top.roughness, (
            f'bottom(noisy)={r_bottom.roughness:.3f} should exceed top(smooth)={r_top.roughness:.3f}')

    def test_roi_frac_clipped_to_valid_range(self):
        """roi_frac > 1.0 should not crash; treated as 1.0."""
        r = estimate_roughness(_noise(), roi_frac=1.5)
        assert 0.0 <= r.roughness <= 1.0
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@ -1,11 +0,0 @@
geofence:
ros__parameters:
# Polygon vertices as flat list [x1, y1, x2, y2, ...]
# Example: square from (0,0) to (10,10)
geofence_vertices: [0.0, 0.0, 10.0, 0.0, 10.0, 10.0, 0.0, 10.0]
# Enforce boundary by zeroing cmd_vel on breach
enforce_boundary: false
# Safety margin (m) - breach triggered before actual boundary
margin: 0.0

View File

@ -1,31 +0,0 @@
"""Launch file for geofence enforcer node."""
from launch import LaunchDescription
from launch_ros.actions import Node
from launch.substitutions import LaunchConfiguration
from launch.actions import DeclareLaunchArgument
import os
from ament_index_python.packages import get_package_share_directory
def generate_launch_description():
    """Generate launch description for geofence.

    Declares a config_file argument defaulting to the installed
    geofence_config.yaml, and starts the geofence_node with those parameters.
    Note: the Node below is given the LaunchConfiguration, so an overridden
    config_file argument is honoured at launch time.
    """
    pkg_dir = get_package_share_directory("saltybot_geofence")
    config_file = os.path.join(pkg_dir, "config", "geofence_config.yaml")
    return LaunchDescription(
        [
            DeclareLaunchArgument(
                "config_file",
                default_value=config_file,
                description="Path to configuration YAML file",
            ),
            Node(
                package="saltybot_geofence",
                executable="geofence_node",
                name="geofence",
                output="screen",
                parameters=[LaunchConfiguration("config_file")],
            ),
        ]
    )

View File

@ -1,22 +0,0 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_geofence</name>
<version>0.1.0</version>
<description>Geofence boundary enforcer for SaltyBot</description>
<maintainer email="sl-controls@saltylab.local">SaltyLab Controls</maintainer>
<license>MIT</license>
<buildtool_depend>ament_python</buildtool_depend>
<depend>rclpy</depend>
<depend>nav_msgs</depend>
<depend>std_msgs</depend>
<depend>geometry_msgs</depend>
<test_depend>pytest</test_depend>
<test_depend>nav_msgs</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -1,143 +0,0 @@
#!/usr/bin/env python3
"""Geofence boundary enforcer for SaltyBot.
Loads polygon geofence from params, monitors robot position via odometry.
Publishes Bool on /saltybot/geofence_breach when exiting boundary.
Optionally zeros cmd_vel to enforce boundary.
Subscribed topics:
/odom (nav_msgs/Odometry) - Robot position and orientation
Published topics:
/saltybot/geofence_breach (std_msgs/Bool) - Outside boundary flag
"""
import rclpy
from rclpy.node import Node
from nav_msgs.msg import Odometry
from std_msgs.msg import Bool
from typing import List, Tuple
class GeofenceNode(Node):
    """ROS2 node for geofence boundary enforcement.

    Monitors /odom, tests the robot position against a configured polygon via
    ray casting, and publishes a Bool breach flag on /saltybot/geofence_breach.
    """

    def __init__(self):
        super().__init__("geofence")
        # Parameters
        # NOTE(review): an empty-list default for 'geofence_vertices' may fail
        # ROS2 parameter type inference on some distributions — confirm on the
        # target distro or supply an explicit ParameterDescriptor.
        self.declare_parameter("geofence_vertices", [])
        self.declare_parameter("enforce_boundary", False)
        self.declare_parameter("margin", 0.0)
        vertices = self.get_parameter("geofence_vertices").value
        self.enforce_boundary = self.get_parameter("enforce_boundary").value
        # NOTE(review): 'margin' and 'enforce_boundary' are read and logged but
        # never used below — no margin is applied in the breach test and no
        # cmd_vel zeroing is implemented despite the module docstring. Confirm
        # whether enforcement lives elsewhere or is still TODO.
        self.margin = self.get_parameter("margin").value
        # Parse vertices from flat list [x1,y1,x2,y2,...]
        self.geofence_vertices = self._parse_vertices(vertices)
        # State tracking
        self.robot_x = 0.0
        self.robot_y = 0.0
        self.inside_geofence = True    # last point-in-polygon verdict (safe default)
        self.breach_published = False  # edge-trigger latch for breach/re-entry logs
        # Subscription to odometry
        self.sub_odom = self.create_subscription(
            Odometry, "/odom", self._on_odometry, 10
        )
        # Publisher for breach status
        self.pub_breach = self.create_publisher(Bool, "/saltybot/geofence_breach", 10)
        self.get_logger().info(
            f"Geofence enforcer initialized with {len(self.geofence_vertices)} vertices. "
            f"Enforce: {self.enforce_boundary}, Margin: {self.margin}m"
        )
        if len(self.geofence_vertices) > 0:
            self.get_logger().info(f"Geofence vertices: {self.geofence_vertices}")

    def _parse_vertices(self, flat_list: List[float]) -> List[Tuple[float, float]]:
        """Parse flat list [x1,y1,x2,y2,...] into vertex tuples.

        Returns [] (geofence disabled) when fewer than 3 vertices are supplied;
        a trailing unpaired value is silently dropped.
        """
        if len(flat_list) < 6:  # Need at least 3 vertices (6 values)
            self.get_logger().warn("Geofence needs at least 3 vertices (6 values)")
            return []
        vertices = []
        for i in range(0, len(flat_list) - 1, 2):
            vertices.append((flat_list[i], flat_list[i + 1]))
        return vertices

    def _on_odometry(self, msg: Odometry) -> None:
        """Process odometry and check geofence boundary.

        Publishes the breach flag on every message (when a fence is defined);
        warning/info logs fire only on inside↔outside transitions, latched via
        breach_published.
        """
        if len(self.geofence_vertices) == 0:
            # No geofence defined — treat as inside and publish nothing.
            self.inside_geofence = True
            return
        # Extract robot position
        self.robot_x = msg.pose.pose.position.x
        self.robot_y = msg.pose.pose.position.y
        # Check if inside geofence
        self.inside_geofence = self._point_in_polygon(
            (self.robot_x, self.robot_y), self.geofence_vertices
        )
        # Publish breach status
        breach = not self.inside_geofence
        if breach and not self.breach_published:
            self.get_logger().warn(
                f"GEOFENCE BREACH! Robot at ({self.robot_x:.2f}, {self.robot_y:.2f})"
            )
            self.breach_published = True
        elif not breach and self.breach_published:
            self.get_logger().info(
                f"Robot re-entered geofence at ({self.robot_x:.2f}, {self.robot_y:.2f})"
            )
            self.breach_published = False
        msg_breach = Bool(data=breach)
        self.pub_breach.publish(msg_breach)

    def _point_in_polygon(self, point: Tuple[float, float], vertices: List[Tuple[float, float]]) -> bool:
        """Ray casting algorithm for point-in-polygon test.

        Casts a horizontal ray from `point` and counts polygon-edge crossings;
        an odd count means inside. Works for concave polygons. Points exactly
        on an edge or vertex may resolve either way due to float rounding.
        """
        x, y = point
        n = len(vertices)
        inside = False
        p1x, p1y = vertices[0]
        for i in range(1, n + 1):  # index wraps (i % n) to close the polygon
            p2x, p2y = vertices[i % n]
            # Check if ray crosses edge
            if y > min(p1y, p2y):
                if y <= max(p1y, p2y):
                    if x <= max(p1x, p2x):
                        if p1y != p2y:
                            # x-coordinate where the horizontal ray meets this edge
                            xinters = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x
                            if p1x == p2x or x <= xinters:
                                inside = not inside
            p1x, p1y = p2x, p2y
        return inside
def main(args=None):
    """Entry point: spin GeofenceNode until interrupted, then clean up."""
    rclpy.init(args=args)
    node = GeofenceNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Ctrl-C is a normal shutdown path, not an error.
        pass
    finally:
        node.destroy_node()
        rclpy.shutdown()


if __name__ == "__main__":
    main()

View File

@ -1,5 +0,0 @@
[develop]
script_dir=$base/lib/saltybot_geofence
[install]
install_scripts=$base/lib/saltybot_geofence

View File

@ -1,24 +0,0 @@
from setuptools import setup, find_packages

# ament_python package manifest for saltybot_geofence: installs the package
# resource marker, package.xml, config and launch files into the share tree,
# and exposes the node entry point as a console script.
setup(
    name='saltybot_geofence',
    version='0.1.0',
    packages=find_packages(),
    data_files=[
        ('share/ament_index/resource_index/packages', ['resource/saltybot_geofence']),
        ('share/saltybot_geofence', ['package.xml']),
        ('share/saltybot_geofence/config', ['config/geofence_config.yaml']),
        ('share/saltybot_geofence/launch', ['launch/geofence.launch.py']),
    ],
    install_requires=['setuptools'],
    zip_safe=True,
    author='SaltyLab Controls',
    author_email='sl-controls@saltylab.local',
    description='Geofence boundary enforcer for SaltyBot',
    license='MIT',
    entry_points={
        'console_scripts': [
            'geofence_node=saltybot_geofence.geofence_node:main',
        ],
    },
)

View File

@ -1,170 +0,0 @@
"""Tests for geofence boundary enforcer."""
import pytest
from nav_msgs.msg import Odometry
from geometry_msgs.msg import Point, Quaternion, Pose, PoseWithCovariance, TwistWithCovariance
import rclpy
from rclpy.time import Time
from saltybot_geofence.geofence_node import GeofenceNode
@pytest.fixture
def rclpy_fixture():
    """Initialise the ROS2 context for a test and shut it down afterwards."""
    rclpy.init()
    yield
    rclpy.shutdown()


@pytest.fixture
def node(rclpy_fixture):
    """Provide a fresh GeofenceNode, destroyed when the test finishes."""
    node = GeofenceNode()
    yield node
    node.destroy_node()
class TestInit:
    """Default parameter values and initial node state."""

    def test_node_initialization(self, node):
        assert node.enforce_boundary is False
        assert node.margin == 0.0
        assert node.inside_geofence is True
class TestVertexParsing:
    """_parse_vertices: flat [x1,y1,...] list → (x, y) tuples; [] if too short."""

    def test_parse_vertices_valid(self, node):
        vertices = node._parse_vertices([0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0])
        assert len(vertices) == 4
        assert vertices[0] == (0.0, 0.0)
        assert vertices[1] == (1.0, 0.0)

    def test_parse_vertices_insufficient(self, node):
        # Fewer than 3 vertices (6 values) disables the fence.
        vertices = node._parse_vertices([0.0, 0.0, 1.0])
        assert len(vertices) == 0
class TestPointInPolygon:
    """_point_in_polygon ray-casting: convex, concave, and boundary cases."""

    def test_point_inside_square(self, node):
        vertices = [(0.0, 0.0), (2.0, 0.0), (2.0, 2.0), (0.0, 2.0)]
        assert node._point_in_polygon((1.0, 1.0), vertices) is True

    def test_point_outside_square(self, node):
        vertices = [(0.0, 0.0), (2.0, 0.0), (2.0, 2.0), (0.0, 2.0)]
        assert node._point_in_polygon((3.0, 1.0), vertices) is False

    def test_point_on_vertex(self, node):
        vertices = [(0.0, 0.0), (2.0, 0.0), (2.0, 2.0), (0.0, 2.0)]
        # Point on vertex behavior may vary (typically outside)
        result = node._point_in_polygon((0.0, 0.0), vertices)
        assert isinstance(result, bool)

    def test_point_on_edge(self, node):
        vertices = [(0.0, 0.0), (2.0, 0.0), (2.0, 2.0), (0.0, 2.0)]
        # Point on edge behavior (typically outside)
        result = node._point_in_polygon((1.0, 0.0), vertices)
        assert isinstance(result, bool)

    def test_triangle_inside(self, node):
        vertices = [(0.0, 0.0), (4.0, 0.0), (2.0, 3.0)]
        assert node._point_in_polygon((2.0, 1.0), vertices) is True

    def test_triangle_outside(self, node):
        vertices = [(0.0, 0.0), (4.0, 0.0), (2.0, 3.0)]
        assert node._point_in_polygon((5.0, 5.0), vertices) is False

    def test_concave_polygon(self, node):
        # L-shaped polygon: ray casting must handle concavity correctly.
        vertices = [(0.0, 0.0), (3.0, 0.0), (3.0, 1.0), (1.0, 1.0), (1.0, 3.0), (0.0, 3.0)]
        assert node._point_in_polygon((0.5, 0.5), vertices) is True
        assert node._point_in_polygon((2.0, 2.0), vertices) is False

    def test_circle_approximation(self, node):
        # Octagon approximating a unit circle.
        import math
        vertices = []
        for i in range(8):
            angle = 2 * math.pi * i / 8
            vertices.append((math.cos(angle), math.sin(angle)))
        # Center should be inside
        assert node._point_in_polygon((0.0, 0.0), vertices) is True
        # Far outside should be outside
        assert node._point_in_polygon((10.0, 10.0), vertices) is False
class TestOdometryProcessing:
    """_on_odometry: position tracking, breach detection, and latch transitions."""

    def test_odometry_update_position(self, node):
        node.geofence_vertices = [(0.0, 0.0), (10.0, 0.0), (10.0, 10.0), (0.0, 10.0)]
        msg = Odometry()
        msg.pose.pose.position.x = 5.0
        msg.pose.pose.position.y = 5.0
        node._on_odometry(msg)
        assert node.robot_x == 5.0
        assert node.robot_y == 5.0

    def test_breach_detection_inside(self, node):
        node.geofence_vertices = [(0.0, 0.0), (10.0, 0.0), (10.0, 10.0), (0.0, 10.0)]
        msg = Odometry()
        msg.pose.pose.position.x = 5.0
        msg.pose.pose.position.y = 5.0
        node._on_odometry(msg)
        assert node.inside_geofence is True

    def test_breach_detection_outside(self, node):
        node.geofence_vertices = [(0.0, 0.0), (10.0, 0.0), (10.0, 10.0), (0.0, 10.0)]
        msg = Odometry()
        msg.pose.pose.position.x = 15.0
        msg.pose.pose.position.y = 5.0
        node._on_odometry(msg)
        assert node.inside_geofence is False

    def test_breach_flag_transition(self, node):
        """breach_published latches on exit and clears on re-entry."""
        node.geofence_vertices = [(0.0, 0.0), (10.0, 0.0), (10.0, 10.0), (0.0, 10.0)]
        assert node.breach_published is False
        # Move outside
        msg = Odometry()
        msg.pose.pose.position.x = 15.0
        msg.pose.pose.position.y = 5.0
        node._on_odometry(msg)
        assert node.breach_published is True
        # Move back inside
        msg.pose.pose.position.x = 5.0
        msg.pose.pose.position.y = 5.0
        node._on_odometry(msg)
        assert node.breach_published is False
class TestEdgeCases:
    """Degenerate configurations: no fence, tiny polygons, large coordinates."""

    def test_no_geofence_defined(self, node):
        node.geofence_vertices = []
        msg = Odometry()
        msg.pose.pose.position.x = 0.0
        msg.pose.pose.position.y = 0.0
        node._on_odometry(msg)
        # Should default to safe (inside)
        assert node.inside_geofence is True

    def test_very_small_polygon(self, node):
        vertices = [(0.0, 0.0), (0.01, 0.0), (0.01, 0.01), (0.0, 0.01)]
        assert node._point_in_polygon((0.005, 0.005), vertices) is True
        assert node._point_in_polygon((0.1, 0.1), vertices) is False

    def test_large_coordinates(self, node):
        vertices = [(0.0, 0.0), (1000.0, 0.0), (1000.0, 1000.0), (0.0, 1000.0)]
        assert node._point_in_polygon((500.0, 500.0), vertices) is True
        assert node._point_in_polygon((1500.0, 500.0), vertices) is False

View File

@ -1,14 +0,0 @@
topic_memory_node:
ros__parameters:
conversation_topic: "/social/conversation_text" # Input: JSON String {person_id, text}
output_topic: "/saltybot/conversation_topics" # Output: JSON String per utterance
# Keyword extraction
min_word_length: 3 # Skip words shorter than this
max_keywords_per_msg: 10 # Extract at most this many keywords per utterance
# Per-person rolling window
max_topics_per_person: 30 # Keep last N unique topics per person
# Stale-person pruning (0 = disabled)
prune_after_s: 1800.0 # Forget person after 30 min of inactivity

View File

@ -1,42 +0,0 @@
"""topic_memory.launch.py — Launch conversation topic memory node (Issue #299).
Usage:
ros2 launch saltybot_social topic_memory.launch.py
ros2 launch saltybot_social topic_memory.launch.py max_topics_per_person:=50
"""
import os
from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
    """Assemble the topic-memory launch: YAML defaults + CLI-overridable args.

    Parameter precedence: values from the packaged YAML file are applied
    first, then the launch-argument overrides dict (later entries win).
    """
    share_dir = get_package_share_directory("saltybot_social")
    params_file = os.path.join(share_dir, "config", "topic_memory_params.yaml")

    # Launch arguments a user may override from the command line.
    launch_args = [
        DeclareLaunchArgument("max_topics_per_person", default_value="30",
                              description="Rolling topic window per person"),
        DeclareLaunchArgument("max_keywords_per_msg", default_value="10",
                              description="Max keywords extracted per utterance"),
        DeclareLaunchArgument("prune_after_s", default_value="1800.0",
                              description="Forget persons idle this long (0=off)"),
    ]

    overrides = {
        "max_topics_per_person": LaunchConfiguration("max_topics_per_person"),
        "max_keywords_per_msg": LaunchConfiguration("max_keywords_per_msg"),
        "prune_after_s": LaunchConfiguration("prune_after_s"),
    }

    memory_node = Node(
        package="saltybot_social",
        executable="topic_memory_node",
        name="topic_memory_node",
        output="screen",
        parameters=[params_file, overrides],
    )

    return LaunchDescription(launch_args + [memory_node])

View File

@ -1,268 +0,0 @@
"""topic_memory_node.py — Conversation topic memory.
Issue #299
Subscribes to /social/conversation_text (std_msgs/String, JSON payload
{"person_id": "...", "text": "..."}), extracts key topics via stop-word
filtered keyword extraction, and maintains a per-person rolling topic
history.
On every message that yields at least one new keyword the node publishes
an updated topic snapshot on /saltybot/conversation_topics (std_msgs/String,
JSON) enabling recall like "last time you mentioned coffee" or "you talked
about the weather with alice".
Published JSON format
{
"person_id": "alice",
"recent_topics": ["coffee", "weather", "robot"], // most-recent first
"new_topics": ["coffee"], // keywords added this turn
"ts": 1234567890.123
}
Keyword extraction pipeline
1. Lowercase + tokenise on non-word characters
2. Filter: length >= min_word_length, alphabetic only
3. Remove stop words (built-in English list)
4. Deduplicate within the utterance
5. Take first max_keywords_per_msg tokens
Per-person storage
Ordered list (insertion order), capped at max_topics_per_person.
Duplicate keywords are promoted to the front (most-recent position).
Persons not seen for prune_after_s seconds are pruned on next publish
(set to 0 to disable pruning).
Parameters
conversation_topic (str, "/social/conversation_text")
output_topic (str, "/saltybot/conversation_topics")
min_word_length (int, 3) minimum character length to keep
max_keywords_per_msg (int, 10) max keywords extracted per utterance
max_topics_per_person(int, 30) rolling window per person
prune_after_s (float, 1800.0) forget persons idle this long (0=off)
"""
from __future__ import annotations
import json
import re
import string
import time
import threading
from typing import Dict, List, Optional
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile
from std_msgs.msg import String
# ── Stop-word list (English) ───────────────────────────────────────────────────
STOP_WORDS: frozenset = frozenset({
    "the", "a", "an", "and", "or", "but", "nor", "not", "so", "yet",
    "for", "with", "from", "into", "onto", "upon", "about", "above",
    "after", "before", "between", "during", "through", "under", "over",
    "at", "by", "in", "of", "on", "to", "up", "as",
    "is", "are", "was", "were", "be", "been", "being",
    "have", "has", "had", "do", "does", "did",
    "will", "would", "could", "should", "may", "might", "shall", "can",
    "it", "its", "this", "that", "these", "those",
    "i", "me", "my", "myself", "we", "our", "ours",
    "you", "your", "yours", "he", "him", "his", "she", "her", "hers",
    "they", "them", "their", "theirs",
    "what", "which", "who", "whom", "when", "where", "why", "how",
    "all", "each", "every", "both", "few", "more", "most",
    "other", "some", "such", "no", "only", "own", "same",
    "than", "then", "too", "very", "just", "also",
    "get", "got", "say", "said", "know", "think", "go", "going", "come",
    "like", "want", "see", "take", "make", "give", "look",
    "yes", "yeah", "okay", "ok", "oh", "ah", "um", "uh", "well",
    "now", "here", "there", "hi", "hey", "hello",
})

# Matches any character that is neither a word character nor whitespace.
_PUNCT_RE = re.compile(r"[^\w\s]")


# ── Keyword extraction ────────────────────────────────────────────────────────
def extract_keywords(text: str,
                     min_length: int = 3,
                     max_keywords: int = 10) -> List[str]:
    """Return up to *max_keywords* deduplicated keywords from *text*.

    Pipeline: lowercase -> drop punctuation -> tokenise -> keep purely
    alphabetic tokens of at least *min_length* characters that are not
    stop words -> deduplicate, preserving first-seen order.
    """
    tokens = _PUNCT_RE.sub(" ", text.lower()).split()
    keywords: dict = {}  # insertion-ordered dict doubles as an ordered set
    for raw in tokens:
        # The regex keeps underscores (they are \w); strip them per token.
        word = raw.strip(string.punctuation + "_")
        if len(word) < min_length or not word.isalpha():
            continue
        if word in STOP_WORDS or word in keywords:
            continue
        keywords[word] = None
        if len(keywords) >= max_keywords:
            break
    return list(keywords)
# ── Per-person topic memory ───────────────────────────────────────────────────
class PersonTopicMemory:
    """Rolling, deduplicated topic list for one person."""

    def __init__(self, max_topics: int = 30) -> None:
        self._capacity = max_topics
        # Insertion-ordered dict used as an ordered set: oldest key first,
        # most recent key last.
        self._entries: dict = {}
        self.last_updated: float = 0.0  # time.monotonic() of the last add()

    def add(self, keywords: List[str]) -> List[str]:
        """Merge *keywords* into the window; return the previously unseen ones.

        Already-known keywords are promoted to the most-recent slot; once the
        window exceeds its capacity the oldest entries are evicted.
        """
        fresh: List[str] = []
        for kw in keywords:
            if kw in self._entries:
                del self._entries[kw]  # re-inserted below -> most recent
            else:
                fresh.append(kw)
            self._entries[kw] = None
        # Evict from the oldest end until back within capacity.
        while len(self._entries) > self._capacity:
            del self._entries[next(iter(self._entries))]
        self.last_updated = time.monotonic()
        return fresh

    @property
    def recent_topics(self) -> List[str]:
        """Most-recent topics first."""
        return list(reversed(self._entries))

    def __len__(self) -> int:
        return len(self._entries)
# ── ROS2 node ─────────────────────────────────────────────────────────────────
class TopicMemoryNode(Node):
    """Extracts and remembers conversation topics per person.

    Subscribes to ``conversation_topic`` (std_msgs/String carrying JSON
    ``{"person_id": ..., "text": ...}``), extracts keywords with
    :func:`extract_keywords`, updates that person's rolling
    :class:`PersonTopicMemory`, and publishes a JSON snapshot on
    ``output_topic`` for every utterance that yields at least one keyword.
    """

    def __init__(self) -> None:
        super().__init__("topic_memory_node")
        # Parameter defaults mirror config/topic_memory_params.yaml.
        self.declare_parameter("conversation_topic", "/social/conversation_text")
        self.declare_parameter("output_topic", "/saltybot/conversation_topics")
        self.declare_parameter("min_word_length", 3)
        self.declare_parameter("max_keywords_per_msg", 10)
        self.declare_parameter("max_topics_per_person", 30)
        self.declare_parameter("prune_after_s", 1800.0)
        conv_topic = self.get_parameter("conversation_topic").value
        out_topic = self.get_parameter("output_topic").value
        self._min_len = self.get_parameter("min_word_length").value
        self._max_kw = self.get_parameter("max_keywords_per_msg").value
        self._max_tp = self.get_parameter("max_topics_per_person").value
        self._prune_s = self.get_parameter("prune_after_s").value
        # person_id -> PersonTopicMemory; guarded by _lock because the public
        # accessors (get_memory / all_persons) may be called from other threads.
        self._memory: Dict[str, PersonTopicMemory] = {}
        self._lock = threading.Lock()
        qos = QoSProfile(depth=10)
        self._pub = self.create_publisher(String, out_topic, qos)
        self._sub = self.create_subscription(
            String, conv_topic, self._on_conversation, qos
        )
        self.get_logger().info(
            f"TopicMemoryNode ready "
            f"(min_len={self._min_len}, max_kw={self._max_kw}, "
            f"max_topics={self._max_tp}, prune_after={self._prune_s}s)"
        )

    # ── Subscription ───────────────────────────────────────────────────────
    def _on_conversation(self, msg: String) -> None:
        """Handle one utterance: extract keywords, update memory, publish."""
        try:
            payload = json.loads(msg.data)
            person_id: str = str(payload.get("person_id", "unknown"))
            text: str = str(payload.get("text", ""))
        except (json.JSONDecodeError, AttributeError) as exc:
            self.get_logger().warn(f"Bad conversation_text payload: {exc}")
            return
        if not text.strip():
            return  # blank utterance -> nothing to extract, no publish
        keywords = extract_keywords(text, self._min_len, self._max_kw)
        if not keywords:
            return  # all tokens filtered out -> no publish
        with self._lock:
            self._prune_stale()
            if person_id not in self._memory:
                self._memory[person_id] = PersonTopicMemory(self._max_tp)
            mem = self._memory[person_id]
            new_topics = mem.add(keywords)
            recent = mem.recent_topics
        # NOTE(review): the publish block's nesting was ambiguous in the
        # reviewed rendering; reconstructed here OUTSIDE the lock (safe since
        # `recent`/`new_topics` are snapshots) — confirm against the original.
        out = String()
        out.data = json.dumps({
            "person_id": person_id,
            "recent_topics": recent,
            "new_topics": new_topics,
            "ts": time.time(),
        })
        self._pub.publish(out)
        if new_topics:
            self.get_logger().info(
                f"[{person_id}] new topics: {new_topics} | "
                f"memory: {recent[:5]}{'...' if len(recent) > 5 else ''}"
            )

    # ── Helpers ────────────────────────────────────────────────────────────
    def _prune_stale(self) -> None:
        """Remove persons not seen for prune_after_s seconds (call under lock)."""
        if self._prune_s <= 0:
            return  # pruning disabled by configuration
        now = time.monotonic()
        stale = [pid for pid, m in self._memory.items()
                 if (now - m.last_updated) > self._prune_s]
        for pid in stale:
            del self._memory[pid]
            self.get_logger().info(f"Pruned stale person: {pid}")

    def get_memory(self, person_id: str) -> Optional[PersonTopicMemory]:
        """Return topic memory for a person (None if not seen)."""
        # Returns the live object, not a copy — callers should treat it as
        # read-only.
        with self._lock:
            return self._memory.get(person_id)

    def all_persons(self) -> Dict[str, List[str]]:
        """Return {person_id: recent_topics} snapshot for all known persons."""
        with self._lock:
            return {pid: m.recent_topics for pid, m in self._memory.items()}
def main(args=None) -> None:
    """CLI entry point: spin the node until interrupted, then shut down cleanly."""
    rclpy.init(args=args)
    topic_memory = TopicMemoryNode()
    try:
        rclpy.spin(topic_memory)
    except KeyboardInterrupt:
        pass  # Ctrl-C is the normal way to stop the node
    finally:
        topic_memory.destroy_node()
        rclpy.shutdown()

View File

@ -53,8 +53,6 @@ setup(
'face_track_servo_node = saltybot_social.face_track_servo_node:main', 'face_track_servo_node = saltybot_social.face_track_servo_node:main',
# Speech volume auto-adjuster (Issue #289) # Speech volume auto-adjuster (Issue #289)
'volume_adjust_node = saltybot_social.volume_adjust_node:main', 'volume_adjust_node = saltybot_social.volume_adjust_node:main',
# Conversation topic memory (Issue #299)
'topic_memory_node = saltybot_social.topic_memory_node:main',
], ],
}, },
) )

View File

@ -1,522 +0,0 @@
"""test_topic_memory.py — Offline tests for topic_memory_node (Issue #299).
Stubs out rclpy so tests run without a ROS install.
"""
import importlib.util
import json
import sys
import time
import types
import unittest
# ── ROS2 stubs ────────────────────────────────────────────────────────────────
def _make_ros_stubs():
    """Install minimal rclpy/std_msgs stand-ins into sys.modules.

    Must run before the node module is imported so that its
    ``import rclpy`` / ``from std_msgs.msg import String`` resolve to these
    fakes instead of requiring a real ROS install.
    """
    # Create empty placeholder modules first so attribute assignment below works.
    for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
                     "std_msgs", "std_msgs.msg"):
        if mod_name not in sys.modules:
            sys.modules[mod_name] = types.ModuleType(mod_name)

    class _Node:
        # Minimal rclpy Node lookalike recording params, pubs, subs and logs.
        def __init__(self, name="node"):
            self._name = name
            # _params may be pre-seeded by the harness (_make_node) before
            # __init__ runs; keep those values if present.
            if not hasattr(self, "_params"):
                self._params = {}
            self._pubs = {}
            self._subs = {}
            self._logs = []

        def declare_parameter(self, name, default):
            # Pre-seeded values win over declared defaults.
            if name not in self._params:
                self._params[name] = default

        def get_parameter(self, name):
            # Mirrors rclpy's Parameter object: access via `.value`.
            class _P:
                def __init__(self, v): self.value = v
            return _P(self._params.get(name))

        def create_publisher(self, msg_type, topic, qos):
            pub = _FakePub()
            self._pubs[topic] = pub
            return pub

        def create_subscription(self, msg_type, topic, cb, qos):
            self._subs[topic] = cb
            return object()

        def get_logger(self):
            # Logger records (level, message) tuples onto the node for asserts.
            node = self
            class _L:
                def info(self, m): node._logs.append(("INFO", m))
                def warn(self, m): node._logs.append(("WARN", m))
                def error(self, m): node._logs.append(("ERROR", m))
            return _L()

        def destroy_node(self): pass

    class _FakePub:
        # Captures published messages so tests can inspect them.
        def __init__(self):
            self.msgs = []
        def publish(self, msg):
            self.msgs.append(msg)

    class _QoSProfile:
        def __init__(self, depth=10): self.depth = depth

    class _String:
        # std_msgs/String stand-in: a single `data` str field.
        def __init__(self): self.data = ""

    rclpy_mod = sys.modules["rclpy"]
    rclpy_mod.init = lambda args=None: None
    rclpy_mod.spin = lambda node: None
    rclpy_mod.shutdown = lambda: None
    sys.modules["rclpy.node"].Node = _Node
    sys.modules["rclpy.qos"].QoSProfile = _QoSProfile
    sys.modules["std_msgs.msg"].String = _String
    return _Node, _FakePub, _String


# Install the stubs at import time, before the module under test is loaded.
_Node, _FakePub, _String = _make_ros_stubs()
# ── Module loader ─────────────────────────────────────────────────────────────
# NOTE(review): absolute, machine-specific path — these tests only run on the
# original dev checkout; consider deriving the path from __file__ instead.
_SRC = (
    "/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
    "saltybot_social/saltybot_social/topic_memory_node.py"
)


def _load_mod():
    """Import topic_memory_node.py from source under a throwaway module name."""
    spec = importlib.util.spec_from_file_location("topic_memory_testmod", _SRC)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod
def _make_node(mod, **kwargs):
    """Build a TopicMemoryNode with stubbed ROS plumbing and param overrides.

    Allocates via ``__new__`` and pre-seeds ``_params`` so the stub
    ``_Node.__init__`` keeps our values and ``declare_parameter()`` does not
    overwrite them; then runs the real ``__init__`` for the wiring.
    """
    node = mod.TopicMemoryNode.__new__(mod.TopicMemoryNode)
    defaults = {
        "conversation_topic": "/social/conversation_text",
        "output_topic": "/saltybot/conversation_topics",
        "min_word_length": 3,
        "max_keywords_per_msg": 10,
        "max_topics_per_person": 30,
        "prune_after_s": 1800.0,
    }
    defaults.update(kwargs)
    node._params = dict(defaults)
    mod.TopicMemoryNode.__init__(node)
    return node
def _msg(person_id, text):
    """Wrap a (person_id, text) pair in the JSON String message format."""
    message = _String()
    message.data = json.dumps({"person_id": person_id, "text": text})
    return message


def _send(node, person_id, text):
    """Deliver a conversation message to the node."""
    callback = node._subs["/social/conversation_text"]
    callback(_msg(person_id, text))
# ── Tests: extract_keywords ───────────────────────────────────────────────────
class TestExtractKeywords(unittest.TestCase):
    """Unit tests for the stop-word-filtered keyword extractor."""

    @classmethod
    def setUpClass(cls): cls.mod = _load_mod()

    def _kw(self, text, min_len=3, max_kw=10):
        # Convenience wrapper around the function under test.
        return self.mod.extract_keywords(text, min_len, max_kw)

    def test_basic(self):
        kws = self._kw("I love drinking coffee")
        self.assertIn("love", kws)
        self.assertIn("drinking", kws)
        self.assertIn("coffee", kws)

    def test_stop_words_removed(self):
        kws = self._kw("the quick brown fox")
        self.assertNotIn("the", kws)

    def test_short_words_removed(self):
        kws = self._kw("go to the museum now", min_len=4)
        self.assertNotIn("go", kws)
        self.assertNotIn("to", kws)
        self.assertNotIn("the", kws)

    def test_deduplication(self):
        kws = self._kw("coffee coffee coffee")
        self.assertEqual(kws.count("coffee"), 1)

    def test_max_keywords_cap(self):
        text = " ".join(f"word{i}" for i in range(20))
        kws = self._kw(text, max_kw=5)
        self.assertLessEqual(len(kws), 5)

    def test_punctuation_stripped(self):
        kws = self._kw("Hello, world! How's weather?")
        # "world" and "weather" should be found, punctuation removed
        self.assertIn("world", kws)
        self.assertIn("weather", kws)

    def test_case_insensitive(self):
        kws = self._kw("Robot ROBOT robot")
        self.assertEqual(kws.count("robot"), 1)

    def test_empty_text(self):
        self.assertEqual(self._kw(""), [])

    def test_all_stop_words(self):
        self.assertEqual(self._kw("the is a and"), [])

    def test_non_alpha_excluded(self):
        kws = self._kw("model42 123 price500")
        # alphanumeric tokens like "model42" contain digits → excluded
        self.assertEqual(kws, [])

    def test_preserves_order(self):
        kws = self._kw("zebra apple mango")
        self.assertEqual(kws, ["zebra", "apple", "mango"])

    def test_min_length_three(self):
        kws = self._kw("cat dog elephant", min_len=3)
        self.assertIn("cat", kws)
        self.assertIn("dog", kws)
        self.assertIn("elephant", kws)

    def test_min_length_four_excludes_short(self):
        kws = self._kw("cat dog elephant", min_len=4)
        self.assertNotIn("cat", kws)
        self.assertNotIn("dog", kws)
        self.assertIn("elephant", kws)
# ── Tests: PersonTopicMemory ──────────────────────────────────────────────────
class TestPersonTopicMemory(unittest.TestCase):
    """Behaviour of the rolling, deduplicated per-person topic window."""

    @classmethod
    def setUpClass(cls): cls.mod = _load_mod()

    def _mem(self, max_topics=10):
        return self.mod.PersonTopicMemory(max_topics)

    def test_empty_initially(self):
        self.assertEqual(len(self._mem()), 0)

    def test_add_returns_new(self):
        m = self._mem()
        added = m.add(["coffee", "weather"])
        self.assertEqual(added, ["coffee", "weather"])

    def test_duplicate_not_in_added(self):
        m = self._mem()
        m.add(["coffee"])
        added = m.add(["coffee", "weather"])
        self.assertNotIn("coffee", added)
        self.assertIn("weather", added)

    def test_recent_topics_most_recent_first(self):
        m = self._mem()
        m.add(["coffee"])
        m.add(["weather"])
        topics = m.recent_topics
        self.assertEqual(topics[0], "weather")
        self.assertEqual(topics[1], "coffee")

    def test_duplicate_promoted_to_front(self):
        m = self._mem()
        m.add(["coffee", "weather", "robot"])
        m.add(["coffee"])  # promote coffee to front
        topics = m.recent_topics
        self.assertEqual(topics[0], "coffee")

    def test_cap_evicts_oldest(self):
        m = self._mem(max_topics=3)
        m.add(["alpha", "beta", "gamma"])
        m.add(["delta"])  # alpha should be evicted
        topics = m.recent_topics
        self.assertNotIn("alpha", topics)
        self.assertIn("delta", topics)
        self.assertEqual(len(topics), 3)

    def test_len(self):
        m = self._mem()
        m.add(["coffee", "weather", "robot"])
        self.assertEqual(len(m), 3)

    def test_last_updated_set(self):
        m = self._mem()
        before = time.monotonic()
        m.add(["test"])
        self.assertGreaterEqual(m.last_updated, before)

    def test_empty_add(self):
        m = self._mem()
        added = m.add([])
        self.assertEqual(added, [])
        self.assertEqual(len(m), 0)

    def test_many_duplicates_stay_within_cap(self):
        m = self._mem(max_topics=5)
        for _ in range(10):
            m.add(["coffee"])
        self.assertEqual(len(m), 1)
        self.assertIn("coffee", m.recent_topics)
# ── Tests: node init ──────────────────────────────────────────────────────────
class TestNodeInit(unittest.TestCase):
    """Construction wiring: publishers, subscriptions, initial state."""

    @classmethod
    def setUpClass(cls): cls.mod = _load_mod()

    def test_instantiates(self):
        self.assertIsNotNone(_make_node(self.mod))

    def test_pub_registered(self):
        node = _make_node(self.mod)
        self.assertIn("/saltybot/conversation_topics", node._pubs)

    def test_sub_registered(self):
        node = _make_node(self.mod)
        self.assertIn("/social/conversation_text", node._subs)

    def test_memory_empty(self):
        node = _make_node(self.mod)
        self.assertEqual(node.all_persons(), {})

    def test_custom_topics(self):
        # Topic names come from parameters, not hard-coded strings.
        node = _make_node(self.mod,
                          conversation_topic="/my/conv",
                          output_topic="/my/topics")
        self.assertIn("/my/conv", node._subs)
        self.assertIn("/my/topics", node._pubs)
# ── Tests: on_conversation callback ──────────────────────────────────────────
class TestOnConversation(unittest.TestCase):
    """End-to-end behaviour of the subscription callback and its JSON output."""

    @classmethod
    def setUpClass(cls): cls.mod = _load_mod()

    def setUp(self):
        self.node = _make_node(self.mod)
        # Fake publisher captures every snapshot the node publishes.
        self.pub = self.node._pubs["/saltybot/conversation_topics"]

    def test_publishes_on_keyword(self):
        _send(self.node, "alice", "I love drinking coffee")
        self.assertEqual(len(self.pub.msgs), 1)

    def test_payload_is_json(self):
        _send(self.node, "alice", "I love drinking coffee")
        payload = json.loads(self.pub.msgs[-1].data)
        self.assertIsInstance(payload, dict)

    def test_payload_person_id(self):
        _send(self.node, "bob", "I enjoy hiking mountains")
        payload = json.loads(self.pub.msgs[-1].data)
        self.assertEqual(payload["person_id"], "bob")

    def test_payload_recent_topics(self):
        _send(self.node, "alice", "I love coffee weather robots")
        payload = json.loads(self.pub.msgs[-1].data)
        self.assertIsInstance(payload["recent_topics"], list)
        self.assertGreater(len(payload["recent_topics"]), 0)

    def test_payload_new_topics(self):
        _send(self.node, "alice", "I love coffee")
        payload = json.loads(self.pub.msgs[-1].data)
        self.assertIn("coffee", payload["new_topics"])

    def test_payload_has_ts(self):
        _send(self.node, "alice", "I love coffee")
        payload = json.loads(self.pub.msgs[-1].data)
        self.assertIn("ts", payload)

    def test_all_stop_words_no_publish(self):
        # No keywords extracted -> node stays silent.
        _send(self.node, "alice", "the is and or")
        self.assertEqual(len(self.pub.msgs), 0)

    def test_empty_text_no_publish(self):
        _send(self.node, "alice", "")
        self.assertEqual(len(self.pub.msgs), 0)

    def test_bad_json_no_crash(self):
        # Malformed payload must be warned about, not raise.
        m = _String(); m.data = "not json at all"
        self.node._subs["/social/conversation_text"](m)
        self.assertEqual(len(self.pub.msgs), 0)
        warns = [l for l in self.node._logs if l[0] == "WARN"]
        self.assertEqual(len(warns), 1)

    def test_missing_person_id_defaults_unknown(self):
        m = _String(); m.data = json.dumps({"text": "coffee weather hiking"})
        self.node._subs["/social/conversation_text"](m)
        payload = json.loads(self.pub.msgs[-1].data)
        self.assertEqual(payload["person_id"], "unknown")

    def test_duplicate_topic_not_in_new(self):
        _send(self.node, "alice", "coffee mountains weather")
        _send(self.node, "alice", "coffee")  # coffee already known
        payload = json.loads(self.pub.msgs[-1].data)
        self.assertNotIn("coffee", payload["new_topics"])

    def test_topics_accumulate_across_turns(self):
        _send(self.node, "alice", "coffee weather")
        _send(self.node, "alice", "mountains hiking")
        mem = self.node.get_memory("alice")
        topics = mem.recent_topics
        self.assertIn("coffee", topics)
        self.assertIn("mountains", topics)

    def test_separate_persons_independent(self):
        _send(self.node, "alice", "coffee weather")
        _send(self.node, "bob", "robots motors")
        alice = self.node.get_memory("alice").recent_topics
        bob = self.node.get_memory("bob").recent_topics
        self.assertIn("coffee", alice)
        self.assertNotIn("robots", alice)
        self.assertIn("robots", bob)
        self.assertNotIn("coffee", bob)

    def test_all_persons_snapshot(self):
        _send(self.node, "alice", "coffee weather")
        _send(self.node, "bob", "robots motors")
        persons = self.node.all_persons()
        self.assertIn("alice", persons)
        self.assertIn("bob", persons)

    def test_recent_topics_most_recent_first(self):
        _send(self.node, "alice", "alpha")
        _send(self.node, "alice", "beta")
        _send(self.node, "alice", "gamma")
        topics = self.node.get_memory("alice").recent_topics
        self.assertEqual(topics[0], "gamma")

    def test_stop_words_not_stored(self):
        _send(self.node, "alice", "the weather and coffee")
        topics = self.node.get_memory("alice").recent_topics
        self.assertNotIn("the", topics)
        self.assertNotIn("and", topics)

    def test_max_keywords_respected(self):
        node = _make_node(self.mod, max_keywords_per_msg=3)
        _send(node, "alice",
              "coffee weather hiking mountains ocean desert forest lake")
        mem = node.get_memory("alice")
        # Only 3 keywords should have been extracted per message
        self.assertLessEqual(len(mem), 3)

    def test_max_topics_cap(self):
        node = _make_node(self.mod, max_topics_per_person=5, max_keywords_per_msg=20)
        words = ["alpha", "beta", "gamma", "delta", "epsilon",
                 "zeta", "eta", "theta"]
        for i, w in enumerate(words):
            _send(node, "alice", w)
        mem = node.get_memory("alice")
        self.assertLessEqual(len(mem), 5)
# ── Tests: prune ──────────────────────────────────────────────────────────────
class TestPrune(unittest.TestCase):
    """Stale-person pruning: disabled, stale, and fresh cases."""

    @classmethod
    def setUpClass(cls): cls.mod = _load_mod()

    def test_no_prune_when_disabled(self):
        node = _make_node(self.mod, prune_after_s=0.0)
        _send(node, "alice", "coffee weather")
        # Manually expire timestamp
        node._memory["alice"].last_updated = time.monotonic() - 9999
        _send(node, "bob", "robots motors")  # triggers prune check
        # alice should still be present (prune disabled)
        self.assertIn("alice", node.all_persons())

    def test_prune_stale_person(self):
        node = _make_node(self.mod, prune_after_s=1.0)
        _send(node, "alice", "coffee weather")
        node._memory["alice"].last_updated = time.monotonic() - 10.0  # stale
        _send(node, "bob", "robots motors")  # triggers prune
        self.assertNotIn("alice", node.all_persons())

    def test_fresh_person_not_pruned(self):
        node = _make_node(self.mod, prune_after_s=1.0)
        _send(node, "alice", "coffee weather")
        # alice is fresh (just added)
        _send(node, "bob", "robots motors")
        self.assertIn("alice", node.all_persons())
# ── Tests: source and config ──────────────────────────────────────────────────
class TestNodeSrc(unittest.TestCase):
    """Static checks that key identifiers exist in the node source text."""

    @classmethod
    def setUpClass(cls):
        with open(_SRC) as f: cls.src = f.read()

    def test_issue_tag(self): self.assertIn("#299", self.src)
    def test_input_topic(self): self.assertIn("/social/conversation_text", self.src)
    def test_output_topic(self): self.assertIn("/saltybot/conversation_topics", self.src)
    def test_extract_keywords(self): self.assertIn("extract_keywords", self.src)
    def test_person_topic_memory(self):self.assertIn("PersonTopicMemory", self.src)
    def test_stop_words(self): self.assertIn("STOP_WORDS", self.src)
    def test_json_output(self): self.assertIn("json.dumps", self.src)
    def test_person_id_in_output(self):self.assertIn("person_id", self.src)
    def test_recent_topics_key(self): self.assertIn("recent_topics", self.src)
    def test_new_topics_key(self): self.assertIn("new_topics", self.src)
    def test_threading_lock(self): self.assertIn("threading.Lock", self.src)
    def test_prune_method(self): self.assertIn("_prune_stale", self.src)
    def test_main_defined(self): self.assertIn("def main", self.src)
    def test_min_word_length_param(self):self.assertIn("min_word_length", self.src)
    def test_max_topics_param(self): self.assertIn("max_topics_per_person", self.src)
class TestConfig(unittest.TestCase):
    """Presence checks for the config YAML, launch file, and setup.py entry.

    NOTE(review): all three paths are absolute and machine-specific — these
    tests only pass on the original dev checkout; derive from __file__ instead.
    """
    # Packaged parameter defaults.
    _CONFIG = (
        "/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
        "saltybot_social/config/topic_memory_params.yaml"
    )
    # Launch description for the node.
    _LAUNCH = (
        "/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
        "saltybot_social/launch/topic_memory.launch.py"
    )
    # Package setup with console_scripts entry points.
    _SETUP = (
        "/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
        "saltybot_social/setup.py"
    )

    def test_config_exists(self):
        import os; self.assertTrue(os.path.exists(self._CONFIG))

    def test_config_min_word_length(self):
        with open(self._CONFIG) as f: c = f.read()
        self.assertIn("min_word_length", c)

    def test_config_max_topics(self):
        with open(self._CONFIG) as f: c = f.read()
        self.assertIn("max_topics_per_person", c)

    def test_config_prune(self):
        with open(self._CONFIG) as f: c = f.read()
        self.assertIn("prune_after_s", c)

    def test_launch_exists(self):
        import os; self.assertTrue(os.path.exists(self._LAUNCH))

    def test_launch_max_topics_arg(self):
        with open(self._LAUNCH) as f: c = f.read()
        self.assertIn("max_topics_per_person", c)

    def test_entry_point(self):
        with open(self._SETUP) as f: c = f.read()
        self.assertIn("topic_memory_node", c)


if __name__ == "__main__":
    unittest.main()