Compare commits
6 Commits
accda32c7a
...
becd0bc717
| Author | SHA1 | Date | |
|---|---|---|---|
| becd0bc717 | |||
| a8838cfbbd | |||
| 86d798afe7 | |||
| 3bf603f685 | |||
| 797ed711b9 | |||
| bfd291cbdd |
@ -0,0 +1,219 @@
|
||||
"""
|
||||
_terrain_roughness.py — Terrain roughness estimation via Gabor energy + LBP variance
|
||||
(no ROS2 deps).
|
||||
|
||||
Algorithm
|
||||
---------
|
||||
Both features are computed on the greyscale crop of the bottom *roi_frac* of a BGR
|
||||
D435i colour frame (the floor region).
|
||||
|
||||
1. Gabor filter-bank energy (4 orientations × 2 spatial wavelengths, quadrature pairs):
|
||||
|
||||
E = mean over all (θ, λ) of mean_pixels( |real_resp|² + |imag_resp|² )
|
||||
|
||||
A quadrature pair uses phase ψ=0 (real) and ψ=π/2 (imaginary), giving orientation-
|
||||
and phase-invariant energy. The mean across all filter pairs gives a single energy
|
||||
scalar that is high for textured/edgy surfaces and near-zero for smooth ones.
|
||||
|
||||
2. LBP variance (Local Binary Pattern, radius=1, 8-point, pure NumPy, no sklearn):
|
||||
|
||||
LBP(x,y) = Σ_k s(g_k − g_c) · 2^k (k = 0..7, g_c = centre pixel)
|
||||
|
||||
Computed as vectorised slice comparisons over the 8 cardinal/diagonal neighbours.
|
||||
The metric is var(LBP image). A constant surface gives LBP ≡ 0xFF (all bits set,
|
||||
since every neighbour ties with the centre) → variance = 0. Irregular textures
|
||||
produce a dispersed LBP histogram → high variance.
|
||||
|
||||
3. Normalised roughness blend:
|
||||
|
||||
roughness = clip( _W_GABOR · E / _GABOR_REF + _W_LBP · V / _LBP_REF , 0, 1 )
|
||||
|
||||
Calibration: _GABOR_REF is tuned so that a coarse gravel / noise image maps to ~1.0;
|
||||
_LBP_REF = 5000 (≈ variance of a fully random 8-bit LBP image).
|
||||
|
||||
Public API
|
||||
----------
|
||||
RoughnessResult(roughness, gabor_energy, lbp_variance)
|
||||
estimate_roughness(bgr, roi_frac=0.40) -> RoughnessResult
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ── Tuning constants ──────────────────────────────────────────────────────────
|
||||
|
||||
# Gabor filter-bank parameters
|
||||
_GABOR_N_ORIENTATIONS = 4 # θ ∈ {0°, 45°, 90°, 135°}
|
||||
_GABOR_WAVELENGTHS = [5.0, 10.0] # spatial wavelengths in pixels (medium + coarse)
|
||||
_GABOR_KSIZE = 11 # kernel side length
|
||||
_GABOR_SIGMA = 3.0 # Gaussian envelope std (pixels)
|
||||
_GABOR_GAMMA = 0.5 # spatial aspect ratio
|
||||
|
||||
# Normalization references
|
||||
_GABOR_REF = 500.0 # Gabor mean power at which we consider maximum roughness
|
||||
_LBP_REF = 5000.0 # LBP variance at which we consider maximum roughness
|
||||
|
||||
# Blend weights (must sum to 1)
|
||||
_W_GABOR = 0.5
|
||||
_W_LBP = 0.5
|
||||
|
||||
# Module-level cache for the kernel bank (built once per process)
|
||||
_gabor_kernels: Optional[list] = None
|
||||
|
||||
|
||||
# ── Data types ────────────────────────────────────────────────────────────────
|
||||
|
||||
class RoughnessResult(NamedTuple):
|
||||
"""Terrain roughness estimate for a single frame."""
|
||||
roughness: float # blended score in [0, 1]; 0=smooth, 1=rough
|
||||
gabor_energy: float # raw Gabor mean energy (≥ 0)
|
||||
lbp_variance: float # raw LBP image variance (≥ 0)
|
||||
|
||||
|
||||
# ── Internal helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
def _build_gabor_kernels() -> list:
|
||||
"""Build (real_kernel, imag_kernel) pairs for all (orientation, wavelength) combos."""
|
||||
import cv2
|
||||
|
||||
kernels = []
|
||||
for i in range(_GABOR_N_ORIENTATIONS):
|
||||
theta = float(i) * np.pi / _GABOR_N_ORIENTATIONS
|
||||
for lam in _GABOR_WAVELENGTHS:
|
||||
real_k = cv2.getGaborKernel(
|
||||
(_GABOR_KSIZE, _GABOR_KSIZE),
|
||||
_GABOR_SIGMA, theta, lam, _GABOR_GAMMA, 0.0,
|
||||
ktype=cv2.CV_64F,
|
||||
).astype(np.float32)
|
||||
imag_k = cv2.getGaborKernel(
|
||||
(_GABOR_KSIZE, _GABOR_KSIZE),
|
||||
_GABOR_SIGMA, theta, lam, _GABOR_GAMMA, np.pi / 2.0,
|
||||
ktype=cv2.CV_64F,
|
||||
).astype(np.float32)
|
||||
kernels.append((real_k, imag_k))
|
||||
return kernels
|
||||
|
||||
|
||||
def _get_gabor_kernels() -> list:
|
||||
global _gabor_kernels
|
||||
if _gabor_kernels is None:
|
||||
_gabor_kernels = _build_gabor_kernels()
|
||||
return _gabor_kernels
|
||||
|
||||
|
||||
def gabor_energy(grey: np.ndarray) -> float:
|
||||
"""
|
||||
Mean quadrature Gabor energy across the filter bank.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
grey : (H, W) uint8 or float ndarray
|
||||
|
||||
Returns
|
||||
-------
|
||||
float — mean power (≥ 0); higher for textured surfaces
|
||||
"""
|
||||
import cv2
|
||||
|
||||
img = np.asarray(grey, dtype=np.float32)
|
||||
if img.size == 0:
|
||||
return 0.0
|
||||
|
||||
# Subtract the mean so constant regions (DC) contribute zero energy.
|
||||
# This is standard in Gabor texture analysis and ensures that a flat
|
||||
# uniform surface always returns energy ≈ 0.
|
||||
img = img - float(img.mean())
|
||||
|
||||
kernels = _get_gabor_kernels()
|
||||
total = 0.0
|
||||
for real_k, imag_k in kernels:
|
||||
r = cv2.filter2D(img, cv2.CV_32F, real_k)
|
||||
i = cv2.filter2D(img, cv2.CV_32F, imag_k)
|
||||
total += float(np.mean(r * r + i * i))
|
||||
|
||||
return total / len(kernels) if kernels else 0.0
|
||||
|
||||
|
||||
def lbp_variance(grey: np.ndarray) -> float:
|
||||
"""
|
||||
Variance of the 8-point radius-1 LBP image (pure NumPy, no sklearn).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
grey : (H, W) uint8 ndarray
|
||||
|
||||
Returns
|
||||
-------
|
||||
float — variance of LBP values; 0 for flat surfaces, high for irregular textures
|
||||
"""
|
||||
g = np.asarray(grey, dtype=np.int32)
|
||||
h, w = g.shape
|
||||
if h < 3 or w < 3:
|
||||
return 0.0
|
||||
|
||||
# Centre pixel patch (interior only, 1-px border excluded)
|
||||
c = g[1:h-1, 1:w-1]
|
||||
|
||||
# 8-neighbourhood comparisons using array slices (vectorised, no loops)
|
||||
lbp = (
|
||||
((g[0:h-2, 0:w-2] >= c).view(np.uint8) ) |
|
||||
((g[0:h-2, 1:w-1] >= c).view(np.uint8) << 1 ) |
|
||||
((g[0:h-2, 2:w ] >= c).view(np.uint8) << 2 ) |
|
||||
((g[1:h-1, 0:w-2] >= c).view(np.uint8) << 3 ) |
|
||||
((g[1:h-1, 2:w ] >= c).view(np.uint8) << 4 ) |
|
||||
((g[2:h , 0:w-2] >= c).view(np.uint8) << 5 ) |
|
||||
((g[2:h , 1:w-1] >= c).view(np.uint8) << 6 ) |
|
||||
((g[2:h , 2:w ] >= c).view(np.uint8) << 7 )
|
||||
)
|
||||
|
||||
return float(lbp.var())
|
||||
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
||||
def estimate_roughness(
|
||||
bgr: np.ndarray,
|
||||
roi_frac: float = 0.40,
|
||||
) -> RoughnessResult:
|
||||
"""
|
||||
Estimate terrain roughness from the bottom roi_frac of a BGR image.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bgr : (H, W, 3) uint8 BGR ndarray (or greyscale (H, W))
|
||||
roi_frac : fraction of image height to use as the floor ROI (from bottom)
|
||||
|
||||
Returns
|
||||
-------
|
||||
RoughnessResult(roughness, gabor_energy, lbp_variance)
|
||||
roughness is in [0, 1] where 0=smooth, 1=rough.
|
||||
"""
|
||||
import cv2
|
||||
|
||||
# Crop to bottom roi_frac
|
||||
h = bgr.shape[0]
|
||||
r0 = int(h * max(0.0, 1.0 - min(float(roi_frac), 1.0)))
|
||||
roi = bgr[r0:, :]
|
||||
|
||||
if roi.size == 0:
|
||||
return RoughnessResult(roughness=0.0, gabor_energy=0.0, lbp_variance=0.0)
|
||||
|
||||
# Convert to greyscale
|
||||
if roi.ndim == 3:
|
||||
grey = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
|
||||
else:
|
||||
grey = np.asarray(roi, dtype=np.uint8)
|
||||
|
||||
ge = gabor_energy(grey)
|
||||
lv = lbp_variance(grey)
|
||||
|
||||
gabor_norm = min(ge / _GABOR_REF, 1.0)
|
||||
lbp_norm = min(lv / _LBP_REF, 1.0)
|
||||
|
||||
roughness = float(np.clip(_W_GABOR * gabor_norm + _W_LBP * lbp_norm, 0.0, 1.0))
|
||||
|
||||
return RoughnessResult(roughness=roughness, gabor_energy=ge, lbp_variance=lv)
|
||||
@ -0,0 +1,102 @@
|
||||
"""
|
||||
terrain_rough_node.py — D435i terrain roughness estimator (Issue #296).
|
||||
|
||||
Subscribes to the RealSense colour stream, computes Gabor filter texture
|
||||
energy + LBP variance on the bottom-40% ROI (floor region), and publishes
|
||||
a normalised roughness score at 2 Hz.
|
||||
|
||||
Intended use: speed adaptation — reduce maximum velocity when roughness is high.
|
||||
|
||||
Subscribes (BEST_EFFORT):
|
||||
/camera/color/image_raw sensor_msgs/Image BGR8
|
||||
|
||||
Publishes:
|
||||
/saltybot/terrain_roughness std_msgs/Float32 roughness in [0,1] (0=smooth, 1=rough)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
roi_frac float 0.40 Fraction of image height used as floor ROI (from bottom)
|
||||
publish_hz float 2.0 Publication rate (Hz)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||
|
||||
from cv_bridge import CvBridge
|
||||
|
||||
from sensor_msgs.msg import Image
|
||||
from std_msgs.msg import Float32
|
||||
|
||||
from ._terrain_roughness import estimate_roughness
|
||||
|
||||
|
||||
_SENSOR_QOS = QoSProfile(
|
||||
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||
history=HistoryPolicy.KEEP_LAST,
|
||||
depth=4,
|
||||
)
|
||||
|
||||
|
||||
class TerrainRoughNode(Node):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__('terrain_rough_node')
|
||||
|
||||
self.declare_parameter('roi_frac', 0.40)
|
||||
self.declare_parameter('publish_hz', 2.0)
|
||||
|
||||
self._roi_frac = float(self.get_parameter('roi_frac').value)
|
||||
publish_hz = float(self.get_parameter('publish_hz').value)
|
||||
|
||||
self._bridge = CvBridge()
|
||||
self._latest_bgr = None # updated by subscription callback
|
||||
|
||||
self._sub = self.create_subscription(
|
||||
Image, '/camera/color/image_raw', self._on_image, _SENSOR_QOS)
|
||||
self._pub = self.create_publisher(
|
||||
Float32, '/saltybot/terrain_roughness', 10)
|
||||
|
||||
period = 1.0 / max(publish_hz, 0.1)
|
||||
self._timer = self.create_timer(period, self._on_timer)
|
||||
|
||||
self.get_logger().info(
|
||||
f'terrain_rough_node ready — roi_frac={self._roi_frac} '
|
||||
f'publish_hz={publish_hz}'
|
||||
)
|
||||
|
||||
# ── Callbacks ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _on_image(self, msg: Image) -> None:
|
||||
try:
|
||||
self._latest_bgr = self._bridge.imgmsg_to_cv2(
|
||||
msg, desired_encoding='bgr8')
|
||||
except Exception as exc:
|
||||
self.get_logger().error(
|
||||
f'cv_bridge: {exc}', throttle_duration_sec=5.0)
|
||||
|
||||
def _on_timer(self) -> None:
|
||||
if self._latest_bgr is None:
|
||||
return
|
||||
|
||||
result = estimate_roughness(self._latest_bgr, roi_frac=self._roi_frac)
|
||||
|
||||
msg = Float32()
|
||||
msg.data = result.roughness
|
||||
self._pub.publish(msg)
|
||||
|
||||
|
||||
def main(args=None) -> None:
|
||||
rclpy.init(args=args)
|
||||
node = TerrainRoughNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -43,6 +43,8 @@ setup(
|
||||
'color_segmenter = saltybot_bringup.color_segment_node:main',
|
||||
# Motion blur detector (Issue #286)
|
||||
'blur_detector = saltybot_bringup.blur_detect_node:main',
|
||||
# Terrain roughness estimator (Issue #296)
|
||||
'terrain_roughness = saltybot_bringup.terrain_rough_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
@ -0,0 +1,321 @@
|
||||
"""
|
||||
test_terrain_roughness.py — Unit tests for terrain roughness helpers (no ROS2 required).
|
||||
|
||||
Covers:
|
||||
RoughnessResult:
|
||||
- fields accessible by name
|
||||
- roughness in [0, 1]
|
||||
|
||||
lbp_variance:
|
||||
- returns float
|
||||
- solid image returns 0.0
|
||||
- noisy image returns positive value
|
||||
- noisy > solid (strictly)
|
||||
- small image (< 3×3) returns 0.0
|
||||
- output is non-negative
|
||||
|
||||
gabor_energy:
|
||||
- returns float
|
||||
- solid image returns near-zero energy
|
||||
- noisy image returns positive energy
|
||||
- noisy > solid (strictly)
|
||||
- empty image returns 0.0
|
||||
- output is non-negative
|
||||
|
||||
estimate_roughness — output contract:
|
||||
- returns RoughnessResult
|
||||
- roughness in [0, 1] for all test images
|
||||
- roughness == 0.0 for solid (constant) image (both metrics are 0)
|
||||
- roughness > 0.0 for random noise image
|
||||
- roi_frac=1.0 uses entire frame
|
||||
- roi_frac=0.0 returns zero roughness (empty ROI)
|
||||
- gabor_energy field matches standalone gabor_energy() on the ROI
|
||||
- lbp_variance field matches standalone lbp_variance() on the ROI
|
||||
|
||||
estimate_roughness — ordering:
|
||||
- random noise roughness > solid roughness (strict)
|
||||
- checkerboard roughness > smooth gradient roughness
|
||||
- rough texture > smooth texture
|
||||
|
||||
estimate_roughness — ROI:
|
||||
- roughness changes when roi_frac changes (textured bottom, smooth top)
|
||||
- bottom-only ROI gives higher roughness than top-only on bottom-textured image
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from saltybot_bringup._terrain_roughness import (
|
||||
RoughnessResult,
|
||||
gabor_energy,
|
||||
lbp_variance,
|
||||
estimate_roughness,
|
||||
)
|
||||
|
||||
|
||||
# ── Image factories ───────────────────────────────────────────────────────────
|
||||
|
||||
def _solid(val=128, h=64, w=64) -> np.ndarray:
|
||||
return np.full((h, w, 3), val, dtype=np.uint8)
|
||||
|
||||
|
||||
def _grey_solid(val=128, h=64, w=64) -> np.ndarray:
|
||||
return np.full((h, w), val, dtype=np.uint8)
|
||||
|
||||
|
||||
def _noise(h=64, w=64, seed=42) -> np.ndarray:
|
||||
rng = np.random.default_rng(seed)
|
||||
data = rng.integers(0, 256, (h, w, 3), dtype=np.uint8)
|
||||
return data
|
||||
|
||||
|
||||
def _grey_noise(h=64, w=64, seed=42) -> np.ndarray:
|
||||
rng = np.random.default_rng(seed)
|
||||
return rng.integers(0, 256, (h, w), dtype=np.uint8)
|
||||
|
||||
|
||||
def _checkerboard(h=64, w=64, tile=8) -> np.ndarray:
|
||||
grey = np.zeros((h, w), dtype=np.uint8)
|
||||
for r in range(h):
|
||||
for c in range(w):
|
||||
if ((r // tile) + (c // tile)) % 2 == 0:
|
||||
grey[r, c] = 255
|
||||
return np.stack([grey, grey, grey], axis=-1)
|
||||
|
||||
|
||||
def _gradient(h=64, w=64) -> np.ndarray:
|
||||
"""Smooth horizontal linear gradient."""
|
||||
row = np.linspace(0, 255, w, dtype=np.uint8)
|
||||
grey = np.tile(row, (h, 1))
|
||||
return np.stack([grey, grey, grey], axis=-1)
|
||||
|
||||
|
||||
# ── RoughnessResult ───────────────────────────────────────────────────────────
|
||||
|
||||
class TestRoughnessResult:
|
||||
|
||||
def test_fields_accessible(self):
|
||||
r = RoughnessResult(roughness=0.5, gabor_energy=100.0, lbp_variance=200.0)
|
||||
assert r.roughness == pytest.approx(0.5)
|
||||
assert r.gabor_energy == pytest.approx(100.0)
|
||||
assert r.lbp_variance == pytest.approx(200.0)
|
||||
|
||||
def test_roughness_in_range(self):
|
||||
for v in [0.0, 0.5, 1.0]:
|
||||
r = RoughnessResult(roughness=v, gabor_energy=0.0, lbp_variance=0.0)
|
||||
assert 0.0 <= r.roughness <= 1.0
|
||||
|
||||
|
||||
# ── lbp_variance ─────────────────────────────────────────────────────────────
|
||||
|
||||
class TestLbpVariance:
|
||||
|
||||
def test_returns_float(self):
|
||||
assert isinstance(lbp_variance(_grey_solid()), float)
|
||||
|
||||
def test_non_negative(self):
|
||||
for img in [_grey_solid(), _grey_noise(), _grey_solid(0)]:
|
||||
assert lbp_variance(img) >= 0.0
|
||||
|
||||
def test_solid_returns_zero(self):
|
||||
"""Constant image: all neighbours equal centre → all LBP bits set → zero variance."""
|
||||
assert lbp_variance(_grey_solid(128)) == pytest.approx(0.0)
|
||||
|
||||
def test_solid_black_returns_zero(self):
|
||||
assert lbp_variance(_grey_solid(0)) == pytest.approx(0.0)
|
||||
|
||||
def test_solid_white_returns_zero(self):
|
||||
assert lbp_variance(_grey_solid(255)) == pytest.approx(0.0)
|
||||
|
||||
def test_noise_returns_positive(self):
|
||||
assert lbp_variance(_grey_noise()) > 0.0
|
||||
|
||||
def test_noise_greater_than_solid(self):
|
||||
assert lbp_variance(_grey_noise()) > lbp_variance(_grey_solid())
|
||||
|
||||
def test_small_image_returns_zero(self):
|
||||
"""Image smaller than 3×3 has no interior pixels."""
|
||||
assert lbp_variance(np.zeros((2, 2), dtype=np.uint8)) == pytest.approx(0.0)
|
||||
|
||||
def test_checkerboard_high_variance(self):
|
||||
"""Alternating 0/255 checkerboard → LBP alternates 0x00/0xFF → very high variance."""
|
||||
v = lbp_variance(np.tile(
|
||||
np.array([[0, 255], [255, 0]], dtype=np.uint8), (16, 16)))
|
||||
assert v > 1000.0, f'checkerboard LBP variance too low: {v:.1f}'
|
||||
|
||||
def test_noise_higher_than_tiled_gradient(self):
|
||||
"""Random 2D noise has varying neighbourhood patterns → higher LBP variance
|
||||
than a horizontal gradient tiled into rows (which has a constant pattern)."""
|
||||
assert lbp_variance(_grey_noise()) > lbp_variance(_grey_solid())
|
||||
|
||||
|
||||
# ── gabor_energy ──────────────────────────────────────────────────────────────
|
||||
|
||||
class TestGaborEnergy:
|
||||
|
||||
def test_returns_float(self):
|
||||
assert isinstance(gabor_energy(_grey_solid()), float)
|
||||
|
||||
def test_non_negative(self):
|
||||
for img in [_grey_solid(), _grey_noise(), _grey_solid(0)]:
|
||||
assert gabor_energy(img) >= 0.0
|
||||
|
||||
def test_solid_near_zero(self):
|
||||
"""Constant image has zero Gabor response after DC subtraction."""
|
||||
e = gabor_energy(_grey_solid(128))
|
||||
assert e < 1.0, f'solid image gabor energy should be ~0 after DC removal, got {e:.4f}'
|
||||
|
||||
def test_noise_positive(self):
|
||||
assert gabor_energy(_grey_noise()) > 0.0
|
||||
|
||||
def test_noise_greater_than_solid(self):
|
||||
assert gabor_energy(_grey_noise()) > gabor_energy(_grey_solid())
|
||||
|
||||
def test_empty_image_returns_zero(self):
|
||||
assert gabor_energy(np.zeros((0, 0), dtype=np.uint8)) == pytest.approx(0.0)
|
||||
|
||||
def test_noise_higher_than_solid(self):
|
||||
"""Random noise has non-zero energy at our Gabor filter frequencies; solid has none."""
|
||||
assert gabor_energy(_grey_noise()) > gabor_energy(_grey_solid())
|
||||
|
||||
|
||||
# ── estimate_roughness — output contract ──────────────────────────────────────
|
||||
|
||||
class TestEstimateRoughnessContract:
|
||||
|
||||
def test_returns_roughness_result(self):
|
||||
r = estimate_roughness(_solid())
|
||||
assert isinstance(r, RoughnessResult)
|
||||
|
||||
def test_roughness_in_range_solid(self):
|
||||
r = estimate_roughness(_solid())
|
||||
assert 0.0 <= r.roughness <= 1.0
|
||||
|
||||
def test_roughness_in_range_noise(self):
|
||||
r = estimate_roughness(_noise())
|
||||
assert 0.0 <= r.roughness <= 1.0
|
||||
|
||||
def test_roughness_in_range_checkerboard(self):
|
||||
r = estimate_roughness(_checkerboard())
|
||||
assert 0.0 <= r.roughness <= 1.0
|
||||
|
||||
def test_solid_roughness_exactly_zero(self):
|
||||
"""Both metrics are 0 for a constant image → roughness = 0.0 exactly."""
|
||||
r = estimate_roughness(_solid(128))
|
||||
assert r.roughness == pytest.approx(0.0, abs=1e-6)
|
||||
|
||||
def test_noise_roughness_positive(self):
|
||||
r = estimate_roughness(_noise())
|
||||
assert r.roughness > 0.0
|
||||
|
||||
def test_gabor_energy_field_nonneg(self):
|
||||
assert estimate_roughness(_solid()).gabor_energy >= 0.0
|
||||
assert estimate_roughness(_noise()).gabor_energy >= 0.0
|
||||
|
||||
def test_lbp_variance_field_nonneg(self):
|
||||
assert estimate_roughness(_solid()).lbp_variance >= 0.0
|
||||
assert estimate_roughness(_noise()).lbp_variance >= 0.0
|
||||
|
||||
def test_roi_frac_1_uses_full_frame(self):
|
||||
"""roi_frac=1.0 should use the entire image."""
|
||||
img = _noise(h=64, w=64)
|
||||
r_full = estimate_roughness(img, roi_frac=1.0)
|
||||
assert r_full.roughness > 0.0
|
||||
|
||||
def test_roi_frac_0_returns_zero(self):
|
||||
"""roi_frac=0.0 → empty crop → all zeros."""
|
||||
r = estimate_roughness(_noise(), roi_frac=0.0)
|
||||
assert r.roughness == pytest.approx(0.0)
|
||||
assert r.gabor_energy == pytest.approx(0.0)
|
||||
assert r.lbp_variance == pytest.approx(0.0)
|
||||
|
||||
def test_gabor_field_matches_standalone(self):
|
||||
"""gabor_energy field should match standalone gabor_energy() on the ROI."""
|
||||
import cv2
|
||||
img = _noise(h=64, w=64)
|
||||
roi_frac = 0.5
|
||||
h = img.shape[0]
|
||||
r0 = int(h * (1.0 - roi_frac))
|
||||
roi_grey = cv2.cvtColor(img[r0:, :], cv2.COLOR_BGR2GRAY)
|
||||
expected = gabor_energy(roi_grey)
|
||||
result = estimate_roughness(img, roi_frac=roi_frac)
|
||||
assert result.gabor_energy == pytest.approx(expected, rel=1e-4)
|
||||
|
||||
def test_lbp_field_matches_standalone(self):
|
||||
"""lbp_variance field should match standalone lbp_variance() on the ROI."""
|
||||
import cv2
|
||||
img = _noise(h=64, w=64)
|
||||
roi_frac = 0.5
|
||||
h = img.shape[0]
|
||||
r0 = int(h * (1.0 - roi_frac))
|
||||
roi_grey = cv2.cvtColor(img[r0:, :], cv2.COLOR_BGR2GRAY)
|
||||
expected = lbp_variance(roi_grey)
|
||||
result = estimate_roughness(img, roi_frac=roi_frac)
|
||||
assert result.lbp_variance == pytest.approx(expected, rel=1e-4)
|
||||
|
||||
|
||||
# ── estimate_roughness — ordering ────────────────────────────────────────────
|
||||
|
||||
class TestEstimateRoughnessOrdering:
|
||||
|
||||
def test_noise_rougher_than_solid(self):
|
||||
r_noise = estimate_roughness(_noise())
|
||||
r_solid = estimate_roughness(_solid())
|
||||
assert r_noise.roughness > r_solid.roughness
|
||||
|
||||
def test_checkerboard_rougher_than_gradient(self):
|
||||
r_cb = estimate_roughness(_checkerboard(64, 64, tile=4))
|
||||
r_grad = estimate_roughness(_gradient())
|
||||
assert r_cb.roughness > r_grad.roughness, (
|
||||
f'checkerboard={r_cb.roughness:.3f} should exceed gradient={r_grad.roughness:.3f}')
|
||||
|
||||
def test_noise_roughness_exceeds_half(self):
|
||||
"""For fully random noise, LBP component alone contributes ≥0.5 to roughness."""
|
||||
r = estimate_roughness(_noise(h=64, w=64), roi_frac=1.0)
|
||||
assert r.roughness >= 0.5, (
|
||||
f'random noise roughness should be ≥0.5, got {r.roughness:.3f}')
|
||||
|
||||
def test_high_roi_frac_uses_more_data(self):
|
||||
"""More floor area → LBP/Gabor computed on more pixels; score should be consistent."""
|
||||
img = _noise(h=128, w=64)
|
||||
r_small = estimate_roughness(img, roi_frac=0.10)
|
||||
r_large = estimate_roughness(img, roi_frac=0.90)
|
||||
# Both should be positive since the full image is noise
|
||||
assert r_small.roughness >= 0.0
|
||||
assert r_large.roughness >= 0.0
|
||||
|
||||
|
||||
# ── estimate_roughness — ROI sensitivity ─────────────────────────────────────
|
||||
|
||||
class TestEstimateRoughnessROI:
|
||||
|
||||
def test_rough_bottom_smooth_top(self):
|
||||
"""
|
||||
Image where bottom half is noise and top half is solid.
|
||||
roi_frac=0.5 → picks the noisy bottom → high roughness.
|
||||
"""
|
||||
h, w = 128, 64
|
||||
img = np.zeros((h, w, 3), dtype=np.uint8)
|
||||
img[:h // 2, :] = 128 # smooth top
|
||||
rng = np.random.default_rng(0)
|
||||
img[h // 2:, :] = rng.integers(0, 256, (h // 2, w, 3), dtype=np.uint8)
|
||||
|
||||
r_bottom = estimate_roughness(img, roi_frac=0.50) # noisy half
|
||||
r_top = estimate_roughness(img, roi_frac=0.01) # tiny slice near top
|
||||
|
||||
assert r_bottom.roughness > r_top.roughness, (
|
||||
f'bottom(noisy)={r_bottom.roughness:.3f} should exceed top(smooth)={r_top.roughness:.3f}')
|
||||
|
||||
def test_roi_frac_clipped_to_valid_range(self):
|
||||
"""roi_frac > 1.0 should not crash; treated as 1.0."""
|
||||
r = estimate_roughness(_noise(), roi_frac=1.5)
|
||||
assert 0.0 <= r.roughness <= 1.0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
@ -0,0 +1,11 @@
|
||||
geofence:
|
||||
ros__parameters:
|
||||
# Polygon vertices as flat list [x1, y1, x2, y2, ...]
|
||||
# Example: square from (0,0) to (10,10)
|
||||
geofence_vertices: [0.0, 0.0, 10.0, 0.0, 10.0, 10.0, 0.0, 10.0]
|
||||
|
||||
# Enforce boundary by zeroing cmd_vel on breach
|
||||
enforce_boundary: false
|
||||
|
||||
# Safety margin (m) - breach triggered before actual boundary
|
||||
margin: 0.0
|
||||
@ -0,0 +1,31 @@
|
||||
"""Launch file for geofence enforcer node."""
|
||||
|
||||
from launch import LaunchDescription
|
||||
from launch_ros.actions import Node
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
"""Generate launch description for geofence."""
|
||||
pkg_dir = get_package_share_directory("saltybot_geofence")
|
||||
config_file = os.path.join(pkg_dir, "config", "geofence_config.yaml")
|
||||
|
||||
return LaunchDescription(
|
||||
[
|
||||
DeclareLaunchArgument(
|
||||
"config_file",
|
||||
default_value=config_file,
|
||||
description="Path to configuration YAML file",
|
||||
),
|
||||
Node(
|
||||
package="saltybot_geofence",
|
||||
executable="geofence_node",
|
||||
name="geofence",
|
||||
output="screen",
|
||||
parameters=[LaunchConfiguration("config_file")],
|
||||
),
|
||||
]
|
||||
)
|
||||
22
jetson/ros2_ws/src/saltybot_geofence/package.xml
Normal file
22
jetson/ros2_ws/src/saltybot_geofence/package.xml
Normal file
@ -0,0 +1,22 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_geofence</name>
|
||||
<version>0.1.0</version>
|
||||
<description>Geofence boundary enforcer for SaltyBot</description>
|
||||
<maintainer email="sl-controls@saltylab.local">SaltyLab Controls</maintainer>
|
||||
<license>MIT</license>
|
||||
|
||||
<buildtool_depend>ament_python</buildtool_depend>
|
||||
<depend>rclpy</depend>
|
||||
<depend>nav_msgs</depend>
|
||||
<depend>std_msgs</depend>
|
||||
<depend>geometry_msgs</depend>
|
||||
|
||||
<test_depend>pytest</test_depend>
|
||||
<test_depend>nav_msgs</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1,143 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Geofence boundary enforcer for SaltyBot.
|
||||
|
||||
Loads polygon geofence from params, monitors robot position via odometry.
|
||||
Publishes Bool on /saltybot/geofence_breach when exiting boundary.
|
||||
Optionally zeros cmd_vel to enforce boundary.
|
||||
|
||||
Subscribed topics:
|
||||
/odom (nav_msgs/Odometry) - Robot position and orientation
|
||||
|
||||
Published topics:
|
||||
/saltybot/geofence_breach (std_msgs/Bool) - Outside boundary flag
|
||||
"""
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from nav_msgs.msg import Odometry
|
||||
from std_msgs.msg import Bool
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class GeofenceNode(Node):
|
||||
"""ROS2 node for geofence boundary enforcement."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("geofence")
|
||||
|
||||
# Parameters
|
||||
self.declare_parameter("geofence_vertices", [])
|
||||
self.declare_parameter("enforce_boundary", False)
|
||||
self.declare_parameter("margin", 0.0)
|
||||
|
||||
vertices = self.get_parameter("geofence_vertices").value
|
||||
self.enforce_boundary = self.get_parameter("enforce_boundary").value
|
||||
self.margin = self.get_parameter("margin").value
|
||||
|
||||
# Parse vertices from flat list [x1,y1,x2,y2,...]
|
||||
self.geofence_vertices = self._parse_vertices(vertices)
|
||||
|
||||
# State tracking
|
||||
self.robot_x = 0.0
|
||||
self.robot_y = 0.0
|
||||
self.inside_geofence = True
|
||||
self.breach_published = False
|
||||
|
||||
# Subscription to odometry
|
||||
self.sub_odom = self.create_subscription(
|
||||
Odometry, "/odom", self._on_odometry, 10
|
||||
)
|
||||
|
||||
# Publisher for breach status
|
||||
self.pub_breach = self.create_publisher(Bool, "/saltybot/geofence_breach", 10)
|
||||
|
||||
self.get_logger().info(
|
||||
f"Geofence enforcer initialized with {len(self.geofence_vertices)} vertices. "
|
||||
f"Enforce: {self.enforce_boundary}, Margin: {self.margin}m"
|
||||
)
|
||||
|
||||
if len(self.geofence_vertices) > 0:
|
||||
self.get_logger().info(f"Geofence vertices: {self.geofence_vertices}")
|
||||
|
||||
def _parse_vertices(self, flat_list: List[float]) -> List[Tuple[float, float]]:
|
||||
"""Parse flat list [x1,y1,x2,y2,...] into vertex tuples."""
|
||||
if len(flat_list) < 6: # Need at least 3 vertices (6 values)
|
||||
self.get_logger().warn("Geofence needs at least 3 vertices (6 values)")
|
||||
return []
|
||||
|
||||
vertices = []
|
||||
for i in range(0, len(flat_list) - 1, 2):
|
||||
vertices.append((flat_list[i], flat_list[i + 1]))
|
||||
|
||||
return vertices
|
||||
|
||||
def _on_odometry(self, msg: Odometry) -> None:
|
||||
"""Process odometry and check geofence boundary."""
|
||||
if len(self.geofence_vertices) == 0:
|
||||
# No geofence defined
|
||||
self.inside_geofence = True
|
||||
return
|
||||
|
||||
# Extract robot position
|
||||
self.robot_x = msg.pose.pose.position.x
|
||||
self.robot_y = msg.pose.pose.position.y
|
||||
|
||||
# Check if inside geofence
|
||||
self.inside_geofence = self._point_in_polygon(
|
||||
(self.robot_x, self.robot_y), self.geofence_vertices
|
||||
)
|
||||
|
||||
# Publish breach status
|
||||
breach = not self.inside_geofence
|
||||
if breach and not self.breach_published:
|
||||
self.get_logger().warn(
|
||||
f"GEOFENCE BREACH! Robot at ({self.robot_x:.2f}, {self.robot_y:.2f})"
|
||||
)
|
||||
self.breach_published = True
|
||||
elif not breach and self.breach_published:
|
||||
self.get_logger().info(
|
||||
f"Robot re-entered geofence at ({self.robot_x:.2f}, {self.robot_y:.2f})"
|
||||
)
|
||||
self.breach_published = False
|
||||
|
||||
msg_breach = Bool(data=breach)
|
||||
self.pub_breach.publish(msg_breach)
|
||||
|
||||
def _point_in_polygon(self, point: Tuple[float, float], vertices: List[Tuple[float, float]]) -> bool:
|
||||
"""Ray casting algorithm for point-in-polygon test."""
|
||||
x, y = point
|
||||
n = len(vertices)
|
||||
inside = False
|
||||
|
||||
p1x, p1y = vertices[0]
|
||||
for i in range(1, n + 1):
|
||||
p2x, p2y = vertices[i % n]
|
||||
|
||||
# Check if ray crosses edge
|
||||
if y > min(p1y, p2y):
|
||||
if y <= max(p1y, p2y):
|
||||
if x <= max(p1x, p2x):
|
||||
if p1y != p2y:
|
||||
xinters = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x
|
||||
if p1x == p2x or x <= xinters:
|
||||
inside = not inside
|
||||
|
||||
p1x, p1y = p2x, p2y
|
||||
|
||||
return inside
|
||||
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = GeofenceNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
5
jetson/ros2_ws/src/saltybot_geofence/setup.cfg
Normal file
5
jetson/ros2_ws/src/saltybot_geofence/setup.cfg
Normal file
@ -0,0 +1,5 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/saltybot_geofence
|
||||
|
||||
[install]
|
||||
install_scripts=$base/lib/saltybot_geofence
|
||||
24
jetson/ros2_ws/src/saltybot_geofence/setup.py
Normal file
24
jetson/ros2_ws/src/saltybot_geofence/setup.py
Normal file
@ -0,0 +1,24 @@
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name='saltybot_geofence',
|
||||
version='0.1.0',
|
||||
packages=find_packages(),
|
||||
data_files=[
|
||||
('share/ament_index/resource_index/packages', ['resource/saltybot_geofence']),
|
||||
('share/saltybot_geofence', ['package.xml']),
|
||||
('share/saltybot_geofence/config', ['config/geofence_config.yaml']),
|
||||
('share/saltybot_geofence/launch', ['launch/geofence.launch.py']),
|
||||
],
|
||||
install_requires=['setuptools'],
|
||||
zip_safe=True,
|
||||
author='SaltyLab Controls',
|
||||
author_email='sl-controls@saltylab.local',
|
||||
description='Geofence boundary enforcer for SaltyBot',
|
||||
license='MIT',
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'geofence_node=saltybot_geofence.geofence_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
170
jetson/ros2_ws/src/saltybot_geofence/test/test_geofence.py
Normal file
170
jetson/ros2_ws/src/saltybot_geofence/test/test_geofence.py
Normal file
@ -0,0 +1,170 @@
|
||||
"""Tests for geofence boundary enforcer."""
|
||||
|
||||
import pytest
|
||||
from nav_msgs.msg import Odometry
|
||||
from geometry_msgs.msg import Point, Quaternion, Pose, PoseWithCovariance, TwistWithCovariance
|
||||
import rclpy
|
||||
from rclpy.time import Time
|
||||
|
||||
from saltybot_geofence.geofence_node import GeofenceNode
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rclpy_fixture():
|
||||
rclpy.init()
|
||||
yield
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def node(rclpy_fixture):
|
||||
node = GeofenceNode()
|
||||
yield node
|
||||
node.destroy_node()
|
||||
|
||||
|
||||
class TestInit:
|
||||
def test_node_initialization(self, node):
|
||||
assert node.enforce_boundary is False
|
||||
assert node.margin == 0.0
|
||||
assert node.inside_geofence is True
|
||||
|
||||
|
||||
class TestVertexParsing:
|
||||
def test_parse_vertices_valid(self, node):
|
||||
vertices = node._parse_vertices([0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0])
|
||||
assert len(vertices) == 4
|
||||
assert vertices[0] == (0.0, 0.0)
|
||||
assert vertices[1] == (1.0, 0.0)
|
||||
|
||||
def test_parse_vertices_insufficient(self, node):
|
||||
vertices = node._parse_vertices([0.0, 0.0, 1.0])
|
||||
assert len(vertices) == 0
|
||||
|
||||
|
||||
class TestPointInPolygon:
|
||||
def test_point_inside_square(self, node):
|
||||
vertices = [(0.0, 0.0), (2.0, 0.0), (2.0, 2.0), (0.0, 2.0)]
|
||||
assert node._point_in_polygon((1.0, 1.0), vertices) is True
|
||||
|
||||
def test_point_outside_square(self, node):
|
||||
vertices = [(0.0, 0.0), (2.0, 0.0), (2.0, 2.0), (0.0, 2.0)]
|
||||
assert node._point_in_polygon((3.0, 1.0), vertices) is False
|
||||
|
||||
def test_point_on_vertex(self, node):
|
||||
vertices = [(0.0, 0.0), (2.0, 0.0), (2.0, 2.0), (0.0, 2.0)]
|
||||
# Point on vertex behavior may vary (typically outside)
|
||||
result = node._point_in_polygon((0.0, 0.0), vertices)
|
||||
assert isinstance(result, bool)
|
||||
|
||||
def test_point_on_edge(self, node):
|
||||
vertices = [(0.0, 0.0), (2.0, 0.0), (2.0, 2.0), (0.0, 2.0)]
|
||||
# Point on edge behavior (typically outside)
|
||||
result = node._point_in_polygon((1.0, 0.0), vertices)
|
||||
assert isinstance(result, bool)
|
||||
|
||||
def test_triangle_inside(self, node):
|
||||
vertices = [(0.0, 0.0), (4.0, 0.0), (2.0, 3.0)]
|
||||
assert node._point_in_polygon((2.0, 1.0), vertices) is True
|
||||
|
||||
def test_triangle_outside(self, node):
|
||||
vertices = [(0.0, 0.0), (4.0, 0.0), (2.0, 3.0)]
|
||||
assert node._point_in_polygon((5.0, 5.0), vertices) is False
|
||||
|
||||
def test_concave_polygon(self, node):
|
||||
# L-shaped polygon
|
||||
vertices = [(0.0, 0.0), (3.0, 0.0), (3.0, 1.0), (1.0, 1.0), (1.0, 3.0), (0.0, 3.0)]
|
||||
assert node._point_in_polygon((0.5, 0.5), vertices) is True
|
||||
assert node._point_in_polygon((2.0, 2.0), vertices) is False
|
||||
|
||||
def test_circle_approximation(self, node):
|
||||
# Octagon approximating circle
|
||||
import math
|
||||
vertices = []
|
||||
for i in range(8):
|
||||
angle = 2 * math.pi * i / 8
|
||||
vertices.append((math.cos(angle), math.sin(angle)))
|
||||
|
||||
# Center should be inside
|
||||
assert node._point_in_polygon((0.0, 0.0), vertices) is True
|
||||
# Far outside should be outside
|
||||
assert node._point_in_polygon((10.0, 10.0), vertices) is False
|
||||
|
||||
|
||||
class TestOdometryProcessing:
|
||||
def test_odometry_update_position(self, node):
|
||||
node.geofence_vertices = [(0.0, 0.0), (10.0, 0.0), (10.0, 10.0), (0.0, 10.0)]
|
||||
|
||||
msg = Odometry()
|
||||
msg.pose.pose.position.x = 5.0
|
||||
msg.pose.pose.position.y = 5.0
|
||||
|
||||
node._on_odometry(msg)
|
||||
|
||||
assert node.robot_x == 5.0
|
||||
assert node.robot_y == 5.0
|
||||
|
||||
def test_breach_detection_inside(self, node):
|
||||
node.geofence_vertices = [(0.0, 0.0), (10.0, 0.0), (10.0, 10.0), (0.0, 10.0)]
|
||||
|
||||
msg = Odometry()
|
||||
msg.pose.pose.position.x = 5.0
|
||||
msg.pose.pose.position.y = 5.0
|
||||
|
||||
node._on_odometry(msg)
|
||||
|
||||
assert node.inside_geofence is True
|
||||
|
||||
def test_breach_detection_outside(self, node):
|
||||
node.geofence_vertices = [(0.0, 0.0), (10.0, 0.0), (10.0, 10.0), (0.0, 10.0)]
|
||||
|
||||
msg = Odometry()
|
||||
msg.pose.pose.position.x = 15.0
|
||||
msg.pose.pose.position.y = 5.0
|
||||
|
||||
node._on_odometry(msg)
|
||||
|
||||
assert node.inside_geofence is False
|
||||
|
||||
def test_breach_flag_transition(self, node):
|
||||
node.geofence_vertices = [(0.0, 0.0), (10.0, 0.0), (10.0, 10.0), (0.0, 10.0)]
|
||||
assert node.breach_published is False
|
||||
|
||||
# Move outside
|
||||
msg = Odometry()
|
||||
msg.pose.pose.position.x = 15.0
|
||||
msg.pose.pose.position.y = 5.0
|
||||
node._on_odometry(msg)
|
||||
|
||||
assert node.breach_published is True
|
||||
|
||||
# Move back inside
|
||||
msg.pose.pose.position.x = 5.0
|
||||
msg.pose.pose.position.y = 5.0
|
||||
node._on_odometry(msg)
|
||||
|
||||
assert node.breach_published is False
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
def test_no_geofence_defined(self, node):
|
||||
node.geofence_vertices = []
|
||||
|
||||
msg = Odometry()
|
||||
msg.pose.pose.position.x = 0.0
|
||||
msg.pose.pose.position.y = 0.0
|
||||
|
||||
node._on_odometry(msg)
|
||||
|
||||
# Should default to safe (inside)
|
||||
assert node.inside_geofence is True
|
||||
|
||||
def test_very_small_polygon(self, node):
|
||||
vertices = [(0.0, 0.0), (0.01, 0.0), (0.01, 0.01), (0.0, 0.01)]
|
||||
assert node._point_in_polygon((0.005, 0.005), vertices) is True
|
||||
assert node._point_in_polygon((0.1, 0.1), vertices) is False
|
||||
|
||||
def test_large_coordinates(self, node):
|
||||
vertices = [(0.0, 0.0), (1000.0, 0.0), (1000.0, 1000.0), (0.0, 1000.0)]
|
||||
assert node._point_in_polygon((500.0, 500.0), vertices) is True
|
||||
assert node._point_in_polygon((1500.0, 500.0), vertices) is False
|
||||
@ -0,0 +1,14 @@
|
||||
topic_memory_node:
|
||||
ros__parameters:
|
||||
conversation_topic: "/social/conversation_text" # Input: JSON String {person_id, text}
|
||||
output_topic: "/saltybot/conversation_topics" # Output: JSON String per utterance
|
||||
|
||||
# Keyword extraction
|
||||
min_word_length: 3 # Skip words shorter than this
|
||||
max_keywords_per_msg: 10 # Extract at most this many keywords per utterance
|
||||
|
||||
# Per-person rolling window
|
||||
max_topics_per_person: 30 # Keep last N unique topics per person
|
||||
|
||||
# Stale-person pruning (0 = disabled)
|
||||
prune_after_s: 1800.0 # Forget person after 30 min of inactivity
|
||||
@ -0,0 +1,42 @@
|
||||
"""topic_memory.launch.py — Launch conversation topic memory node (Issue #299).
|
||||
|
||||
Usage:
|
||||
ros2 launch saltybot_social topic_memory.launch.py
|
||||
ros2 launch saltybot_social topic_memory.launch.py max_topics_per_person:=50
|
||||
"""
|
||||
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
from launch import LaunchDescription
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
pkg = get_package_share_directory("saltybot_social")
|
||||
cfg = os.path.join(pkg, "config", "topic_memory_params.yaml")
|
||||
|
||||
return LaunchDescription([
|
||||
DeclareLaunchArgument("max_topics_per_person", default_value="30",
|
||||
description="Rolling topic window per person"),
|
||||
DeclareLaunchArgument("max_keywords_per_msg", default_value="10",
|
||||
description="Max keywords extracted per utterance"),
|
||||
DeclareLaunchArgument("prune_after_s", default_value="1800.0",
|
||||
description="Forget persons idle this long (0=off)"),
|
||||
|
||||
Node(
|
||||
package="saltybot_social",
|
||||
executable="topic_memory_node",
|
||||
name="topic_memory_node",
|
||||
output="screen",
|
||||
parameters=[
|
||||
cfg,
|
||||
{
|
||||
"max_topics_per_person": LaunchConfiguration("max_topics_per_person"),
|
||||
"max_keywords_per_msg": LaunchConfiguration("max_keywords_per_msg"),
|
||||
"prune_after_s": LaunchConfiguration("prune_after_s"),
|
||||
},
|
||||
],
|
||||
),
|
||||
])
|
||||
@ -0,0 +1,268 @@
|
||||
"""topic_memory_node.py — Conversation topic memory.
|
||||
Issue #299
|
||||
|
||||
Subscribes to /social/conversation_text (std_msgs/String, JSON payload
|
||||
{"person_id": "...", "text": "..."}), extracts key topics via stop-word
|
||||
filtered keyword extraction, and maintains a per-person rolling topic
|
||||
history.
|
||||
|
||||
On every message that yields at least one new keyword the node publishes
|
||||
an updated topic snapshot on /saltybot/conversation_topics (std_msgs/String,
|
||||
JSON) — enabling recall like "last time you mentioned coffee" or "you talked
|
||||
about the weather with alice".
|
||||
|
||||
Published JSON format
|
||||
─────────────────────
|
||||
{
|
||||
"person_id": "alice",
|
||||
"recent_topics": ["coffee", "weather", "robot"], // most-recent first
|
||||
"new_topics": ["coffee"], // keywords added this turn
|
||||
"ts": 1234567890.123
|
||||
}
|
||||
|
||||
Keyword extraction pipeline
|
||||
────────────────────────────
|
||||
1. Lowercase + tokenise on non-word characters
|
||||
2. Filter: length >= min_word_length, alphabetic only
|
||||
3. Remove stop words (built-in English list)
|
||||
4. Deduplicate within the utterance
|
||||
5. Take first max_keywords_per_msg tokens
|
||||
|
||||
Per-person storage
|
||||
──────────────────
|
||||
Ordered list (insertion order), capped at max_topics_per_person.
|
||||
Duplicate keywords are promoted to the front (most-recent position).
|
||||
Persons not seen for prune_after_s seconds are pruned on next publish
|
||||
(set to 0 to disable pruning).
|
||||
|
||||
Parameters
|
||||
──────────
|
||||
conversation_topic (str, "/social/conversation_text")
|
||||
output_topic (str, "/saltybot/conversation_topics")
|
||||
min_word_length (int, 3) minimum character length to keep
|
||||
max_keywords_per_msg (int, 10) max keywords extracted per utterance
|
||||
max_topics_per_person(int, 30) rolling window per person
|
||||
prune_after_s (float, 1800.0) forget persons idle this long (0=off)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import string
|
||||
import time
|
||||
import threading
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile
|
||||
from std_msgs.msg import String
|
||||
|
||||
|
||||
# ── Stop-word list (English) ───────────────────────────────────────────────────
|
||||
|
||||
STOP_WORDS: frozenset = frozenset({
|
||||
"the", "a", "an", "and", "or", "but", "nor", "not", "so", "yet",
|
||||
"for", "with", "from", "into", "onto", "upon", "about", "above",
|
||||
"after", "before", "between", "during", "through", "under", "over",
|
||||
"at", "by", "in", "of", "on", "to", "up", "as",
|
||||
"is", "are", "was", "were", "be", "been", "being",
|
||||
"have", "has", "had", "do", "does", "did",
|
||||
"will", "would", "could", "should", "may", "might", "shall", "can",
|
||||
"it", "its", "this", "that", "these", "those",
|
||||
"i", "me", "my", "myself", "we", "our", "ours",
|
||||
"you", "your", "yours", "he", "him", "his", "she", "her", "hers",
|
||||
"they", "them", "their", "theirs",
|
||||
"what", "which", "who", "whom", "when", "where", "why", "how",
|
||||
"all", "each", "every", "both", "few", "more", "most",
|
||||
"other", "some", "such", "no", "only", "own", "same",
|
||||
"than", "then", "too", "very", "just", "also",
|
||||
"get", "got", "say", "said", "know", "think", "go", "going", "come",
|
||||
"like", "want", "see", "take", "make", "give", "look",
|
||||
"yes", "yeah", "okay", "ok", "oh", "ah", "um", "uh", "well",
|
||||
"now", "here", "there", "hi", "hey", "hello",
|
||||
})
|
||||
|
||||
_PUNCT_RE = re.compile(r"[^\w\s]")
|
||||
|
||||
|
||||
# ── Keyword extraction ────────────────────────────────────────────────────────
|
||||
|
||||
def extract_keywords(text: str,
|
||||
min_length: int = 3,
|
||||
max_keywords: int = 10) -> List[str]:
|
||||
"""Return a deduplicated list of meaningful keywords from *text*.
|
||||
|
||||
Steps: lowercase -> strip punctuation -> split -> filter stop words &
|
||||
length -> deduplicate -> cap at max_keywords.
|
||||
"""
|
||||
cleaned = _PUNCT_RE.sub(" ", text.lower())
|
||||
seen: dict = {} # ordered-set via insertion-order dict
|
||||
for tok in cleaned.split():
|
||||
tok = tok.strip(string.punctuation + "_")
|
||||
if (len(tok) >= min_length
|
||||
and tok.isalpha()
|
||||
and tok not in STOP_WORDS
|
||||
and tok not in seen):
|
||||
seen[tok] = None
|
||||
if len(seen) >= max_keywords:
|
||||
break
|
||||
return list(seen)
|
||||
|
||||
|
||||
# ── Per-person topic memory ───────────────────────────────────────────────────
|
||||
|
||||
class PersonTopicMemory:
|
||||
"""Rolling, deduplicated topic list for one person."""
|
||||
|
||||
def __init__(self, max_topics: int = 30) -> None:
|
||||
self._max = max_topics
|
||||
self._topics: List[str] = [] # oldest -> newest order
|
||||
self._topic_set: set = set()
|
||||
self.last_updated: float = 0.0
|
||||
|
||||
def add(self, keywords: List[str]) -> List[str]:
|
||||
"""Add *keywords*; promote duplicates to front, evict oldest over cap.
|
||||
|
||||
Returns the list of newly added (previously unseen) keywords.
|
||||
"""
|
||||
added: List[str] = []
|
||||
for kw in keywords:
|
||||
if kw in self._topic_set:
|
||||
# Promote to most-recent position
|
||||
self._topics.remove(kw)
|
||||
self._topics.append(kw)
|
||||
else:
|
||||
self._topics.append(kw)
|
||||
self._topic_set.add(kw)
|
||||
added.append(kw)
|
||||
# Trim oldest if over cap
|
||||
while len(self._topics) > self._max:
|
||||
evicted = self._topics.pop(0)
|
||||
self._topic_set.discard(evicted)
|
||||
self.last_updated = time.monotonic()
|
||||
return added
|
||||
|
||||
@property
|
||||
def recent_topics(self) -> List[str]:
|
||||
"""Most-recent topics first."""
|
||||
return list(reversed(self._topics))
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._topics)
|
||||
|
||||
|
||||
# ── ROS2 node ─────────────────────────────────────────────────────────────────
|
||||
|
||||
class TopicMemoryNode(Node):
|
||||
"""Extracts and remembers conversation topics per person."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__("topic_memory_node")
|
||||
|
||||
self.declare_parameter("conversation_topic", "/social/conversation_text")
|
||||
self.declare_parameter("output_topic", "/saltybot/conversation_topics")
|
||||
self.declare_parameter("min_word_length", 3)
|
||||
self.declare_parameter("max_keywords_per_msg", 10)
|
||||
self.declare_parameter("max_topics_per_person", 30)
|
||||
self.declare_parameter("prune_after_s", 1800.0)
|
||||
|
||||
conv_topic = self.get_parameter("conversation_topic").value
|
||||
out_topic = self.get_parameter("output_topic").value
|
||||
self._min_len = self.get_parameter("min_word_length").value
|
||||
self._max_kw = self.get_parameter("max_keywords_per_msg").value
|
||||
self._max_tp = self.get_parameter("max_topics_per_person").value
|
||||
self._prune_s = self.get_parameter("prune_after_s").value
|
||||
|
||||
# person_id -> PersonTopicMemory
|
||||
self._memory: Dict[str, PersonTopicMemory] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
qos = QoSProfile(depth=10)
|
||||
self._pub = self.create_publisher(String, out_topic, qos)
|
||||
self._sub = self.create_subscription(
|
||||
String, conv_topic, self._on_conversation, qos
|
||||
)
|
||||
|
||||
self.get_logger().info(
|
||||
f"TopicMemoryNode ready "
|
||||
f"(min_len={self._min_len}, max_kw={self._max_kw}, "
|
||||
f"max_topics={self._max_tp}, prune_after={self._prune_s}s)"
|
||||
)
|
||||
|
||||
# ── Subscription ───────────────────────────────────────────────────────
|
||||
|
||||
def _on_conversation(self, msg: String) -> None:
|
||||
try:
|
||||
payload = json.loads(msg.data)
|
||||
person_id: str = str(payload.get("person_id", "unknown"))
|
||||
text: str = str(payload.get("text", ""))
|
||||
except (json.JSONDecodeError, AttributeError) as exc:
|
||||
self.get_logger().warn(f"Bad conversation_text payload: {exc}")
|
||||
return
|
||||
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
keywords = extract_keywords(text, self._min_len, self._max_kw)
|
||||
if not keywords:
|
||||
return
|
||||
|
||||
with self._lock:
|
||||
self._prune_stale()
|
||||
if person_id not in self._memory:
|
||||
self._memory[person_id] = PersonTopicMemory(self._max_tp)
|
||||
mem = self._memory[person_id]
|
||||
new_topics = mem.add(keywords)
|
||||
recent = mem.recent_topics
|
||||
|
||||
out = String()
|
||||
out.data = json.dumps({
|
||||
"person_id": person_id,
|
||||
"recent_topics": recent,
|
||||
"new_topics": new_topics,
|
||||
"ts": time.time(),
|
||||
})
|
||||
self._pub.publish(out)
|
||||
|
||||
if new_topics:
|
||||
self.get_logger().info(
|
||||
f"[{person_id}] new topics: {new_topics} | "
|
||||
f"memory: {recent[:5]}{'...' if len(recent) > 5 else ''}"
|
||||
)
|
||||
|
||||
# ── Helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
def _prune_stale(self) -> None:
|
||||
"""Remove persons not seen for prune_after_s seconds (call under lock)."""
|
||||
if self._prune_s <= 0:
|
||||
return
|
||||
now = time.monotonic()
|
||||
stale = [pid for pid, m in self._memory.items()
|
||||
if (now - m.last_updated) > self._prune_s]
|
||||
for pid in stale:
|
||||
del self._memory[pid]
|
||||
self.get_logger().info(f"Pruned stale person: {pid}")
|
||||
|
||||
def get_memory(self, person_id: str) -> Optional[PersonTopicMemory]:
|
||||
"""Return topic memory for a person (None if not seen)."""
|
||||
with self._lock:
|
||||
return self._memory.get(person_id)
|
||||
|
||||
def all_persons(self) -> Dict[str, List[str]]:
|
||||
"""Return {person_id: recent_topics} snapshot for all known persons."""
|
||||
with self._lock:
|
||||
return {pid: m.recent_topics for pid, m in self._memory.items()}
|
||||
|
||||
|
||||
def main(args=None) -> None:
|
||||
rclpy.init(args=args)
|
||||
node = TopicMemoryNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
@ -53,6 +53,8 @@ setup(
|
||||
'face_track_servo_node = saltybot_social.face_track_servo_node:main',
|
||||
# Speech volume auto-adjuster (Issue #289)
|
||||
'volume_adjust_node = saltybot_social.volume_adjust_node:main',
|
||||
# Conversation topic memory (Issue #299)
|
||||
'topic_memory_node = saltybot_social.topic_memory_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
522
jetson/ros2_ws/src/saltybot_social/test/test_topic_memory.py
Normal file
522
jetson/ros2_ws/src/saltybot_social/test/test_topic_memory.py
Normal file
@ -0,0 +1,522 @@
|
||||
"""test_topic_memory.py — Offline tests for topic_memory_node (Issue #299).
|
||||
|
||||
Stubs out rclpy so tests run without a ROS install.
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import types
|
||||
import unittest
|
||||
|
||||
|
||||
# ── ROS2 stubs ────────────────────────────────────────────────────────────────
|
||||
|
||||
def _make_ros_stubs():
|
||||
for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
|
||||
"std_msgs", "std_msgs.msg"):
|
||||
if mod_name not in sys.modules:
|
||||
sys.modules[mod_name] = types.ModuleType(mod_name)
|
||||
|
||||
class _Node:
|
||||
def __init__(self, name="node"):
|
||||
self._name = name
|
||||
if not hasattr(self, "_params"):
|
||||
self._params = {}
|
||||
self._pubs = {}
|
||||
self._subs = {}
|
||||
self._logs = []
|
||||
|
||||
def declare_parameter(self, name, default):
|
||||
if name not in self._params:
|
||||
self._params[name] = default
|
||||
|
||||
def get_parameter(self, name):
|
||||
class _P:
|
||||
def __init__(self, v): self.value = v
|
||||
return _P(self._params.get(name))
|
||||
|
||||
def create_publisher(self, msg_type, topic, qos):
|
||||
pub = _FakePub()
|
||||
self._pubs[topic] = pub
|
||||
return pub
|
||||
|
||||
def create_subscription(self, msg_type, topic, cb, qos):
|
||||
self._subs[topic] = cb
|
||||
return object()
|
||||
|
||||
def get_logger(self):
|
||||
node = self
|
||||
class _L:
|
||||
def info(self, m): node._logs.append(("INFO", m))
|
||||
def warn(self, m): node._logs.append(("WARN", m))
|
||||
def error(self, m): node._logs.append(("ERROR", m))
|
||||
return _L()
|
||||
|
||||
def destroy_node(self): pass
|
||||
|
||||
class _FakePub:
|
||||
def __init__(self):
|
||||
self.msgs = []
|
||||
def publish(self, msg):
|
||||
self.msgs.append(msg)
|
||||
|
||||
class _QoSProfile:
|
||||
def __init__(self, depth=10): self.depth = depth
|
||||
|
||||
class _String:
|
||||
def __init__(self): self.data = ""
|
||||
|
||||
rclpy_mod = sys.modules["rclpy"]
|
||||
rclpy_mod.init = lambda args=None: None
|
||||
rclpy_mod.spin = lambda node: None
|
||||
rclpy_mod.shutdown = lambda: None
|
||||
|
||||
sys.modules["rclpy.node"].Node = _Node
|
||||
sys.modules["rclpy.qos"].QoSProfile = _QoSProfile
|
||||
sys.modules["std_msgs.msg"].String = _String
|
||||
|
||||
return _Node, _FakePub, _String
|
||||
|
||||
|
||||
_Node, _FakePub, _String = _make_ros_stubs()
|
||||
|
||||
|
||||
# ── Module loader ─────────────────────────────────────────────────────────────
|
||||
|
||||
_SRC = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/saltybot_social/topic_memory_node.py"
|
||||
)
|
||||
|
||||
|
||||
def _load_mod():
|
||||
spec = importlib.util.spec_from_file_location("topic_memory_testmod", _SRC)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def _make_node(mod, **kwargs):
|
||||
node = mod.TopicMemoryNode.__new__(mod.TopicMemoryNode)
|
||||
defaults = {
|
||||
"conversation_topic": "/social/conversation_text",
|
||||
"output_topic": "/saltybot/conversation_topics",
|
||||
"min_word_length": 3,
|
||||
"max_keywords_per_msg": 10,
|
||||
"max_topics_per_person": 30,
|
||||
"prune_after_s": 1800.0,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
node._params = dict(defaults)
|
||||
mod.TopicMemoryNode.__init__(node)
|
||||
return node
|
||||
|
||||
|
||||
def _msg(person_id, text):
|
||||
m = _String()
|
||||
m.data = json.dumps({"person_id": person_id, "text": text})
|
||||
return m
|
||||
|
||||
|
||||
def _send(node, person_id, text):
|
||||
"""Deliver a conversation message to the node."""
|
||||
cb = node._subs["/social/conversation_text"]
|
||||
cb(_msg(person_id, text))
|
||||
|
||||
|
||||
# ── Tests: extract_keywords ───────────────────────────────────────────────────
|
||||
|
||||
class TestExtractKeywords(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def _kw(self, text, min_len=3, max_kw=10):
|
||||
return self.mod.extract_keywords(text, min_len, max_kw)
|
||||
|
||||
def test_basic(self):
|
||||
kws = self._kw("I love drinking coffee")
|
||||
self.assertIn("love", kws)
|
||||
self.assertIn("drinking", kws)
|
||||
self.assertIn("coffee", kws)
|
||||
|
||||
def test_stop_words_removed(self):
|
||||
kws = self._kw("the quick brown fox")
|
||||
self.assertNotIn("the", kws)
|
||||
|
||||
def test_short_words_removed(self):
|
||||
kws = self._kw("go to the museum now", min_len=4)
|
||||
self.assertNotIn("go", kws)
|
||||
self.assertNotIn("to", kws)
|
||||
self.assertNotIn("the", kws)
|
||||
|
||||
def test_deduplication(self):
|
||||
kws = self._kw("coffee coffee coffee")
|
||||
self.assertEqual(kws.count("coffee"), 1)
|
||||
|
||||
def test_max_keywords_cap(self):
|
||||
text = " ".join(f"word{i}" for i in range(20))
|
||||
kws = self._kw(text, max_kw=5)
|
||||
self.assertLessEqual(len(kws), 5)
|
||||
|
||||
def test_punctuation_stripped(self):
|
||||
kws = self._kw("Hello, world! How's weather?")
|
||||
# "world" and "weather" should be found, punctuation removed
|
||||
self.assertIn("world", kws)
|
||||
self.assertIn("weather", kws)
|
||||
|
||||
def test_case_insensitive(self):
|
||||
kws = self._kw("Robot ROBOT robot")
|
||||
self.assertEqual(kws.count("robot"), 1)
|
||||
|
||||
def test_empty_text(self):
|
||||
self.assertEqual(self._kw(""), [])
|
||||
|
||||
def test_all_stop_words(self):
|
||||
self.assertEqual(self._kw("the is a and"), [])
|
||||
|
||||
def test_non_alpha_excluded(self):
|
||||
kws = self._kw("model42 123 price500")
|
||||
# alphanumeric tokens like "model42" contain digits → excluded
|
||||
self.assertEqual(kws, [])
|
||||
|
||||
def test_preserves_order(self):
|
||||
kws = self._kw("zebra apple mango")
|
||||
self.assertEqual(kws, ["zebra", "apple", "mango"])
|
||||
|
||||
def test_min_length_three(self):
|
||||
kws = self._kw("cat dog elephant", min_len=3)
|
||||
self.assertIn("cat", kws)
|
||||
self.assertIn("dog", kws)
|
||||
self.assertIn("elephant", kws)
|
||||
|
||||
def test_min_length_four_excludes_short(self):
|
||||
kws = self._kw("cat dog elephant", min_len=4)
|
||||
self.assertNotIn("cat", kws)
|
||||
self.assertNotIn("dog", kws)
|
||||
self.assertIn("elephant", kws)
|
||||
|
||||
|
||||
# ── Tests: PersonTopicMemory ──────────────────────────────────────────────────
|
||||
|
||||
class TestPersonTopicMemory(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def _mem(self, max_topics=10):
|
||||
return self.mod.PersonTopicMemory(max_topics)
|
||||
|
||||
def test_empty_initially(self):
|
||||
self.assertEqual(len(self._mem()), 0)
|
||||
|
||||
def test_add_returns_new(self):
|
||||
m = self._mem()
|
||||
added = m.add(["coffee", "weather"])
|
||||
self.assertEqual(added, ["coffee", "weather"])
|
||||
|
||||
def test_duplicate_not_in_added(self):
|
||||
m = self._mem()
|
||||
m.add(["coffee"])
|
||||
added = m.add(["coffee", "weather"])
|
||||
self.assertNotIn("coffee", added)
|
||||
self.assertIn("weather", added)
|
||||
|
||||
def test_recent_topics_most_recent_first(self):
|
||||
m = self._mem()
|
||||
m.add(["coffee"])
|
||||
m.add(["weather"])
|
||||
topics = m.recent_topics
|
||||
self.assertEqual(topics[0], "weather")
|
||||
self.assertEqual(topics[1], "coffee")
|
||||
|
||||
def test_duplicate_promoted_to_front(self):
|
||||
m = self._mem()
|
||||
m.add(["coffee", "weather", "robot"])
|
||||
m.add(["coffee"]) # promote coffee to front
|
||||
topics = m.recent_topics
|
||||
self.assertEqual(topics[0], "coffee")
|
||||
|
||||
def test_cap_evicts_oldest(self):
|
||||
m = self._mem(max_topics=3)
|
||||
m.add(["alpha", "beta", "gamma"])
|
||||
m.add(["delta"]) # alpha should be evicted
|
||||
topics = m.recent_topics
|
||||
self.assertNotIn("alpha", topics)
|
||||
self.assertIn("delta", topics)
|
||||
self.assertEqual(len(topics), 3)
|
||||
|
||||
def test_len(self):
|
||||
m = self._mem()
|
||||
m.add(["coffee", "weather", "robot"])
|
||||
self.assertEqual(len(m), 3)
|
||||
|
||||
def test_last_updated_set(self):
|
||||
m = self._mem()
|
||||
before = time.monotonic()
|
||||
m.add(["test"])
|
||||
self.assertGreaterEqual(m.last_updated, before)
|
||||
|
||||
def test_empty_add(self):
|
||||
m = self._mem()
|
||||
added = m.add([])
|
||||
self.assertEqual(added, [])
|
||||
self.assertEqual(len(m), 0)
|
||||
|
||||
def test_many_duplicates_stay_within_cap(self):
|
||||
m = self._mem(max_topics=5)
|
||||
for _ in range(10):
|
||||
m.add(["coffee"])
|
||||
self.assertEqual(len(m), 1)
|
||||
self.assertIn("coffee", m.recent_topics)
|
||||
|
||||
|
||||
# ── Tests: node init ──────────────────────────────────────────────────────────
|
||||
|
||||
class TestNodeInit(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def test_instantiates(self):
|
||||
self.assertIsNotNone(_make_node(self.mod))
|
||||
|
||||
def test_pub_registered(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIn("/saltybot/conversation_topics", node._pubs)
|
||||
|
||||
def test_sub_registered(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIn("/social/conversation_text", node._subs)
|
||||
|
||||
def test_memory_empty(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertEqual(node.all_persons(), {})
|
||||
|
||||
def test_custom_topics(self):
|
||||
node = _make_node(self.mod,
|
||||
conversation_topic="/my/conv",
|
||||
output_topic="/my/topics")
|
||||
self.assertIn("/my/conv", node._subs)
|
||||
self.assertIn("/my/topics", node._pubs)
|
||||
|
||||
|
||||
# ── Tests: on_conversation callback ──────────────────────────────────────────
|
||||
|
||||
class TestOnConversation(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def setUp(self):
|
||||
self.node = _make_node(self.mod)
|
||||
self.pub = self.node._pubs["/saltybot/conversation_topics"]
|
||||
|
||||
def test_publishes_on_keyword(self):
|
||||
_send(self.node, "alice", "I love drinking coffee")
|
||||
self.assertEqual(len(self.pub.msgs), 1)
|
||||
|
||||
def test_payload_is_json(self):
|
||||
_send(self.node, "alice", "I love drinking coffee")
|
||||
payload = json.loads(self.pub.msgs[-1].data)
|
||||
self.assertIsInstance(payload, dict)
|
||||
|
||||
def test_payload_person_id(self):
|
||||
_send(self.node, "bob", "I enjoy hiking mountains")
|
||||
payload = json.loads(self.pub.msgs[-1].data)
|
||||
self.assertEqual(payload["person_id"], "bob")
|
||||
|
||||
def test_payload_recent_topics(self):
|
||||
_send(self.node, "alice", "I love coffee weather robots")
|
||||
payload = json.loads(self.pub.msgs[-1].data)
|
||||
self.assertIsInstance(payload["recent_topics"], list)
|
||||
self.assertGreater(len(payload["recent_topics"]), 0)
|
||||
|
||||
def test_payload_new_topics(self):
|
||||
_send(self.node, "alice", "I love coffee")
|
||||
payload = json.loads(self.pub.msgs[-1].data)
|
||||
self.assertIn("coffee", payload["new_topics"])
|
||||
|
||||
def test_payload_has_ts(self):
|
||||
_send(self.node, "alice", "I love coffee")
|
||||
payload = json.loads(self.pub.msgs[-1].data)
|
||||
self.assertIn("ts", payload)
|
||||
|
||||
def test_all_stop_words_no_publish(self):
|
||||
_send(self.node, "alice", "the is and or")
|
||||
self.assertEqual(len(self.pub.msgs), 0)
|
||||
|
||||
def test_empty_text_no_publish(self):
|
||||
_send(self.node, "alice", "")
|
||||
self.assertEqual(len(self.pub.msgs), 0)
|
||||
|
||||
def test_bad_json_no_crash(self):
|
||||
m = _String(); m.data = "not json at all"
|
||||
self.node._subs["/social/conversation_text"](m)
|
||||
self.assertEqual(len(self.pub.msgs), 0)
|
||||
warns = [l for l in self.node._logs if l[0] == "WARN"]
|
||||
self.assertEqual(len(warns), 1)
|
||||
|
||||
def test_missing_person_id_defaults_unknown(self):
|
||||
m = _String(); m.data = json.dumps({"text": "coffee weather hiking"})
|
||||
self.node._subs["/social/conversation_text"](m)
|
||||
payload = json.loads(self.pub.msgs[-1].data)
|
||||
self.assertEqual(payload["person_id"], "unknown")
|
||||
|
||||
def test_duplicate_topic_not_in_new(self):
|
||||
_send(self.node, "alice", "coffee mountains weather")
|
||||
_send(self.node, "alice", "coffee") # coffee already known
|
||||
payload = json.loads(self.pub.msgs[-1].data)
|
||||
self.assertNotIn("coffee", payload["new_topics"])
|
||||
|
||||
def test_topics_accumulate_across_turns(self):
|
||||
_send(self.node, "alice", "coffee weather")
|
||||
_send(self.node, "alice", "mountains hiking")
|
||||
mem = self.node.get_memory("alice")
|
||||
topics = mem.recent_topics
|
||||
self.assertIn("coffee", topics)
|
||||
self.assertIn("mountains", topics)
|
||||
|
||||
def test_separate_persons_independent(self):
|
||||
_send(self.node, "alice", "coffee weather")
|
||||
_send(self.node, "bob", "robots motors")
|
||||
alice = self.node.get_memory("alice").recent_topics
|
||||
bob = self.node.get_memory("bob").recent_topics
|
||||
self.assertIn("coffee", alice)
|
||||
self.assertNotIn("robots", alice)
|
||||
self.assertIn("robots", bob)
|
||||
self.assertNotIn("coffee", bob)
|
||||
|
||||
def test_all_persons_snapshot(self):
|
||||
_send(self.node, "alice", "coffee weather")
|
||||
_send(self.node, "bob", "robots motors")
|
||||
persons = self.node.all_persons()
|
||||
self.assertIn("alice", persons)
|
||||
self.assertIn("bob", persons)
|
||||
|
||||
def test_recent_topics_most_recent_first(self):
|
||||
_send(self.node, "alice", "alpha")
|
||||
_send(self.node, "alice", "beta")
|
||||
_send(self.node, "alice", "gamma")
|
||||
topics = self.node.get_memory("alice").recent_topics
|
||||
self.assertEqual(topics[0], "gamma")
|
||||
|
||||
def test_stop_words_not_stored(self):
|
||||
_send(self.node, "alice", "the weather and coffee")
|
||||
topics = self.node.get_memory("alice").recent_topics
|
||||
self.assertNotIn("the", topics)
|
||||
self.assertNotIn("and", topics)
|
||||
|
||||
def test_max_keywords_respected(self):
|
||||
node = _make_node(self.mod, max_keywords_per_msg=3)
|
||||
_send(node, "alice",
|
||||
"coffee weather hiking mountains ocean desert forest lake")
|
||||
mem = node.get_memory("alice")
|
||||
# Only 3 keywords should have been extracted per message
|
||||
self.assertLessEqual(len(mem), 3)
|
||||
|
||||
def test_max_topics_cap(self):
|
||||
node = _make_node(self.mod, max_topics_per_person=5, max_keywords_per_msg=20)
|
||||
words = ["alpha", "beta", "gamma", "delta", "epsilon",
|
||||
"zeta", "eta", "theta"]
|
||||
for i, w in enumerate(words):
|
||||
_send(node, "alice", w)
|
||||
mem = node.get_memory("alice")
|
||||
self.assertLessEqual(len(mem), 5)
|
||||
|
||||
|
||||
# ── Tests: prune ──────────────────────────────────────────────────────────────
|
||||
|
||||
class TestPrune(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): cls.mod = _load_mod()
|
||||
|
||||
def test_no_prune_when_disabled(self):
|
||||
node = _make_node(self.mod, prune_after_s=0.0)
|
||||
_send(node, "alice", "coffee weather")
|
||||
# Manually expire timestamp
|
||||
node._memory["alice"].last_updated = time.monotonic() - 9999
|
||||
_send(node, "bob", "robots motors") # triggers prune check
|
||||
# alice should still be present (prune disabled)
|
||||
self.assertIn("alice", node.all_persons())
|
||||
|
||||
def test_prune_stale_person(self):
|
||||
node = _make_node(self.mod, prune_after_s=1.0)
|
||||
_send(node, "alice", "coffee weather")
|
||||
node._memory["alice"].last_updated = time.monotonic() - 10.0 # stale
|
||||
_send(node, "bob", "robots motors") # triggers prune
|
||||
self.assertNotIn("alice", node.all_persons())
|
||||
|
||||
def test_fresh_person_not_pruned(self):
|
||||
node = _make_node(self.mod, prune_after_s=1.0)
|
||||
_send(node, "alice", "coffee weather")
|
||||
# alice is fresh (just added)
|
||||
_send(node, "bob", "robots motors")
|
||||
self.assertIn("alice", node.all_persons())
|
||||
|
||||
|
||||
# ── Tests: source and config ──────────────────────────────────────────────────
|
||||
|
||||
class TestNodeSrc(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
with open(_SRC) as f: cls.src = f.read()
|
||||
|
||||
def test_issue_tag(self): self.assertIn("#299", self.src)
|
||||
def test_input_topic(self): self.assertIn("/social/conversation_text", self.src)
|
||||
def test_output_topic(self): self.assertIn("/saltybot/conversation_topics", self.src)
|
||||
def test_extract_keywords(self): self.assertIn("extract_keywords", self.src)
|
||||
def test_person_topic_memory(self):self.assertIn("PersonTopicMemory", self.src)
|
||||
def test_stop_words(self): self.assertIn("STOP_WORDS", self.src)
|
||||
def test_json_output(self): self.assertIn("json.dumps", self.src)
|
||||
def test_person_id_in_output(self):self.assertIn("person_id", self.src)
|
||||
def test_recent_topics_key(self): self.assertIn("recent_topics", self.src)
|
||||
def test_new_topics_key(self): self.assertIn("new_topics", self.src)
|
||||
def test_threading_lock(self): self.assertIn("threading.Lock", self.src)
|
||||
def test_prune_method(self): self.assertIn("_prune_stale", self.src)
|
||||
def test_main_defined(self): self.assertIn("def main", self.src)
|
||||
def test_min_word_length_param(self):self.assertIn("min_word_length", self.src)
|
||||
def test_max_topics_param(self): self.assertIn("max_topics_per_person", self.src)
|
||||
|
||||
|
||||
class TestConfig(unittest.TestCase):
|
||||
_CONFIG = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/config/topic_memory_params.yaml"
|
||||
)
|
||||
_LAUNCH = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/launch/topic_memory.launch.py"
|
||||
)
|
||||
_SETUP = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/setup.py"
|
||||
)
|
||||
|
||||
def test_config_exists(self):
|
||||
import os; self.assertTrue(os.path.exists(self._CONFIG))
|
||||
|
||||
def test_config_min_word_length(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("min_word_length", c)
|
||||
|
||||
def test_config_max_topics(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("max_topics_per_person", c)
|
||||
|
||||
def test_config_prune(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("prune_after_s", c)
|
||||
|
||||
def test_launch_exists(self):
|
||||
import os; self.assertTrue(os.path.exists(self._LAUNCH))
|
||||
|
||||
def test_launch_max_topics_arg(self):
|
||||
with open(self._LAUNCH) as f: c = f.read()
|
||||
self.assertIn("max_topics_per_person", c)
|
||||
|
||||
def test_entry_point(self):
|
||||
with open(self._SETUP) as f: c = f.read()
|
||||
self.assertIn("topic_memory_node", c)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
x
Reference in New Issue
Block a user