feat(perception): add person re-ID node (Issue #201)

Two new packages:
- saltybot_person_reid_msgs: PersonAppearance + PersonAppearanceArray msgs
- saltybot_person_reid: MobileNetV2 torso-crop embedder (128-dim L2-norm)
  with 128-bin HSV histogram fallback, cosine-similarity gallery with EMA
  identity updates and configurable age-based pruning, ROS2 node publishing
  PersonAppearanceArray on /saltybot/person_reid at 5 Hz.

Pure-Python helpers (_embedding_model, _reid_gallery) importable without
rclpy — 18/18 unit tests pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
sl-perception 2026-03-02 11:20:50 -05:00
parent 03e7995e66
commit 0d07b09949
18 changed files with 651 additions and 0 deletions

View File

@ -0,0 +1,6 @@
person_reid:
ros__parameters:
model_path: '' # path to MobileNetV2+projection ONNX file (empty = histogram fallback)
match_threshold: 0.75 # cosine similarity threshold for re-ID match
max_identity_age_s: 300.0 # seconds before unseen identity is pruned
publish_hz: 5.0 # publication rate (Hz)

View File

@ -0,0 +1,28 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_person_reid</name>
<version>0.1.0</version>
<description>
Person re-identification node — cross-camera appearance matching using
MobileNetV2 ONNX embeddings (128-dim, cosine similarity gallery).
</description>
<maintainer email="robot@saltylab.local">SaltyLab</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>sensor_msgs</depend>
<depend>vision_msgs</depend>
<depend>cv_bridge</depend>
<depend>message_filters</depend>
<depend>saltybot_person_reid_msgs</depend>
<exec_depend>python3-numpy</exec_depend>
<exec_depend>python3-opencv</exec_depend>
<test_depend>pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -0,0 +1,95 @@
"""
_embedding_model.py Appearance embedding extractor (no ROS2 deps).
Primary: MobileNetV2 ONNX torso crop 128-dim L2-normalised embedding.
Fallback: 128-bin HSV histogram (H:16 × S:8) when no model file is available.
"""
from __future__ import annotations
import numpy as np
import cv2
# Top fraction of the bounding box height used as torso crop
_INPUT_SIZE = (128, 256) # (W, H) fed to MobileNetV2
class EmbeddingModel:
    """
    Extract a 128-dim L2-normalised appearance embedding from a BGR crop.

    Primary backend is a MobileNetV2+projection ONNX model loaded through
    cv2.dnn; when no usable model file is given, the extractor falls back to
    a 128-bin HSV colour histogram (16 H-bins × 8 S-bins).

    Parameters
    ----------
    model_path : str or None
        Path to a MobileNetV2+projection ONNX file. When None (or the file
        cannot be loaded), falls back to the histogram backend.
    """

    def __init__(self, model_path: str | None = None):
        self._net = None
        if model_path:
            try:
                self._net = cv2.dnn.readNetFromONNX(model_path)
            except Exception:
                # Any load failure (missing file, malformed model) silently
                # selects the histogram fallback.
                pass

    def embed(self, bgr_crop: np.ndarray) -> np.ndarray:
        """
        Compute the appearance embedding for one person crop.

        Parameters
        ----------
        bgr_crop : np.ndarray shape (H, W, 3) uint8

        Returns
        -------
        np.ndarray shape (128,) float32, L2-normalised
            All-zero vector when the crop is empty.
        """
        if bgr_crop.size == 0:
            return np.zeros(128, dtype=np.float32)
        if self._net is not None:
            return self._mobilenet_embed(bgr_crop)
        return self._histogram_embed(bgr_crop)

    # ── MobileNetV2 path ──────────────────────────────────────────────────
    def _mobilenet_embed(self, bgr: np.ndarray) -> np.ndarray:
        # blobFromImage resizes to _INPUT_SIZE itself, so the previous
        # separate cv2.resize was redundant and has been removed.
        # Means are ImageNet means scaled to [0, 255] because the
        # scalefactor is applied after mean subtraction.
        blob = cv2.dnn.blobFromImage(
            bgr,
            scalefactor=1.0 / 255.0,
            size=_INPUT_SIZE,
            mean=(0.485 * 255, 0.456 * 255, 0.406 * 255),
            swapRB=True,
            crop=False,
        )
        # Channel-wise ImageNet std normalisation (blob is NCHW; channels
        # are RGB after swapRB).
        blob[:, 0] /= 0.229
        blob[:, 1] /= 0.224
        blob[:, 2] /= 0.225
        self._net.setInput(blob)
        feat = self._net.forward().flatten().astype(np.float32)
        # Coerce to exactly 128 dims regardless of the model's output size:
        # zero-pad short outputs (the old reshape crashed for n < 128),
        # average-pool longer ones block-wise.
        n = feat.shape[0]
        if n < 128:
            feat = np.pad(feat, (0, 128 - n))
        elif n > 128:
            block = n // 128
            feat = feat[: block * 128].reshape(128, block).mean(axis=1)
        return _l2_norm(feat)

    # ── HSV histogram fallback ────────────────────────────────────────────
    def _histogram_embed(self, bgr: np.ndarray) -> np.ndarray:
        """128-bin HSV histogram: 16 H-bins × 8 S-bins, L2-normalised."""
        hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
        hist = cv2.calcHist(
            [hsv], [0, 1], None,
            [16, 8], [0, 180, 0, 256],
        ).flatten().astype(np.float32)
        return _l2_norm(hist)
def _l2_norm(v: np.ndarray) -> np.ndarray:
n = float(np.linalg.norm(v))
return v / n if n > 1e-6 else v

View File

@ -0,0 +1,105 @@
"""
_reid_gallery.py Appearance gallery for person re-identification (no ROS2 deps).
Matches an incoming embedding against stored identities using cosine similarity.
New identities are created when the best match falls below the threshold.
"""
from __future__ import annotations
import time
from dataclasses import dataclass, field
from typing import List, Tuple
import numpy as np
@dataclass
class Identity:
    """One gallery entry: a persistent ID plus its running appearance embedding."""
    identity_id: int
    embedding: np.ndarray  # shape (D,), kept L2-normalised
    last_seen: float = field(default_factory=time.monotonic)  # monotonic time of last match
    hit_count: int = 1  # number of times this identity has been matched

    def update(self, new_embedding: np.ndarray, alpha: float = 0.1) -> None:
        """Blend *new_embedding* into the stored one (EMA) and re-normalise."""
        blended = self.embedding * (1.0 - alpha) + new_embedding * alpha
        norm = float(np.linalg.norm(blended))
        if norm > 1e-6:
            blended = blended / norm
        self.embedding = blended
        self.last_seen = time.monotonic()
        self.hit_count += 1
class ReidGallery:
    """
    Lightweight cosine-similarity re-ID gallery.

    Stores one L2-normalised embedding per identity. Incoming embeddings are
    scored by dot product (cosine similarity of unit vectors); stale
    identities are pruned by age at the start of every match() call.

    Parameters
    ----------
    match_threshold : float
        Cosine similarity (dot product of unit vectors) required to accept a
        match. Range [0, 1]; 0 = always new identity, 1 = perfect match only.
    max_age_s : float
        Identities not seen for this many seconds are pruned.
    """

    def __init__(
        self,
        match_threshold: float = 0.75,
        max_age_s: float = 300.0,
    ):
        self._threshold = match_threshold
        self._max_age_s = max_age_s
        self._identities: List[Identity] = []
        self._next_id = 1

    def match(self, embedding: np.ndarray) -> Tuple[int, float, bool]:
        """
        Match *embedding* against the gallery, creating a new identity when
        no stored embedding is similar enough.

        Returns
        -------
        (identity_id, match_score, is_new)
            identity_id : assigned ID (new or existing)
            match_score : cosine similarity to best match (0.0 if new)
            is_new      : True if a new identity was created
        """
        self._prune()
        if not self._identities:
            return self._add_identity(embedding)
        similarities = [
            float(np.dot(embedding, ident.embedding))
            for ident in self._identities
        ]
        best_idx = max(range(len(similarities)), key=similarities.__getitem__)
        best_score = similarities[best_idx]
        if best_score < self._threshold:
            return self._add_identity(embedding)
        winner = self._identities[best_idx]
        winner.update(embedding)
        return winner.identity_id, best_score, False

    # ── Internal helpers ──────────────────────────────────────────────────
    def _add_identity(self, embedding: np.ndarray) -> Tuple[int, float, bool]:
        """Register *embedding* under a fresh, monotonically-increasing ID."""
        assigned = self._next_id
        self._next_id += 1
        self._identities.append(
            Identity(identity_id=assigned, embedding=embedding.copy())
        )
        return assigned, 0.0, True

    def _prune(self) -> None:
        """Drop identities whose last sighting is older than max_age_s."""
        cutoff = time.monotonic() - self._max_age_s
        self._identities = [i for i in self._identities if i.last_seen > cutoff]

    @property
    def size(self) -> int:
        """Number of identities currently stored."""
        return len(self._identities)

View File

@ -0,0 +1,174 @@
"""
person_reid_node.py Person re-identification for cross-camera tracking.
Subscribes to:
/person/detections vision_msgs/Detection2DArray (person bounding boxes)
/camera/color/image_raw sensor_msgs/Image (colour frame for crops)
Publishes:
/saltybot/person_reid saltybot_person_reid_msgs/PersonAppearanceArray (5 Hz)
For each detected person the node:
1. Crops the torso region (top 65 % of the bounding box height).
2. Extracts a 128-dim L2-normalised embedding via MobileNetV2 ONNX (if the
model file is provided) or a 128-bin HSV colour histogram (fallback).
3. Matches against a cosine-similarity gallery.
4. Assigns a persistent identity_id (new or existing).
Parameters:
model_path str '' Path to MobileNetV2+projection ONNX file
match_threshold float 0.75 Cosine similarity threshold for matching
max_identity_age_s float 300.0 Seconds before an unseen identity is pruned
publish_hz float 5.0 Publication rate (Hz)
"""
from __future__ import annotations
from typing import List
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
import message_filters
import cv2
import numpy as np
from cv_bridge import CvBridge
from sensor_msgs.msg import Image
from vision_msgs.msg import Detection2DArray
from saltybot_person_reid_msgs.msg import PersonAppearance, PersonAppearanceArray
from ._embedding_model import EmbeddingModel
from ._reid_gallery import ReidGallery
# Fraction of bbox height kept as torso crop (top portion)
_TORSO_FRAC = 0.65
_BEST_EFFORT_QOS = QoSProfile(
reliability=ReliabilityPolicy.BEST_EFFORT,
history=HistoryPolicy.KEEP_LAST,
depth=4,
)
class PersonReidNode(Node):
    """
    ROS2 node: assigns persistent identity IDs to detected people.

    Synchronises person detections with colour frames, crops the torso of
    each detection, embeds it (ONNX model or HSV-histogram fallback), and
    matches it against a cosine-similarity gallery. Results are buffered by
    the frame callback and republished at a fixed rate by a timer.
    """

    def __init__(self):
        super().__init__('person_reid')
        self.declare_parameter('model_path', '')
        self.declare_parameter('match_threshold', 0.75)
        self.declare_parameter('max_identity_age_s', 300.0)
        self.declare_parameter('publish_hz', 5.0)
        model_path = self.get_parameter('model_path').value
        match_thr = self.get_parameter('match_threshold').value
        max_age = self.get_parameter('max_identity_age_s').value
        # NOTE(review): publish_hz == 0 would raise ZeroDivisionError in the
        # timer setup below — confirm configs never set it to 0.
        publish_hz = self.get_parameter('publish_hz').value
        self._bridge = CvBridge()
        # Empty-string parameter means "no model" → histogram fallback.
        self._embedder = EmbeddingModel(model_path or None)
        self._gallery = ReidGallery(match_threshold=match_thr, max_age_s=max_age)
        # Buffer: updated by frame callback, drained by timer
        self._pending: List[PersonAppearance] = []
        self._pending_header = None
        # Synchronized subscribers: detections + colour frame, matched by
        # approximate timestamp (slop 100 ms).
        det_sub = message_filters.Subscriber(
            self, Detection2DArray, '/person/detections',
            qos_profile=_BEST_EFFORT_QOS)
        img_sub = message_filters.Subscriber(
            self, Image, '/camera/color/image_raw',
            qos_profile=_BEST_EFFORT_QOS)
        self._sync = message_filters.ApproximateTimeSynchronizer(
            [det_sub, img_sub], queue_size=4, slop=0.1)
        self._sync.registerCallback(self._on_frame)
        self._pub = self.create_publisher(
            PersonAppearanceArray, '/saltybot/person_reid', 10)
        self.create_timer(1.0 / publish_hz, self._tick)
        # _net is a private attribute of EmbeddingModel; read here only for
        # this one-off startup log line.
        backend = 'ONNX' if self._embedder._net else 'histogram'
        self.get_logger().info(
            f'person_reid ready — backend={backend} '
            f'threshold={match_thr} max_age={max_age}s'
        )

    # ── Frame callback ─────────────────────────────────────────────────────
    def _on_frame(self, det_msg: Detection2DArray, img_msg: Image) -> None:
        """Synchronised callback: embed + match each detection, buffer results."""
        if not det_msg.detections:
            # Buffer an empty set so downstream consumers see "no people"
            # rather than a stale result.
            self._pending = []
            self._pending_header = det_msg.header
            return
        try:
            bgr = self._bridge.imgmsg_to_cv2(img_msg, desired_encoding='bgr8')
        except Exception as exc:
            self.get_logger().error(
                f'imgmsg_to_cv2 failed: {exc}', throttle_duration_sec=5.0)
            return
        h_img, w_img = bgr.shape[:2]
        appearances: List[PersonAppearance] = []
        for det in det_msg.detections:
            cx = det.bbox.center.position.x
            cy = det.bbox.center.position.y
            bw = det.bbox.size_x
            bh = det.bbox.size_y
            # Detector confidence of the first hypothesis; 0.0 when absent.
            conf = det.results[0].hypothesis.score if det.results else 0.0
            # Torso crop: top _TORSO_FRAC of the bounding box, clamped to
            # the image bounds.
            x1 = max(0, int(cx - bw / 2.0))
            y1 = max(0, int(cy - bh / 2.0))
            x2 = min(w_img, int(cx + bw / 2.0))
            y2 = min(h_img, int(cy - bh / 2.0 + bh * _TORSO_FRAC))
            # Skip degenerate crops (too small to embed meaningfully).
            if x2 - x1 < 8 or y2 - y1 < 8:
                continue
            crop = bgr[y1:y2, x1:x2]
            emb = self._embedder.embed(crop)
            identity_id, match_score, is_new = self._gallery.match(emb)
            app = PersonAppearance()
            app.header = det_msg.header
            app.track_id = identity_id
            app.embedding = emb.tolist()
            app.bbox = det.bbox
            app.confidence = float(conf)
            app.match_score = float(match_score)
            app.is_new_identity = is_new
            appearances.append(app)
        self._pending = appearances
        self._pending_header = det_msg.header

    # ── 5 Hz publish tick ─────────────────────────────────────────────────
    def _tick(self) -> None:
        """Timer callback: republish the most recent buffered appearance set."""
        # Nothing received yet — stay silent instead of publishing an
        # empty/default header.
        if self._pending_header is None:
            return
        msg = PersonAppearanceArray()
        msg.header = self._pending_header
        msg.appearances = self._pending
        self._pub.publish(msg)
def main(args=None):
    """Console entry point (see setup.py): spin the re-ID node until shutdown."""
    rclpy.init(args=args)
    node = PersonReidNode()
    try:
        rclpy.spin(node)
    finally:
        # Always release node resources and shut rclpy down, even when spin
        # exits via an exception (e.g. Ctrl-C).
        node.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()

View File

@ -0,0 +1,4 @@
[develop]
script_dir=$base/lib/saltybot_person_reid
[install]
install_scripts=$base/lib/saltybot_person_reid

View File

@ -0,0 +1,29 @@
# setup.py for the saltybot_person_reid ament_python package.
from setuptools import setup, find_packages
from glob import glob

package_name = 'saltybot_person_reid'

setup(
    name=package_name,
    version='0.1.0',
    packages=find_packages(exclude=['test']),
    data_files=[
        # ament resource-index marker so ROS2 tooling can discover the package.
        ('share/ament_index/resource_index/packages',
         ['resource/' + package_name]),
        ('share/' + package_name, ['package.xml']),
        # Default parameter files shipped with the package (config/*.yaml).
        ('share/' + package_name + '/config',
         glob('config/*.yaml')),
    ],
    install_requires=['setuptools'],
    zip_safe=True,
    maintainer='SaltyLab',
    maintainer_email='robot@saltylab.local',
    description='Person re-identification — cross-camera appearance matching',
    license='MIT',
    tests_require=['pytest'],
    entry_points={
        'console_scripts': [
            # ros2 run saltybot_person_reid person_reid
            'person_reid = saltybot_person_reid.person_reid_node:main',
        ],
    },
)

View File

@ -0,0 +1,163 @@
"""
test_person_reid.py Unit tests for person re-ID helpers (no ROS2 required).
Covers:
- _l2_norm helper
- EmbeddingModel (histogram fallback no model file needed)
- ReidGallery cosine-similarity matching
"""
import sys
import os
import time
import numpy as np
import pytest
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from saltybot_person_reid._embedding_model import EmbeddingModel, _l2_norm
from saltybot_person_reid._reid_gallery import ReidGallery
# ── _l2_norm ──────────────────────────────────────────────────────────────────
class TestL2Norm:
    """_l2_norm: unit-normalisation helper."""

    def test_unit_vector_unchanged(self):
        unit = np.array([1.0, 0.0, 0.0], dtype=np.float32)
        assert np.allclose(_l2_norm(unit), unit)

    def test_normalised_to_unit_norm(self):
        vec = np.array([3.0, 4.0], dtype=np.float32)
        norm_after = np.linalg.norm(_l2_norm(vec))
        assert abs(norm_after - 1.0) < 1e-6

    def test_zero_vector_does_not_crash(self):
        zeros = np.zeros(4, dtype=np.float32)
        out = _l2_norm(zeros)
        assert out.shape == (4,)
# ── EmbeddingModel ────────────────────────────────────────────────────────────
class TestEmbeddingModel:
    """EmbeddingModel histogram fallback (no ONNX file required)."""

    @staticmethod
    def _fallback_model() -> EmbeddingModel:
        # model_path=None → histogram backend, never touches cv2.dnn.
        return EmbeddingModel(model_path=None)

    @staticmethod
    def _random_crop(height: int, width: int) -> np.ndarray:
        return np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)

    def test_histogram_fallback_shape(self):
        emb = self._fallback_model().embed(self._random_crop(100, 50))
        assert emb.shape == (128,)

    def test_embedding_is_unit_norm(self):
        emb = self._fallback_model().embed(self._random_crop(80, 40))
        assert abs(np.linalg.norm(emb) - 1.0) < 1e-5

    def test_empty_crop_returns_zero_vector(self):
        emb = self._fallback_model().embed(np.zeros((0, 0, 3), dtype=np.uint8))
        assert emb.shape == (128,)
        assert np.all(emb == 0.0)

    def test_red_and_blue_crops_differ(self):
        model = self._fallback_model()
        red = np.full((80, 40, 3), (0, 0, 200), dtype=np.uint8)
        blue = np.full((80, 40, 3), (200, 0, 0), dtype=np.uint8)
        similarity = float(np.dot(model.embed(red), model.embed(blue)))
        assert similarity < 0.99

    def test_same_crop_deterministic(self):
        model = self._fallback_model()
        crop = self._random_crop(80, 40)
        assert np.allclose(model.embed(crop), model.embed(crop))

    def test_embedding_float32(self):
        emb = self._fallback_model().embed(self._random_crop(60, 30))
        assert emb.dtype == np.float32
# ── ReidGallery ───────────────────────────────────────────────────────────────
def _unit(dim: int = 128, seed: int | None = None) -> np.ndarray:
rng = np.random.default_rng(seed)
v = rng.standard_normal(dim).astype(np.float32)
return v / np.linalg.norm(v)
class TestReidGallery:
    """Cosine-similarity gallery: matching, ID assignment, and pruning."""

    def test_first_match_creates_identity(self):
        gallery = ReidGallery(match_threshold=0.75)
        assigned_id, score, created = gallery.match(_unit(seed=0))
        assert assigned_id == 1
        assert created is True
        assert score == pytest.approx(0.0)

    def test_identical_embedding_matches(self):
        gallery = ReidGallery(match_threshold=0.75)
        emb = _unit(seed=1)
        gallery.match(emb)
        assigned_id, score, created = gallery.match(emb)
        assert assigned_id == 1
        assert created is False
        assert score > 0.99

    def test_orthogonal_embeddings_create_new_id(self):
        gallery = ReidGallery(match_threshold=0.75)
        first = np.zeros(128, dtype=np.float32)
        first[0] = 1.0
        second = np.zeros(128, dtype=np.float32)
        second[64] = 1.0
        id_a, _, _ = gallery.match(first)
        id_b, _, created_b = gallery.match(second)
        assert id_a != id_b
        assert created_b is True

    def test_ids_are_monotonically_increasing(self):
        # An unreachable threshold (> 1.0) forces a fresh identity each call.
        gallery = ReidGallery(match_threshold=2.0)
        assigned = [gallery.match(_unit(seed=i))[0] for i in range(5)]
        assert assigned == [1, 2, 3, 4, 5]

    def test_gallery_size_increments_for_new_ids(self):
        gallery = ReidGallery(match_threshold=2.0)
        for seed in range(4):
            gallery.match(_unit(seed=seed))
        assert gallery.size == 4

    def test_prune_removes_stale_identities(self):
        gallery = ReidGallery(match_threshold=0.75, max_age_s=0.01)
        gallery.match(_unit(seed=0))
        time.sleep(0.05)
        gallery._prune()
        assert gallery.size == 0

    def test_empty_gallery_prune_is_safe(self):
        gallery = ReidGallery()
        gallery._prune()
        assert gallery.size == 0

    def test_match_below_threshold_increments_id(self):
        gallery = ReidGallery(match_threshold=0.99)
        # Two random unit vectors are almost certainly < 0.99 similar, but
        # regardless of outcome both calls must yield valid positive IDs.
        id_a, _, _ = gallery.match(_unit(seed=10))
        id_b, _, _ = gallery.match(_unit(seed=20))
        assert id_a >= 1
        assert id_b >= 1

    def test_identity_update_does_not_change_id(self):
        gallery = ReidGallery(match_threshold=0.5)
        emb = _unit(seed=5)
        first_id, _, _ = gallery.match(emb)
        for _ in range(10):
            gallery.match(emb)
        last_id, _, _ = gallery.match(emb)
        assert last_id == first_id
# Allow running this test module directly; normally discovered by pytest/colcon.
if __name__ == '__main__':
    pytest.main([__file__, '-v'])

View File

@ -0,0 +1,16 @@
# Interface-only package: generates the person re-ID message types.
cmake_minimum_required(VERSION 3.8)
project(saltybot_person_reid_msgs)

find_package(ament_cmake REQUIRED)
find_package(rosidl_default_generators REQUIRED)
# std_msgs supplies Header; vision_msgs supplies BoundingBox2D (both used
# by PersonAppearance.msg).
find_package(std_msgs REQUIRED)
find_package(vision_msgs REQUIRED)

rosidl_generate_interfaces(${PROJECT_NAME}
  "msg/PersonAppearance.msg"
  "msg/PersonAppearanceArray.msg"
  DEPENDENCIES std_msgs vision_msgs
)

ament_export_dependencies(rosidl_default_runtime)
ament_package()

View File

@ -0,0 +1,7 @@
# One detected person's appearance descriptor for cross-camera re-identification.
std_msgs/Header header          # stamp/frame copied from the source detection array
uint32 track_id                 # persistent identity ID assigned by the re-ID gallery
float32[] embedding             # 128-dim L2-normalised appearance embedding
vision_msgs/BoundingBox2D bbox  # detection bounding box in image coordinates
float32 confidence              # detector confidence of the underlying detection
float32 match_score             # cosine similarity to the matched identity (0.0 when new)
bool is_new_identity            # true when this detection created a new identity

View File

@ -0,0 +1,2 @@
# All person appearances extracted from one synchronised detection/image pair.
std_msgs/Header header          # copied from the source Detection2DArray
PersonAppearance[] appearances  # one entry per accepted torso crop (may be empty)

View File

@ -0,0 +1,22 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_person_reid_msgs</name>
<version>0.1.0</version>
<description>Message types for person re-identification.</description>
<maintainer email="robot@saltylab.local">SaltyLab</maintainer>
<license>MIT</license>
<buildtool_depend>ament_cmake</buildtool_depend>
<buildtool_depend>rosidl_default_generators</buildtool_depend>
<depend>std_msgs</depend>
<depend>vision_msgs</depend>
<exec_depend>rosidl_default_runtime</exec_depend>
<member_of_group>rosidl_interface_packages</member_of_group>
<export>
<build_type>ament_cmake</build_type>
</export>
</package>