diff --git a/jetson/ros2_ws/src/saltybot_social_face/config/face_recognition_params.yaml b/jetson/ros2_ws/src/saltybot_social_face/config/face_recognition_params.yaml
new file mode 100644
index 0000000..f5a6517
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_face/config/face_recognition_params.yaml
@@ -0,0 +1,11 @@
+face_recognizer:
+ ros__parameters:
+ scrfd_engine_path: '/mnt/nvme/saltybot/models/scrfd_2.5g.engine'
+ scrfd_onnx_path: '/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx'
+ arcface_engine_path: '/mnt/nvme/saltybot/models/arcface_r50.engine'
+ arcface_onnx_path: '/mnt/nvme/saltybot/models/arcface_r50.onnx'
+ gallery_dir: '/mnt/nvme/saltybot/gallery'
+ recognition_threshold: 0.35
+ publish_debug_image: false
+ max_faces: 10
+ scrfd_conf_thresh: 0.5
diff --git a/jetson/ros2_ws/src/saltybot_social_face/launch/face_recognition.launch.py b/jetson/ros2_ws/src/saltybot_social_face/launch/face_recognition.launch.py
new file mode 100644
index 0000000..f177b7e
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_face/launch/face_recognition.launch.py
@@ -0,0 +1,80 @@
+"""
+face_recognition.launch.py -- Launch file for the SCRFD + ArcFace face recognition node.
+
+Launches the face_recognizer node with configurable model paths and parameters.
+The RealSense camera must be running separately (e.g., via realsense.launch.py).
+"""
+
+from launch import LaunchDescription
+from launch.actions import DeclareLaunchArgument
+from launch.substitutions import LaunchConfiguration
+from launch_ros.actions import Node
+
+
+def generate_launch_description():
+ """Generate launch description for face recognition pipeline."""
+ return LaunchDescription([
+ DeclareLaunchArgument(
+ 'scrfd_engine_path',
+ default_value='/mnt/nvme/saltybot/models/scrfd_2.5g.engine',
+ description='Path to SCRFD TensorRT engine file',
+ ),
+ DeclareLaunchArgument(
+ 'scrfd_onnx_path',
+ default_value='/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx',
+ description='Path to SCRFD ONNX model file (fallback)',
+ ),
+ DeclareLaunchArgument(
+ 'arcface_engine_path',
+ default_value='/mnt/nvme/saltybot/models/arcface_r50.engine',
+ description='Path to ArcFace TensorRT engine file',
+ ),
+ DeclareLaunchArgument(
+ 'arcface_onnx_path',
+ default_value='/mnt/nvme/saltybot/models/arcface_r50.onnx',
+ description='Path to ArcFace ONNX model file (fallback)',
+ ),
+ DeclareLaunchArgument(
+ 'gallery_dir',
+ default_value='/mnt/nvme/saltybot/gallery',
+ description='Directory for persistent face gallery storage',
+ ),
+ DeclareLaunchArgument(
+ 'recognition_threshold',
+ default_value='0.35',
+ description='Cosine similarity threshold for face recognition',
+ ),
+ DeclareLaunchArgument(
+ 'publish_debug_image',
+ default_value='false',
+ description='Publish annotated debug image to /social/faces/debug_image',
+ ),
+ DeclareLaunchArgument(
+ 'max_faces',
+ default_value='10',
+ description='Maximum faces to process per frame',
+ ),
+ DeclareLaunchArgument(
+ 'scrfd_conf_thresh',
+ default_value='0.5',
+ description='SCRFD detection confidence threshold',
+ ),
+
+ Node(
+ package='saltybot_social_face',
+ executable='face_recognition',
+ name='face_recognizer',
+ output='screen',
+ parameters=[{
+ 'scrfd_engine_path': LaunchConfiguration('scrfd_engine_path'),
+ 'scrfd_onnx_path': LaunchConfiguration('scrfd_onnx_path'),
+ 'arcface_engine_path': LaunchConfiguration('arcface_engine_path'),
+ 'arcface_onnx_path': LaunchConfiguration('arcface_onnx_path'),
+ 'gallery_dir': LaunchConfiguration('gallery_dir'),
+ 'recognition_threshold': LaunchConfiguration('recognition_threshold'),
+ 'publish_debug_image': LaunchConfiguration('publish_debug_image'),
+ 'max_faces': LaunchConfiguration('max_faces'),
+ 'scrfd_conf_thresh': LaunchConfiguration('scrfd_conf_thresh'),
+ }],
+ ),
+ ])
diff --git a/jetson/ros2_ws/src/saltybot_social_face/package.xml b/jetson/ros2_ws/src/saltybot_social_face/package.xml
new file mode 100644
index 0000000..7a21282
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_face/package.xml
@@ -0,0 +1,27 @@
+
+
+
+ saltybot_social_face
+ 0.1.0
+ SCRFD face detection and ArcFace recognition for SaltyBot social interactions
+ seb
+ MIT
+
+ rclpy
+ sensor_msgs
+ cv_bridge
+ image_transport
+ saltybot_social_msgs
+
+ python3-numpy
+ python3-opencv
+
+ ament_copyright
+ ament_flake8
+ ament_pep257
+ python3-pytest
+
+
+ ament_python
+
+
diff --git a/jetson/ros2_ws/src/saltybot_social_face/resource/saltybot_social_face b/jetson/ros2_ws/src/saltybot_social_face/resource/saltybot_social_face
new file mode 100644
index 0000000..e69de29
diff --git a/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/__init__.py b/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/__init__.py
new file mode 100644
index 0000000..e7872ab
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/__init__.py
@@ -0,0 +1 @@
+"""SaltyBot social face detection and recognition package."""
diff --git a/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/arcface_recognizer.py b/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/arcface_recognizer.py
new file mode 100644
index 0000000..9d04dc4
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/arcface_recognizer.py
@@ -0,0 +1,316 @@
+"""
+arcface_recognizer.py -- ArcFace face embedding extraction and gallery matching.
+
+Performs face alignment using 5-point landmarks (insightface standard reference),
+extracts 512-dimensional embeddings via ArcFace (TRT FP16 or ONNX fallback),
+and matches against a persistent gallery using cosine similarity.
+"""
+
+import os
+import logging
+from typing import Optional
+
+import numpy as np
+import cv2
+
+logger = logging.getLogger(__name__)
+
+# InsightFace standard reference landmarks for 112x112 alignment
+ARCFACE_SRC = np.array([
+ [38.2946, 51.6963], # left eye
+ [73.5318, 51.5014], # right eye
+ [56.0252, 71.7366], # nose
+ [41.5493, 92.3655], # left mouth
+ [70.7299, 92.2041], # right mouth
+], dtype=np.float32)
+
+
+# -- Inference backends --------------------------------------------------------
+
+class _TRTBackend:
+ """TensorRT inference engine for ArcFace."""
+
+ def __init__(self, engine_path: str):
+ import tensorrt as trt
+ import pycuda.driver as cuda
+ import pycuda.autoinit # noqa: F401
+
+ self._cuda = cuda
+ trt_logger = trt.Logger(trt.Logger.WARNING)
+ with open(engine_path, 'rb') as f, trt.Runtime(trt_logger) as runtime:
+ self._engine = runtime.deserialize_cuda_engine(f.read())
+ self._context = self._engine.create_execution_context()
+
+ self._inputs = []
+ self._outputs = []
+ self._bindings = []
+ for i in range(self._engine.num_io_tensors):
+ name = self._engine.get_tensor_name(i)
+ dtype = trt.nptype(self._engine.get_tensor_dtype(name))
+ shape = tuple(self._engine.get_tensor_shape(name))
+ nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
+ host_mem = cuda.pagelocked_empty(shape, dtype)
+ device_mem = cuda.mem_alloc(nbytes)
+ self._bindings.append(int(device_mem))
+ if self._engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
+ self._inputs.append({'host': host_mem, 'device': device_mem})
+ else:
+ self._outputs.append({'host': host_mem, 'device': device_mem,
+ 'shape': shape})
+ self._stream = cuda.Stream()
+
+ def infer(self, input_data: np.ndarray) -> np.ndarray:
+ """Run inference and return the embedding vector."""
+ np.copyto(self._inputs[0]['host'], input_data.ravel())
+ self._cuda.memcpy_htod_async(
+ self._inputs[0]['device'], self._inputs[0]['host'], self._stream)
+ self._context.execute_async_v2(self._bindings, self._stream.handle)
+ for out in self._outputs:
+ self._cuda.memcpy_dtoh_async(out['host'], out['device'], self._stream)
+ self._stream.synchronize()
+ return self._outputs[0]['host'].reshape(self._outputs[0]['shape']).copy()
+
+
+class _ONNXBackend:
+ """ONNX Runtime inference (CUDA EP with CPU fallback)."""
+
+ def __init__(self, onnx_path: str):
+ import onnxruntime as ort
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+ self._session = ort.InferenceSession(onnx_path, providers=providers)
+ self._input_name = self._session.get_inputs()[0].name
+
+ def infer(self, input_data: np.ndarray) -> np.ndarray:
+ """Run inference and return the embedding vector."""
+ results = self._session.run(None, {self._input_name: input_data})
+ return results[0]
+
+
+# -- Face alignment ------------------------------------------------------------
+
+def align_face(bgr: np.ndarray, landmarks_10: list[float]) -> np.ndarray:
+ """Align a face to 112x112 using 5-point landmarks.
+
+ Args:
+ bgr: Source BGR image.
+ landmarks_10: Flat list of 10 floats [x0,y0, x1,y1, ..., x4,y4].
+
+ Returns:
+ Aligned BGR face crop of shape (112, 112, 3).
+ """
+ src_pts = np.array(landmarks_10, dtype=np.float32).reshape(5, 2)
+ M, _ = cv2.estimateAffinePartial2D(src_pts, ARCFACE_SRC)
+ if M is None:
+ # Fallback: simple crop and resize from bbox-like region
+ cx = np.mean(src_pts[:, 0])
+ cy = np.mean(src_pts[:, 1])
+ spread = max(np.ptp(src_pts[:, 0]), np.ptp(src_pts[:, 1])) * 1.5
+ half = spread / 2
+ x1 = max(0, int(cx - half))
+ y1 = max(0, int(cy - half))
+ x2 = min(bgr.shape[1], int(cx + half))
+ y2 = min(bgr.shape[0], int(cy + half))
+ crop = bgr[y1:y2, x1:x2]
+ return cv2.resize(crop, (112, 112), interpolation=cv2.INTER_LINEAR)
+
+ aligned = cv2.warpAffine(bgr, M, (112, 112), borderMode=cv2.BORDER_REPLICATE)
+ return aligned
+
+
+# -- Main recognizer class -----------------------------------------------------
+
+class ArcFaceRecognizer:
+ """ArcFace face embedding extractor and gallery matcher.
+
+ Args:
+ engine_path: Path to TensorRT engine file.
+ onnx_path: Path to ONNX model file (used if engine not available).
+ """
+
+ def __init__(self, engine_path: str = '', onnx_path: str = ''):
+ self._backend: Optional[_TRTBackend | _ONNXBackend] = None
+ self.gallery: dict[int, dict] = {}
+
+ # Try TRT first, then ONNX
+ if engine_path and os.path.isfile(engine_path):
+ try:
+ self._backend = _TRTBackend(engine_path)
+ logger.info('ArcFace TensorRT backend loaded: %s', engine_path)
+ return
+ except Exception as e:
+ logger.warning('ArcFace TRT load failed (%s), trying ONNX', e)
+
+ if onnx_path and os.path.isfile(onnx_path):
+ try:
+ self._backend = _ONNXBackend(onnx_path)
+ logger.info('ArcFace ONNX backend loaded: %s', onnx_path)
+ return
+ except Exception as e:
+ logger.error('ArcFace ONNX load failed: %s', e)
+
+ logger.error('No ArcFace model loaded. Recognition will be unavailable.')
+
+ @property
+ def is_loaded(self) -> bool:
+ """Return True if a backend is loaded and ready."""
+ return self._backend is not None
+
+ def embed(self, bgr_face_112x112: np.ndarray) -> np.ndarray:
+ """Extract 512-dim L2-normalized embedding from a 112x112 aligned face.
+
+ Args:
+ bgr_face_112x112: Aligned face crop, BGR, shape (112, 112, 3).
+
+ Returns:
+ L2-normalized embedding of shape (512,).
+ """
+ if self._backend is None:
+ return np.zeros(512, dtype=np.float32)
+
+ # Preprocess: BGR->RGB, /255, subtract 0.5, divide 0.5 -> [1,3,112,112]
+ rgb = cv2.cvtColor(bgr_face_112x112, cv2.COLOR_BGR2RGB).astype(np.float32)
+ rgb = rgb / 255.0
+ rgb = (rgb - 0.5) / 0.5
+ blob = rgb.transpose(2, 0, 1)[np.newaxis] # [1, 3, 112, 112]
+ blob = np.ascontiguousarray(blob)
+
+ output = self._backend.infer(blob)
+ embedding = output.flatten()[:512].astype(np.float32)
+
+ # L2 normalize
+ norm = np.linalg.norm(embedding)
+ if norm > 0:
+ embedding = embedding / norm
+ return embedding
+
+ def align_and_embed(self, bgr_image: np.ndarray, landmarks_10: list[float]) -> np.ndarray:
+ """Align face using landmarks and extract embedding.
+
+ Args:
+ bgr_image: Full BGR image.
+ landmarks_10: Flat list of 10 floats from SCRFD detection.
+
+ Returns:
+ L2-normalized embedding of shape (512,).
+ """
+ aligned = align_face(bgr_image, landmarks_10)
+ return self.embed(aligned)
+
+ def load_gallery(self, gallery_path: str) -> None:
+ """Load gallery from .npz file with JSON metadata sidecar.
+
+ Args:
+ gallery_path: Path to the .npz gallery file.
+ """
+ import json
+
+ if not os.path.isfile(gallery_path):
+ logger.info('No gallery file at %s, starting empty.', gallery_path)
+ self.gallery = {}
+ return
+
+ data = np.load(gallery_path, allow_pickle=False)
+ meta_path = gallery_path.replace('.npz', '_meta.json')
+
+ if os.path.isfile(meta_path):
+ with open(meta_path, 'r') as f:
+ meta = json.load(f)
+ else:
+ meta = {}
+
+ self.gallery = {}
+ for key in data.files:
+ pid = int(key)
+ embedding = data[key].astype(np.float32)
+ norm = np.linalg.norm(embedding)
+ if norm > 0:
+ embedding = embedding / norm
+ info = meta.get(str(pid), {})
+ self.gallery[pid] = {
+ 'name': info.get('name', f'person_{pid}'),
+ 'embedding': embedding,
+ 'samples': info.get('samples', 1),
+ 'enrolled_at': info.get('enrolled_at', 0.0),
+ }
+ logger.info('Gallery loaded: %d persons from %s', len(self.gallery), gallery_path)
+
+ def save_gallery(self, gallery_path: str) -> None:
+ """Save gallery to .npz file with JSON metadata sidecar.
+
+ Args:
+ gallery_path: Path to the .npz gallery file.
+ """
+ import json
+
+ arrays = {}
+ meta = {}
+ for pid, info in self.gallery.items():
+ arrays[str(pid)] = info['embedding']
+ meta[str(pid)] = {
+ 'name': info['name'],
+ 'samples': info['samples'],
+ 'enrolled_at': info['enrolled_at'],
+ }
+
+ os.makedirs(os.path.dirname(gallery_path) or '.', exist_ok=True)
+ np.savez(gallery_path, **arrays)
+
+ meta_path = gallery_path.replace('.npz', '_meta.json')
+ with open(meta_path, 'w') as f:
+ json.dump(meta, f, indent=2)
+ logger.info('Gallery saved: %d persons to %s', len(self.gallery), gallery_path)
+
+ def match(self, embedding: np.ndarray, threshold: float = 0.35) -> tuple[int, str, float]:
+ """Match an embedding against the gallery.
+
+ Args:
+ embedding: L2-normalized query embedding of shape (512,).
+ threshold: Minimum cosine similarity for a match.
+
+ Returns:
+ (person_id, person_name, score) or (-1, '', 0.0) if no match.
+ """
+ if not self.gallery:
+ return (-1, '', 0.0)
+
+ best_pid = -1
+ best_name = ''
+ best_score = 0.0
+
+ for pid, info in self.gallery.items():
+ score = float(np.dot(embedding, info['embedding']))
+ if score > best_score:
+ best_score = score
+ best_pid = pid
+ best_name = info['name']
+
+ if best_score >= threshold:
+ return (best_pid, best_name, best_score)
+ return (-1, '', 0.0)
+
+ def enroll(self, person_id: int, person_name: str, embeddings_list: list[np.ndarray]) -> None:
+ """Enroll a person by averaging multiple embeddings.
+
+ Args:
+ person_id: Unique integer ID for this person.
+ person_name: Human-readable name.
+ embeddings_list: List of L2-normalized embeddings to average.
+ """
+ import time as _time
+
+ if not embeddings_list:
+ return
+
+ mean_emb = np.mean(embeddings_list, axis=0).astype(np.float32)
+ norm = np.linalg.norm(mean_emb)
+ if norm > 0:
+ mean_emb = mean_emb / norm
+
+ self.gallery[person_id] = {
+ 'name': person_name,
+ 'embedding': mean_emb,
+ 'samples': len(embeddings_list),
+ 'enrolled_at': _time.time(),
+ }
+ logger.info('Enrolled person %d (%s) with %d samples.',
+ person_id, person_name, len(embeddings_list))
diff --git a/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/enrollment_cli.py b/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/enrollment_cli.py
new file mode 100644
index 0000000..037bf8e
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/enrollment_cli.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+"""
+enrollment_cli.py -- CLI tool for enrolling persons via the /social/enroll service.
+
+Usage:
+ ros2 run saltybot_social_face enrollment_cli -- --name Alice --mode face --samples 10
+"""
+
+import argparse
+import sys
+
+import rclpy
+from rclpy.node import Node
+
+from saltybot_social_msgs.srv import EnrollPerson
+
+
+class EnrollmentCLI(Node):
+ """Simple CLI node that calls the EnrollPerson service."""
+
+ def __init__(self, name: str, mode: str, n_samples: int):
+ super().__init__('enrollment_cli')
+ self._client = self.create_client(EnrollPerson, '/social/enroll')
+
+ self.get_logger().info('Waiting for /social/enroll service...')
+ if not self._client.wait_for_service(timeout_sec=10.0):
+ self.get_logger().error('Service /social/enroll not available.')
+ return
+
+ request = EnrollPerson.Request()
+ request.name = name
+ request.mode = mode
+ request.n_samples = n_samples
+
+ self.get_logger().info(
+ 'Enrolling "%s" (mode=%s, samples=%d)...', name, mode, n_samples)
+
+ future = self._client.call_async(request)
+ rclpy.spin_until_future_complete(self, future, timeout_sec=120.0)
+
+ if future.result() is not None:
+ result = future.result()
+ if result.success:
+ self.get_logger().info(
+ 'Enrollment successful: person_id=%d, %s',
+ result.person_id, result.message)
+ else:
+ self.get_logger().error(
+ 'Enrollment failed: %s', result.message)
+ else:
+ self.get_logger().error('Enrollment service call timed out or failed.')
+
+
+def main(args=None):
+ """Entry point for enrollment CLI."""
+ parser = argparse.ArgumentParser(description='Enroll a person for face recognition.')
+ parser.add_argument('--name', type=str, required=True,
+ help='Name of the person to enroll.')
+ parser.add_argument('--mode', type=str, default='face',
+ choices=['face', 'voice', 'both'],
+ help='Enrollment mode (default: face).')
+ parser.add_argument('--samples', type=int, default=10,
+ help='Number of face samples to collect (default: 10).')
+
+ # Parse only known args so ROS2 remapping args pass through
+ parsed, remaining = parser.parse_known_args(args=sys.argv[1:])
+
+ rclpy.init(args=remaining)
+ node = EnrollmentCLI(parsed.name, parsed.mode, parsed.samples)
+ try:
+ pass # Node does all work in __init__
+ finally:
+ node.destroy_node()
+ rclpy.shutdown()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/face_gallery.py b/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/face_gallery.py
new file mode 100644
index 0000000..50de106
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/face_gallery.py
@@ -0,0 +1,206 @@
+"""
+face_gallery.py -- Persistent face embedding gallery backed by numpy .npz + JSON.
+
+Thread-safe gallery storage for face recognition. Embeddings are stored in a
+.npz file, with a sidecar metadata.json containing names, sample counts, and
+enrollment timestamps. Auto-increment IDs start at 1.
+"""
+
+import json
+import logging
+import os
+import threading
+import time
+from typing import Optional
+
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+
+class FaceGallery:
+ """Persistent, thread-safe face embedding gallery.
+
+ Args:
+ gallery_dir: Directory for gallery.npz and metadata.json files.
+ """
+
+ def __init__(self, gallery_dir: str):
+ self._gallery_dir = gallery_dir
+ self._npz_path = os.path.join(gallery_dir, 'gallery.npz')
+ self._meta_path = os.path.join(gallery_dir, 'metadata.json')
+ self._gallery: dict[int, dict] = {}
+ self._next_id = 1
+ self._lock = threading.Lock()
+
+ def load(self) -> None:
+ """Load gallery from disk. Populates internal gallery dict."""
+ with self._lock:
+ self._gallery = {}
+ self._next_id = 1
+
+ if not os.path.isfile(self._npz_path):
+ logger.info('No gallery file at %s, starting empty.', self._npz_path)
+ return
+
+ data = np.load(self._npz_path, allow_pickle=False)
+ meta: dict = {}
+ if os.path.isfile(self._meta_path):
+ with open(self._meta_path, 'r') as f:
+ meta = json.load(f)
+
+ for key in data.files:
+ pid = int(key)
+ embedding = data[key].astype(np.float32)
+ norm = np.linalg.norm(embedding)
+ if norm > 0:
+ embedding = embedding / norm
+ info = meta.get(str(pid), {})
+ self._gallery[pid] = {
+ 'name': info.get('name', f'person_{pid}'),
+ 'embedding': embedding,
+ 'samples': info.get('samples', 1),
+ 'enrolled_at': info.get('enrolled_at', 0.0),
+ }
+ if pid >= self._next_id:
+ self._next_id = pid + 1
+
+ logger.info('Gallery loaded: %d persons from %s',
+ len(self._gallery), self._npz_path)
+
+ def save(self) -> None:
+ """Save gallery to disk (npz + JSON sidecar)."""
+ with self._lock:
+ os.makedirs(self._gallery_dir, exist_ok=True)
+
+ arrays = {}
+ meta = {}
+ for pid, info in self._gallery.items():
+ arrays[str(pid)] = info['embedding']
+ meta[str(pid)] = {
+ 'name': info['name'],
+ 'samples': info['samples'],
+ 'enrolled_at': info['enrolled_at'],
+ }
+
+ np.savez(self._npz_path, **arrays)
+ with open(self._meta_path, 'w') as f:
+ json.dump(meta, f, indent=2)
+
+ logger.info('Gallery saved: %d persons to %s',
+ len(self._gallery), self._npz_path)
+
+ def add_person(self, name: str, embedding: np.ndarray, samples: int = 1) -> int:
+ """Add a new person to the gallery.
+
+ Args:
+ name: Person's name.
+ embedding: L2-normalized 512-dim embedding.
+ samples: Number of samples used to compute the embedding.
+
+ Returns:
+ Assigned person_id (auto-increment integer).
+ """
+ with self._lock:
+ pid = self._next_id
+ self._next_id += 1
+ emb = embedding.astype(np.float32)
+ norm = np.linalg.norm(emb)
+ if norm > 0:
+ emb = emb / norm
+ self._gallery[pid] = {
+ 'name': name,
+ 'embedding': emb,
+ 'samples': samples,
+ 'enrolled_at': time.time(),
+ }
+ logger.info('Added person %d (%s), %d samples.', pid, name, samples)
+ return pid
+
+ def update_name(self, person_id: int, new_name: str) -> bool:
+ """Update a person's name.
+
+ Args:
+ person_id: The ID of the person to update.
+ new_name: New name string.
+
+ Returns:
+ True if the person was found and updated.
+ """
+ with self._lock:
+ if person_id not in self._gallery:
+ return False
+ self._gallery[person_id]['name'] = new_name
+ return True
+
+ def delete_person(self, person_id: int) -> bool:
+ """Remove a person from the gallery.
+
+ Args:
+ person_id: The ID of the person to delete.
+
+ Returns:
+ True if the person was found and removed.
+ """
+ with self._lock:
+ if person_id not in self._gallery:
+ return False
+ del self._gallery[person_id]
+ logger.info('Deleted person %d.', person_id)
+ return True
+
+ def get_all(self) -> list[dict]:
+ """Get all gallery entries.
+
+ Returns:
+ List of dicts with keys: person_id, name, embedding, samples, enrolled_at.
+ """
+ with self._lock:
+ result = []
+ for pid, info in self._gallery.items():
+ result.append({
+ 'person_id': pid,
+ 'name': info['name'],
+ 'embedding': info['embedding'].copy(),
+ 'samples': info['samples'],
+ 'enrolled_at': info['enrolled_at'],
+ })
+ return result
+
+ def match(self, query_embedding: np.ndarray, threshold: float = 0.35) -> tuple[int, str, float]:
+ """Match a query embedding against the gallery using cosine similarity.
+
+ Args:
+ query_embedding: L2-normalized 512-dim embedding.
+ threshold: Minimum cosine similarity for a match.
+
+ Returns:
+ (person_id, name, score) or (-1, '', 0.0) if no match.
+ """
+ with self._lock:
+ if not self._gallery:
+ return (-1, '', 0.0)
+
+ best_pid = -1
+ best_name = ''
+ best_score = 0.0
+
+ query = query_embedding.astype(np.float32)
+ norm = np.linalg.norm(query)
+ if norm > 0:
+ query = query / norm
+
+ for pid, info in self._gallery.items():
+ score = float(np.dot(query, info['embedding']))
+ if score > best_score:
+ best_score = score
+ best_pid = pid
+ best_name = info['name']
+
+ if best_score >= threshold:
+ return (best_pid, best_name, best_score)
+ return (-1, '', 0.0)
+
+ def __len__(self) -> int:
+ with self._lock:
+ return len(self._gallery)
diff --git a/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/face_recognition_node.py b/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/face_recognition_node.py
new file mode 100644
index 0000000..d73fc98
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/face_recognition_node.py
@@ -0,0 +1,431 @@
+"""
+face_recognition_node.py -- ROS2 node for SCRFD face detection + ArcFace recognition.
+
+Pipeline:
+ 1. Subscribe to /camera/color/image_raw (RealSense D435i color stream).
+ 2. Run SCRFD face detection (TensorRT FP16 or ONNX fallback).
+ 3. For each detected face, align and extract ArcFace embedding.
+ 4. Match embedding against persistent gallery.
+ 5. Publish FaceDetectionArray with identified faces.
+
+Services:
+ /social/enroll -- Enroll a new person (collects N face samples).
+ /social/persons/list -- List all enrolled persons.
+ /social/persons/delete -- Delete a person from the gallery.
+ /social/persons/update -- Update a person's name.
+"""
+
+import time
+
+import numpy as np
+
+import rclpy
+from rclpy.node import Node
+from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy, DurabilityPolicy
+
+import cv2
+from cv_bridge import CvBridge
+
+from sensor_msgs.msg import Image
+from builtin_interfaces.msg import Time
+
+from saltybot_social_msgs.msg import (
+ FaceDetection,
+ FaceDetectionArray,
+ FaceEmbedding,
+ FaceEmbeddingArray,
+)
+from saltybot_social_msgs.srv import (
+ EnrollPerson,
+ ListPersons,
+ DeletePerson,
+ UpdatePerson,
+)
+
+from .scrfd_detector import SCRFDDetector
+from .arcface_recognizer import ArcFaceRecognizer
+from .face_gallery import FaceGallery
+
+
+class FaceRecognitionNode(Node):
+ """ROS2 node: SCRFD face detection + ArcFace gallery matching."""
+
+ def __init__(self):
+ super().__init__('face_recognizer')
+ self._bridge = CvBridge()
+ self._frame_count = 0
+ self._fps_t0 = time.monotonic()
+
+ # -- Parameters --------------------------------------------------------
+ self.declare_parameter('scrfd_engine_path',
+ '/mnt/nvme/saltybot/models/scrfd_2.5g.engine')
+ self.declare_parameter('scrfd_onnx_path',
+ '/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx')
+ self.declare_parameter('arcface_engine_path',
+ '/mnt/nvme/saltybot/models/arcface_r50.engine')
+ self.declare_parameter('arcface_onnx_path',
+ '/mnt/nvme/saltybot/models/arcface_r50.onnx')
+ self.declare_parameter('gallery_dir', '/mnt/nvme/saltybot/gallery')
+ self.declare_parameter('recognition_threshold', 0.35)
+ self.declare_parameter('publish_debug_image', False)
+ self.declare_parameter('max_faces', 10)
+ self.declare_parameter('scrfd_conf_thresh', 0.5)
+
+ self._recognition_threshold = self.get_parameter('recognition_threshold').value
+ self._pub_debug_flag = self.get_parameter('publish_debug_image').value
+ self._max_faces = self.get_parameter('max_faces').value
+
+ # -- Models ------------------------------------------------------------
+ self._detector = SCRFDDetector(
+ engine_path=self.get_parameter('scrfd_engine_path').value,
+ onnx_path=self.get_parameter('scrfd_onnx_path').value,
+ conf_thresh=self.get_parameter('scrfd_conf_thresh').value,
+ )
+ self._recognizer = ArcFaceRecognizer(
+ engine_path=self.get_parameter('arcface_engine_path').value,
+ onnx_path=self.get_parameter('arcface_onnx_path').value,
+ )
+
+ # -- Gallery -----------------------------------------------------------
+ gallery_dir = self.get_parameter('gallery_dir').value
+ self._gallery = FaceGallery(gallery_dir)
+ self._gallery.load()
+ self.get_logger().info('Gallery loaded: %d persons.', len(self._gallery))
+
+ # -- Enrollment state --------------------------------------------------
+ self._enrolling = None # {name, samples_needed, collected: [embeddings]}
+
+ # -- QoS profiles ------------------------------------------------------
+ best_effort_qos = QoSProfile(
+ reliability=ReliabilityPolicy.BEST_EFFORT,
+ history=HistoryPolicy.KEEP_LAST,
+ depth=1,
+ )
+ reliable_qos = QoSProfile(
+ reliability=ReliabilityPolicy.RELIABLE,
+ durability=DurabilityPolicy.TRANSIENT_LOCAL,
+ history=HistoryPolicy.KEEP_LAST,
+ depth=1,
+ )
+
+ # -- Subscribers -------------------------------------------------------
+ self.create_subscription(
+ Image,
+ '/camera/color/image_raw',
+ self._on_image,
+ best_effort_qos,
+ )
+
+ # -- Publishers --------------------------------------------------------
+ self._pub_detections = self.create_publisher(
+ FaceDetectionArray, '/social/faces/detections', best_effort_qos)
+ self._pub_embeddings = self.create_publisher(
+ FaceEmbeddingArray, '/social/faces/embeddings', reliable_qos)
+ if self._pub_debug_flag:
+ self._pub_debug_img = self.create_publisher(
+ Image, '/social/faces/debug_image', best_effort_qos)
+
+ # -- Services ----------------------------------------------------------
+ self.create_service(EnrollPerson, '/social/enroll', self._handle_enroll)
+ self.create_service(ListPersons, '/social/persons/list', self._handle_list)
+ self.create_service(DeletePerson, '/social/persons/delete', self._handle_delete)
+ self.create_service(UpdatePerson, '/social/persons/update', self._handle_update)
+
+ # Publish initial gallery state
+ self._publish_gallery_embeddings()
+
+ self.get_logger().info('FaceRecognitionNode ready.')
+
+ # -- Image callback --------------------------------------------------------
+
+ def _on_image(self, msg: Image):
+ """Process incoming camera frame: detect, recognize, publish."""
+ try:
+ bgr = self._bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')
+ except Exception as e:
+ self.get_logger().error('Image decode error: %s', str(e),
+ throttle_duration_sec=5.0)
+ return
+
+ # Detect faces
+ detections = self._detector.detect(bgr)
+
+ # Limit face count
+ if len(detections) > self._max_faces:
+ detections = sorted(detections, key=lambda d: d['score'], reverse=True)
+ detections = detections[:self._max_faces]
+
+ # Build output message
+ det_array = FaceDetectionArray()
+ det_array.header = msg.header
+
+ for det in detections:
+ # Extract embedding and match gallery
+ embedding = self._recognizer.align_and_embed(bgr, det['kps'])
+ pid, pname, score = self._gallery.match(
+ embedding, self._recognition_threshold)
+
+ # Handle enrollment: collect embedding from largest face
+ if self._enrolling is not None:
+ self._enrollment_collect(det, embedding)
+
+ # Build FaceDetection message
+ face_msg = FaceDetection()
+ face_msg.header = msg.header
+ face_msg.face_id = pid
+ face_msg.person_name = pname
+ face_msg.confidence = det['score']
+ face_msg.recognition_score = score
+
+ bbox = det['bbox']
+ face_msg.bbox_x = bbox[0]
+ face_msg.bbox_y = bbox[1]
+ face_msg.bbox_w = bbox[2] - bbox[0]
+ face_msg.bbox_h = bbox[3] - bbox[1]
+
+ kps = det['kps']
+ for i in range(10):
+ face_msg.landmarks[i] = kps[i]
+
+ det_array.faces.append(face_msg)
+
+ self._pub_detections.publish(det_array)
+
+ # Debug image
+ if self._pub_debug_flag and hasattr(self, '_pub_debug_img'):
+ debug_img = self._draw_debug(bgr, detections, det_array.faces)
+ self._pub_debug_img.publish(
+ self._bridge.cv2_to_imgmsg(debug_img, encoding='bgr8'))
+
+ # FPS logging
+ self._frame_count += 1
+ if self._frame_count % 30 == 0:
+ elapsed = time.monotonic() - self._fps_t0
+ fps = 30.0 / elapsed if elapsed > 0 else 0.0
+ self._fps_t0 = time.monotonic()
+ self.get_logger().info(
+ 'FPS: %.1f | faces: %d', fps, len(detections))
+
+ # -- Enrollment logic ------------------------------------------------------
+
+ def _enrollment_collect(self, det: dict, embedding: np.ndarray):
+ """Collect an embedding sample during enrollment (largest face only)."""
+ if self._enrolling is None:
+ return
+
+ # Only collect from the largest face (by bbox area)
+ bbox = det['bbox']
+ area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
+
+ if not hasattr(self, '_enroll_best_area'):
+ self._enroll_best_area = 0.0
+ self._enroll_best_embedding = None
+
+ if area > self._enroll_best_area:
+ self._enroll_best_area = area
+ self._enroll_best_embedding = embedding
+
+ def _enrollment_frame_end(self):
+ """Called at end of each frame to finalize enrollment sample collection."""
+ if self._enrolling is None or self._enroll_best_embedding is None:
+ return
+
+ self._enrolling['collected'].append(self._enroll_best_embedding)
+ self._enroll_best_area = 0.0
+ self._enroll_best_embedding = None
+
+ collected = len(self._enrolling['collected'])
+ needed = self._enrolling['samples_needed']
+ self.get_logger().info('Enrollment: %d/%d samples for "%s".',
+ collected, needed, self._enrolling['name'])
+
+ if collected >= needed:
+ # Finalize enrollment
+ name = self._enrolling['name']
+ embeddings = self._enrolling['collected']
+ mean_emb = np.mean(embeddings, axis=0).astype(np.float32)
+ norm = np.linalg.norm(mean_emb)
+ if norm > 0:
+ mean_emb = mean_emb / norm
+
+ pid = self._gallery.add_person(name, mean_emb, samples=len(embeddings))
+ self._gallery.save()
+ self._publish_gallery_embeddings()
+
+ self.get_logger().info('Enrollment complete: person %d (%s).', pid, name)
+
+ # Store result for the service callback
+ self._enrolling['result_pid'] = pid
+ self._enrolling['done'] = True
+ self._enrolling = None
+
+ # -- Service handlers ------------------------------------------------------
+
+ def _handle_enroll(self, request, response):
+ """Handle EnrollPerson service: start collecting face samples."""
+ name = request.name.strip()
+ if not name:
+ response.success = False
+ response.message = 'Name cannot be empty.'
+ response.person_id = -1
+ return response
+
+ n_samples = request.n_samples if request.n_samples > 0 else 10
+
+ self.get_logger().info('Starting enrollment for "%s" (%d samples).',
+ name, n_samples)
+
+ # Set enrollment state — frames will collect embeddings
+ self._enrolling = {
+ 'name': name,
+ 'samples_needed': n_samples,
+ 'collected': [],
+ 'done': False,
+ 'result_pid': -1,
+ }
+ self._enroll_best_area = 0.0
+ self._enroll_best_embedding = None
+
+ # Spin until enrollment is done (blocking service)
+ rate = self.create_rate(10) # 10 Hz check
+ timeout_sec = n_samples * 2.0 + 10.0 # generous timeout
+ t0 = time.monotonic()
+
+ while not self._enrolling.get('done', False):
+ # Finalize any pending frame collection
+ self._enrollment_frame_end()
+
+ if time.monotonic() - t0 > timeout_sec:
+ self._enrolling = None
+ response.success = False
+ response.message = f'Enrollment timed out after {timeout_sec:.0f}s.'
+ response.person_id = -1
+ return response
+
+ rclpy.spin_once(self, timeout_sec=0.1)
+
+ response.success = True
+ response.message = f'Enrolled "{name}" with {n_samples} samples.'
+ response.person_id = self._enrolling.get('result_pid', -1) if self._enrolling else -1
+
+ # Clean up (already set to None in _enrollment_frame_end on success)
+ return response
+
+ def _handle_list(self, request, response):
+ """Handle ListPersons service: return all gallery entries."""
+ entries = self._gallery.get_all()
+ for entry in entries:
+ emb_msg = FaceEmbedding()
+ emb_msg.person_id = entry['person_id']
+ emb_msg.person_name = entry['name']
+ emb_msg.embedding = entry['embedding'].tolist()
+ emb_msg.sample_count = entry['samples']
+
+ secs = int(entry['enrolled_at'])
+ nsecs = int((entry['enrolled_at'] - secs) * 1e9)
+ emb_msg.enrolled_at = Time(sec=secs, nanosec=nsecs)
+
+ response.persons.append(emb_msg)
+ return response
+
+ def _handle_delete(self, request, response):
+ """Handle DeletePerson service: remove a person from the gallery."""
+ if self._gallery.delete_person(request.person_id):
+ self._gallery.save()
+ self._publish_gallery_embeddings()
+ response.success = True
+ response.message = f'Deleted person {request.person_id}.'
+ else:
+ response.success = False
+ response.message = f'Person {request.person_id} not found.'
+ return response
+
+ def _handle_update(self, request, response):
+ """Handle UpdatePerson service: rename a person."""
+ new_name = request.new_name.strip()
+ if not new_name:
+ response.success = False
+ response.message = 'New name cannot be empty.'
+ return response
+
+ if self._gallery.update_name(request.person_id, new_name):
+ self._gallery.save()
+ self._publish_gallery_embeddings()
+ response.success = True
+ response.message = f'Updated person {request.person_id} to "{new_name}".'
+ else:
+ response.success = False
+ response.message = f'Person {request.person_id} not found.'
+ return response
+
+ # -- Gallery publishing ----------------------------------------------------
+
+ def _publish_gallery_embeddings(self):
+ """Publish current gallery as FaceEmbeddingArray (latched-like)."""
+ entries = self._gallery.get_all()
+ msg = FaceEmbeddingArray()
+ msg.header.stamp = self.get_clock().now().to_msg()
+
+ for entry in entries:
+ emb_msg = FaceEmbedding()
+ emb_msg.person_id = entry['person_id']
+ emb_msg.person_name = entry['name']
+ emb_msg.embedding = entry['embedding'].tolist()
+ emb_msg.sample_count = entry['samples']
+
+ secs = int(entry['enrolled_at'])
+ nsecs = int((entry['enrolled_at'] - secs) * 1e9)
+ emb_msg.enrolled_at = Time(sec=secs, nanosec=nsecs)
+
+ msg.embeddings.append(emb_msg)
+
+ self._pub_embeddings.publish(msg)
+
+ # -- Debug image -----------------------------------------------------------
+
+ def _draw_debug(self, bgr: np.ndarray, detections: list[dict],
+ face_msgs: list) -> np.ndarray:
+ """Draw bounding boxes, landmarks, and names on the image."""
+ vis = bgr.copy()
+ for det, face_msg in zip(detections, face_msgs):
+ bbox = det['bbox']
+ x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
+
+ # Color: green if recognized, yellow if unknown
+ if face_msg.face_id >= 0:
+ color = (0, 255, 0)
+ label = f'{face_msg.person_name} ({face_msg.recognition_score:.2f})'
+ else:
+ color = (0, 255, 255)
+ label = f'unknown ({face_msg.confidence:.2f})'
+
+ cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2)
+ cv2.putText(vis, label, (x1, y1 - 8),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
+
+ # Draw landmarks
+ kps = det['kps']
+ for k in range(5):
+ px, py = int(kps[k * 2]), int(kps[k * 2 + 1])
+ cv2.circle(vis, (px, py), 2, (0, 0, 255), -1)
+
+ return vis
+
+
+# -- Entry point ---------------------------------------------------------------
+
+def main(args=None):
+ """ROS2 entry point for face_recognition node."""
+ rclpy.init(args=args)
+ node = FaceRecognitionNode()
+ try:
+ rclpy.spin(node)
+ except KeyboardInterrupt:
+ pass
+ finally:
+ node.destroy_node()
+ rclpy.shutdown()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/scrfd_detector.py b/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/scrfd_detector.py
new file mode 100644
index 0000000..70625ae
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_face/saltybot_social_face/scrfd_detector.py
@@ -0,0 +1,350 @@
+"""
+scrfd_detector.py -- SCRFD face detection with TensorRT FP16 + ONNX fallback.
+
+SCRFD (Sample and Computation Redistribution for Face Detection) produces
+9 output tensors across 3 strides (8, 16, 32), each with score, bbox, and
+keypoint branches. This module handles anchor generation, bbox/keypoint
+decoding, and NMS to produce final face detections.
+"""
+
+import os
+import logging
+from typing import Optional
+
+import numpy as np
+import cv2
+
+logger = logging.getLogger(__name__)
+
+_STRIDES = [8, 16, 32]
+_NUM_ANCHORS = 2 # anchors per cell per stride
+
+
+# -- Inference backends --------------------------------------------------------
+
+class _TRTBackend:
+ """TensorRT inference engine for SCRFD."""
+
+ def __init__(self, engine_path: str):
+ import tensorrt as trt
+ import pycuda.driver as cuda
+ import pycuda.autoinit # noqa: F401
+
+ self._cuda = cuda
+ trt_logger = trt.Logger(trt.Logger.WARNING)
+ with open(engine_path, 'rb') as f, trt.Runtime(trt_logger) as runtime:
+ self._engine = runtime.deserialize_cuda_engine(f.read())
+ self._context = self._engine.create_execution_context()
+
+ self._inputs = []
+ self._outputs = []
+ self._output_names = []
+ self._bindings = []
+ for i in range(self._engine.num_io_tensors):
+ name = self._engine.get_tensor_name(i)
+ dtype = trt.nptype(self._engine.get_tensor_dtype(name))
+ shape = tuple(self._engine.get_tensor_shape(name))
+ nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
+ host_mem = cuda.pagelocked_empty(shape, dtype)
+ device_mem = cuda.mem_alloc(nbytes)
+ self._bindings.append(int(device_mem))
+ if self._engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
+ self._inputs.append({'host': host_mem, 'device': device_mem})
+ else:
+ self._outputs.append({'host': host_mem, 'device': device_mem,
+ 'shape': shape})
+ self._output_names.append(name)
+ self._stream = cuda.Stream()
+
+ def infer(self, input_data: np.ndarray) -> list[np.ndarray]:
+ """Run inference and return output tensors."""
+ np.copyto(self._inputs[0]['host'], input_data.ravel())
+ self._cuda.memcpy_htod_async(
+ self._inputs[0]['device'], self._inputs[0]['host'], self._stream)
+ self._context.execute_async_v2(self._bindings, self._stream.handle)
+ results = []
+ for out in self._outputs:
+ self._cuda.memcpy_dtoh_async(out['host'], out['device'], self._stream)
+ self._stream.synchronize()
+ for out in self._outputs:
+ results.append(out['host'].reshape(out['shape']).copy())
+ return results
+
+
+class _ONNXBackend:
+ """ONNX Runtime inference (CUDA EP with CPU fallback)."""
+
+ def __init__(self, onnx_path: str):
+ import onnxruntime as ort
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+ self._session = ort.InferenceSession(onnx_path, providers=providers)
+ self._input_name = self._session.get_inputs()[0].name
+ self._output_names = [o.name for o in self._session.get_outputs()]
+
+ def infer(self, input_data: np.ndarray) -> list[np.ndarray]:
+ """Run inference and return output tensors."""
+ return self._session.run(None, {self._input_name: input_data})
+
+
+# -- NMS ----------------------------------------------------------------------
+
+def _nms(boxes: np.ndarray, scores: np.ndarray, iou_thresh: float) -> list[int]:
+ """Non-maximum suppression. boxes: [N, 4] as x1,y1,x2,y2."""
+ if len(boxes) == 0:
+ return []
+
+ x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
+ areas = (x2 - x1) * (y2 - y1)
+ order = scores.argsort()[::-1]
+ keep = []
+
+ while order.size > 0:
+ i = order[0]
+ keep.append(int(i))
+ if order.size == 1:
+ break
+ xx1 = np.maximum(x1[i], x1[order[1:]])
+ yy1 = np.maximum(y1[i], y1[order[1:]])
+ xx2 = np.minimum(x2[i], x2[order[1:]])
+ yy2 = np.minimum(y2[i], y2[order[1:]])
+ inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
+ iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-6)
+ remaining = np.where(iou <= iou_thresh)[0]
+ order = order[remaining + 1]
+
+ return keep
+
+
+# -- Anchor generation ---------------------------------------------------------
+
+def _generate_anchors(input_h: int, input_w: int) -> dict[int, np.ndarray]:
+ """Generate anchor centers for each stride.
+
+ Returns dict mapping stride -> array of shape [H*W*num_anchors, 2],
+ where each row is (cx, cy) in input pixel coordinates.
+ """
+ anchors = {}
+ for stride in _STRIDES:
+ feat_h = input_h // stride
+ feat_w = input_w // stride
+ grid_y, grid_x = np.mgrid[:feat_h, :feat_w]
+ centers = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1).astype(np.float32)
+ centers = (centers + 0.5) * stride
+ # Repeat for num_anchors per cell
+ centers = np.tile(centers, (_NUM_ANCHORS, 1)) # [H*W*2, 2]
+ # Interleave properly: [anchor0_cell0, anchor1_cell0, anchor0_cell1, ...]
+ centers = np.repeat(
+ np.stack([grid_x.ravel(), grid_y.ravel()], axis=1).astype(np.float32),
+ _NUM_ANCHORS, axis=0
+ )
+ centers = (centers + 0.5) * stride
+ anchors[stride] = centers
+ return anchors
+
+
+# -- Main detector class -------------------------------------------------------
+
+class SCRFDDetector:
+ """SCRFD face detector with TensorRT FP16 and ONNX fallback.
+
+ Args:
+ engine_path: Path to TensorRT engine file.
+ onnx_path: Path to ONNX model file (used if engine not available).
+ conf_thresh: Minimum confidence for detections.
+ nms_iou: IoU threshold for NMS.
+ input_size: Model input resolution (square).
+ """
+
+ def __init__(
+ self,
+ engine_path: str = '',
+ onnx_path: str = '',
+ conf_thresh: float = 0.5,
+ nms_iou: float = 0.4,
+ input_size: int = 640,
+ ):
+ self._conf_thresh = conf_thresh
+ self._nms_iou = nms_iou
+ self._input_size = input_size
+ self._backend: Optional[_TRTBackend | _ONNXBackend] = None
+ self._anchors = _generate_anchors(input_size, input_size)
+
+ # Try TRT first, then ONNX
+ if engine_path and os.path.isfile(engine_path):
+ try:
+ self._backend = _TRTBackend(engine_path)
+ logger.info('SCRFD TensorRT backend loaded: %s', engine_path)
+ return
+ except Exception as e:
+ logger.warning('SCRFD TRT load failed (%s), trying ONNX', e)
+
+ if onnx_path and os.path.isfile(onnx_path):
+ try:
+ self._backend = _ONNXBackend(onnx_path)
+ logger.info('SCRFD ONNX backend loaded: %s', onnx_path)
+ return
+ except Exception as e:
+ logger.error('SCRFD ONNX load failed: %s', e)
+
+ logger.error('No SCRFD model loaded. Detection will be unavailable.')
+
+ @property
+ def is_loaded(self) -> bool:
+ """Return True if a backend is loaded and ready."""
+ return self._backend is not None
+
+ def detect(self, bgr: np.ndarray) -> list[dict]:
+ """Detect faces in a BGR image.
+
+ Args:
+ bgr: Input image in BGR format, shape (H, W, 3).
+
+ Returns:
+ List of dicts with keys:
+ bbox: [x1, y1, x2, y2] in original image coordinates
+ kps: [x0,y0, x1,y1, ..., x4,y4] — 10 floats, 5 landmarks
+ score: detection confidence
+ """
+ if self._backend is None:
+ return []
+
+ orig_h, orig_w = bgr.shape[:2]
+ tensor, scale, pad_w, pad_h = self._preprocess(bgr)
+ outputs = self._backend.infer(tensor)
+ detections = self._decode_outputs(outputs)
+ detections = self._rescale(detections, scale, pad_w, pad_h, orig_w, orig_h)
+ return detections
+
+ def _preprocess(self, bgr: np.ndarray) -> tuple[np.ndarray, float, int, int]:
+ """Resize to input_size x input_size with letterbox padding, normalize."""
+ h, w = bgr.shape[:2]
+ size = self._input_size
+ scale = min(size / h, size / w)
+ new_w, new_h = int(w * scale), int(h * scale)
+ pad_w = (size - new_w) // 2
+ pad_h = (size - new_h) // 2
+
+ resized = cv2.resize(bgr, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
+ canvas = np.full((size, size, 3), 0, dtype=np.uint8)
+ canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized
+
+ # Normalize: subtract 127.5, divide 128.0
+ blob = canvas.astype(np.float32)
+ blob = (blob - 127.5) / 128.0
+ # HWC -> NCHW
+ blob = blob.transpose(2, 0, 1)[np.newaxis]
+ blob = np.ascontiguousarray(blob)
+ return blob, scale, pad_w, pad_h
+
+ def _decode_outputs(self, outputs: list[np.ndarray]) -> list[dict]:
+ """Decode SCRFD 9-output format into face detections.
+
+ SCRFD outputs 9 tensors, 3 per stride (score, bbox, kps):
+ score_8, bbox_8, kps_8, score_16, bbox_16, kps_16, score_32, bbox_32, kps_32
+ """
+ all_scores = []
+ all_bboxes = []
+ all_kps = []
+
+ for idx, stride in enumerate(_STRIDES):
+ score_out = outputs[idx * 3].squeeze() # [H*W*num_anchors]
+ bbox_out = outputs[idx * 3 + 1].squeeze() # [H*W*num_anchors, 4]
+ kps_out = outputs[idx * 3 + 2].squeeze() # [H*W*num_anchors, 10]
+
+ if score_out.ndim == 0:
+ continue
+
+ # Ensure proper shapes
+ if score_out.ndim == 1:
+ n = score_out.shape[0]
+ else:
+ n = score_out.shape[0]
+ score_out = score_out.ravel()
+
+ if bbox_out.ndim == 1:
+ bbox_out = bbox_out.reshape(-1, 4)
+ if kps_out.ndim == 1:
+ kps_out = kps_out.reshape(-1, 10)
+
+ # Filter by confidence
+ mask = score_out > self._conf_thresh
+ if not mask.any():
+ continue
+
+ scores = score_out[mask]
+ bboxes = bbox_out[mask]
+ kps = kps_out[mask]
+
+ anchors = self._anchors[stride]
+ # Trim or pad anchors to match output count
+ if anchors.shape[0] > n:
+ anchors = anchors[:n]
+ elif anchors.shape[0] < n:
+ continue
+
+ anchors = anchors[mask]
+
+ # Decode bboxes: center = anchor + pred[:2]*stride, size = exp(pred[2:])*stride
+ cx = anchors[:, 0] + bboxes[:, 0] * stride
+ cy = anchors[:, 1] + bboxes[:, 1] * stride
+ w = np.exp(bboxes[:, 2]) * stride
+ h = np.exp(bboxes[:, 3]) * stride
+ x1 = cx - w / 2.0
+ y1 = cy - h / 2.0
+ x2 = cx + w / 2.0
+ y2 = cy + h / 2.0
+ decoded_bboxes = np.stack([x1, y1, x2, y2], axis=1)
+
+ # Decode keypoints: kp = anchor + pred * stride
+ decoded_kps = np.zeros_like(kps)
+ for k in range(5):
+ decoded_kps[:, k * 2] = anchors[:, 0] + kps[:, k * 2] * stride
+ decoded_kps[:, k * 2 + 1] = anchors[:, 1] + kps[:, k * 2 + 1] * stride
+
+ all_scores.append(scores)
+ all_bboxes.append(decoded_bboxes)
+ all_kps.append(decoded_kps)
+
+ if not all_scores:
+ return []
+
+ scores = np.concatenate(all_scores)
+ bboxes = np.concatenate(all_bboxes)
+ kps = np.concatenate(all_kps)
+
+ # NMS
+ keep = _nms(bboxes, scores, self._nms_iou)
+
+ results = []
+ for i in keep:
+ results.append({
+ 'bbox': bboxes[i].tolist(),
+ 'kps': kps[i].tolist(),
+ 'score': float(scores[i]),
+ })
+ return results
+
+ def _rescale(
+ self,
+ detections: list[dict],
+ scale: float,
+ pad_w: int,
+ pad_h: int,
+ orig_w: int,
+ orig_h: int,
+ ) -> list[dict]:
+ """Rescale detections from model input space to original image space."""
+ for det in detections:
+ bbox = det['bbox']
+ bbox[0] = max(0.0, (bbox[0] - pad_w) / scale)
+ bbox[1] = max(0.0, (bbox[1] - pad_h) / scale)
+ bbox[2] = min(float(orig_w), (bbox[2] - pad_w) / scale)
+ bbox[3] = min(float(orig_h), (bbox[3] - pad_h) / scale)
+ det['bbox'] = bbox
+
+ kps = det['kps']
+ for k in range(5):
+ kps[k * 2] = (kps[k * 2] - pad_w) / scale
+ kps[k * 2 + 1] = (kps[k * 2 + 1] - pad_h) / scale
+ det['kps'] = kps
+ return detections
diff --git a/jetson/ros2_ws/src/saltybot_social_face/scripts/build_face_trt_engines.py b/jetson/ros2_ws/src/saltybot_social_face/scripts/build_face_trt_engines.py
new file mode 100644
index 0000000..c2a1a83
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_face/scripts/build_face_trt_engines.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python3
+"""
+build_face_trt_engines.py -- Build TensorRT FP16 engines for SCRFD and ArcFace.
+
+Converts ONNX model files to optimized TensorRT engines with FP16 precision
+for fast inference on Jetson Orin Nano Super.
+
+Usage:
+ python3 build_face_trt_engines.py \
+ --scrfd-onnx /path/to/scrfd_2.5g_bnkps.onnx \
+ --arcface-onnx /path/to/arcface_r50.onnx \
+ --output-dir /mnt/nvme/saltybot/models \
+ --fp16 --workspace-mb 1024
+"""
+
+import argparse
+import os
+import time
+
+
+def build_engine(onnx_path: str, engine_path: str, fp16: bool, workspace_mb: int):
+ """Build a TensorRT engine from an ONNX model.
+
+ Args:
+ onnx_path: Path to the source ONNX model file.
+ engine_path: Output path for the serialized TensorRT engine.
+ fp16: Enable FP16 precision.
+ workspace_mb: Maximum workspace size in megabytes.
+ """
+ import tensorrt as trt
+
+ logger = trt.Logger(trt.Logger.INFO)
+ builder = trt.Builder(logger)
+ network = builder.create_network(
+ 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
+ parser = trt.OnnxParser(network, logger)
+
+ print(f'Parsing ONNX model: {onnx_path}')
+ t0 = time.monotonic()
+
+ with open(onnx_path, 'rb') as f:
+ if not parser.parse(f.read()):
+ for i in range(parser.num_errors):
+ print(f' ONNX parse error: {parser.get_error(i)}')
+ raise RuntimeError(f'Failed to parse {onnx_path}')
+
+ parse_time = time.monotonic() - t0
+ print(f' Parsed in {parse_time:.1f}s')
+
+ config = builder.create_builder_config()
+ config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
+ workspace_mb * (1 << 20))
+ if fp16:
+ if builder.platform_has_fast_fp16:
+ config.set_flag(trt.BuilderFlag.FP16)
+ print(' FP16 enabled.')
+ else:
+ print(' Warning: FP16 not supported on this platform, using FP32.')
+
+ print(f'Building engine (this may take several minutes)...')
+ t0 = time.monotonic()
+ serialized = builder.build_serialized_network(network, config)
+ build_time = time.monotonic() - t0
+
+ if serialized is None:
+ raise RuntimeError('Engine build failed.')
+
+ os.makedirs(os.path.dirname(engine_path) or '.', exist_ok=True)
+ with open(engine_path, 'wb') as f:
+ f.write(serialized)
+
+ size_mb = os.path.getsize(engine_path) / (1 << 20)
+ print(f' Engine saved: {engine_path} ({size_mb:.1f} MB, built in {build_time:.1f}s)')
+
+
+def main():
+ """Main entry point for TRT engine building."""
+ parser = argparse.ArgumentParser(
+ description='Build TensorRT FP16 engines for SCRFD and ArcFace.')
+ parser.add_argument('--scrfd-onnx', type=str, default='',
+ help='Path to SCRFD ONNX model.')
+ parser.add_argument('--arcface-onnx', type=str, default='',
+ help='Path to ArcFace ONNX model.')
+ parser.add_argument('--output-dir', type=str,
+ default='/mnt/nvme/saltybot/models',
+ help='Output directory for engine files.')
+ parser.add_argument('--fp16', action='store_true', default=True,
+ help='Enable FP16 precision (default: True).')
+ parser.add_argument('--no-fp16', action='store_false', dest='fp16',
+ help='Disable FP16 (use FP32 only).')
+ parser.add_argument('--workspace-mb', type=int, default=1024,
+ help='TRT workspace size in MB (default: 1024).')
+ args = parser.parse_args()
+
+ if not args.scrfd_onnx and not args.arcface_onnx:
+ parser.error('At least one of --scrfd-onnx or --arcface-onnx is required.')
+
+ if args.scrfd_onnx:
+ engine_path = os.path.join(args.output_dir, 'scrfd_2.5g.engine')
+ print(f'\n=== Building SCRFD engine ===')
+ build_engine(args.scrfd_onnx, engine_path, args.fp16, args.workspace_mb)
+
+ if args.arcface_onnx:
+ engine_path = os.path.join(args.output_dir, 'arcface_r50.engine')
+ print(f'\n=== Building ArcFace engine ===')
+ build_engine(args.arcface_onnx, engine_path, args.fp16, args.workspace_mb)
+
+ print('\nDone.')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/jetson/ros2_ws/src/saltybot_social_face/setup.cfg b/jetson/ros2_ws/src/saltybot_social_face/setup.cfg
new file mode 100644
index 0000000..fb7d235
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_face/setup.cfg
@@ -0,0 +1,4 @@
+[develop]
+script_dir=$base/lib/saltybot_social_face
+[install]
+install_scripts=$base/lib/saltybot_social_face
diff --git a/jetson/ros2_ws/src/saltybot_social_face/setup.py b/jetson/ros2_ws/src/saltybot_social_face/setup.py
new file mode 100644
index 0000000..fa85fe9
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_face/setup.py
@@ -0,0 +1,30 @@
+"""Setup for saltybot_social_face package."""
+
+from setuptools import find_packages, setup
+
+package_name = 'saltybot_social_face'
+
+setup(
+ name=package_name,
+ version='0.1.0',
+ packages=find_packages(exclude=['test']),
+ data_files=[
+ ('share/ament_index/resource_index/packages', ['resource/' + package_name]),
+ ('share/' + package_name, ['package.xml']),
+ ('share/' + package_name + '/launch', ['launch/face_recognition.launch.py']),
+ ('share/' + package_name + '/config', ['config/face_recognition_params.yaml']),
+ ],
+ install_requires=['setuptools'],
+ zip_safe=True,
+ maintainer='seb',
+ maintainer_email='seb@vayrette.com',
+ description='SCRFD face detection and ArcFace recognition for SaltyBot social interactions',
+ license='MIT',
+ tests_require=['pytest'],
+ entry_points={
+ 'console_scripts': [
+ 'face_recognition = saltybot_social_face.face_recognition_node:main',
+ 'enrollment_cli = saltybot_social_face.enrollment_cli:main',
+ ],
+ },
+)
diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/CMakeLists.txt b/jetson/ros2_ws/src/saltybot_social_msgs/CMakeLists.txt
new file mode 100644
index 0000000..4dfa816
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_msgs/CMakeLists.txt
@@ -0,0 +1,28 @@
+cmake_minimum_required(VERSION 3.8)
+project(saltybot_social_msgs)
+
+if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+ add_compile_options(-Wall -Wextra -Wpedantic)
+endif()
+
+find_package(ament_cmake REQUIRED)
+find_package(rosidl_default_generators REQUIRED)
+find_package(std_msgs REQUIRED)
+find_package(geometry_msgs REQUIRED)
+find_package(builtin_interfaces REQUIRED)
+
+rosidl_generate_interfaces(${PROJECT_NAME}
+ "msg/FaceDetection.msg"
+ "msg/FaceDetectionArray.msg"
+ "msg/FaceEmbedding.msg"
+ "msg/FaceEmbeddingArray.msg"
+ "msg/PersonState.msg"
+ "msg/PersonStateArray.msg"
+ "srv/EnrollPerson.srv"
+ "srv/ListPersons.srv"
+ "srv/DeletePerson.srv"
+ "srv/UpdatePerson.srv"
+ DEPENDENCIES std_msgs geometry_msgs builtin_interfaces
+)
+
+ament_package()
diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceDetection.msg b/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceDetection.msg
new file mode 100644
index 0000000..840dca7
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceDetection.msg
@@ -0,0 +1,12 @@
+std_msgs/Header header
+int32 face_id # -1 if unknown
+string person_name # "" if unknown
+float32 confidence # detection confidence 0-1
+float32 recognition_score # cosine similarity 0-1 (0 if unknown)
+# Bounding box in pixels
+float32 bbox_x
+float32 bbox_y
+float32 bbox_w
+float32 bbox_h
+# 5-point landmarks [x0,y0, x1,y1, x2,y2, x3,y3, x4,y4] = left_eye, right_eye, nose, left_mouth, right_mouth
+float32[10] landmarks
diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceDetectionArray.msg b/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceDetectionArray.msg
new file mode 100644
index 0000000..66550cc
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceDetectionArray.msg
@@ -0,0 +1,2 @@
+std_msgs/Header header
+saltybot_social_msgs/FaceDetection[] faces
diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceEmbedding.msg b/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceEmbedding.msg
new file mode 100644
index 0000000..74eebb6
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceEmbedding.msg
@@ -0,0 +1,5 @@
+int32 person_id
+string person_name
+float32[] embedding # 512-dim ArcFace embedding
+builtin_interfaces/Time enrolled_at
+int32 sample_count
diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceEmbeddingArray.msg b/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceEmbeddingArray.msg
new file mode 100644
index 0000000..a9c23d9
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_msgs/msg/FaceEmbeddingArray.msg
@@ -0,0 +1,2 @@
+std_msgs/Header header
+saltybot_social_msgs/FaceEmbedding[] embeddings
diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/msg/PersonState.msg b/jetson/ros2_ws/src/saltybot_social_msgs/msg/PersonState.msg
new file mode 100644
index 0000000..2856a26
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_msgs/msg/PersonState.msg
@@ -0,0 +1,19 @@
+std_msgs/Header header
+int32 person_id
+string person_name
+int32 face_id
+string speaker_id # from audio, "" if unknown
+string uwb_anchor_id # "" if no UWB
+geometry_msgs/Point position # in base_link frame, zeros if unknown
+float32 distance # metres
+float32 bearing_deg # degrees, 0=forward
+uint8 state # see STATE_* constants
+uint8 STATE_UNKNOWN=0
+uint8 STATE_APPROACHING=1
+uint8 STATE_ENGAGED=2
+uint8 STATE_TALKING=3
+uint8 STATE_LEAVING=4
+uint8 STATE_ABSENT=5
+float32 engagement_score # 0-1, attention model output
+builtin_interfaces/Time last_seen
+int32 camera_id # which CSI camera last saw them (-1=depth cam)
diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/msg/PersonStateArray.msg b/jetson/ros2_ws/src/saltybot_social_msgs/msg/PersonStateArray.msg
new file mode 100644
index 0000000..a864fc5
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_msgs/msg/PersonStateArray.msg
@@ -0,0 +1,3 @@
+std_msgs/Header header
+saltybot_social_msgs/PersonState[] persons
+int32 primary_attention_id # person_id of focus target (-1 if none)
diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/package.xml b/jetson/ros2_ws/src/saltybot_social_msgs/package.xml
new file mode 100644
index 0000000..ac80146
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_msgs/package.xml
@@ -0,0 +1,24 @@
+
+
+
+ saltybot_social_msgs
+ 0.1.0
+ Custom message and service definitions for the SaltyBot social sprint
+ seb
+ MIT
+
+ ament_cmake
+ rosidl_default_generators
+
+ std_msgs
+ geometry_msgs
+ builtin_interfaces
+
+ rosidl_default_runtime
+
+ rosidl_interface_packages
+
+
+ ament_cmake
+
+
diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/srv/DeletePerson.srv b/jetson/ros2_ws/src/saltybot_social_msgs/srv/DeletePerson.srv
new file mode 100644
index 0000000..0a77e93
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_msgs/srv/DeletePerson.srv
@@ -0,0 +1,4 @@
+int32 person_id
+---
+bool success
+string message
diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/srv/EnrollPerson.srv b/jetson/ros2_ws/src/saltybot_social_msgs/srv/EnrollPerson.srv
new file mode 100644
index 0000000..4edca74
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_msgs/srv/EnrollPerson.srv
@@ -0,0 +1,7 @@
+string name
+string mode # "face", "voice", "both"
+int32 n_samples # number of face crops to average (default 10)
+---
+bool success
+string message
+int32 person_id
diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/srv/ListPersons.srv b/jetson/ros2_ws/src/saltybot_social_msgs/srv/ListPersons.srv
new file mode 100644
index 0000000..bed755a
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_msgs/srv/ListPersons.srv
@@ -0,0 +1,2 @@
+---
+saltybot_social_msgs/FaceEmbedding[] persons
diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/srv/UpdatePerson.srv b/jetson/ros2_ws/src/saltybot_social_msgs/srv/UpdatePerson.srv
new file mode 100644
index 0000000..8fc0abf
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_msgs/srv/UpdatePerson.srv
@@ -0,0 +1,5 @@
+int32 person_id
+string new_name
+---
+bool success
+string message