From f61a03b3c526f8dba5a6ff99c2321d3abde2a140 Mon Sep 17 00:00:00 2001 From: sl-perception Date: Sun, 1 Mar 2026 23:11:20 -0500 Subject: [PATCH] feat(social): face detection + recognition (SCRFD + ArcFace TRT FP16, Issue #80) Add two new ROS2 packages for the social sprint: saltybot_social_msgs (ament_cmake): - FaceDetection, FaceDetectionArray, FaceEmbedding, FaceEmbeddingArray - PersonState, PersonStateArray - EnrollPerson, ListPersons, DeletePerson, UpdatePerson services saltybot_social_face (ament_python): - SCRFDDetector: SCRFD face detection with TRT FP16 + ONNX fallback - 640x640 input, 3-stride anchor decoding, NMS - ArcFaceRecognizer: 512-dim embedding extraction with gallery matching - 5-point landmark alignment to 112x112, cosine similarity - FaceGallery: thread-safe persistent gallery (npz + JSON sidecar) - FaceRecognitionNode: ROS2 node subscribing /camera/color/image_raw, publishing /social/faces/detections, /social/faces/embeddings - Enrollment via /social/enroll service (N-sample face averaging) - Launch file, config YAML, TRT engine builder script Co-Authored-By: Claude Sonnet 4.6 --- .../config/face_recognition_params.yaml | 11 + .../launch/face_recognition.launch.py | 80 ++++ .../src/saltybot_social_face/package.xml | 27 ++ .../resource/saltybot_social_face | 0 .../saltybot_social_face/__init__.py | 1 + .../arcface_recognizer.py | 316 +++++++++++++ .../saltybot_social_face/enrollment_cli.py | 78 ++++ .../saltybot_social_face/face_gallery.py | 206 +++++++++ .../face_recognition_node.py | 431 ++++++++++++++++++ .../saltybot_social_face/scrfd_detector.py | 350 ++++++++++++++ .../scripts/build_face_trt_engines.py | 112 +++++ .../src/saltybot_social_face/setup.cfg | 4 + .../ros2_ws/src/saltybot_social_face/setup.py | 30 ++ 13 files changed, 1646 insertions(+) create mode 100644 jetson/ros2_ws/src/saltybot_social_face/config/face_recognition_params.yaml create mode 100644 jetson/ros2_ws/src/saltybot_social_face/launch/face_recognition.launch.py create 
mode 100644 jetson/ros2_ws/src/saltybot_social_face/package.xml create mode 100644 jetson/ros2_ws/src/saltybot_social_face/resource/saltybot_social_face create mode 100644 jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/__init__.py create mode 100644 jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/arcface_recognizer.py create mode 100644 jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/enrollment_cli.py create mode 100644 jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/face_gallery.py create mode 100644 jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/face_recognition_node.py create mode 100644 jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/scrfd_detector.py create mode 100644 jetson/ros2_ws/src/saltybot_social_face/scripts/build_face_trt_engines.py create mode 100644 jetson/ros2_ws/src/saltybot_social_face/setup.cfg create mode 100644 jetson/ros2_ws/src/saltybot_social_face/setup.py diff --git a/jetson/ros2_ws/src/saltybot_social_face/config/face_recognition_params.yaml b/jetson/ros2_ws/src/saltybot_social_face/config/face_recognition_params.yaml new file mode 100644 index 0000000..f5a6517 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_face/config/face_recognition_params.yaml @@ -0,0 +1,11 @@ +face_recognizer: + ros__parameters: + scrfd_engine_path: '/mnt/nvme/saltybot/models/scrfd_2.5g.engine' + scrfd_onnx_path: '/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx' + arcface_engine_path: '/mnt/nvme/saltybot/models/arcface_r50.engine' + arcface_onnx_path: '/mnt/nvme/saltybot/models/arcface_r50.onnx' + gallery_dir: '/mnt/nvme/saltybot/gallery' + recognition_threshold: 0.35 + publish_debug_image: false + max_faces: 10 + scrfd_conf_thresh: 0.5 diff --git a/jetson/ros2_ws/src/saltybot_social_face/launch/face_recognition.launch.py b/jetson/ros2_ws/src/saltybot_social_face/launch/face_recognition.launch.py new file mode 100644 index 0000000..f177b7e --- /dev/null +++ 
b/jetson/ros2_ws/src/saltybot_social_face/launch/face_recognition.launch.py @@ -0,0 +1,80 @@ +""" +face_recognition.launch.py -- Launch file for the SCRFD + ArcFace face recognition node. + +Launches the face_recognizer node with configurable model paths and parameters. +The RealSense camera must be running separately (e.g., via realsense.launch.py). +""" + +from launch import LaunchDescription +from launch.actions import DeclareLaunchArgument +from launch.substitutions import LaunchConfiguration +from launch_ros.actions import Node + + +def generate_launch_description(): + """Generate launch description for face recognition pipeline.""" + return LaunchDescription([ + DeclareLaunchArgument( + 'scrfd_engine_path', + default_value='/mnt/nvme/saltybot/models/scrfd_2.5g.engine', + description='Path to SCRFD TensorRT engine file', + ), + DeclareLaunchArgument( + 'scrfd_onnx_path', + default_value='/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx', + description='Path to SCRFD ONNX model file (fallback)', + ), + DeclareLaunchArgument( + 'arcface_engine_path', + default_value='/mnt/nvme/saltybot/models/arcface_r50.engine', + description='Path to ArcFace TensorRT engine file', + ), + DeclareLaunchArgument( + 'arcface_onnx_path', + default_value='/mnt/nvme/saltybot/models/arcface_r50.onnx', + description='Path to ArcFace ONNX model file (fallback)', + ), + DeclareLaunchArgument( + 'gallery_dir', + default_value='/mnt/nvme/saltybot/gallery', + description='Directory for persistent face gallery storage', + ), + DeclareLaunchArgument( + 'recognition_threshold', + default_value='0.35', + description='Cosine similarity threshold for face recognition', + ), + DeclareLaunchArgument( + 'publish_debug_image', + default_value='false', + description='Publish annotated debug image to /social/faces/debug_image', + ), + DeclareLaunchArgument( + 'max_faces', + default_value='10', + description='Maximum faces to process per frame', + ), + DeclareLaunchArgument( + 'scrfd_conf_thresh', + 
default_value='0.5', + description='SCRFD detection confidence threshold', + ), + + Node( + package='saltybot_social_face', + executable='face_recognition', + name='face_recognizer', + output='screen', + parameters=[{ + 'scrfd_engine_path': LaunchConfiguration('scrfd_engine_path'), + 'scrfd_onnx_path': LaunchConfiguration('scrfd_onnx_path'), + 'arcface_engine_path': LaunchConfiguration('arcface_engine_path'), + 'arcface_onnx_path': LaunchConfiguration('arcface_onnx_path'), + 'gallery_dir': LaunchConfiguration('gallery_dir'), + 'recognition_threshold': LaunchConfiguration('recognition_threshold'), + 'publish_debug_image': LaunchConfiguration('publish_debug_image'), + 'max_faces': LaunchConfiguration('max_faces'), + 'scrfd_conf_thresh': LaunchConfiguration('scrfd_conf_thresh'), + }], + ), + ]) diff --git a/jetson/ros2_ws/src/saltybot_social_face/package.xml b/jetson/ros2_ws/src/saltybot_social_face/package.xml new file mode 100644 index 0000000..7a21282 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_face/package.xml @@ -0,0 +1,27 @@ + + + + saltybot_social_face + 0.1.0 + SCRFD face detection and ArcFace recognition for SaltyBot social interactions + seb + MIT + + rclpy + sensor_msgs + cv_bridge + image_transport + saltybot_social_msgs + + python3-numpy + python3-opencv + + ament_copyright + ament_flake8 + ament_pep257 + python3-pytest + + + ament_python + + diff --git a/jetson/ros2_ws/src/saltybot_social_face/resource/saltybot_social_face b/jetson/ros2_ws/src/saltybot_social_face/resource/saltybot_social_face new file mode 100644 index 0000000..e69de29 diff --git a/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/__init__.py b/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/__init__.py new file mode 100644 index 0000000..e7872ab --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/__init__.py @@ -0,0 +1 @@ +"""SaltyBot social face detection and recognition package.""" diff --git 
"""
arcface_recognizer.py -- ArcFace face embedding extraction and gallery matching.

Performs face alignment using 5-point landmarks (insightface standard reference),
extracts 512-dimensional embeddings via ArcFace (TRT FP16 or ONNX fallback),
and matches against a persistent gallery using cosine similarity.
"""

import json
import logging
import os
import time
from typing import Optional

import numpy as np

# OpenCV is imported inside align_face()/ArcFaceRecognizer.embed(), in the same
# way this module already defers tensorrt/pycuda/onnxruntime, so the gallery
# helpers (load/save/match/enroll) stay usable without OpenCV installed.

logger = logging.getLogger(__name__)

# InsightFace standard reference landmarks for 112x112 alignment
ARCFACE_SRC = np.array([
    [38.2946, 51.6963],   # left eye
    [73.5318, 51.5014],   # right eye
    [56.0252, 71.7366],   # nose
    [41.5493, 92.3655],   # left mouth
    [70.7299, 92.2041],   # right mouth
], dtype=np.float32)

_FACE_SIZE = 112        # side length of the aligned face crop, in pixels
_EMBEDDING_DIM = 512    # number of embedding components kept from the model output


# -- Small helpers -------------------------------------------------------------

def _l2_normalize(vec: np.ndarray) -> np.ndarray:
    """Return *vec* scaled to unit L2 norm; a zero vector is returned unchanged."""
    norm = np.linalg.norm(vec)
    return vec / norm if norm > 0 else vec


def _gallery_paths(gallery_path: str) -> tuple[str, str]:
    """Return ``(npz_path, meta_path)`` for a user supplied gallery path.

    ``np.savez`` appends ``.npz`` when the name lacks it, so the same
    normalisation is applied here to keep save and load pointing at one file.
    Only the trailing extension is swapped for ``_meta.json``; a ``.npz``
    substring elsewhere in the path is left alone.
    """
    npz_path = gallery_path if gallery_path.endswith('.npz') else gallery_path + '.npz'
    meta_path = npz_path[:-len('.npz')] + '_meta.json'
    return npz_path, meta_path


def _atomic_write(path: str, mode: str, write) -> None:
    """Write a file through a temporary sibling and ``os.replace`` it into place.

    Args:
        path: Final destination path.
        mode: ``open`` mode for the temporary file (``'w'`` or ``'wb'``).
        write: Callable that receives the open file object and writes to it.
    """
    tmp_path = f'{path}.tmp'
    try:
        with open(tmp_path, mode) as f:
            write(f)
        os.replace(tmp_path, path)
    finally:
        # Only present when the write or the replace failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# -- Inference backends --------------------------------------------------------

class _TRTBackend:
    """TensorRT inference engine for ArcFace."""

    def __init__(self, engine_path: str):
        """Deserialize the engine and allocate host/device buffers.

        Args:
            engine_path: Path to a serialized TensorRT engine.

        Raises:
            RuntimeError: If the engine cannot be deserialized.
            ValueError: If an I/O tensor has a dynamic (negative) dimension.
        """
        import tensorrt as trt
        import pycuda.driver as cuda
        import pycuda.autoinit  # noqa: F401

        self._cuda = cuda
        trt_logger = trt.Logger(trt.Logger.WARNING)
        with open(engine_path, 'rb') as f, trt.Runtime(trt_logger) as runtime:
            self._engine = runtime.deserialize_cuda_engine(f.read())
        if self._engine is None:
            raise RuntimeError(f'Failed to deserialize TensorRT engine: {engine_path}')
        self._context = self._engine.create_execution_context()

        self._inputs = []
        self._outputs = []
        self._bindings = []
        for i in range(self._engine.num_io_tensors):
            name = self._engine.get_tensor_name(i)
            dtype = trt.nptype(self._engine.get_tensor_dtype(name))
            shape = tuple(self._engine.get_tensor_shape(name))
            if any(dim < 0 for dim in shape):
                raise ValueError(
                    f'Tensor "{name}" has dynamic shape {shape}; '
                    'a fixed-shape engine is required.')
            host_mem = cuda.pagelocked_empty(shape, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            self._bindings.append(int(device_mem))
            # The named-tensor API used above pairs with set_tensor_address /
            # execute_async_v3; register the address when the context has it.
            if hasattr(self._context, 'set_tensor_address'):
                self._context.set_tensor_address(name, int(device_mem))
            if self._engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                self._inputs.append({'host': host_mem, 'device': device_mem})
            else:
                self._outputs.append({'host': host_mem, 'device': device_mem,
                                      'shape': shape})
        self._stream = cuda.Stream()

    def infer(self, input_data: np.ndarray) -> np.ndarray:
        """Run inference and return a copy of the first output tensor.

        Args:
            input_data: Array with as many elements as the engine's first input.

        Returns:
            The first output, shaped like the engine's output tensor.
        """
        host_in = self._inputs[0]['host']
        # The pinned buffer keeps the engine's N-D shape, so a flattened source
        # cannot be broadcast into it; reshape to the buffer's shape instead.
        np.copyto(host_in, np.asarray(input_data).reshape(host_in.shape))
        self._cuda.memcpy_htod_async(
            self._inputs[0]['device'], host_in, self._stream)

        execute_v3 = getattr(self._context, 'execute_async_v3', None)
        if execute_v3 is not None:
            execute_v3(self._stream.handle)
        else:
            self._context.execute_async_v2(self._bindings, self._stream.handle)

        for out in self._outputs:
            self._cuda.memcpy_dtoh_async(out['host'], out['device'], self._stream)
        self._stream.synchronize()
        first = self._outputs[0]
        return first['host'].reshape(first['shape']).copy()


class _ONNXBackend:
    """ONNX Runtime inference (CUDA EP with CPU fallback)."""

    def __init__(self, onnx_path: str):
        """Create an inference session for the model at *onnx_path*."""
        import onnxruntime as ort
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        self._session = ort.InferenceSession(onnx_path, providers=providers)
        self._input_name = self._session.get_inputs()[0].name

    def infer(self, input_data: np.ndarray) -> np.ndarray:
        """Run inference and return the first model output."""
        results = self._session.run(None, {self._input_name: input_data})
        return results[0]


# -- Face alignment ------------------------------------------------------------

def align_face(bgr: np.ndarray, landmarks_10: list[float]) -> np.ndarray:
    """Align a face to 112x112 using 5-point landmarks.

    Args:
        bgr: Source BGR image.
        landmarks_10: Flat list of 10 floats [x0,y0, x1,y1, ..., x4,y4].

    Returns:
        Aligned BGR face crop of shape (112, 112, 3).
    """
    import cv2

    src_pts = np.array(landmarks_10, dtype=np.float32).reshape(5, 2)
    M, _ = cv2.estimateAffinePartial2D(src_pts, ARCFACE_SRC)
    if M is not None:
        return cv2.warpAffine(bgr, M, (_FACE_SIZE, _FACE_SIZE),
                              borderMode=cv2.BORDER_REPLICATE)

    # Fallback: plain crop around the landmark centroid. Degenerate landmarks
    # give a spread of ~0, so the window is forced to cover at least one pixel
    # inside the image; an empty crop would make cv2.resize raise.
    height, width = bgr.shape[:2]
    cx = float(np.mean(src_pts[:, 0]))
    cy = float(np.mean(src_pts[:, 1]))
    spread = max(np.ptp(src_pts[:, 0]), np.ptp(src_pts[:, 1])) * 1.5
    half = max(spread / 2, 1.0)
    x1 = int(np.clip(cx - half, 0, width - 1))
    y1 = int(np.clip(cy - half, 0, height - 1))
    x2 = int(np.clip(cx + half, x1 + 1, width))
    y2 = int(np.clip(cy + half, y1 + 1, height))
    crop = bgr[y1:y2, x1:x2]
    return cv2.resize(crop, (_FACE_SIZE, _FACE_SIZE), interpolation=cv2.INTER_LINEAR)


# -- Main recognizer class -----------------------------------------------------

class ArcFaceRecognizer:
    """ArcFace face embedding extractor and gallery matcher.

    Args:
        engine_path: Path to TensorRT engine file.
        onnx_path: Path to ONNX model file (used if engine not available).
    """

    def __init__(self, engine_path: str = '', onnx_path: str = ''):
        self._backend: Optional[_TRTBackend | _ONNXBackend] = None
        # person_id -> {'name', 'embedding', 'samples', 'enrolled_at'}
        self.gallery: dict[int, dict] = {}

        # Try TRT first, then ONNX
        if engine_path and os.path.isfile(engine_path):
            try:
                self._backend = _TRTBackend(engine_path)
                logger.info('ArcFace TensorRT backend loaded: %s', engine_path)
                return
            except Exception as e:
                logger.warning('ArcFace TRT load failed (%s), trying ONNX', e)

        if onnx_path and os.path.isfile(onnx_path):
            try:
                self._backend = _ONNXBackend(onnx_path)
                logger.info('ArcFace ONNX backend loaded: %s', onnx_path)
                return
            except Exception as e:
                logger.error('ArcFace ONNX load failed: %s', e)

        logger.error('No ArcFace model loaded. Recognition will be unavailable.')

    @property
    def is_loaded(self) -> bool:
        """Return True if a backend is loaded and ready."""
        return self._backend is not None

    def embed(self, bgr_face_112x112: np.ndarray) -> np.ndarray:
        """Extract 512-dim L2-normalized embedding from a 112x112 aligned face.

        Args:
            bgr_face_112x112: Aligned face crop, BGR, shape (112, 112, 3).

        Returns:
            L2-normalized embedding of shape (512,). All zeros when no
            backend is loaded.
        """
        if self._backend is None:
            return np.zeros(_EMBEDDING_DIM, dtype=np.float32)

        import cv2

        # Preprocess: BGR->RGB, /255, subtract 0.5, divide 0.5 -> [1,3,112,112]
        rgb = cv2.cvtColor(bgr_face_112x112, cv2.COLOR_BGR2RGB).astype(np.float32)
        rgb = (rgb / 255.0 - 0.5) / 0.5
        blob = np.ascontiguousarray(rgb.transpose(2, 0, 1)[np.newaxis])

        output = self._backend.infer(blob)
        embedding = output.flatten()[:_EMBEDDING_DIM].astype(np.float32)
        return _l2_normalize(embedding)

    def align_and_embed(self, bgr_image: np.ndarray, landmarks_10: list[float]) -> np.ndarray:
        """Align face using landmarks and extract embedding.

        Args:
            bgr_image: Full BGR image.
            landmarks_10: Flat list of 10 floats from SCRFD detection.

        Returns:
            L2-normalized embedding of shape (512,).
        """
        return self.embed(align_face(bgr_image, landmarks_10))

    def load_gallery(self, gallery_path: str) -> None:
        """Load gallery from .npz file with JSON metadata sidecar.

        Replaces the current in-memory gallery. A missing file yields an
        empty gallery. Entries whose npz key is not an integer are skipped
        with a warning.

        Args:
            gallery_path: Path to the .npz gallery file.
        """
        npz_path, meta_path = _gallery_paths(gallery_path)
        self.gallery = {}

        if not os.path.isfile(npz_path):
            logger.info('No gallery file at %s, starting empty.', npz_path)
            return

        meta = {}
        if os.path.isfile(meta_path):
            with open(meta_path, 'r') as f:
                meta = json.load(f)

        # NpzFile keeps the archive open until closed, hence the context manager.
        with np.load(npz_path, allow_pickle=False) as data:
            for key in data.files:
                try:
                    pid = int(key)
                except ValueError:
                    logger.warning('Skipping gallery entry with non-integer key %r.', key)
                    continue
                info = meta.get(str(pid), {})
                self.gallery[pid] = {
                    'name': info.get('name', f'person_{pid}'),
                    'embedding': _l2_normalize(data[key].astype(np.float32)),
                    'samples': info.get('samples', 1),
                    'enrolled_at': info.get('enrolled_at', 0.0),
                }
        logger.info('Gallery loaded: %d persons from %s', len(self.gallery), npz_path)

    def save_gallery(self, gallery_path: str) -> None:
        """Save gallery to .npz file with JSON metadata sidecar.

        Both files are written to a temporary name and renamed into place so
        an interrupted save cannot leave a truncated gallery behind.

        Args:
            gallery_path: Path to the .npz gallery file.
        """
        npz_path, meta_path = _gallery_paths(gallery_path)
        arrays = {str(pid): info['embedding'] for pid, info in self.gallery.items()}
        meta = {
            str(pid): {
                'name': info['name'],
                'samples': info['samples'],
                'enrolled_at': info['enrolled_at'],
            }
            for pid, info in self.gallery.items()
        }

        os.makedirs(os.path.dirname(npz_path) or '.', exist_ok=True)
        _atomic_write(npz_path, 'wb', lambda f: np.savez(f, **arrays))
        _atomic_write(meta_path, 'w', lambda f: json.dump(meta, f, indent=2))
        logger.info('Gallery saved: %d persons to %s', len(self.gallery), npz_path)

    def match(self, embedding: np.ndarray, threshold: float = 0.35) -> tuple[int, str, float]:
        """Match an embedding against the gallery.

        Args:
            embedding: L2-normalized query embedding of shape (512,).
            threshold: Minimum cosine similarity for a match.

        Returns:
            (person_id, person_name, score) or (-1, '', 0.0) if no match.
        """
        best_pid = -1
        best_name = ''
        # Cosine similarity may be negative, so the running best starts below
        # every possible score rather than at 0.0.
        best_score = float('-inf')

        for pid, info in self.gallery.items():
            score = float(np.dot(embedding, info['embedding']))
            if score > best_score:
                best_pid, best_name, best_score = pid, info['name'], score

        if best_pid != -1 and best_score >= threshold:
            return (best_pid, best_name, best_score)
        return (-1, '', 0.0)

    def enroll(self, person_id: int, person_name: str, embeddings_list: list[np.ndarray]) -> None:
        """Enroll a person by averaging multiple embeddings.

        An empty *embeddings_list* is a no-op. An existing entry with the
        same *person_id* is overwritten.

        Args:
            person_id: Unique integer ID for this person.
            person_name: Human-readable name.
            embeddings_list: List of L2-normalized embeddings to average.
        """
        if not embeddings_list:
            return

        mean_emb = _l2_normalize(np.mean(embeddings_list, axis=0).astype(np.float32))
        self.gallery[person_id] = {
            'name': person_name,
            'embedding': mean_emb,
            'samples': len(embeddings_list),
            'enrolled_at': time.time(),
        }
        logger.info('Enrolled person %d (%s) with %d samples.',
                    person_id, person_name, len(embeddings_list))
#!/usr/bin/env python3
"""
enrollment_cli.py -- CLI tool for enrolling persons via the /social/enroll service.

Usage:
    ros2 run saltybot_social_face enrollment_cli -- --name Alice --mode face --samples 10
"""

import argparse
import sys

import rclpy
from rclpy.node import Node

from saltybot_social_msgs.srv import EnrollPerson


class EnrollmentCLI(Node):
    """Simple CLI node that calls the EnrollPerson service.

    The whole request/response exchange happens in ``__init__``; afterwards
    ``success`` tells whether the service reported a successful enrollment.
    """

    def __init__(self, name: str, mode: str, n_samples: int):
        """Call /social/enroll and log the outcome.

        Args:
            name: Name of the person to enroll.
            mode: Enrollment mode forwarded to the service.
            n_samples: Number of face samples the service should collect.
        """
        super().__init__('enrollment_cli')
        self.success = False
        self._client = self.create_client(EnrollPerson, '/social/enroll')
        log = self.get_logger()

        log.info('Waiting for /social/enroll service...')
        if not self._client.wait_for_service(timeout_sec=10.0):
            log.error('Service /social/enroll not available.')
            return

        request = EnrollPerson.Request()
        request.name = name
        request.mode = mode
        request.n_samples = n_samples

        # rclpy loggers take one pre-formatted message; printf-style
        # positional arguments raise TypeError.
        log.info(f'Enrolling "{name}" (mode={mode}, samples={n_samples})...')

        future = self._client.call_async(request)
        rclpy.spin_until_future_complete(self, future, timeout_sec=120.0)

        result = future.result() if future.done() else None
        if result is None:
            log.error('Enrollment service call timed out or failed.')
            return

        if result.success:
            self.success = True
            log.info(
                f'Enrollment successful: person_id={result.person_id}, {result.message}')
        else:
            log.error(f'Enrollment failed: {result.message}')


def main(args=None):
    """Entry point for enrollment CLI.

    Args:
        args: Optional argument list (without the program name). Defaults to
            ``sys.argv[1:]``.

    Returns:
        0 when the enrollment succeeded, 1 otherwise.
    """
    parser = argparse.ArgumentParser(description='Enroll a person for face recognition.')
    parser.add_argument('--name', type=str, required=True,
                        help='Name of the person to enroll.')
    parser.add_argument('--mode', type=str, default='face',
                        choices=['face', 'voice', 'both'],
                        help='Enrollment mode (default: face).')
    parser.add_argument('--samples', type=int, default=10,
                        help='Number of face samples to collect (default: 10).')

    # Parse only known args so ROS2 remapping args pass through
    argv = sys.argv[1:] if args is None else list(args)
    parsed, remaining = parser.parse_known_args(args=argv)

    rclpy.init(args=remaining)
    node = None
    try:
        # Node does all work in __init__
        node = EnrollmentCLI(parsed.name, parsed.mode, parsed.samples)
        return 0 if node.success else 1
    finally:
        if node is not None:
            node.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    sys.exit(main())
"""
face_gallery.py -- Persistent face embedding gallery backed by numpy .npz + JSON.

Thread-safe gallery storage for face recognition. Embeddings are stored in a
.npz file, with a sidecar metadata.json containing names, sample counts, and
enrollment timestamps. Auto-increment IDs start at 1.
"""

import json
import logging
import os
import threading
import time
from typing import Optional

import numpy as np

logger = logging.getLogger(__name__)


def _unit(vec: np.ndarray) -> np.ndarray:
    """Return *vec* as float32 scaled to unit L2 norm (zero vectors unchanged)."""
    out = vec.astype(np.float32)
    norm = np.linalg.norm(out)
    return out / norm if norm > 0 else out


def _atomic_write(path: str, mode: str, write) -> None:
    """Write via a temporary sibling file, then ``os.replace`` it into place.

    Args:
        path: Final destination path.
        mode: ``open`` mode for the temporary file (``'w'`` or ``'wb'``).
        write: Callable that receives the open file object and writes to it.
    """
    tmp_path = f'{path}.tmp'
    try:
        with open(tmp_path, mode) as f:
            write(f)
        os.replace(tmp_path, path)
    finally:
        # Only present when the write or the replace failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


class FaceGallery:
    """Persistent, thread-safe face embedding gallery.

    Args:
        gallery_dir: Directory for gallery.npz and metadata.json files.
    """

    def __init__(self, gallery_dir: str):
        self._gallery_dir = gallery_dir
        self._npz_path = os.path.join(gallery_dir, 'gallery.npz')
        self._meta_path = os.path.join(gallery_dir, 'metadata.json')
        # person_id -> {'name', 'embedding', 'samples', 'enrolled_at'}
        self._gallery: dict[int, dict] = {}
        self._next_id = 1
        self._lock = threading.Lock()

    def load(self) -> None:
        """Load gallery from disk, replacing the in-memory contents.

        A missing gallery.npz yields an empty gallery. Entries whose npz key
        is not an integer are skipped with a warning. The next auto-increment
        ID continues after the largest loaded ID.
        """
        with self._lock:
            self._gallery = {}
            self._next_id = 1

            if not os.path.isfile(self._npz_path):
                logger.info('No gallery file at %s, starting empty.', self._npz_path)
                return

            meta: dict = {}
            if os.path.isfile(self._meta_path):
                with open(self._meta_path, 'r') as f:
                    meta = json.load(f)

            # NpzFile keeps the archive open until closed.
            with np.load(self._npz_path, allow_pickle=False) as data:
                for key in data.files:
                    try:
                        pid = int(key)
                    except ValueError:
                        logger.warning(
                            'Skipping gallery entry with non-integer key %r.', key)
                        continue
                    info = meta.get(str(pid), {})
                    self._gallery[pid] = {
                        'name': info.get('name', f'person_{pid}'),
                        'embedding': _unit(data[key]),
                        'samples': info.get('samples', 1),
                        'enrolled_at': info.get('enrolled_at', 0.0),
                    }
                    self._next_id = max(self._next_id, pid + 1)

            logger.info('Gallery loaded: %d persons from %s',
                        len(self._gallery), self._npz_path)

    def save(self) -> None:
        """Save gallery to disk (npz + JSON sidecar).

        Each file is written to a temporary name and renamed into place so an
        interrupted save cannot leave a truncated gallery behind.
        """
        with self._lock:
            os.makedirs(self._gallery_dir, exist_ok=True)

            arrays = {str(pid): info['embedding']
                      for pid, info in self._gallery.items()}
            meta = {
                str(pid): {
                    'name': info['name'],
                    'samples': info['samples'],
                    'enrolled_at': info['enrolled_at'],
                }
                for pid, info in self._gallery.items()
            }

            _atomic_write(self._npz_path, 'wb', lambda f: np.savez(f, **arrays))
            _atomic_write(self._meta_path, 'w',
                          lambda f: json.dump(meta, f, indent=2))

            logger.info('Gallery saved: %d persons to %s',
                        len(self._gallery), self._npz_path)

    def add_person(self, name: str, embedding: np.ndarray, samples: int = 1) -> int:
        """Add a new person to the gallery.

        The embedding is stored as float32 and re-normalized to unit length.

        Args:
            name: Person's name.
            embedding: L2-normalized 512-dim embedding.
            samples: Number of samples used to compute the embedding.

        Returns:
            Assigned person_id (auto-increment integer).
        """
        with self._lock:
            pid = self._next_id
            self._next_id += 1
            self._gallery[pid] = {
                'name': name,
                'embedding': _unit(embedding),
                'samples': samples,
                'enrolled_at': time.time(),
            }
            logger.info('Added person %d (%s), %d samples.', pid, name, samples)
            return pid

    def update_name(self, person_id: int, new_name: str) -> bool:
        """Update a person's name.

        Args:
            person_id: The ID of the person to update.
            new_name: New name string.

        Returns:
            True if the person was found and updated.
        """
        with self._lock:
            entry = self._gallery.get(person_id)
            if entry is None:
                return False
            entry['name'] = new_name
            return True

    def delete_person(self, person_id: int) -> bool:
        """Remove a person from the gallery.

        Args:
            person_id: The ID of the person to delete.

        Returns:
            True if the person was found and removed.
        """
        with self._lock:
            if self._gallery.pop(person_id, None) is None:
                return False
            logger.info('Deleted person %d.', person_id)
            return True

    def get_all(self) -> list[dict]:
        """Get all gallery entries.

        Returns:
            List of dicts with keys: person_id, name, embedding, samples,
            enrolled_at. Embeddings are copies, so callers may modify them.
        """
        with self._lock:
            return [
                {
                    'person_id': pid,
                    'name': info['name'],
                    'embedding': info['embedding'].copy(),
                    'samples': info['samples'],
                    'enrolled_at': info['enrolled_at'],
                }
                for pid, info in self._gallery.items()
            ]

    def match(self, query_embedding: np.ndarray, threshold: float = 0.35) -> tuple[int, str, float]:
        """Match a query embedding against the gallery using cosine similarity.

        Args:
            query_embedding: L2-normalized 512-dim embedding.
            threshold: Minimum cosine similarity for a match.

        Returns:
            (person_id, name, score) or (-1, '', 0.0) if no match.
        """
        with self._lock:
            if not self._gallery:
                return (-1, '', 0.0)

            query = _unit(query_embedding)

            best_pid = -1
            best_name = ''
            # Cosine similarity may be negative, so start below every score.
            best_score = float('-inf')
            for pid, info in self._gallery.items():
                score = float(np.dot(query, info['embedding']))
                if score > best_score:
                    best_pid, best_name, best_score = pid, info['name'], score

            if best_score >= threshold:
                return (best_pid, best_name, best_score)
            return (-1, '', 0.0)

    def __len__(self) -> int:
        """Return the number of enrolled persons."""
        with self._lock:
            return len(self._gallery)
+""" + +import time + +import numpy as np + +import rclpy +from rclpy.node import Node +from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy, DurabilityPolicy + +import cv2 +from cv_bridge import CvBridge + +from sensor_msgs.msg import Image +from builtin_interfaces.msg import Time + +from saltybot_social_msgs.msg import ( + FaceDetection, + FaceDetectionArray, + FaceEmbedding, + FaceEmbeddingArray, +) +from saltybot_social_msgs.srv import ( + EnrollPerson, + ListPersons, + DeletePerson, + UpdatePerson, +) + +from .scrfd_detector import SCRFDDetector +from .arcface_recognizer import ArcFaceRecognizer +from .face_gallery import FaceGallery + + +class FaceRecognitionNode(Node): + """ROS2 node: SCRFD face detection + ArcFace gallery matching.""" + + def __init__(self): + super().__init__('face_recognizer') + self._bridge = CvBridge() + self._frame_count = 0 + self._fps_t0 = time.monotonic() + + # -- Parameters -------------------------------------------------------- + self.declare_parameter('scrfd_engine_path', + '/mnt/nvme/saltybot/models/scrfd_2.5g.engine') + self.declare_parameter('scrfd_onnx_path', + '/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx') + self.declare_parameter('arcface_engine_path', + '/mnt/nvme/saltybot/models/arcface_r50.engine') + self.declare_parameter('arcface_onnx_path', + '/mnt/nvme/saltybot/models/arcface_r50.onnx') + self.declare_parameter('gallery_dir', '/mnt/nvme/saltybot/gallery') + self.declare_parameter('recognition_threshold', 0.35) + self.declare_parameter('publish_debug_image', False) + self.declare_parameter('max_faces', 10) + self.declare_parameter('scrfd_conf_thresh', 0.5) + + self._recognition_threshold = self.get_parameter('recognition_threshold').value + self._pub_debug_flag = self.get_parameter('publish_debug_image').value + self._max_faces = self.get_parameter('max_faces').value + + # -- Models ------------------------------------------------------------ + self._detector = SCRFDDetector( + 
engine_path=self.get_parameter('scrfd_engine_path').value, + onnx_path=self.get_parameter('scrfd_onnx_path').value, + conf_thresh=self.get_parameter('scrfd_conf_thresh').value, + ) + self._recognizer = ArcFaceRecognizer( + engine_path=self.get_parameter('arcface_engine_path').value, + onnx_path=self.get_parameter('arcface_onnx_path').value, + ) + + # -- Gallery ----------------------------------------------------------- + gallery_dir = self.get_parameter('gallery_dir').value + self._gallery = FaceGallery(gallery_dir) + self._gallery.load() + self.get_logger().info('Gallery loaded: %d persons.', len(self._gallery)) + + # -- Enrollment state -------------------------------------------------- + self._enrolling = None # {name, samples_needed, collected: [embeddings]} + + # -- QoS profiles ------------------------------------------------------ + best_effort_qos = QoSProfile( + reliability=ReliabilityPolicy.BEST_EFFORT, + history=HistoryPolicy.KEEP_LAST, + depth=1, + ) + reliable_qos = QoSProfile( + reliability=ReliabilityPolicy.RELIABLE, + durability=DurabilityPolicy.TRANSIENT_LOCAL, + history=HistoryPolicy.KEEP_LAST, + depth=1, + ) + + # -- Subscribers ------------------------------------------------------- + self.create_subscription( + Image, + '/camera/color/image_raw', + self._on_image, + best_effort_qos, + ) + + # -- Publishers -------------------------------------------------------- + self._pub_detections = self.create_publisher( + FaceDetectionArray, '/social/faces/detections', best_effort_qos) + self._pub_embeddings = self.create_publisher( + FaceEmbeddingArray, '/social/faces/embeddings', reliable_qos) + if self._pub_debug_flag: + self._pub_debug_img = self.create_publisher( + Image, '/social/faces/debug_image', best_effort_qos) + + # -- Services ---------------------------------------------------------- + self.create_service(EnrollPerson, '/social/enroll', self._handle_enroll) + self.create_service(ListPersons, '/social/persons/list', self._handle_list) + 
def _on_image(self, msg: Image):
    """Process one camera frame: detect faces, match them, publish results.

    The frame is converted to BGR, run through the detector, and the
    detections are capped at ``self._max_faces`` (highest score first).
    Each remaining face is embedded and matched against the gallery, and
    one ``FaceDetection`` per face is published in a ``FaceDetectionArray``
    that reuses the incoming header. While an enrollment is active every
    face is also offered to ``_enrollment_collect``.

    Args:
        msg: Incoming camera image message.
    """
    try:
        bgr = self._bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')
    except Exception as e:
        # rclpy loggers accept a single, already formatted message string;
        # extra printf-style positional arguments raise TypeError.
        self.get_logger().error(f'Image decode error: {e}',
                                throttle_duration_sec=5.0)
        return

    # Detect faces
    detections = self._detector.detect(bgr)

    # Limit face count, keeping the most confident detections
    if len(detections) > self._max_faces:
        detections = sorted(detections, key=lambda d: d['score'], reverse=True)
        detections = detections[:self._max_faces]

    # Build output message
    det_array = FaceDetectionArray()
    det_array.header = msg.header

    for det in detections:
        # Extract embedding and match gallery
        embedding = self._recognizer.align_and_embed(bgr, det['kps'])
        pid, pname, score = self._gallery.match(
            embedding, self._recognition_threshold)

        # Handle enrollment: the collector keeps only the largest face
        if self._enrolling is not None:
            self._enrollment_collect(det, embedding)

        # Build FaceDetection message
        face_msg = FaceDetection()
        face_msg.header = msg.header
        face_msg.face_id = pid
        face_msg.person_name = pname
        face_msg.confidence = det['score']
        face_msg.recognition_score = score

        # Detector boxes are corner pairs; the message stores origin + size
        bbox = det['bbox']
        face_msg.bbox_x = bbox[0]
        face_msg.bbox_y = bbox[1]
        face_msg.bbox_w = bbox[2] - bbox[0]
        face_msg.bbox_h = bbox[3] - bbox[1]

        # Five landmarks flattened as x0, y0, ..., x4, y4
        kps = det['kps']
        for i in range(10):
            face_msg.landmarks[i] = kps[i]

        det_array.faces.append(face_msg)

    self._pub_detections.publish(det_array)

    # Debug image (the publisher only exists when the flag was set at startup)
    if self._pub_debug_flag and hasattr(self, '_pub_debug_img'):
        debug_img = self._draw_debug(bgr, detections, det_array.faces)
        self._pub_debug_img.publish(
            self._bridge.cv2_to_imgmsg(debug_img, encoding='bgr8'))

    # FPS logging, averaged over the last 30 frames
    self._frame_count += 1
    if self._frame_count % 30 == 0:
        elapsed = time.monotonic() - self._fps_t0
        fps = 30.0 / elapsed if elapsed > 0 else 0.0
        self._fps_t0 = time.monotonic()
        self.get_logger().info(
            f'FPS: {fps:.1f} | faces: {len(detections)}')
self.get_logger().info('Enrollment complete: person %d (%s).', pid, name) + + # Store result for the service callback + self._enrolling['result_pid'] = pid + self._enrolling['done'] = True + self._enrolling = None + + # -- Service handlers ------------------------------------------------------ + + def _handle_enroll(self, request, response): + """Handle EnrollPerson service: start collecting face samples.""" + name = request.name.strip() + if not name: + response.success = False + response.message = 'Name cannot be empty.' + response.person_id = -1 + return response + + n_samples = request.n_samples if request.n_samples > 0 else 10 + + self.get_logger().info('Starting enrollment for "%s" (%d samples).', + name, n_samples) + + # Set enrollment state — frames will collect embeddings + self._enrolling = { + 'name': name, + 'samples_needed': n_samples, + 'collected': [], + 'done': False, + 'result_pid': -1, + } + self._enroll_best_area = 0.0 + self._enroll_best_embedding = None + + # Spin until enrollment is done (blocking service) + rate = self.create_rate(10) # 10 Hz check + timeout_sec = n_samples * 2.0 + 10.0 # generous timeout + t0 = time.monotonic() + + while not self._enrolling.get('done', False): + # Finalize any pending frame collection + self._enrollment_frame_end() + + if time.monotonic() - t0 > timeout_sec: + self._enrolling = None + response.success = False + response.message = f'Enrollment timed out after {timeout_sec:.0f}s.' + response.person_id = -1 + return response + + rclpy.spin_once(self, timeout_sec=0.1) + + response.success = True + response.message = f'Enrolled "{name}" with {n_samples} samples.' 
+ response.person_id = self._enrolling.get('result_pid', -1) if self._enrolling else -1 + + # Clean up (already set to None in _enrollment_frame_end on success) + return response + + def _handle_list(self, request, response): + """Handle ListPersons service: return all gallery entries.""" + entries = self._gallery.get_all() + for entry in entries: + emb_msg = FaceEmbedding() + emb_msg.person_id = entry['person_id'] + emb_msg.person_name = entry['name'] + emb_msg.embedding = entry['embedding'].tolist() + emb_msg.sample_count = entry['samples'] + + secs = int(entry['enrolled_at']) + nsecs = int((entry['enrolled_at'] - secs) * 1e9) + emb_msg.enrolled_at = Time(sec=secs, nanosec=nsecs) + + response.persons.append(emb_msg) + return response + + def _handle_delete(self, request, response): + """Handle DeletePerson service: remove a person from the gallery.""" + if self._gallery.delete_person(request.person_id): + self._gallery.save() + self._publish_gallery_embeddings() + response.success = True + response.message = f'Deleted person {request.person_id}.' + else: + response.success = False + response.message = f'Person {request.person_id} not found.' + return response + + def _handle_update(self, request, response): + """Handle UpdatePerson service: rename a person.""" + new_name = request.new_name.strip() + if not new_name: + response.success = False + response.message = 'New name cannot be empty.' + return response + + if self._gallery.update_name(request.person_id, new_name): + self._gallery.save() + self._publish_gallery_embeddings() + response.success = True + response.message = f'Updated person {request.person_id} to "{new_name}".' + else: + response.success = False + response.message = f'Person {request.person_id} not found.' 
+ return response + + # -- Gallery publishing ---------------------------------------------------- + + def _publish_gallery_embeddings(self): + """Publish current gallery as FaceEmbeddingArray (latched-like).""" + entries = self._gallery.get_all() + msg = FaceEmbeddingArray() + msg.header.stamp = self.get_clock().now().to_msg() + + for entry in entries: + emb_msg = FaceEmbedding() + emb_msg.person_id = entry['person_id'] + emb_msg.person_name = entry['name'] + emb_msg.embedding = entry['embedding'].tolist() + emb_msg.sample_count = entry['samples'] + + secs = int(entry['enrolled_at']) + nsecs = int((entry['enrolled_at'] - secs) * 1e9) + emb_msg.enrolled_at = Time(sec=secs, nanosec=nsecs) + + msg.embeddings.append(emb_msg) + + self._pub_embeddings.publish(msg) + + # -- Debug image ----------------------------------------------------------- + + def _draw_debug(self, bgr: np.ndarray, detections: list[dict], + face_msgs: list) -> np.ndarray: + """Draw bounding boxes, landmarks, and names on the image.""" + vis = bgr.copy() + for det, face_msg in zip(detections, face_msgs): + bbox = det['bbox'] + x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) + + # Color: green if recognized, yellow if unknown + if face_msg.face_id >= 0: + color = (0, 255, 0) + label = f'{face_msg.person_name} ({face_msg.recognition_score:.2f})' + else: + color = (0, 255, 255) + label = f'unknown ({face_msg.confidence:.2f})' + + cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2) + cv2.putText(vis, label, (x1, y1 - 8), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1) + + # Draw landmarks + kps = det['kps'] + for k in range(5): + px, py = int(kps[k * 2]), int(kps[k * 2 + 1]) + cv2.circle(vis, (px, py), 2, (0, 0, 255), -1) + + return vis + + +# -- Entry point --------------------------------------------------------------- + +def main(args=None): + """ROS2 entry point for face_recognition node.""" + rclpy.init(args=args) + node = FaceRecognitionNode() + try: + rclpy.spin(node) + except 
KeyboardInterrupt: + pass + finally: + node.destroy_node() + rclpy.shutdown() + + +if __name__ == '__main__': + main() diff --git a/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/scrfd_detector.py b/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/scrfd_detector.py new file mode 100644 index 0000000..70625ae --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/scrfd_detector.py @@ -0,0 +1,350 @@ +""" +scrfd_detector.py -- SCRFD face detection with TensorRT FP16 + ONNX fallback. + +SCRFD (Sample and Computation Redistribution for Face Detection) produces +9 output tensors across 3 strides (8, 16, 32), each with score, bbox, and +keypoint branches. This module handles anchor generation, bbox/keypoint +decoding, and NMS to produce final face detections. +""" + +import os +import logging +from typing import Optional + +import numpy as np +import cv2 + +logger = logging.getLogger(__name__) + +_STRIDES = [8, 16, 32] +_NUM_ANCHORS = 2 # anchors per cell per stride + + +# -- Inference backends -------------------------------------------------------- + +class _TRTBackend: + """TensorRT inference engine for SCRFD.""" + + def __init__(self, engine_path: str): + import tensorrt as trt + import pycuda.driver as cuda + import pycuda.autoinit # noqa: F401 + + self._cuda = cuda + trt_logger = trt.Logger(trt.Logger.WARNING) + with open(engine_path, 'rb') as f, trt.Runtime(trt_logger) as runtime: + self._engine = runtime.deserialize_cuda_engine(f.read()) + self._context = self._engine.create_execution_context() + + self._inputs = [] + self._outputs = [] + self._output_names = [] + self._bindings = [] + for i in range(self._engine.num_io_tensors): + name = self._engine.get_tensor_name(i) + dtype = trt.nptype(self._engine.get_tensor_dtype(name)) + shape = tuple(self._engine.get_tensor_shape(name)) + nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize + host_mem = cuda.pagelocked_empty(shape, dtype) + device_mem = 
cuda.mem_alloc(nbytes) + self._bindings.append(int(device_mem)) + if self._engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: + self._inputs.append({'host': host_mem, 'device': device_mem}) + else: + self._outputs.append({'host': host_mem, 'device': device_mem, + 'shape': shape}) + self._output_names.append(name) + self._stream = cuda.Stream() + + def infer(self, input_data: np.ndarray) -> list[np.ndarray]: + """Run inference and return output tensors.""" + np.copyto(self._inputs[0]['host'], input_data.ravel()) + self._cuda.memcpy_htod_async( + self._inputs[0]['device'], self._inputs[0]['host'], self._stream) + self._context.execute_async_v2(self._bindings, self._stream.handle) + results = [] + for out in self._outputs: + self._cuda.memcpy_dtoh_async(out['host'], out['device'], self._stream) + self._stream.synchronize() + for out in self._outputs: + results.append(out['host'].reshape(out['shape']).copy()) + return results + + +class _ONNXBackend: + """ONNX Runtime inference (CUDA EP with CPU fallback).""" + + def __init__(self, onnx_path: str): + import onnxruntime as ort + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + self._session = ort.InferenceSession(onnx_path, providers=providers) + self._input_name = self._session.get_inputs()[0].name + self._output_names = [o.name for o in self._session.get_outputs()] + + def infer(self, input_data: np.ndarray) -> list[np.ndarray]: + """Run inference and return output tensors.""" + return self._session.run(None, {self._input_name: input_data}) + + +# -- NMS ---------------------------------------------------------------------- + +def _nms(boxes: np.ndarray, scores: np.ndarray, iou_thresh: float) -> list[int]: + """Non-maximum suppression. 
boxes: [N, 4] as x1,y1,x2,y2.""" + if len(boxes) == 0: + return [] + + x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3] + areas = (x2 - x1) * (y2 - y1) + order = scores.argsort()[::-1] + keep = [] + + while order.size > 0: + i = order[0] + keep.append(int(i)) + if order.size == 1: + break + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1) + iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-6) + remaining = np.where(iou <= iou_thresh)[0] + order = order[remaining + 1] + + return keep + + +# -- Anchor generation --------------------------------------------------------- + +def _generate_anchors(input_h: int, input_w: int) -> dict[int, np.ndarray]: + """Generate anchor centers for each stride. + + Returns dict mapping stride -> array of shape [H*W*num_anchors, 2], + where each row is (cx, cy) in input pixel coordinates. + """ + anchors = {} + for stride in _STRIDES: + feat_h = input_h // stride + feat_w = input_w // stride + grid_y, grid_x = np.mgrid[:feat_h, :feat_w] + centers = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1).astype(np.float32) + centers = (centers + 0.5) * stride + # Repeat for num_anchors per cell + centers = np.tile(centers, (_NUM_ANCHORS, 1)) # [H*W*2, 2] + # Interleave properly: [anchor0_cell0, anchor1_cell0, anchor0_cell1, ...] + centers = np.repeat( + np.stack([grid_x.ravel(), grid_y.ravel()], axis=1).astype(np.float32), + _NUM_ANCHORS, axis=0 + ) + centers = (centers + 0.5) * stride + anchors[stride] = centers + return anchors + + +# -- Main detector class ------------------------------------------------------- + +class SCRFDDetector: + """SCRFD face detector with TensorRT FP16 and ONNX fallback. + + Args: + engine_path: Path to TensorRT engine file. + onnx_path: Path to ONNX model file (used if engine not available). 
def __init__(
    self,
    engine_path: str = '',
    onnx_path: str = '',
    conf_thresh: float = 0.5,
    nms_iou: float = 0.4,
    input_size: int = 640,
):
    """Store thresholds, precompute anchors and load an inference backend.

    The TensorRT engine is tried first; the ONNX model is used only when
    no TensorRT backend could be loaded. If neither loads, the detector
    stays without a backend and an error is logged.
    """
    self._conf_thresh = conf_thresh
    self._nms_iou = nms_iou
    self._input_size = input_size
    self._backend: Optional[_TRTBackend | _ONNXBackend] = None
    self._anchors = _generate_anchors(input_size, input_size)

    # Preferred backend: TensorRT
    if engine_path and os.path.isfile(engine_path):
        try:
            self._backend = _TRTBackend(engine_path)
        except Exception as exc:
            logger.warning('SCRFD TRT load failed (%s), trying ONNX', exc)
        else:
            logger.info('SCRFD TensorRT backend loaded: %s', engine_path)

    # Fallback backend: ONNX Runtime
    if self._backend is None and onnx_path and os.path.isfile(onnx_path):
        try:
            self._backend = _ONNXBackend(onnx_path)
        except Exception as exc:
            logger.error('SCRFD ONNX load failed: %s', exc)
        else:
            logger.info('SCRFD ONNX backend loaded: %s', onnx_path)

    if self._backend is None:
        logger.error('No SCRFD model loaded. Detection will be unavailable.')
def _preprocess(self, bgr: np.ndarray) -> tuple[np.ndarray, float, int, int]:
    """Letterbox the image into the square model input and normalize it.

    Args:
        bgr: Input image, shape (H, W, 3).

    Returns:
        Tuple ``(blob, scale, pad_w, pad_h)``: the contiguous float32 NCHW
        tensor, the resize factor applied to the image, and the left/top
        padding in pixels.
    """
    side = self._input_size
    src_h, src_w = bgr.shape[:2]

    # Uniform scale so the longer image side fits the square input
    scale = min(side / src_h, side / src_w)
    new_w = int(src_w * scale)
    new_h = int(src_h * scale)

    # Center the resized image; the border stays black
    pad_w = (side - new_w) // 2
    pad_h = (side - new_h) // 2
    letterboxed = np.zeros((side, side, 3), dtype=np.uint8)
    letterboxed[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = cv2.resize(
        bgr, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    # Normalize: subtract 127.5, divide 128.0
    normalized = (letterboxed.astype(np.float32) - 127.5) / 128.0

    # HWC -> NCHW with a leading batch axis
    blob = np.ascontiguousarray(normalized.transpose(2, 0, 1)[np.newaxis])
    return blob, scale, pad_w, pad_h
+ + SCRFD outputs 9 tensors, 3 per stride (score, bbox, kps): + score_8, bbox_8, kps_8, score_16, bbox_16, kps_16, score_32, bbox_32, kps_32 + """ + all_scores = [] + all_bboxes = [] + all_kps = [] + + for idx, stride in enumerate(_STRIDES): + score_out = outputs[idx * 3].squeeze() # [H*W*num_anchors] + bbox_out = outputs[idx * 3 + 1].squeeze() # [H*W*num_anchors, 4] + kps_out = outputs[idx * 3 + 2].squeeze() # [H*W*num_anchors, 10] + + if score_out.ndim == 0: + continue + + # Ensure proper shapes + if score_out.ndim == 1: + n = score_out.shape[0] + else: + n = score_out.shape[0] + score_out = score_out.ravel() + + if bbox_out.ndim == 1: + bbox_out = bbox_out.reshape(-1, 4) + if kps_out.ndim == 1: + kps_out = kps_out.reshape(-1, 10) + + # Filter by confidence + mask = score_out > self._conf_thresh + if not mask.any(): + continue + + scores = score_out[mask] + bboxes = bbox_out[mask] + kps = kps_out[mask] + + anchors = self._anchors[stride] + # Trim or pad anchors to match output count + if anchors.shape[0] > n: + anchors = anchors[:n] + elif anchors.shape[0] < n: + continue + + anchors = anchors[mask] + + # Decode bboxes: center = anchor + pred[:2]*stride, size = exp(pred[2:])*stride + cx = anchors[:, 0] + bboxes[:, 0] * stride + cy = anchors[:, 1] + bboxes[:, 1] * stride + w = np.exp(bboxes[:, 2]) * stride + h = np.exp(bboxes[:, 3]) * stride + x1 = cx - w / 2.0 + y1 = cy - h / 2.0 + x2 = cx + w / 2.0 + y2 = cy + h / 2.0 + decoded_bboxes = np.stack([x1, y1, x2, y2], axis=1) + + # Decode keypoints: kp = anchor + pred * stride + decoded_kps = np.zeros_like(kps) + for k in range(5): + decoded_kps[:, k * 2] = anchors[:, 0] + kps[:, k * 2] * stride + decoded_kps[:, k * 2 + 1] = anchors[:, 1] + kps[:, k * 2 + 1] * stride + + all_scores.append(scores) + all_bboxes.append(decoded_bboxes) + all_kps.append(decoded_kps) + + if not all_scores: + return [] + + scores = np.concatenate(all_scores) + bboxes = np.concatenate(all_bboxes) + kps = np.concatenate(all_kps) + + # NMS + 
keep = _nms(bboxes, scores, self._nms_iou) + + results = [] + for i in keep: + results.append({ + 'bbox': bboxes[i].tolist(), + 'kps': kps[i].tolist(), + 'score': float(scores[i]), + }) + return results + + def _rescale( + self, + detections: list[dict], + scale: float, + pad_w: int, + pad_h: int, + orig_w: int, + orig_h: int, + ) -> list[dict]: + """Rescale detections from model input space to original image space.""" + for det in detections: + bbox = det['bbox'] + bbox[0] = max(0.0, (bbox[0] - pad_w) / scale) + bbox[1] = max(0.0, (bbox[1] - pad_h) / scale) + bbox[2] = min(float(orig_w), (bbox[2] - pad_w) / scale) + bbox[3] = min(float(orig_h), (bbox[3] - pad_h) / scale) + det['bbox'] = bbox + + kps = det['kps'] + for k in range(5): + kps[k * 2] = (kps[k * 2] - pad_w) / scale + kps[k * 2 + 1] = (kps[k * 2 + 1] - pad_h) / scale + det['kps'] = kps + return detections diff --git a/jetson/ros2_ws/src/saltybot_social_face/scripts/build_face_trt_engines.py b/jetson/ros2_ws/src/saltybot_social_face/scripts/build_face_trt_engines.py new file mode 100644 index 0000000..c2a1a83 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_face/scripts/build_face_trt_engines.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +""" +build_face_trt_engines.py -- Build TensorRT FP16 engines for SCRFD and ArcFace. + +Converts ONNX model files to optimized TensorRT engines with FP16 precision +for fast inference on Jetson Orin Nano Super. + +Usage: + python3 build_face_trt_engines.py \ + --scrfd-onnx /path/to/scrfd_2.5g_bnkps.onnx \ + --arcface-onnx /path/to/arcface_r50.onnx \ + --output-dir /mnt/nvme/saltybot/models \ + --fp16 --workspace-mb 1024 +""" + +import argparse +import os +import time + + +def build_engine(onnx_path: str, engine_path: str, fp16: bool, workspace_mb: int): + """Build a TensorRT engine from an ONNX model. + + Args: + onnx_path: Path to the source ONNX model file. + engine_path: Output path for the serialized TensorRT engine. + fp16: Enable FP16 precision. 
def main():
    """Parse command-line options and build the requested TensorRT engines.

    At least one of ``--scrfd-onnx`` / ``--arcface-onnx`` must be given;
    otherwise argparse reports an error and exits. Engines are written
    into ``--output-dir``, SCRFD first, then ArcFace.
    """
    parser = argparse.ArgumentParser(
        description='Build TensorRT FP16 engines for SCRFD and ArcFace.')
    parser.add_argument('--scrfd-onnx', type=str, default='',
                        help='Path to SCRFD ONNX model.')
    parser.add_argument('--arcface-onnx', type=str, default='',
                        help='Path to ArcFace ONNX model.')
    parser.add_argument('--output-dir', type=str,
                        default='/mnt/nvme/saltybot/models',
                        help='Output directory for engine files.')
    parser.add_argument('--fp16', action='store_true', default=True,
                        help='Enable FP16 precision (default: True).')
    parser.add_argument('--no-fp16', action='store_false', dest='fp16',
                        help='Disable FP16 (use FP32 only).')
    parser.add_argument('--workspace-mb', type=int, default=1024,
                        help='TRT workspace size in MB (default: 1024).')
    args = parser.parse_args()

    # (source ONNX path, engine file name, banner label), in build order
    jobs = (
        (args.scrfd_onnx, 'scrfd_2.5g.engine', 'SCRFD'),
        (args.arcface_onnx, 'arcface_r50.engine', 'ArcFace'),
    )

    if not any(onnx_path for onnx_path, _, _ in jobs):
        parser.error('At least one of --scrfd-onnx or --arcface-onnx is required.')

    for onnx_path, engine_name, label in jobs:
        if not onnx_path:
            continue
        engine_path = os.path.join(args.output_dir, engine_name)
        print(f'\n=== Building {label} engine ===')
        build_engine(onnx_path, engine_path, args.fp16, args.workspace_mb)

    print('\nDone.')
package_name + '/launch', ['launch/face_recognition.launch.py']), + ('share/' + package_name + '/config', ['config/face_recognition_params.yaml']), + ], + install_requires=['setuptools'], + zip_safe=True, + maintainer='seb', + maintainer_email='seb@vayrette.com', + description='SCRFD face detection and ArcFace recognition for SaltyBot social interactions', + license='MIT', + tests_require=['pytest'], + entry_points={ + 'console_scripts': [ + 'face_recognition = saltybot_social_face.face_recognition_node:main', + 'enrollment_cli = saltybot_social_face.enrollment_cli:main', + ], + }, +) -- 2.47.2