Compare commits

...

1 Commits

Author SHA1 Message Date
0612eedbcd feat(social): face detection + recognition (SCRFD + ArcFace TRT FP16, Issue #80)
Add two new ROS2 packages for the social sprint:

saltybot_social_msgs (ament_cmake):
- FaceDetection, FaceDetectionArray, FaceEmbedding, FaceEmbeddingArray
- PersonState, PersonStateArray
- EnrollPerson, ListPersons, DeletePerson, UpdatePerson services

saltybot_social_face (ament_python):
- SCRFDDetector: SCRFD face detection with TRT FP16 + ONNX fallback
  - 640x640 input, 3-stride anchor decoding, NMS
- ArcFaceRecognizer: 512-dim embedding extraction with gallery matching
  - 5-point landmark alignment to 112x112, cosine similarity
- FaceGallery: thread-safe persistent gallery (npz + JSON sidecar)
- FaceRecognitionNode: ROS2 node subscribing /camera/color/image_raw,
  publishing /social/faces/detections, /social/faces/embeddings
- Enrollment via /social/enroll service (N-sample face averaging)
- Launch file, config YAML, TRT engine builder script

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-01 23:11:20 -05:00
25 changed files with 1759 additions and 0 deletions

View File

@ -0,0 +1,11 @@
# Parameters for the node named face_recognizer.
face_recognizer:
ros__parameters:
# SCRFD face detector: TensorRT engine, with the ONNX model as fallback.
scrfd_engine_path: '/mnt/nvme/saltybot/models/scrfd_2.5g.engine'
scrfd_onnx_path: '/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx'
# ArcFace embedding model: TensorRT engine, with the ONNX model as fallback.
arcface_engine_path: '/mnt/nvme/saltybot/models/arcface_r50.engine'
arcface_onnx_path: '/mnt/nvme/saltybot/models/arcface_r50.onnx'
# Directory holding the persistent face gallery.
gallery_dir: '/mnt/nvme/saltybot/gallery'
# Minimum cosine similarity for a gallery match.
recognition_threshold: 0.35
# When true, an annotated image is published on /social/faces/debug_image.
publish_debug_image: false
# Maximum number of faces processed per frame.
max_faces: 10
# SCRFD detection confidence threshold.
scrfd_conf_thresh: 0.5

View File

@ -0,0 +1,80 @@
"""
face_recognition.launch.py -- Launch file for the SCRFD + ArcFace face recognition node.
Launches the face_recognizer node with configurable model paths and parameters.
The RealSense camera must be running separately (e.g., via realsense.launch.py).
"""
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
    """Build the launch description for the face recognition node.

    One launch argument is declared per node parameter; each argument's value
    is forwarded to the ``face_recognizer`` node under the same name.
    """
    # (argument name, default value, description), in declaration order.
    argument_table = (
        ('scrfd_engine_path',
         '/mnt/nvme/saltybot/models/scrfd_2.5g.engine',
         'Path to SCRFD TensorRT engine file'),
        ('scrfd_onnx_path',
         '/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx',
         'Path to SCRFD ONNX model file (fallback)'),
        ('arcface_engine_path',
         '/mnt/nvme/saltybot/models/arcface_r50.engine',
         'Path to ArcFace TensorRT engine file'),
        ('arcface_onnx_path',
         '/mnt/nvme/saltybot/models/arcface_r50.onnx',
         'Path to ArcFace ONNX model file (fallback)'),
        ('gallery_dir',
         '/mnt/nvme/saltybot/gallery',
         'Directory for persistent face gallery storage'),
        ('recognition_threshold',
         '0.35',
         'Cosine similarity threshold for face recognition'),
        ('publish_debug_image',
         'false',
         'Publish annotated debug image to /social/faces/debug_image'),
        ('max_faces',
         '10',
         'Maximum faces to process per frame'),
        ('scrfd_conf_thresh',
         '0.5',
         'SCRFD detection confidence threshold'),
    )

    entities = [
        DeclareLaunchArgument(arg_name, default_value=default, description=text)
        for arg_name, default, text in argument_table
    ]

    # Every declared argument is passed straight through as a node parameter.
    node_parameters = {
        arg_name: LaunchConfiguration(arg_name)
        for arg_name, _default, _text in argument_table
    }
    entities.append(
        Node(
            package='saltybot_social_face',
            executable='face_recognition',
            name='face_recognizer',
            output='screen',
            parameters=[node_parameters],
        )
    )
    return LaunchDescription(entities)

View File

@ -0,0 +1,27 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
  <name>saltybot_social_face</name>
  <version>0.1.0</version>
  <description>SCRFD face detection and ArcFace recognition for SaltyBot social interactions</description>
  <maintainer email="seb@vayrette.com">seb</maintainer>
  <license>MIT</license>

  <depend>rclpy</depend>
  <depend>sensor_msgs</depend>
  <!-- face_recognition_node.py imports builtin_interfaces.msg.Time directly. -->
  <depend>builtin_interfaces</depend>
  <depend>cv_bridge</depend>
  <depend>image_transport</depend>
  <depend>saltybot_social_msgs</depend>

  <exec_depend>python3-numpy</exec_depend>
  <exec_depend>python3-opencv</exec_depend>
  <!-- face_recognition.launch.py imports launch and launch_ros. -->
  <exec_depend>launch</exec_depend>
  <exec_depend>launch_ros</exec_depend>

  <test_depend>ament_copyright</test_depend>
  <test_depend>ament_flake8</test_depend>
  <test_depend>ament_pep257</test_depend>
  <test_depend>python3-pytest</test_depend>

  <export>
    <build_type>ament_python</build_type>
  </export>
</package>

View File

@ -0,0 +1 @@
"""SaltyBot social face detection and recognition package."""

View File

@ -0,0 +1,316 @@
"""
arcface_recognizer.py -- ArcFace face embedding extraction and gallery matching.
Performs face alignment using 5-point landmarks (insightface standard reference),
extracts 512-dimensional embeddings via ArcFace (TRT FP16 or ONNX fallback),
and matches against a persistent gallery using cosine similarity.
"""
import os
import logging
from typing import Optional
import numpy as np
import cv2
logger = logging.getLogger(__name__)
# InsightFace standard reference landmarks for 112x112 alignment.
# Five (x, y) points, in the same order as the flat landmark list passed to
# align_face(); they are the destination points of the alignment transform.
ARCFACE_SRC = np.array([
    [38.2946, 51.6963], # left eye
    [73.5318, 51.5014], # right eye
    [56.0252, 71.7366], # nose
    [41.5493, 92.3655], # left mouth
    [70.7299, 92.2041], # right mouth
], dtype=np.float32)
# -- Inference backends --------------------------------------------------------
class _TRTBackend:
    """TensorRT inference engine for ArcFace.

    Deserializes an engine file, allocates one page-locked host buffer and one
    device buffer per I/O tensor, and runs inference on a private CUDA stream.

    Args:
        engine_path: Path to the serialized TensorRT engine file.

    Raises:
        RuntimeError: If the engine file cannot be deserialized.
    """

    def __init__(self, engine_path: str):
        import tensorrt as trt
        import pycuda.driver as cuda
        import pycuda.autoinit  # noqa: F401

        self._cuda = cuda
        trt_logger = trt.Logger(trt.Logger.WARNING)
        with open(engine_path, 'rb') as f, trt.Runtime(trt_logger) as runtime:
            self._engine = runtime.deserialize_cuda_engine(f.read())
        if self._engine is None:
            # Fail with a clear message instead of an AttributeError below.
            raise RuntimeError(
                f'Failed to deserialize TensorRT engine: {engine_path}')
        self._context = self._engine.create_execution_context()

        self._inputs = []
        self._outputs = []
        self._bindings = []
        self._tensor_names = []
        for i in range(self._engine.num_io_tensors):
            name = self._engine.get_tensor_name(i)
            dtype = trt.nptype(self._engine.get_tensor_dtype(name))
            shape = tuple(self._engine.get_tensor_shape(name))
            nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
            host_mem = cuda.pagelocked_empty(shape, dtype)
            device_mem = cuda.mem_alloc(nbytes)
            self._tensor_names.append(name)
            self._bindings.append(int(device_mem))
            if self._engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                self._inputs.append({'host': host_mem, 'device': device_mem})
            else:
                self._outputs.append({'host': host_mem, 'device': device_mem,
                                      'shape': shape})
        self._stream = cuda.Stream()

        # The engine is already queried through the name-based tensor API, so
        # prefer the matching name-based execution call when the context has
        # it; fall back to the bindings-list call otherwise.
        self._use_v3 = hasattr(self._context, 'execute_async_v3')
        if self._use_v3:
            for name, address in zip(self._tensor_names, self._bindings):
                self._context.set_tensor_address(name, address)

    def infer(self, input_data: np.ndarray) -> np.ndarray:
        """Run inference and return the first output tensor.

        Args:
            input_data: Input blob holding exactly as many elements as the
                engine's first input tensor.

        Returns:
            A copy of the first output, shaped like the engine's output tensor.

        Raises:
            ValueError: If ``input_data`` has the wrong number of elements.
        """
        first_input = self._inputs[0]
        host_in = first_input['host']
        # The host buffer is N-dimensional, so the source must be reshaped to
        # match it; a flattened source cannot be broadcast into it.
        np.copyto(host_in, np.asarray(input_data).reshape(host_in.shape))
        self._cuda.memcpy_htod_async(
            first_input['device'], host_in, self._stream)
        if self._use_v3:
            self._context.execute_async_v3(self._stream.handle)
        else:
            self._context.execute_async_v2(self._bindings, self._stream.handle)
        for out in self._outputs:
            self._cuda.memcpy_dtoh_async(out['host'], out['device'], self._stream)
        self._stream.synchronize()
        first_output = self._outputs[0]
        return first_output['host'].reshape(first_output['shape']).copy()
class _ONNXBackend:
    """ONNX Runtime inference session (CUDA provider first, CPU second)."""

    def __init__(self, onnx_path: str):
        import onnxruntime as ort

        self._session = ort.InferenceSession(
            onnx_path,
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
        )
        # The model is fed through its first declared input.
        first_input = self._session.get_inputs()[0]
        self._input_name = first_input.name

    def infer(self, input_data: np.ndarray) -> np.ndarray:
        """Feed ``input_data`` to the model and return its first output."""
        feed = {self._input_name: input_data}
        outputs = self._session.run(None, feed)
        return outputs[0]
# -- Face alignment ------------------------------------------------------------
def align_face(bgr: np.ndarray, landmarks_10: list[float]) -> np.ndarray:
    """Align a face to 112x112 using 5-point landmarks.

    A partial affine transform is estimated from the detected landmarks to
    the ARCFACE_SRC reference points and used to warp the image. If no
    transform can be estimated, a square region around the landmarks is
    cropped and resized instead.

    Args:
        bgr: Source BGR image.
        landmarks_10: Flat list of 10 floats [x0,y0, x1,y1, ..., x4,y4].

    Returns:
        Aligned BGR face crop of shape (112, 112, 3). An all-zero crop is
        returned when the fallback region is empty.
    """
    src_pts = np.array(landmarks_10, dtype=np.float32).reshape(5, 2)
    M, _ = cv2.estimateAffinePartial2D(src_pts, ARCFACE_SRC)
    if M is not None:
        return cv2.warpAffine(bgr, M, (112, 112), borderMode=cv2.BORDER_REPLICATE)

    # Fallback: simple crop and resize from a bbox-like region around the
    # landmark centroid.
    cx = np.mean(src_pts[:, 0])
    cy = np.mean(src_pts[:, 1])
    spread = max(np.ptp(src_pts[:, 0]), np.ptp(src_pts[:, 1])) * 1.5
    half = spread / 2
    x1 = max(0, int(cx - half))
    y1 = max(0, int(cy - half))
    x2 = min(bgr.shape[1], int(cx + half))
    y2 = min(bgr.shape[0], int(cy + half))
    crop = bgr[y1:y2, x1:x2]
    if crop.size == 0:
        # Coincident or off-image landmarks give an empty region, and
        # cv2.resize raises on empty input.
        return np.zeros((112, 112) + tuple(bgr.shape[2:]), dtype=bgr.dtype)
    return cv2.resize(crop, (112, 112), interpolation=cv2.INTER_LINEAR)
# -- Main recognizer class -----------------------------------------------------
class ArcFaceRecognizer:
    """ArcFace face embedding extractor and gallery matcher.

    The TensorRT engine is tried first; the ONNX model is used if the engine
    is missing or fails to load. With no model loaded, ``embed`` returns
    all-zero embeddings.

    Args:
        engine_path: Path to TensorRT engine file.
        onnx_path: Path to ONNX model file (used if engine not available).
    """

    def __init__(self, engine_path: str = '', onnx_path: str = ''):
        self._backend: Optional[_TRTBackend | _ONNXBackend] = None
        # person_id -> {'name', 'embedding', 'samples', 'enrolled_at'}
        self.gallery: dict[int, dict] = {}

        # Try TRT first, then ONNX
        if engine_path and os.path.isfile(engine_path):
            try:
                self._backend = _TRTBackend(engine_path)
                logger.info('ArcFace TensorRT backend loaded: %s', engine_path)
                return
            except Exception as e:
                logger.warning('ArcFace TRT load failed (%s), trying ONNX', e)

        if onnx_path and os.path.isfile(onnx_path):
            try:
                self._backend = _ONNXBackend(onnx_path)
                logger.info('ArcFace ONNX backend loaded: %s', onnx_path)
                return
            except Exception as e:
                logger.error('ArcFace ONNX load failed: %s', e)

        logger.error('No ArcFace model loaded. Recognition will be unavailable.')

    @staticmethod
    def _meta_path(gallery_path: str) -> str:
        """Return the JSON sidecar path that belongs to ``gallery_path``.

        Only the file extension is replaced, so a '.npz' substring elsewhere
        in the path (for example in a directory name) is left untouched.
        """
        root, _ext = os.path.splitext(gallery_path)
        return root + '_meta.json'

    @property
    def is_loaded(self) -> bool:
        """Return True if a backend is loaded and ready."""
        return self._backend is not None

    def embed(self, bgr_face_112x112: np.ndarray) -> np.ndarray:
        """Extract 512-dim L2-normalized embedding from a 112x112 aligned face.

        Args:
            bgr_face_112x112: Aligned face crop, BGR, shape (112, 112, 3).

        Returns:
            L2-normalized embedding of shape (512,), or all zeros when no
            backend is loaded.
        """
        if self._backend is None:
            return np.zeros(512, dtype=np.float32)

        # Preprocess: BGR->RGB, /255, subtract 0.5, divide 0.5 -> [1,3,112,112]
        rgb = cv2.cvtColor(bgr_face_112x112, cv2.COLOR_BGR2RGB).astype(np.float32)
        rgb = rgb / 255.0
        rgb = (rgb - 0.5) / 0.5
        blob = rgb.transpose(2, 0, 1)[np.newaxis]  # [1, 3, 112, 112]
        blob = np.ascontiguousarray(blob)

        output = self._backend.infer(blob)
        embedding = output.flatten()[:512].astype(np.float32)

        # L2 normalize (skipped for an all-zero vector to avoid dividing by 0)
        norm = np.linalg.norm(embedding)
        if norm > 0:
            embedding = embedding / norm
        return embedding

    def align_and_embed(self, bgr_image: np.ndarray, landmarks_10: list[float]) -> np.ndarray:
        """Align face using landmarks and extract embedding.

        Args:
            bgr_image: Full BGR image.
            landmarks_10: Flat list of 10 floats from SCRFD detection.

        Returns:
            L2-normalized embedding of shape (512,).
        """
        aligned = align_face(bgr_image, landmarks_10)
        return self.embed(aligned)

    def load_gallery(self, gallery_path: str) -> None:
        """Load gallery from .npz file with JSON metadata sidecar.

        Replaces the in-memory gallery. Entries without metadata get a
        default name, a sample count of 1 and an enrollment time of 0.0.

        Args:
            gallery_path: Path to the .npz gallery file.
        """
        import json

        self.gallery = {}
        if not os.path.isfile(gallery_path):
            logger.info('No gallery file at %s, starting empty.', gallery_path)
            return

        meta_path = self._meta_path(gallery_path)
        meta = {}
        if os.path.isfile(meta_path):
            with open(meta_path, 'r') as f:
                meta = json.load(f)

        # NpzFile keeps the archive open until closed, so read every array
        # inside the context manager.
        with np.load(gallery_path, allow_pickle=False) as data:
            for key in data.files:
                pid = int(key)
                embedding = data[key].astype(np.float32)
                norm = np.linalg.norm(embedding)
                if norm > 0:
                    embedding = embedding / norm
                info = meta.get(str(pid), {})
                self.gallery[pid] = {
                    'name': info.get('name', f'person_{pid}'),
                    'embedding': embedding,
                    'samples': info.get('samples', 1),
                    'enrolled_at': info.get('enrolled_at', 0.0),
                }

        logger.info('Gallery loaded: %d persons from %s', len(self.gallery), gallery_path)

    def save_gallery(self, gallery_path: str) -> None:
        """Save gallery to .npz file with JSON metadata sidecar.

        Args:
            gallery_path: Path to the .npz gallery file.
        """
        import json

        arrays = {}
        meta = {}
        for pid, info in self.gallery.items():
            arrays[str(pid)] = info['embedding']
            meta[str(pid)] = {
                'name': info['name'],
                'samples': info['samples'],
                'enrolled_at': info['enrolled_at'],
            }

        os.makedirs(os.path.dirname(gallery_path) or '.', exist_ok=True)
        np.savez(gallery_path, **arrays)

        with open(self._meta_path(gallery_path), 'w') as f:
            json.dump(meta, f, indent=2)

        logger.info('Gallery saved: %d persons to %s', len(self.gallery), gallery_path)

    def match(self, embedding: np.ndarray, threshold: float = 0.35) -> tuple[int, str, float]:
        """Match an embedding against the gallery.

        Args:
            embedding: L2-normalized query embedding of shape (512,).
            threshold: Minimum cosine similarity for a match.

        Returns:
            (person_id, person_name, score) or (-1, '', 0.0) if no match.
        """
        best_pid = -1
        best_name = ''
        best_score = 0.0

        # Both sides are unit vectors, so the dot product is the cosine
        # similarity.
        for pid, info in self.gallery.items():
            score = float(np.dot(embedding, info['embedding']))
            if score > best_score:
                best_score = score
                best_pid = pid
                best_name = info['name']

        if best_pid != -1 and best_score >= threshold:
            return (best_pid, best_name, best_score)
        return (-1, '', 0.0)

    def enroll(self, person_id: int, person_name: str, embeddings_list: list[np.ndarray]) -> None:
        """Enroll a person by averaging multiple embeddings.

        Overwrites any existing entry with the same ``person_id``. Does
        nothing when ``embeddings_list`` is empty.

        Args:
            person_id: Unique integer ID for this person.
            person_name: Human-readable name.
            embeddings_list: List of L2-normalized embeddings to average.
        """
        import time as _time

        if not embeddings_list:
            return

        mean_emb = np.mean(embeddings_list, axis=0).astype(np.float32)
        norm = np.linalg.norm(mean_emb)
        if norm > 0:
            mean_emb = mean_emb / norm

        self.gallery[person_id] = {
            'name': person_name,
            'embedding': mean_emb,
            'samples': len(embeddings_list),
            'enrolled_at': _time.time(),
        }
        logger.info('Enrolled person %d (%s) with %d samples.',
                    person_id, person_name, len(embeddings_list))

View File

@ -0,0 +1,78 @@
#!/usr/bin/env python3
"""
enrollment_cli.py -- CLI tool for enrolling persons via the /social/enroll service.
Usage:
ros2 run saltybot_social_face enrollment_cli -- --name Alice --mode face --samples 10
"""
import argparse
import sys
import rclpy
from rclpy.node import Node
from saltybot_social_msgs.srv import EnrollPerson
class EnrollmentCLI(Node):
    """Simple CLI node that calls the EnrollPerson service.

    All work happens in the constructor: it waits for the service, sends one
    request, waits for the reply and logs the outcome.

    Args:
        name: Name of the person to enroll.
        mode: Enrollment mode string forwarded in the request.
        n_samples: Number of samples requested from the service.
    """

    def __init__(self, name: str, mode: str, n_samples: int):
        super().__init__('enrollment_cli')
        log = self.get_logger()
        self._client = self.create_client(EnrollPerson, '/social/enroll')

        log.info('Waiting for /social/enroll service...')
        if not self._client.wait_for_service(timeout_sec=10.0):
            log.error('Service /social/enroll not available.')
            return

        request = EnrollPerson.Request()
        request.name = name
        request.mode = mode
        request.n_samples = n_samples

        # rclpy logger methods take one pre-formatted message; printf-style
        # positional arguments raise TypeError.
        log.info(f'Enrolling "{name}" (mode={mode}, samples={n_samples})...')
        future = self._client.call_async(request)
        rclpy.spin_until_future_complete(self, future, timeout_sec=120.0)

        result = future.result()
        if result is None:
            log.error('Enrollment service call timed out or failed.')
        elif result.success:
            log.info(
                f'Enrollment successful: person_id={result.person_id}, '
                f'{result.message}')
        else:
            log.error(f'Enrollment failed: {result.message}')
def main(args=None):
    """Entry point for enrollment CLI.

    Args:
        args: Command-line arguments excluding the program name. Defaults to
            ``sys.argv[1:]`` when None.
    """
    parser = argparse.ArgumentParser(description='Enroll a person for face recognition.')
    parser.add_argument('--name', type=str, required=True,
                        help='Name of the person to enroll.')
    parser.add_argument('--mode', type=str, default='face',
                        choices=['face', 'voice', 'both'],
                        help='Enrollment mode (default: face).')
    parser.add_argument('--samples', type=int, default=10,
                        help='Number of face samples to collect (default: 10).')

    # Honor explicitly passed arguments instead of always reading sys.argv.
    cli_args = sys.argv[1:] if args is None else args
    # Parse only known args so ROS2 remapping args pass through
    parsed, remaining = parser.parse_known_args(args=cli_args)

    rclpy.init(args=remaining)
    node = None
    try:
        # Node does all work in __init__
        node = EnrollmentCLI(parsed.name, parsed.mode, parsed.samples)
    finally:
        # Shut rclpy down even when node construction raises.
        if node is not None:
            node.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()

View File

@ -0,0 +1,206 @@
"""
face_gallery.py -- Persistent face embedding gallery backed by numpy .npz + JSON.
Thread-safe gallery storage for face recognition. Embeddings are stored in a
.npz file, with a sidecar metadata.json containing names, sample counts, and
enrollment timestamps. Auto-increment IDs start at 1.
"""
import json
import logging
import os
import threading
import time
from typing import Optional
import numpy as np
logger = logging.getLogger(__name__)
class FaceGallery:
    """Persistent, thread-safe face embedding gallery.

    Entries live in memory as ``person_id -> {'name', 'embedding', 'samples',
    'enrolled_at'}`` and are persisted to ``gallery.npz`` (embeddings) plus
    ``metadata.json`` (everything else). Every public method takes the
    internal lock.

    Args:
        gallery_dir: Directory for gallery.npz and metadata.json files.
    """

    def __init__(self, gallery_dir: str):
        self._gallery_dir = gallery_dir
        self._npz_path = os.path.join(gallery_dir, 'gallery.npz')
        self._meta_path = os.path.join(gallery_dir, 'metadata.json')
        self._gallery: dict[int, dict] = {}
        self._next_id = 1
        self._lock = threading.Lock()

    def load(self) -> None:
        """Load gallery from disk, replacing the in-memory contents.

        The next auto-increment ID becomes one more than the largest ID found.
        """
        with self._lock:
            self._gallery = {}
            self._next_id = 1

            if not os.path.isfile(self._npz_path):
                logger.info('No gallery file at %s, starting empty.', self._npz_path)
                return

            meta: dict = {}
            if os.path.isfile(self._meta_path):
                with open(self._meta_path, 'r') as f:
                    meta = json.load(f)

            # NpzFile keeps the archive open until closed, so read every
            # array inside the context manager.
            with np.load(self._npz_path, allow_pickle=False) as data:
                for key in data.files:
                    pid = int(key)
                    embedding = data[key].astype(np.float32)
                    norm = np.linalg.norm(embedding)
                    if norm > 0:
                        embedding = embedding / norm
                    info = meta.get(str(pid), {})
                    self._gallery[pid] = {
                        'name': info.get('name', f'person_{pid}'),
                        'embedding': embedding,
                        'samples': info.get('samples', 1),
                        'enrolled_at': info.get('enrolled_at', 0.0),
                    }
                    if pid >= self._next_id:
                        self._next_id = pid + 1

            logger.info('Gallery loaded: %d persons from %s',
                        len(self._gallery), self._npz_path)

    def save(self) -> None:
        """Save gallery to disk (npz + JSON sidecar).

        Each file is written to a temporary sibling and then moved into place
        with ``os.replace``, so an interrupted save does not leave a truncated
        gallery behind.
        """
        with self._lock:
            os.makedirs(self._gallery_dir, exist_ok=True)

            arrays = {}
            meta = {}
            for pid, info in self._gallery.items():
                arrays[str(pid)] = info['embedding']
                meta[str(pid)] = {
                    'name': info['name'],
                    'samples': info['samples'],
                    'enrolled_at': info['enrolled_at'],
                }

            # Writing through a file object stops np.savez from appending
            # '.npz' to the temporary name.
            npz_tmp = self._npz_path + '.tmp'
            with open(npz_tmp, 'wb') as f:
                np.savez(f, **arrays)
            os.replace(npz_tmp, self._npz_path)

            meta_tmp = self._meta_path + '.tmp'
            with open(meta_tmp, 'w') as f:
                json.dump(meta, f, indent=2)
            os.replace(meta_tmp, self._meta_path)

            logger.info('Gallery saved: %d persons to %s',
                        len(self._gallery), self._npz_path)

    def add_person(self, name: str, embedding: np.ndarray, samples: int = 1) -> int:
        """Add a new person to the gallery.

        The embedding is converted to float32 and re-normalized before it is
        stored. Nothing is written to disk; call ``save`` to persist.

        Args:
            name: Person's name.
            embedding: L2-normalized 512-dim embedding.
            samples: Number of samples used to compute the embedding.

        Returns:
            Assigned person_id (auto-increment integer).
        """
        with self._lock:
            pid = self._next_id
            self._next_id += 1

            emb = embedding.astype(np.float32)
            norm = np.linalg.norm(emb)
            if norm > 0:
                emb = emb / norm

            self._gallery[pid] = {
                'name': name,
                'embedding': emb,
                'samples': samples,
                'enrolled_at': time.time(),
            }
            logger.info('Added person %d (%s), %d samples.', pid, name, samples)
            return pid

    def update_name(self, person_id: int, new_name: str) -> bool:
        """Update a person's name.

        Args:
            person_id: The ID of the person to update.
            new_name: New name string.

        Returns:
            True if the person was found and updated.
        """
        with self._lock:
            entry = self._gallery.get(person_id)
            if entry is None:
                return False
            entry['name'] = new_name
            return True

    def delete_person(self, person_id: int) -> bool:
        """Remove a person from the gallery.

        Args:
            person_id: The ID of the person to delete.

        Returns:
            True if the person was found and removed.
        """
        with self._lock:
            if person_id not in self._gallery:
                return False
            del self._gallery[person_id]
            logger.info('Deleted person %d.', person_id)
            return True

    def get_all(self) -> list[dict]:
        """Get all gallery entries.

        Returns:
            List of dicts with keys: person_id, name, embedding, samples,
            enrolled_at. Embeddings are copies, so callers may modify them.
        """
        with self._lock:
            return [
                {
                    'person_id': pid,
                    'name': info['name'],
                    'embedding': info['embedding'].copy(),
                    'samples': info['samples'],
                    'enrolled_at': info['enrolled_at'],
                }
                for pid, info in self._gallery.items()
            ]

    def match(self, query_embedding: np.ndarray, threshold: float = 0.35) -> tuple[int, str, float]:
        """Match a query embedding against the gallery using cosine similarity.

        The query is re-normalized before comparison.

        Args:
            query_embedding: L2-normalized 512-dim embedding.
            threshold: Minimum cosine similarity for a match.

        Returns:
            (person_id, name, score) or (-1, '', 0.0) if no match.
        """
        with self._lock:
            if not self._gallery:
                return (-1, '', 0.0)

            query = query_embedding.astype(np.float32)
            norm = np.linalg.norm(query)
            if norm > 0:
                query = query / norm

            best_pid = -1
            best_name = ''
            best_score = 0.0
            for pid, info in self._gallery.items():
                score = float(np.dot(query, info['embedding']))
                if score > best_score:
                    best_score = score
                    best_pid = pid
                    best_name = info['name']

            if best_score >= threshold:
                return (best_pid, best_name, best_score)
            return (-1, '', 0.0)

    def __len__(self) -> int:
        """Return the number of enrolled persons."""
        with self._lock:
            return len(self._gallery)

View File

@ -0,0 +1,431 @@
"""
face_recognition_node.py -- ROS2 node for SCRFD face detection + ArcFace recognition.
Pipeline:
1. Subscribe to /camera/color/image_raw (RealSense D435i color stream).
2. Run SCRFD face detection (TensorRT FP16 or ONNX fallback).
3. For each detected face, align and extract ArcFace embedding.
4. Match embedding against persistent gallery.
5. Publish FaceDetectionArray with identified faces.
Services:
/social/enroll -- Enroll a new person (collects N face samples).
/social/persons/list -- List all enrolled persons.
/social/persons/delete -- Delete a person from the gallery.
/social/persons/update -- Update a person's name.
"""
import time
import numpy as np
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy, DurabilityPolicy
import cv2
from cv_bridge import CvBridge
from sensor_msgs.msg import Image
from builtin_interfaces.msg import Time
from saltybot_social_msgs.msg import (
FaceDetection,
FaceDetectionArray,
FaceEmbedding,
FaceEmbeddingArray,
)
from saltybot_social_msgs.srv import (
EnrollPerson,
ListPersons,
DeletePerson,
UpdatePerson,
)
from .scrfd_detector import SCRFDDetector
from .arcface_recognizer import ArcFaceRecognizer
from .face_gallery import FaceGallery
class FaceRecognitionNode(Node):
"""ROS2 node: SCRFD face detection + ArcFace gallery matching."""
def __init__(self):
    """Declare parameters, load models and gallery, and create ROS interfaces."""
    super().__init__('face_recognizer')
    self._bridge = CvBridge()
    self._frame_count = 0
    self._fps_t0 = time.monotonic()

    # -- Parameters --------------------------------------------------------
    self.declare_parameter('scrfd_engine_path',
                           '/mnt/nvme/saltybot/models/scrfd_2.5g.engine')
    self.declare_parameter('scrfd_onnx_path',
                           '/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx')
    self.declare_parameter('arcface_engine_path',
                           '/mnt/nvme/saltybot/models/arcface_r50.engine')
    self.declare_parameter('arcface_onnx_path',
                           '/mnt/nvme/saltybot/models/arcface_r50.onnx')
    self.declare_parameter('gallery_dir', '/mnt/nvme/saltybot/gallery')
    self.declare_parameter('recognition_threshold', 0.35)
    self.declare_parameter('publish_debug_image', False)
    self.declare_parameter('max_faces', 10)
    self.declare_parameter('scrfd_conf_thresh', 0.5)

    self._recognition_threshold = self.get_parameter('recognition_threshold').value
    self._pub_debug_flag = self.get_parameter('publish_debug_image').value
    self._max_faces = self.get_parameter('max_faces').value

    # -- Models ------------------------------------------------------------
    self._detector = SCRFDDetector(
        engine_path=self.get_parameter('scrfd_engine_path').value,
        onnx_path=self.get_parameter('scrfd_onnx_path').value,
        conf_thresh=self.get_parameter('scrfd_conf_thresh').value,
    )
    self._recognizer = ArcFaceRecognizer(
        engine_path=self.get_parameter('arcface_engine_path').value,
        onnx_path=self.get_parameter('arcface_onnx_path').value,
    )

    # -- Gallery -----------------------------------------------------------
    gallery_dir = self.get_parameter('gallery_dir').value
    self._gallery = FaceGallery(gallery_dir)
    self._gallery.load()
    # rclpy logger methods take one pre-formatted message; printf-style
    # positional arguments raise TypeError.
    self.get_logger().info(f'Gallery loaded: {len(self._gallery)} persons.')

    # -- Enrollment state --------------------------------------------------
    self._enrolling = None  # {name, samples_needed, collected: [embeddings]}
    # Best (largest-face) candidate seen since the last collected sample.
    self._enroll_best_area = 0.0
    self._enroll_best_embedding = None

    # -- QoS profiles ------------------------------------------------------
    best_effort_qos = QoSProfile(
        reliability=ReliabilityPolicy.BEST_EFFORT,
        history=HistoryPolicy.KEEP_LAST,
        depth=1,
    )
    # Transient-local so late subscribers still get the last gallery message.
    reliable_qos = QoSProfile(
        reliability=ReliabilityPolicy.RELIABLE,
        durability=DurabilityPolicy.TRANSIENT_LOCAL,
        history=HistoryPolicy.KEEP_LAST,
        depth=1,
    )

    # -- Subscribers -------------------------------------------------------
    self.create_subscription(
        Image,
        '/camera/color/image_raw',
        self._on_image,
        best_effort_qos,
    )

    # -- Publishers --------------------------------------------------------
    self._pub_detections = self.create_publisher(
        FaceDetectionArray, '/social/faces/detections', best_effort_qos)
    self._pub_embeddings = self.create_publisher(
        FaceEmbeddingArray, '/social/faces/embeddings', reliable_qos)
    if self._pub_debug_flag:
        self._pub_debug_img = self.create_publisher(
            Image, '/social/faces/debug_image', best_effort_qos)

    # -- Services ----------------------------------------------------------
    self.create_service(EnrollPerson, '/social/enroll', self._handle_enroll)
    self.create_service(ListPersons, '/social/persons/list', self._handle_list)
    self.create_service(DeletePerson, '/social/persons/delete', self._handle_delete)
    self.create_service(UpdatePerson, '/social/persons/update', self._handle_update)

    # Publish initial gallery state
    self._publish_gallery_embeddings()
    self.get_logger().info('FaceRecognitionNode ready.')
# -- Image callback --------------------------------------------------------
def _on_image(self, msg: Image):
    """Process incoming camera frame: detect, recognize, publish.

    Args:
        msg: Camera image; its header is copied onto every published message.
    """
    try:
        bgr = self._bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')
    except Exception as e:
        # rclpy logger methods take one message plus keyword filters;
        # printf-style positional arguments raise TypeError.
        self.get_logger().error(f'Image decode error: {e}',
                                throttle_duration_sec=5.0)
        return

    # Detect faces
    detections = self._detector.detect(bgr)

    # Limit face count, keeping the highest-scoring detections
    if len(detections) > self._max_faces:
        detections = sorted(detections, key=lambda d: d['score'], reverse=True)
        detections = detections[:self._max_faces]

    # Build output message
    det_array = FaceDetectionArray()
    det_array.header = msg.header

    for det in detections:
        # Extract embedding and match gallery
        embedding = self._recognizer.align_and_embed(bgr, det['kps'])
        pid, pname, score = self._gallery.match(
            embedding, self._recognition_threshold)

        # Handle enrollment: collect embedding from largest face
        if self._enrolling is not None:
            self._enrollment_collect(det, embedding)

        # Build FaceDetection message
        face_msg = FaceDetection()
        face_msg.header = msg.header
        face_msg.face_id = pid
        face_msg.person_name = pname
        face_msg.confidence = det['score']
        face_msg.recognition_score = score

        # bbox is indexed as [x1, y1, x2, y2] and published as x, y, w, h
        bbox = det['bbox']
        face_msg.bbox_x = bbox[0]
        face_msg.bbox_y = bbox[1]
        face_msg.bbox_w = bbox[2] - bbox[0]
        face_msg.bbox_h = bbox[3] - bbox[1]

        kps = det['kps']
        for i in range(10):
            face_msg.landmarks[i] = kps[i]

        det_array.faces.append(face_msg)

    self._pub_detections.publish(det_array)

    # Debug image
    if self._pub_debug_flag and hasattr(self, '_pub_debug_img'):
        debug_img = self._draw_debug(bgr, detections, det_array.faces)
        debug_msg = self._bridge.cv2_to_imgmsg(debug_img, encoding='bgr8')
        # Carry over stamp and frame so the debug image can be matched to
        # the frame it was drawn from.
        debug_msg.header = msg.header
        self._pub_debug_img.publish(debug_msg)

    # FPS logging, once every 30 frames
    self._frame_count += 1
    if self._frame_count % 30 == 0:
        now = time.monotonic()
        elapsed = now - self._fps_t0
        fps = 30.0 / elapsed if elapsed > 0 else 0.0
        self._fps_t0 = now
        self.get_logger().info(f'FPS: {fps:.1f} | faces: {len(detections)}')
# -- Enrollment logic ------------------------------------------------------
def _enrollment_collect(self, det: dict, embedding: np.ndarray):
    """Remember this embedding if it comes from the largest face seen so far.

    Does nothing unless an enrollment is in progress.

    Args:
        det: Detection dict; only its 'bbox' entry is read.
        embedding: Embedding extracted for this detection.
    """
    if self._enrolling is None:
        return

    # Face size is the bounding-box area.
    box = det['bbox']
    width = box[2] - box[0]
    height = box[3] - box[1]
    face_area = width * height

    # Create the running-best slots on first use.
    if not hasattr(self, '_enroll_best_area'):
        self._enroll_best_area = 0.0
        self._enroll_best_embedding = None

    if not face_area > self._enroll_best_area:
        return
    self._enroll_best_area = face_area
    self._enroll_best_embedding = embedding
def _enrollment_frame_end(self):
    """Commit the pending best embedding as one enrollment sample.

    Does nothing when no enrollment is active or no candidate is pending.
    Once enough samples are collected, the averaged embedding is added to the
    gallery, the gallery is saved and republished, and the enrollment state
    dict is marked done and detached from the node.
    """
    if self._enrolling is None or self._enroll_best_embedding is None:
        return

    self._enrolling['collected'].append(self._enroll_best_embedding)
    self._enroll_best_area = 0.0
    self._enroll_best_embedding = None

    collected = len(self._enrolling['collected'])
    needed = self._enrolling['samples_needed']
    # rclpy logger methods take one pre-formatted message; printf-style
    # positional arguments raise TypeError.
    self.get_logger().info(
        f'Enrollment: {collected}/{needed} samples for '
        f'"{self._enrolling["name"]}".')

    if collected >= needed:
        # Finalize enrollment
        name = self._enrolling['name']
        embeddings = self._enrolling['collected']
        mean_emb = np.mean(embeddings, axis=0).astype(np.float32)
        norm = np.linalg.norm(mean_emb)
        if norm > 0:
            mean_emb = mean_emb / norm

        pid = self._gallery.add_person(name, mean_emb, samples=len(embeddings))
        self._gallery.save()
        self._publish_gallery_embeddings()
        self.get_logger().info(f'Enrollment complete: person {pid} ({name}).')

        # Store result for the service callback. The dict is updated before
        # the node drops its reference, so a caller that kept its own
        # reference can still read the outcome.
        self._enrolling['result_pid'] = pid
        self._enrolling['done'] = True
        self._enrolling = None
# -- Service handlers ------------------------------------------------------
def _handle_enroll(self, request, response):
    """Handle EnrollPerson service: collect face samples, then enroll.

    Blocks until the requested number of samples has been collected or the
    timeout (2 s per sample plus 10 s) expires.

    Args:
        request: EnrollPerson request; ``name`` and ``n_samples`` are read.
            A non-positive ``n_samples`` falls back to 10.
        response: EnrollPerson response to fill in.

    Returns:
        The response with ``success``, ``message`` and ``person_id`` set.
    """
    name = request.name.strip()
    if not name:
        response.success = False
        response.message = 'Name cannot be empty.'
        response.person_id = -1
        return response

    if self._enrolling is not None:
        # The nested spin below can dispatch another enroll request; refuse
        # it rather than overwrite the enrollment already in progress.
        response.success = False
        response.message = 'Another enrollment is already in progress.'
        response.person_id = -1
        return response

    n_samples = request.n_samples if request.n_samples > 0 else 10
    # rclpy logger methods take one pre-formatted message; printf-style
    # positional arguments raise TypeError.
    self.get_logger().info(
        f'Starting enrollment for "{name}" ({n_samples} samples).')

    # Keep a local reference to the state dict: _enrollment_frame_end sets
    # self._enrolling to None on completion, so the loop condition and the
    # result must be read through this reference.
    state = {
        'name': name,
        'samples_needed': n_samples,
        'collected': [],
        'done': False,
        'result_pid': -1,
    }
    # Set enrollment state — frames will collect embeddings
    self._enrolling = state
    self._enroll_best_area = 0.0
    self._enroll_best_embedding = None

    timeout_sec = n_samples * 2.0 + 10.0  # generous timeout
    t0 = time.monotonic()

    try:
        while not state['done']:
            # Finalize any pending frame collection
            self._enrollment_frame_end()
            if state['done']:
                break

            if time.monotonic() - t0 > timeout_sec:
                response.success = False
                response.message = f'Enrollment timed out after {timeout_sec:.0f}s.'
                response.person_id = -1
                return response

            # NOTE(review): this spins the node from inside one of its own
            # service callbacks. Confirm the executor set up in main() allows
            # nested spinning; otherwise move sample collection to a
            # reentrant callback group or an asynchronous flow.
            rclpy.spin_once(self, timeout_sec=0.1)
    finally:
        # Never leave a stale enrollment behind (timeout or exception).
        if self._enrolling is state:
            self._enrolling = None

    response.success = True
    response.message = f'Enrolled "{name}" with {n_samples} samples.'
    response.person_id = state['result_pid']
    return response
def _handle_list(self, request, response):
    """Handle ListPersons service: append one FaceEmbedding per gallery entry.

    Args:
        request: ListPersons request (not read).
        response: ListPersons response whose ``persons`` list is filled.

    Returns:
        The filled response.
    """
    for person in self._gallery.get_all():
        # Split the float timestamp into whole seconds and nanoseconds.
        stamp = person['enrolled_at']
        whole_secs = int(stamp)
        frac_nsecs = int((stamp - whole_secs) * 1e9)

        record = FaceEmbedding()
        record.person_id = person['person_id']
        record.person_name = person['name']
        record.embedding = person['embedding'].tolist()
        record.sample_count = person['samples']
        record.enrolled_at = Time(sec=whole_secs, nanosec=frac_nsecs)
        response.persons.append(record)
    return response
def _handle_delete(self, request, response):
    """Handle DeletePerson service: drop one gallery entry by id.

    On success the gallery is saved and the gallery embeddings are
    re-published.

    Args:
        request: DeletePerson request; ``person_id`` is read.
        response: DeletePerson response to populate.

    Returns:
        The populated response.
    """
    removed = self._gallery.delete_person(request.person_id)
    if not removed:
        response.success = False
        response.message = f'Person {request.person_id} not found.'
        return response

    self._gallery.save()
    self._publish_gallery_embeddings()
    response.success = True
    response.message = f'Deleted person {request.person_id}.'
    return response
def _handle_update(self, request, response):
"""Handle UpdatePerson service: rename a person."""
new_name = request.new_name.strip()
if not new_name:
response.success = False
response.message = 'New name cannot be empty.'
return response
if self._gallery.update_name(request.person_id, new_name):
self._gallery.save()
self._publish_gallery_embeddings()
response.success = True
response.message = f'Updated person {request.person_id} to "{new_name}".'
else:
response.success = False
response.message = f'Person {request.person_id} not found.'
return response
# -- Gallery publishing ----------------------------------------------------
def _publish_gallery_embeddings(self):
    """Publish every gallery entry in a single FaceEmbeddingArray message.

    The header is stamped with the node clock's current time.
    """
    records = self._gallery.get_all()
    array_msg = FaceEmbeddingArray()
    array_msg.header.stamp = self.get_clock().now().to_msg()
    for record in records:
        item = FaceEmbedding()
        item.person_id = record['person_id']
        item.person_name = record['name']
        item.embedding = record['embedding'].tolist()
        item.sample_count = record['samples']
        # Split the float timestamp into whole seconds and nanoseconds.
        whole = int(record['enrolled_at'])
        frac_ns = int((record['enrolled_at'] - whole) * 1e9)
        item.enrolled_at = Time(sec=whole, nanosec=frac_ns)
        array_msg.embeddings.append(item)
    self._pub_embeddings.publish(array_msg)
# -- Debug image -----------------------------------------------------------
def _draw_debug(self, bgr: np.ndarray, detections: list[dict],
face_msgs: list) -> np.ndarray:
"""Draw bounding boxes, landmarks, and names on the image."""
vis = bgr.copy()
for det, face_msg in zip(detections, face_msgs):
bbox = det['bbox']
x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
# Color: green if recognized, yellow if unknown
if face_msg.face_id >= 0:
color = (0, 255, 0)
label = f'{face_msg.person_name} ({face_msg.recognition_score:.2f})'
else:
color = (0, 255, 255)
label = f'unknown ({face_msg.confidence:.2f})'
cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2)
cv2.putText(vis, label, (x1, y1 - 8),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
# Draw landmarks
kps = det['kps']
for k in range(5):
px, py = int(kps[k * 2]), int(kps[k * 2 + 1])
cv2.circle(vis, (px, py), 2, (0, 0, 255), -1)
return vis
# -- Entry point ---------------------------------------------------------------
def main(args=None):
    """ROS2 entry point for face_recognition node.

    Args:
        args: Command-line arguments forwarded to ``rclpy.init``.
    """
    rclpy.init(args=args)
    node = None
    try:
        # Construct inside the try block so a failing constructor still
        # reaches the shutdown logic below.
        node = FaceRecognitionNode()
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        if node is not None:
            node.destroy_node()
        # On Ctrl-C the context may already have been shut down by the
        # signal handler; shutting it down a second time raises.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()

View File

@ -0,0 +1,350 @@
"""
scrfd_detector.py -- SCRFD face detection with TensorRT FP16 + ONNX fallback.
SCRFD (Sample and Computation Redistribution for Face Detection) produces
9 output tensors across 3 strides (8, 16, 32), each with score, bbox, and
keypoint branches. This module handles anchor generation, bbox/keypoint
decoding, and NMS to produce final face detections.
"""
import os
import logging
from typing import Optional
import numpy as np
import cv2
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Strides processed by the decoder, in iteration order.
_STRIDES = [8, 16, 32]
_NUM_ANCHORS = 2 # anchors per cell per stride
# -- Inference backends --------------------------------------------------------
class _TRTBackend:
    """TensorRT inference engine for SCRFD.

    Deserializes an engine file, allocates one page-locked host buffer and
    one device buffer per I/O tensor, and runs inference on a private CUDA
    stream using the name-based tensor API.
    """

    def __init__(self, engine_path: str):
        """Load the engine and allocate I/O buffers.

        Args:
            engine_path: Path to a serialized TensorRT engine.

        Raises:
            RuntimeError: If the engine cannot be deserialized or declares a
                dynamic (negative) dimension, which the fixed-size buffers
                allocated here cannot serve.
        """
        import tensorrt as trt
        import pycuda.driver as cuda
        import pycuda.autoinit  # noqa: F401

        self._cuda = cuda
        trt_logger = trt.Logger(trt.Logger.WARNING)
        with open(engine_path, 'rb') as f, trt.Runtime(trt_logger) as runtime:
            self._engine = runtime.deserialize_cuda_engine(f.read())
        if self._engine is None:
            raise RuntimeError(
                f'Failed to deserialize TensorRT engine: {engine_path}')
        self._context = self._engine.create_execution_context()

        self._inputs = []
        self._outputs = []
        self._output_names = []
        for i in range(self._engine.num_io_tensors):
            name = self._engine.get_tensor_name(i)
            dtype = trt.nptype(self._engine.get_tensor_dtype(name))
            shape = tuple(self._engine.get_tensor_shape(name))
            if any(dim < 0 for dim in shape):
                raise RuntimeError(
                    f'Tensor "{name}" has dynamic shape {shape}; '
                    'a static-shape engine is required.')
            host_mem = cuda.pagelocked_empty(shape, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            # Bind the device buffer by tensor name, matching the name-based
            # enumeration above (required by execute_async_v3).
            self._context.set_tensor_address(name, int(device_mem))
            if self._engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                self._inputs.append({'host': host_mem, 'device': device_mem})
            else:
                self._outputs.append({'host': host_mem, 'device': device_mem,
                                      'shape': shape})
                self._output_names.append(name)
        self._stream = cuda.Stream()

    def infer(self, input_data: np.ndarray) -> list[np.ndarray]:
        """Run inference and return output tensors.

        Args:
            input_data: Input tensor holding exactly as many elements as the
                engine's first input.

        Returns:
            One array per engine output, in engine enumeration order, each a
            copy reshaped to the tensor shape reported by the engine.
        """
        feed = self._inputs[0]
        # The pinned host buffer is N-dimensional, so the input must be given
        # the same shape; a flattened array cannot be broadcast into it.
        np.copyto(feed['host'], input_data.reshape(feed['host'].shape))
        self._cuda.memcpy_htod_async(feed['device'], feed['host'], self._stream)
        self._context.execute_async_v3(self._stream.handle)
        for out in self._outputs:
            self._cuda.memcpy_dtoh_async(out['host'], out['device'], self._stream)
        self._stream.synchronize()
        # Copy so callers never alias the reusable pinned buffers.
        return [out['host'].reshape(out['shape']).copy()
                for out in self._outputs]
class _ONNXBackend:
    """ONNX Runtime inference (CUDA EP with CPU fallback)."""

    def __init__(self, onnx_path: str):
        """Create an inference session for the model at ``onnx_path``."""
        import onnxruntime as ort

        session = ort.InferenceSession(
            onnx_path,
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
        )
        self._session = session
        self._input_name = session.get_inputs()[0].name
        self._output_names = [spec.name for spec in session.get_outputs()]

    def infer(self, input_data: np.ndarray) -> list[np.ndarray]:
        """Feed ``input_data`` to the first model input and return all outputs."""
        feeds = {self._input_name: input_data}
        return self._session.run(None, feeds)
# -- NMS ----------------------------------------------------------------------
def _nms(boxes: np.ndarray, scores: np.ndarray, iou_thresh: float) -> list[int]:
"""Non-maximum suppression. boxes: [N, 4] as x1,y1,x2,y2."""
if len(boxes) == 0:
return []
x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
areas = (x2 - x1) * (y2 - y1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(int(i))
if order.size == 1:
break
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-6)
remaining = np.where(iou <= iou_thresh)[0]
order = order[remaining + 1]
return keep
# -- Anchor generation ---------------------------------------------------------
def _generate_anchors(input_h: int, input_w: int) -> dict[int, np.ndarray]:
    """Generate anchor centers for each stride.

    Args:
        input_h: Model input height in pixels.
        input_w: Model input width in pixels.

    Returns:
        Dict mapping stride -> float32 array of shape [H*W*num_anchors, 2],
        where each row is (cx, cy) in input pixel coordinates. Rows are
        ordered cell by cell (x varying fastest), with the anchors of one
        cell adjacent: [anchor0_cell0, anchor1_cell0, anchor0_cell1, ...].
    """
    anchors = {}
    for stride in _STRIDES:
        feat_h = input_h // stride
        feat_w = input_w // stride
        grid_y, grid_x = np.mgrid[:feat_h, :feat_w]
        cells = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1).astype(np.float32)
        # Cell centre in pixels, repeated once per anchor of that cell.
        anchors[stride] = np.repeat((cells + 0.5) * stride, _NUM_ANCHORS, axis=0)
    return anchors
# -- Main detector class -------------------------------------------------------
class SCRFDDetector:
    """SCRFD face detector with TensorRT FP16 and ONNX fallback.

    Decoding follows the conventions of the reference SCRFD implementation:
    anchor centres sit on the stride grid, box regressions are distances
    (left, top, right, bottom) from the anchor centre in stride units, and
    keypoint regressions are (dx, dy) offsets in stride units. Output tensors
    are matched to strides by their shape rather than by position, so both
    stride-major and type-major output orderings are accepted.

    Args:
        engine_path: Path to TensorRT engine file.
        onnx_path: Path to ONNX model file (used if engine not available).
        conf_thresh: Minimum confidence for detections.
        nms_iou: IoU threshold for NMS.
        input_size: Model input resolution (square).
    """

    def __init__(
        self,
        engine_path: str = '',
        onnx_path: str = '',
        conf_thresh: float = 0.5,
        nms_iou: float = 0.4,
        input_size: int = 640,
    ):
        self._conf_thresh = conf_thresh
        self._nms_iou = nms_iou
        self._input_size = input_size
        self._backend: Optional[_TRTBackend | _ONNXBackend] = None
        # Anchor centres per stride, precomputed for the fixed input size.
        self._anchors = {
            stride: self._anchor_centers(input_size, stride, _NUM_ANCHORS)
            for stride in _STRIDES
        }

        # Try TRT first, then ONNX
        if engine_path and os.path.isfile(engine_path):
            try:
                self._backend = _TRTBackend(engine_path)
                logger.info('SCRFD TensorRT backend loaded: %s', engine_path)
                return
            except Exception as e:
                logger.warning('SCRFD TRT load failed (%s), trying ONNX', e)
        if onnx_path and os.path.isfile(onnx_path):
            try:
                self._backend = _ONNXBackend(onnx_path)
                logger.info('SCRFD ONNX backend loaded: %s', onnx_path)
                return
            except Exception as e:
                logger.error('SCRFD ONNX load failed: %s', e)
        logger.error('No SCRFD model loaded. Detection will be unavailable.')

    @property
    def is_loaded(self) -> bool:
        """Return True if a backend is loaded and ready."""
        return self._backend is not None

    @staticmethod
    def _anchor_centers(input_size: int, stride: int, num_anchors: int) -> np.ndarray:
        """Return anchor centres for one stride as a [cells*num_anchors, 2] array.

        Rows are (cx, cy) in input pixels, ordered cell by cell with x varying
        fastest; the anchors belonging to one cell are adjacent.
        """
        feat = input_size // stride
        grid_y, grid_x = np.mgrid[:feat, :feat]
        cells = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1).astype(np.float32)
        return np.repeat(cells * stride, num_anchors, axis=0)

    def detect(self, bgr: np.ndarray) -> list[dict]:
        """Detect faces in a BGR image.

        Args:
            bgr: Input image in BGR format, shape (H, W, 3).

        Returns:
            List of dicts with keys:
                bbox: [x1, y1, x2, y2] in original image coordinates
                kps: [x0,y0, x1,y1, ..., x4,y4] 10 floats, 5 landmarks
                score: detection confidence
            Empty when no backend is loaded or the image is empty.
        """
        if self._backend is None:
            return []
        if bgr is None or bgr.size == 0:
            return []
        orig_h, orig_w = bgr.shape[:2]
        tensor, scale, pad_w, pad_h = self._preprocess(bgr)
        outputs = self._backend.infer(tensor)
        detections = self._decode_outputs(outputs)
        return self._rescale(detections, scale, pad_w, pad_h, orig_w, orig_h)

    def _preprocess(self, bgr: np.ndarray) -> tuple[np.ndarray, float, int, int]:
        """Letterbox to input_size x input_size, convert to RGB, normalize.

        Returns:
            (blob, scale, pad_w, pad_h): the float32 NCHW tensor, the resize
            factor applied to the image, and the left/top padding in pixels.
        """
        h, w = bgr.shape[:2]
        size = self._input_size
        scale = min(size / h, size / w)
        # Extreme aspect ratios could otherwise round a side down to zero.
        new_w = max(1, int(w * scale))
        new_h = max(1, int(h * scale))
        pad_w = (size - new_w) // 2
        pad_h = (size - new_h) // 2
        resized = cv2.resize(bgr, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        canvas = np.zeros((size, size, 3), dtype=np.uint8)
        canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized
        # The reference SCRFD pipeline swaps BGR to RGB before normalizing.
        blob = canvas[:, :, ::-1].astype(np.float32)
        # Normalize: subtract 127.5, divide 128.0
        blob = (blob - 127.5) / 128.0
        # HWC -> NCHW
        blob = np.ascontiguousarray(blob.transpose(2, 0, 1)[np.newaxis])
        return blob, scale, pad_w, pad_h

    def _decode_outputs(self, outputs: list[np.ndarray]) -> list[dict]:
        """Decode the SCRFD score/bbox/keypoint tensors into face detections.

        Each output is flattened to [rows, cols] with cols in {1, 4, 10}
        (score, bbox, keypoints) and matched to the stride whose anchor count
        equals ``rows``. A stride with an incomplete set of tensors is skipped.

        Args:
            outputs: Raw backend outputs, in any order.

        Returns:
            Detections in model-input coordinates after NMS.
        """
        by_shape = {}
        for out in outputs:
            arr = np.asarray(out, dtype=np.float32)
            cols = arr.shape[-1] if arr.ndim >= 2 and arr.shape[-1] in (4, 10) else 1
            arr = arr.reshape(-1, cols)
            by_shape[arr.shape] = arr

        all_scores = []
        all_bboxes = []
        all_kps = []
        for stride in _STRIDES:
            anchors = self._anchors[stride]
            n = anchors.shape[0]
            score_out = by_shape.get((n, 1))
            bbox_out = by_shape.get((n, 4))
            kps_out = by_shape.get((n, 10))
            if score_out is None or bbox_out is None or kps_out is None:
                logger.debug('SCRFD outputs for stride %d not found; skipping', stride)
                continue
            decoded = self._decode_stride(score_out, bbox_out, kps_out, anchors, stride)
            if decoded is None:
                continue
            scores, bboxes, kps = decoded
            all_scores.append(scores)
            all_bboxes.append(bboxes)
            all_kps.append(kps)

        if not all_scores:
            return []
        scores = np.concatenate(all_scores)
        bboxes = np.concatenate(all_bboxes)
        kps = np.concatenate(all_kps)

        keep = _nms(bboxes, scores, self._nms_iou)
        return [
            {
                'bbox': bboxes[i].tolist(),
                'kps': kps[i].tolist(),
                'score': float(scores[i]),
            }
            for i in keep
        ]

    def _decode_stride(
        self,
        score_out: np.ndarray,
        bbox_out: np.ndarray,
        kps_out: np.ndarray,
        anchors: np.ndarray,
        stride: int,
    ) -> Optional[tuple[np.ndarray, np.ndarray, np.ndarray]]:
        """Decode one stride's predictions above the confidence threshold.

        Args:
            score_out: Scores, [N] or [N, 1].
            bbox_out: Distances to the left/top/right/bottom edges, [N, 4],
                in stride units.
            kps_out: Landmark (dx, dy) offsets, [N, 10], in stride units.
            anchors: Anchor centres, [N, 2], in input pixels.
            stride: Feature-map stride in pixels.

        Returns:
            (scores [M], boxes [M, 4] as x1,y1,x2,y2, keypoints [M, 10]) for
            the M anchors scoring above the threshold, or None if M is zero.
        """
        flat_scores = score_out.ravel()
        mask = flat_scores > self._conf_thresh
        if not mask.any():
            return None
        centers = anchors[mask]
        distances = bbox_out[mask] * stride
        boxes = np.concatenate(
            [centers - distances[:, :2], centers + distances[:, 2:4]], axis=1)
        offsets = kps_out[mask].reshape(-1, 5, 2) * stride
        keypoints = (centers[:, np.newaxis, :] + offsets).reshape(-1, 10)
        return flat_scores[mask], boxes, keypoints

    def _rescale(
        self,
        detections: list[dict],
        scale: float,
        pad_w: int,
        pad_h: int,
        orig_w: int,
        orig_h: int,
    ) -> list[dict]:
        """Rescale detections from model input space to original image space.

        Box coordinates are clamped to the image on both sides, so a box that
        falls into the letterbox padding cannot end up with a negative extent.
        Keypoints are rescaled but not clamped.

        Returns:
            The same list, with each dict's ``bbox`` and ``kps`` replaced.
        """
        max_x = float(orig_w)
        max_y = float(orig_h)
        for det in detections:
            x1, y1, x2, y2 = det['bbox'][:4]
            det['bbox'] = [
                min(max_x, max(0.0, (x1 - pad_w) / scale)),
                min(max_y, max(0.0, (y1 - pad_h) / scale)),
                min(max_x, max(0.0, (x2 - pad_w) / scale)),
                min(max_y, max(0.0, (y2 - pad_h) / scale)),
            ]
            # Even indices are x coordinates, odd indices are y coordinates.
            det['kps'] = [
                (value - (pad_w if i % 2 == 0 else pad_h)) / scale
                for i, value in enumerate(det['kps'])
            ]
        return detections

View File

@ -0,0 +1,112 @@
#!/usr/bin/env python3
"""
build_face_trt_engines.py -- Build TensorRT FP16 engines for SCRFD and ArcFace.
Converts ONNX model files to optimized TensorRT engines with FP16 precision
for fast inference on Jetson Orin Nano Super.
Usage:
python3 build_face_trt_engines.py \
--scrfd-onnx /path/to/scrfd_2.5g_bnkps.onnx \
--arcface-onnx /path/to/arcface_r50.onnx \
--output-dir /mnt/nvme/saltybot/models \
--fp16 --workspace-mb 1024
"""
import argparse
import os
import time
def build_engine(onnx_path: str, engine_path: str, fp16: bool, workspace_mb: int):
    """Build a TensorRT engine from an ONNX model and write it to disk.

    Args:
        onnx_path: Path to the source ONNX model file.
        engine_path: Output path for the serialized TensorRT engine.
        fp16: Enable FP16 precision (ignored with a warning if the platform
            reports no fast FP16 support).
        workspace_mb: Maximum workspace size in megabytes.

    Raises:
        RuntimeError: If the ONNX model cannot be parsed or the build fails.
        OSError: If the model cannot be read or the output cannot be written.
    """
    import tensorrt as trt

    # Create the output directory up front so an unwritable destination
    # fails before the multi-minute build rather than after it.
    os.makedirs(os.path.dirname(engine_path) or '.', exist_ok=True)

    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)

    print(f'Parsing ONNX model: {onnx_path}')
    t0 = time.monotonic()
    with open(onnx_path, 'rb') as f:
        if not parser.parse(f.read()):
            for i in range(parser.num_errors):
                print(f'  ONNX parse error: {parser.get_error(i)}')
            raise RuntimeError(f'Failed to parse {onnx_path}')
    parse_time = time.monotonic() - t0
    print(f'  Parsed in {parse_time:.1f}s')

    # NOTE(review): no optimization profile is configured here; if a model
    # declares dynamic input dimensions the build is expected to fail --
    # confirm the ONNX files are exported with static shapes.
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
                                 workspace_mb * (1 << 20))
    if fp16:
        if builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
            print('  FP16 enabled.')
        else:
            print('  Warning: FP16 not supported on this platform, using FP32.')

    print('Building engine (this may take several minutes)...')
    t0 = time.monotonic()
    serialized = builder.build_serialized_network(network, config)
    build_time = time.monotonic() - t0
    if serialized is None:
        raise RuntimeError('Engine build failed.')

    with open(engine_path, 'wb') as f:
        f.write(serialized)
    size_mb = os.path.getsize(engine_path) / (1 << 20)
    print(f'  Engine saved: {engine_path} ({size_mb:.1f} MB, built in {build_time:.1f}s)')
def main(argv=None):
    """Main entry point for TRT engine building.

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]`` when None.

    Exits with an argparse usage error (status 2) when no model is given or
    when a given ONNX file does not exist, before any build is started.
    """
    parser = argparse.ArgumentParser(
        description='Build TensorRT FP16 engines for SCRFD and ArcFace.')
    parser.add_argument('--scrfd-onnx', type=str, default='',
                        help='Path to SCRFD ONNX model.')
    parser.add_argument('--arcface-onnx', type=str, default='',
                        help='Path to ArcFace ONNX model.')
    parser.add_argument('--output-dir', type=str,
                        default='/mnt/nvme/saltybot/models',
                        help='Output directory for engine files.')
    parser.add_argument('--fp16', action='store_true', default=True,
                        help='Enable FP16 precision (default: True).')
    parser.add_argument('--no-fp16', action='store_false', dest='fp16',
                        help='Disable FP16 (use FP32 only).')
    parser.add_argument('--workspace-mb', type=int, default=1024,
                        help='TRT workspace size in MB (default: 1024).')
    args = parser.parse_args(argv)

    if not args.scrfd_onnx and not args.arcface_onnx:
        parser.error('At least one of --scrfd-onnx or --arcface-onnx is required.')

    # (label, source model, engine file name) for every requested build.
    jobs = [
        job for job in (
            ('SCRFD', args.scrfd_onnx, 'scrfd_2.5g.engine'),
            ('ArcFace', args.arcface_onnx, 'arcface_r50.engine'),
        )
        if job[1]
    ]

    # Validate every input first so a typo in the second path does not
    # surface only after the first engine has been built.
    for label, onnx_path, _ in jobs:
        if not os.path.isfile(onnx_path):
            parser.error(f'{label} ONNX model not found: {onnx_path}')

    for label, onnx_path, engine_name in jobs:
        engine_path = os.path.join(args.output_dir, engine_name)
        print(f'\n=== Building {label} engine ===')
        build_engine(onnx_path, engine_path, args.fp16, args.workspace_mb)
    print('\nDone.')


if __name__ == '__main__':
    main()

View File

@ -0,0 +1,4 @@
[develop]
script_dir=$base/lib/saltybot_social_face
[install]
install_scripts=$base/lib/saltybot_social_face

View File

@ -0,0 +1,30 @@
"""Setup for saltybot_social_face package."""
from setuptools import find_packages, setup
# Package name, reused below for the install paths.
package_name = 'saltybot_social_face'
setup(
name=package_name,
version='0.1.0',
# Every package found in the source tree except the tests.
packages=find_packages(exclude=['test']),
# Non-Python files and their install destinations: the resource index
# marker, the package manifest, the launch file and the parameter file.
data_files=[
('share/ament_index/resource_index/packages', ['resource/' + package_name]),
('share/' + package_name, ['package.xml']),
('share/' + package_name + '/launch', ['launch/face_recognition.launch.py']),
('share/' + package_name + '/config', ['config/face_recognition_params.yaml']),
],
install_requires=['setuptools'],
zip_safe=True,
maintainer='seb',
maintainer_email='seb@vayrette.com',
description='SCRFD face detection and ArcFace recognition for SaltyBot social interactions',
license='MIT',
tests_require=['pytest'],
# Executable name -> "module:function" that it runs.
entry_points={
'console_scripts': [
'face_recognition = saltybot_social_face.face_recognition_node:main',
'enrollment_cli = saltybot_social_face.enrollment_cli:main',
],
},
)

View File

@ -0,0 +1,28 @@
# Build configuration for the saltybot_social_msgs interface package.
cmake_minimum_required(VERSION 3.8)
project(saltybot_social_msgs)
# Enable extra compiler warnings on GCC and Clang.
if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
add_compile_options(-Wall -Wextra -Wpedantic)
endif()
# Build system, interface generator, and the packages whose types are
# referenced by the definitions below.
find_package(ament_cmake REQUIRED)
find_package(rosidl_default_generators REQUIRED)
find_package(std_msgs REQUIRED)
find_package(geometry_msgs REQUIRED)
find_package(builtin_interfaces REQUIRED)
# Generate code for every message and service definition in this package.
rosidl_generate_interfaces(${PROJECT_NAME}
"msg/FaceDetection.msg"
"msg/FaceDetectionArray.msg"
"msg/FaceEmbedding.msg"
"msg/FaceEmbeddingArray.msg"
"msg/PersonState.msg"
"msg/PersonStateArray.msg"
"srv/EnrollPerson.srv"
"srv/ListPersons.srv"
"srv/DeletePerson.srv"
"srv/UpdatePerson.srv"
DEPENDENCIES std_msgs geometry_msgs builtin_interfaces
)
ament_package()

View File

@ -0,0 +1,12 @@
std_msgs/Header header
int32 face_id # -1 if unknown
string person_name # "" if unknown
float32 confidence # detection confidence 0-1
float32 recognition_score # cosine similarity 0-1 (0 if unknown)
# Bounding box in pixels
float32 bbox_x
float32 bbox_y
float32 bbox_w
float32 bbox_h
# 5-point landmarks [x0,y0, x1,y1, x2,y2, x3,y3, x4,y4] = left_eye, right_eye, nose, left_mouth, right_mouth
float32[10] landmarks

View File

@ -0,0 +1,2 @@
std_msgs/Header header
saltybot_social_msgs/FaceDetection[] faces

View File

@ -0,0 +1,5 @@
int32 person_id
string person_name
float32[] embedding # 512-dim ArcFace embedding
builtin_interfaces/Time enrolled_at
int32 sample_count

View File

@ -0,0 +1,2 @@
std_msgs/Header header
saltybot_social_msgs/FaceEmbedding[] embeddings

View File

@ -0,0 +1,19 @@
std_msgs/Header header
int32 person_id
string person_name
int32 face_id
string speaker_id # from audio, "" if unknown
string uwb_anchor_id # "" if no UWB
geometry_msgs/Point position # in base_link frame, zeros if unknown
float32 distance # metres
float32 bearing_deg # degrees, 0=forward
uint8 state # see STATE_* constants
uint8 STATE_UNKNOWN=0
uint8 STATE_APPROACHING=1
uint8 STATE_ENGAGED=2
uint8 STATE_TALKING=3
uint8 STATE_LEAVING=4
uint8 STATE_ABSENT=5
float32 engagement_score # 0-1, attention model output
builtin_interfaces/Time last_seen
int32 camera_id # which CSI camera last saw them (-1=depth cam)

View File

@ -0,0 +1,3 @@
std_msgs/Header header
saltybot_social_msgs/PersonState[] persons
int32 primary_attention_id # person_id of focus target (-1 if none)

View File

@ -0,0 +1,24 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_social_msgs</name>
<version>0.1.0</version>
<description>Custom message and service definitions for the SaltyBot social sprint</description>
<maintainer email="seb@vayrette.com">seb</maintainer>
<license>MIT</license>
<buildtool_depend>ament_cmake</buildtool_depend>
<buildtool_depend>rosidl_default_generators</buildtool_depend>
<depend>std_msgs</depend>
<depend>geometry_msgs</depend>
<depend>builtin_interfaces</depend>
<exec_depend>rosidl_default_runtime</exec_depend>
<member_of_group>rosidl_interface_packages</member_of_group>
<export>
<build_type>ament_cmake</build_type>
</export>
</package>

View File

@ -0,0 +1,4 @@
int32 person_id
---
bool success
string message

View File

@ -0,0 +1,7 @@
string name
string mode # "face", "voice", "both"
int32 n_samples # number of face crops to average (default 10)
---
bool success
string message
int32 person_id

View File

@ -0,0 +1,2 @@
---
saltybot_social_msgs/FaceEmbedding[] persons

View File

@ -0,0 +1,5 @@
int32 person_id
string new_name
---
bool success
string message