Compare commits


7 Commits

Author SHA1 Message Date
44771751e2 feat(social): personality system — SOUL.md persona, mood engine, relationship DB (Issue #84)
New packages:
- saltybot_social_msgs: PersonalityState.msg + QueryMood.srv custom interfaces
- saltybot_social_personality: full personality node

Features:
- SOUL.md YAML/Markdown persona file: name, humor_level (0-10), sass_level (0-10),
  base_mood, per-tier greeting templates, mood prefix strings
- Hot-reload: SoulWatcher polls SOUL.md every reload_interval seconds, applies
  changes live without restarting the node
- Per-person relationship memory in SQLite: score, interaction_count,
  first/last_seen, learned preferences (JSON), full interaction log
- Mood engine (pure functions): happy | curious | annoyed | playful
  driven by relationship score, interaction count, recent event window (120s)
- Greeting personalisation: stranger | regular | favorite tiers
  keyed on interaction count thresholds from SOUL.md
- Publishes /social/personality/state (PersonalityState) at publish_rate Hz
- /social/personality/query_mood (QueryMood) service for on-demand mood query
- Full ROS2 dynamic reconfigure: soul_file, db_path, reload_interval, publish_rate
- 52 unit tests, no ROS2 runtime required

ROS2 interfaces:
  Sub: /social/person_detected  (std_msgs/String JSON)
  Pub: /social/personality/state (saltybot_social_msgs/PersonalityState)
  Srv: /social/personality/query_mood (saltybot_social_msgs/QueryMood)
2026-03-01 23:56:05 -05:00
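The mood engine itself is not included in this compare view; purely as an illustration, a pure-function mood selector of the kind the commit describes (the signature and thresholds here are assumptions, not the actual saltybot_social_personality code) might look like:

```python
def pick_mood(score: float, interactions: int, recent_events: int) -> str:
    """Map relationship state to one of the four moods.

    Illustrative sketch only: thresholds are placeholders, not the
    values a real SOUL.md would configure.
    """
    if score < -1.0:
        return "annoyed"   # negative relationship dominates everything else
    if recent_events >= 3:
        return "playful"   # lots of activity inside the recent-event window
    if interactions < 5:
        return "curious"   # still getting to know this person
    return "happy"         # established positive relationship
```

Being a pure function of its inputs, a selector like this is trivially unit-testable without a ROS2 runtime, which matches the "52 unit tests, no ROS2 runtime required" claim above.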
dc746ccedc Merge pull request 'feat(social): face detection + recognition #80' (#96) from sl-perception/social-face-detection into main 2026-03-01 23:55:18 -05:00
d6a6965af6 Merge pull request 'feat(social): person enrollment system #87' (#95) from sl-perception/social-enrollment into main 2026-03-01 23:55:16 -05:00
35b940e1f5 Merge pull request 'feat(social): Issue #86 — physical expression + motor attention' (#94) from sl-firmware/social-expression into main 2026-03-01 23:55:14 -05:00
5143e5bfac feat(social): Issue #86 — physical expression + motor attention
ESP32-C3 NeoPixel sketch (esp32/social_expression/social_expression.ino):
  - Adafruit NeoPixel + ArduinoJson, serial JSON protocol 115200 8N1
  - Mood→colour: happy=green, curious=blue, annoyed=red, playful=rainbow
  - Idle breathing animation (sine-modulated warm white)
  - Falls back to idle automatically after IDLE_TIMEOUT_MS (3 s) with no command

ROS2 saltybot_social_msgs (new package):
  - Mood.msg — {mood, intensity}
  - Person.msg — {track_id, bearing_rad, distance_m, confidence, is_speaking, source}
  - PersonArray.msg — {persons[], active_id}

ROS2 saltybot_social (new package):
  - expression_node: subscribes /social/mood → JSON serial to ESP32-C3
    reconnects on port error; sends idle frame after idle_timeout_s
  - attention_node: subscribes /social/persons → /cmd_vel rotation-only
    proportional control with dead zone; prefers active speaker, falls
    back to highest-confidence person; lost-target idle after 2 s
  - launch/social.launch.py — combined launch
  - config YAML for both nodes with documented parameters
  - test/test_attention.py — 15 pytest-only unit tests

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-01 23:35:59 -05:00
5c4f18e46c feat(social): person enrollment system — SQLite gallery + voice trigger (Issue #87)
- saltybot_social_msgs: 6 msg + 5 srv definitions for social interaction
- saltybot_social_enrollment: enrollment_node + enrollment_cli
- PersonDB: thread-safe SQLite-backed gallery (embeddings, voice samples)
- Voice-triggered enrollment via "remember me my name is X" phrase
- CLI: enroll/list/delete/rename via ros2 run
- Services: /social/enroll, /social/persons/list|delete|update
- Gallery sync from /social/faces/embeddings topic
2026-03-01 23:32:26 -05:00
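The voice trigger above implies extracting a name from a transcript once the phrase is heard; a minimal illustrative sketch (function name and trimming rules are assumptions, not the enrollment_node implementation):

```python
def extract_enroll_name(transcript, trigger="remember me my name is"):
    """Return the name following the enrollment trigger phrase, or None.

    Case-insensitive match on the trigger; the name keeps its original
    casing. Sketch only — punctuation handling is a guess.
    """
    lowered = transcript.lower()
    idx = lowered.find(trigger)
    if idx < 0:
        return None                       # trigger phrase not present
    name = transcript[idx + len(trigger):].strip().strip(".!?")
    return name or None                   # empty remainder -> no name
```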
f61a03b3c5 feat(social): face detection + recognition (SCRFD + ArcFace TRT FP16, Issue #80)
Add two new ROS2 packages for the social sprint:

saltybot_social_msgs (ament_cmake):
- FaceDetection, FaceDetectionArray, FaceEmbedding, FaceEmbeddingArray
- PersonState, PersonStateArray
- EnrollPerson, ListPersons, DeletePerson, UpdatePerson services

saltybot_social_face (ament_python):
- SCRFDDetector: SCRFD face detection with TRT FP16 + ONNX fallback
  - 640x640 input, 3-stride anchor decoding, NMS
- ArcFaceRecognizer: 512-dim embedding extraction with gallery matching
  - 5-point landmark alignment to 112x112, cosine similarity
- FaceGallery: thread-safe persistent gallery (npz + JSON sidecar)
- FaceRecognitionNode: ROS2 node subscribing /camera/color/image_raw,
  publishing /social/faces/detections, /social/faces/embeddings
- Enrollment via /social/enroll service (N-sample face averaging)
- Launch file, config YAML, TRT engine builder script

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-01 23:31:48 -05:00
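Gallery matching by cosine similarity over 512-dim embeddings, as mentioned above, can be sketched as follows. This is an illustration under the assumption of L2-normalised embeddings (so the dot product equals the cosine similarity); the threshold value and this API are not taken from saltybot_social_face:

```python
import numpy as np

def match_embedding(query, gallery, threshold=0.35):
    """Best cosine match for `query` among gallery embeddings.

    gallery: name -> L2-normalised embedding vector.
    Returns the best-matching name, or None if no similarity
    exceeds the (illustrative) threshold.
    """
    best_name, best_sim = None, threshold
    for name, emb in gallery.items():
        sim = float(np.dot(query, emb))   # cosine, since both are unit-norm
        if sim > best_sim:
            best_name, best_sim = name, sim
    return best_name
```

In practice ArcFace-style pipelines also average N enrollment samples per person (the "N-sample face averaging" noted above) before storing the gallery vector.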
37 changed files with 3231 additions and 4 deletions

View File

@@ -0,0 +1,206 @@
/*
 * social_expression.ino -- LED mood display for saltybot social layer.
 *
 * Hardware
 *
 *   MCU  : ESP32-C3 (e.g. Seeed XIAO ESP32-C3 or equivalent)
 *   LEDs : NeoPixel ring or strip, data pin GPIO5, power via 5V rail
 *
 * Serial protocol (from Orin Nano, 115200 8N1)
 *
 *   One JSON line per command, newline-terminated:
 *     {"mood":"happy","intensity":0.8}\n
 *
 *   mood values : "happy" | "curious" | "annoyed" | "playful" | "idle"
 *   intensity   : 0.0 (off) .. 1.0 (full brightness)
 *
 *   If no command arrives for IDLE_TIMEOUT_MS, the node enters the
 *   breathing idle animation automatically.
 *
 * Mood colour mapping
 *
 *   happy    green   #00DC00
 *   curious  blue    #0040FF
 *   annoyed  red     #FF1000
 *   playful  rainbow (cycling hue across all pixels)
 *   idle     soft white breathing (sine modulated)
 *
 * Dependencies (Arduino Library Manager)
 *
 *   Adafruit NeoPixel 1.11.0
 *   ArduinoJson       7.0.0
 *
 * Build: Arduino IDE 2 / arduino-cli with board "ESP32C3 Dev Module"
 * Board manager URL: https://raw.githubusercontent.com/espressif/arduino-esp32/gh-pages/package_esp32_index.json
 */
#include <Adafruit_NeoPixel.h>
#include <ArduinoJson.h>

// ── Hardware config ──────────────────────────────────────────────────────────
#define LED_PIN        5    // GPIO connected to NeoPixel data line
#define LED_COUNT      16   // pixels in ring / strip — adjust as needed
#define LED_BRIGHTNESS 120  // global cap 0-255 (≈47%), protects PSU

// ── Timing constants ─────────────────────────────────────────────────────────
#define IDLE_TIMEOUT_MS  3000  // fall back to breathing after 3 s silence
#define LOOP_INTERVAL_MS 20    // animation tick ≈ 50 Hz

// ── Colours (R, G, B) ────────────────────────────────────────────────────────
static const uint8_t COL_HAPPY[3]   = {   0, 220,   0 };  // green
static const uint8_t COL_CURIOUS[3] = {   0,  64, 255 };  // blue
static const uint8_t COL_ANNOYED[3] = { 255,  16,   0 };  // red

// ── State ────────────────────────────────────────────────────────────────────
enum Mood { MOOD_IDLE, MOOD_HAPPY, MOOD_CURIOUS, MOOD_ANNOYED, MOOD_PLAYFUL };

static Mood     g_mood        = MOOD_IDLE;
static float    g_intensity   = 1.0f;
static uint32_t g_last_cmd_ms = 0;  // millis() of last received command

// Animation counters
static uint16_t g_rainbow_hue = 0;  // 0..65535, cycles for playful
static uint32_t g_last_tick   = 0;

// Serial receive buffer
static char    g_serial_buf[128];
static uint8_t g_buf_pos = 0;

// ── NeoPixel object ──────────────────────────────────────────────────────────
Adafruit_NeoPixel strip(LED_COUNT, LED_PIN, NEO_GRB + NEO_KHZ800);

// ── Helpers ──────────────────────────────────────────────────────────────────
static uint8_t scale(uint8_t v, float intensity) {
  return (uint8_t)(v * intensity);
}
/*
 * Breathing: sine envelope over time.
 * Returns a brightness factor in ~0.08..0.43 (dim floor to gentle peak),
 * period 4 s.
 */
static float breath_factor(uint32_t now_ms) {
  float phase = (float)(now_ms % 4000) / 4000.0f;  // 0..1 per period
  return 0.08f + 0.35f * (0.5f + 0.5f * sinf(2.0f * M_PI * phase));
}
static void set_all(uint8_t r, uint8_t g, uint8_t b) {
  for (int i = 0; i < LED_COUNT; i++) {
    strip.setPixelColor(i, strip.Color(r, g, b));
  }
}

// ── Animation drivers ────────────────────────────────────────────────────────
static void animate_solid(const uint8_t col[3], float intensity) {
  set_all(scale(col[0], intensity),
          scale(col[1], intensity),
          scale(col[2], intensity));
}

static void animate_breathing(uint32_t now_ms, float intensity) {
  float bf = breath_factor(now_ms) * intensity;
  // Warm white: R=255 G=200 B=120
  set_all(scale(255, bf), scale(200, bf), scale(120, bf));
}

static void animate_rainbow(float intensity) {
  // Spread full wheel across all pixels
  for (int i = 0; i < LED_COUNT; i++) {
    uint16_t hue = g_rainbow_hue + (uint16_t)((float)i / LED_COUNT * 65536.0f);
    uint32_t rgb = strip.ColorHSV(hue, 255,
                                  (uint8_t)(255.0f * intensity));
    strip.setPixelColor(i, rgb);
  }
  // Advance hue each tick (full cycle in ~6 s at 50 Hz)
  g_rainbow_hue += 218;
}

// ── Serial parser ────────────────────────────────────────────────────────────
static void parse_command(const char *line) {
  StaticJsonDocument<128> doc;
  DeserializationError err = deserializeJson(doc, line);
  if (err) return;

  const char *mood_str = doc["mood"] | "";
  float intensity = doc["intensity"] | 1.0f;
  if (intensity < 0.0f) intensity = 0.0f;
  if (intensity > 1.0f) intensity = 1.0f;

  if      (strcmp(mood_str, "happy") == 0)   g_mood = MOOD_HAPPY;
  else if (strcmp(mood_str, "curious") == 0) g_mood = MOOD_CURIOUS;
  else if (strcmp(mood_str, "annoyed") == 0) g_mood = MOOD_ANNOYED;
  else if (strcmp(mood_str, "playful") == 0) g_mood = MOOD_PLAYFUL;
  else                                       g_mood = MOOD_IDLE;

  g_intensity = intensity;
  g_last_cmd_ms = millis();
}

static void read_serial(void) {
  while (Serial.available()) {
    char c = (char)Serial.read();
    if (c == '\n' || c == '\r') {
      if (g_buf_pos > 0) {
        g_serial_buf[g_buf_pos] = '\0';
        parse_command(g_serial_buf);
        g_buf_pos = 0;
      }
    } else if (g_buf_pos < (sizeof(g_serial_buf) - 1)) {
      g_serial_buf[g_buf_pos++] = c;
    } else {
      // Buffer overflow — discard line
      g_buf_pos = 0;
    }
  }
}

// ── Arduino entry points ─────────────────────────────────────────────────────
void setup(void) {
  Serial.begin(115200);
  strip.begin();
  strip.setBrightness(LED_BRIGHTNESS);
  strip.clear();
  strip.show();
  g_last_tick = millis();
}

void loop(void) {
  uint32_t now = millis();
  read_serial();

  // Fall back to idle if no command for IDLE_TIMEOUT_MS
  if ((now - g_last_cmd_ms) > IDLE_TIMEOUT_MS) {
    g_mood = MOOD_IDLE;
  }

  // Throttle animation ticks
  if ((now - g_last_tick) < LOOP_INTERVAL_MS) return;
  g_last_tick = now;

  switch (g_mood) {
    case MOOD_HAPPY:
      animate_solid(COL_HAPPY, g_intensity);
      break;
    case MOOD_CURIOUS:
      animate_solid(COL_CURIOUS, g_intensity);
      break;
    case MOOD_ANNOYED:
      animate_solid(COL_ANNOYED, g_intensity);
      break;
    case MOOD_PLAYFUL:
      animate_rainbow(g_intensity);
      break;
    case MOOD_IDLE:
    default:
      animate_breathing(now, g_intensity > 0.0f ? g_intensity : 1.0f);
      break;
  }
  strip.show();
}
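The newline-terminated JSON frames this sketch consumes can be built host-side in a few lines of Python. This is an illustrative sketch, not code from the PR; sending requires pyserial, and the /dev/esp32-social path assumes the udev rule suggested in expression_params.yaml:

```python
import json

def build_mood_frame(mood: str, intensity: float) -> bytes:
    """One newline-terminated JSON frame as social_expression.ino expects."""
    intensity = max(0.0, min(1.0, intensity))  # clamp like the firmware does
    return (json.dumps({"mood": mood, "intensity": round(intensity, 3)},
                       separators=(",", ":")) + "\n").encode("ascii")

# With pyserial, transmission is just:
#   serial.Serial("/dev/esp32-social", 115200).write(build_mood_frame("happy", 0.8))
```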

View File

@@ -0,0 +1,30 @@
# attention_params.yaml — Social attention controller configuration.
#
# ── Proportional gain ─────────────────────────────────────────────────────────
# angular_z = clamp(kp_angular * bearing_rad, ±max_angular_vel)
#
# kp_angular : 1.0 → 1.0 rad/s per 1 rad error (57° → full speed).
# Increase for snappier tracking; decrease if robot oscillates.
kp_angular: 1.0 # rad/s per rad
# ── Velocity limits ──────────────────────────────────────────────────────────
# Hard cap on rotation output. In social mode the balance loop must not be
# disturbed by large sudden angular commands; keep this ≤ 1.0 rad/s.
max_angular_vel: 0.8 # rad/s
# ── Dead zone ────────────────────────────────────────────────────────────────
# If |bearing| ≤ dead_zone_rad, rotation is suppressed.
# 0.15 rad ≈ 8.6° — prevents chattering when person is roughly centred.
dead_zone_rad: 0.15 # radians
# ── Control loop ─────────────────────────────────────────────────────────────
control_rate: 20.0 # Hz
# ── Lost-target timeout ───────────────────────────────────────────────────────
# If no /social/persons message arrives for this many seconds, publish zero
# and enter IDLE state. The cmd_vel bridge deadman will also kick in.
lost_timeout_s: 2.0 # seconds
# ── Master enable ─────────────────────────────────────────────────────────────
# Toggle at runtime: ros2 param set /attention_node attention_enabled false
attention_enabled: true
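The control law in the comments above, worked through numerically with this file's defaults (a standalone sketch that mirrors, but is not, the package's controller code):

```python
def attention_cmd(bearing: float, kp: float = 1.0,
                  dead_zone: float = 0.15, vmax: float = 0.8) -> float:
    """angular_z = clamp(kp * bearing, ±vmax), zero inside the dead zone."""
    if abs(bearing) <= dead_zone:
        return 0.0
    return max(-vmax, min(vmax, kp * bearing))

# bearing 0.10 rad -> 0.0 rad/s  (inside dead zone)
# bearing 0.50 rad -> 0.5 rad/s  (proportional region)
# bearing 2.00 rad -> 0.8 rad/s  (saturated at max_angular_vel)
```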

View File

@@ -0,0 +1,22 @@
# expression_params.yaml — LED mood display bridge configuration.
#
# serial_port : udev symlink or device path for the ESP32-C3 USB-CDC port.
#   Recommended: create a udev rule:
#     SUBSYSTEM=="tty", ATTRS{idVendor}=="303a", ATTRS{idProduct}=="1001",
#       SYMLINK+="esp32-social"
#
# baud_rate : must match social_expression.ino (default 115200)
#
# idle_timeout_s : if no /social/mood message arrives for this many seconds,
#   the node sends {"mood":"idle","intensity":1.0} so the ESP32 breathing
#   animation is synchronised with node awareness.
#
# control_rate : how often to check the idle timeout (Hz).
#   Does NOT gate the mood messages — those are forwarded immediately.
expression_node:
  ros__parameters:
    serial_port: /dev/esp32-social
    baud_rate: 115200
    idle_timeout_s: 3.0
    control_rate: 10.0

View File

@@ -1,3 +1,18 @@
"""
social.launch.py -- Launch the full saltybot social stack.

Includes:
  person_state_tracker -- multi-modal person identity fusion (Issue #82)
  expression_node      -- /social/mood → ESP32-C3 NeoPixel serial (Issue #86)
  attention_node       -- /social/persons → /cmd_vel rotation (Issue #86)

Usage:
  ros2 launch saltybot_social social.launch.py
  ros2 launch saltybot_social social.launch.py serial_port:=/dev/ttyUSB1
"""
import os
from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
@@ -5,7 +20,12 @@ from launch_ros.actions import Node

def generate_launch_description():
    pkg = get_package_share_directory("saltybot_social")
    exp_cfg = os.path.join(pkg, "config", "expression_params.yaml")
    att_cfg = os.path.join(pkg, "config", "attention_params.yaml")
    return LaunchDescription([
        # person_state_tracker args (Issue #82)
        DeclareLaunchArgument(
            'engagement_distance',
            default_value='2.0',
@@ -21,6 +41,19 @@ def generate_launch_description():
            default_value='false',
            description='Whether UWB anchor data is available'
        ),
        # expression_node args (Issue #86)
        DeclareLaunchArgument("serial_port", default_value="/dev/esp32-social"),
        DeclareLaunchArgument("baud_rate", default_value="115200"),
        DeclareLaunchArgument("idle_timeout_s", default_value="3.0"),
        # attention_node args (Issue #86)
        DeclareLaunchArgument("kp_angular", default_value="1.0"),
        DeclareLaunchArgument("max_angular_vel", default_value="0.8"),
        DeclareLaunchArgument("dead_zone_rad", default_value="0.15"),
        DeclareLaunchArgument("lost_timeout_s", default_value="2.0"),
        DeclareLaunchArgument("attention_enabled", default_value="true"),
        Node(
            package='saltybot_social',
            executable='person_state_tracker',
@@ -32,4 +65,36 @@ def generate_launch_description():
                'uwb_enabled': LaunchConfiguration('uwb_enabled'),
            }],
        ),
        Node(
            package="saltybot_social",
            executable="expression_node",
            name="expression_node",
            output="screen",
            parameters=[
                exp_cfg,
                {
                    "serial_port": LaunchConfiguration("serial_port"),
                    "baud_rate": LaunchConfiguration("baud_rate"),
                    "idle_timeout_s": LaunchConfiguration("idle_timeout_s"),
                },
            ],
        ),
        Node(
            package="saltybot_social",
            executable="attention_node",
            name="attention_node",
            output="screen",
            parameters=[
                att_cfg,
                {
                    "kp_angular": LaunchConfiguration("kp_angular"),
                    "max_angular_vel": LaunchConfiguration("max_angular_vel"),
                    "dead_zone_rad": LaunchConfiguration("dead_zone_rad"),
                    "lost_timeout_s": LaunchConfiguration("lost_timeout_s"),
                    "attention_enabled": LaunchConfiguration("attention_enabled"),
                },
            ],
        ),
    ])

View File

@@ -3,7 +3,12 @@
<package format="3">
  <name>saltybot_social</name>
  <version>0.1.0</version>
  <description>Multi-modal person identity fusion and state tracking for saltybot</description>
  <description>
    Social interaction layer for saltybot.
    person_state_tracker: multi-modal person identity fusion (Issue #82).
    expression_node: bridges /social/mood to ESP32-C3 NeoPixel ring over serial (Issue #86).
    attention_node: rotates robot toward active speaker via /social/persons bearing (Issue #86).
  </description>
  <maintainer email="seb@vayrette.com">seb</maintainer>
  <license>MIT</license>
  <depend>rclpy</depend>

View File

@@ -0,0 +1,205 @@
"""
attention_node.py -- Social attention controller for saltybot.

Rotates the robot to face the active speaker or most confident detected person.
Publishes rotation-only cmd_vel (linear.x = 0) so balance is not disturbed.

Subscribes
  /social/persons (saltybot_social_msgs/PersonArray)
    Contains all tracked persons; active_id marks the current speaker.
    bearing_rad: signed bearing in base_link (+ve = left/CCW).

Publishes
  /cmd_vel (geometry_msgs/Twist)
    angular.z = clamp(kp_angular * bearing, ±max_angular_vel)
    linear.x  = 0 always (social mode, not follow mode)

Control law
  1. Pick target: person with active_id; fall back to highest-confidence person.
  2. Dead zone: if |bearing| ≤ dead_zone_rad → angular_z = 0.
  3. Proportional: angular_z = kp_angular * bearing.
  4. Clamp to ±max_angular_vel.
  5. If no fresh person data for lost_timeout_s → publish zero and stay idle.

State machine
  ATTENDING -- fresh target; publishing proportional rotation
  IDLE      -- no target; publishing zero

Parameters
  kp_angular        (float) 1.0   proportional gain (rad/s per rad error)
  max_angular_vel   (float) 0.8   hard cap (rad/s)
  dead_zone_rad     (float) 0.15  ~8.6° dead zone around robot heading
  control_rate      (float) 20.0  Hz
  lost_timeout_s    (float) 2.0   seconds before going idle
  attention_enabled (bool)  True  runtime kill switch
"""
import math
import time

import rclpy
from rclpy.node import Node
from geometry_msgs.msg import Twist

from saltybot_social_msgs.msg import PersonArray


# ── Pure helpers (used by tests) ──────────────────────────────────────────────

def clamp(v: float, lo: float, hi: float) -> float:
    return max(lo, min(hi, v))


def select_target(persons, active_id: int):
    """
    Pick the attention target from a PersonArray's persons list.

    Priority:
      1. Person whose track_id == active_id (speaking / explicitly active).
      2. Person with highest confidence (fallback when no speaker declared).

    Returns the selected Person message, or None if persons is empty.
    """
    if not persons:
        return None
    # Prefer the declared active speaker
    for p in persons:
        if p.track_id == active_id:
            return p
    # Fall back to most confident detection
    return max(persons, key=lambda p: p.confidence)


def compute_attention_cmd(
    bearing_rad: float,
    dead_zone_rad: float,
    kp_angular: float,
    max_angular_vel: float,
) -> float:
    """
    Proportional rotation command toward bearing.

    Returns angular_z (rad/s).
    Zero inside dead zone; proportional + clamped outside.
    """
    if abs(bearing_rad) <= dead_zone_rad:
        return 0.0
    return clamp(kp_angular * bearing_rad, -max_angular_vel, max_angular_vel)


# ── ROS2 node ─────────────────────────────────────────────────────────────────

class AttentionNode(Node):
    def __init__(self):
        super().__init__("attention_node")
        self.declare_parameter("kp_angular", 1.0)
        self.declare_parameter("max_angular_vel", 0.8)
        self.declare_parameter("dead_zone_rad", 0.15)
        self.declare_parameter("control_rate", 20.0)
        self.declare_parameter("lost_timeout_s", 2.0)
        self.declare_parameter("attention_enabled", True)
        self._reload_params()

        # State
        self._last_persons_t = 0.0  # monotonic; 0 = never received
        self._persons = []
        self._active_id = -1
        self._state = "idle"

        self.create_subscription(
            PersonArray, "/social/persons", self._persons_cb, 10
        )
        self._cmd_pub = self.create_publisher(Twist, "/cmd_vel", 10)

        rate = self._p["control_rate"]
        self._timer = self.create_timer(1.0 / rate, self._control_cb)

        self.get_logger().info(
            f"AttentionNode ready kp={self._p['kp_angular']} "
            f"dead_zone={math.degrees(self._p['dead_zone_rad']):.1f}° "
            f"max_ω={self._p['max_angular_vel']} rad/s"
        )

    def _reload_params(self):
        self._p = {
            "kp_angular": self.get_parameter("kp_angular").value,
            "max_angular_vel": self.get_parameter("max_angular_vel").value,
            "dead_zone_rad": self.get_parameter("dead_zone_rad").value,
            "control_rate": self.get_parameter("control_rate").value,
            "lost_timeout_s": self.get_parameter("lost_timeout_s").value,
            "attention_enabled": self.get_parameter("attention_enabled").value,
        }

    # ── Callbacks ─────────────────────────────────────────────────────────────

    def _persons_cb(self, msg: PersonArray):
        self._persons = msg.persons
        self._active_id = msg.active_id
        self._last_persons_t = time.monotonic()
        if self._state == "idle" and msg.persons:
            self._state = "attending"
            self.get_logger().info("Attention: person detected — attending")

    def _control_cb(self):
        self._reload_params()
        twist = Twist()

        if not self._p["attention_enabled"]:
            self._cmd_pub.publish(twist)
            return

        now = time.monotonic()
        fresh = (
            self._last_persons_t > 0.0
            and (now - self._last_persons_t) < self._p["lost_timeout_s"]
        )
        if not fresh:
            if self._state == "attending":
                self._state = "idle"
                self.get_logger().info("Attention: no person — idle")
            self._cmd_pub.publish(twist)
            return

        target = select_target(self._persons, self._active_id)
        if target is None:
            self._cmd_pub.publish(twist)
            return

        self._state = "attending"
        twist.angular.z = compute_attention_cmd(
            bearing_rad=target.bearing_rad,
            dead_zone_rad=self._p["dead_zone_rad"],
            kp_angular=self._p["kp_angular"],
            max_angular_vel=self._p["max_angular_vel"],
        )
        self._cmd_pub.publish(twist)


# ── Entry point ───────────────────────────────────────────────────────────────

def main(args=None):
    rclpy.init(args=args)
    node = AttentionNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.try_shutdown()


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,157 @@
"""
expression_node.py -- LED mood display bridge for saltybot social layer.

Subscribes
  /social/mood (saltybot_social_msgs/Mood)
    mood      : "happy" | "curious" | "annoyed" | "playful" | "idle"
    intensity : 0.0 .. 1.0

Behaviour
  * On every /social/mood message, serialises the command as a single-line
    JSON frame and writes it over the ESP32-C3 serial link:
      {"mood":"happy","intensity":0.80}\n
  * If no mood message arrives for idle_timeout_s, sends an explicit
    idle command so the ESP32 breathing animation is synchronised with the
    node's own awareness of whether the robot is active.
  * Re-opens the serial port automatically if the device disconnects.

Parameters
  serial_port    (str)   /dev/ttyUSB0  ESP32-C3 USB-CDC device
  baud_rate      (int)   115200
  idle_timeout_s (float) 3.0           seconds before sending idle
  control_rate   (float) 10.0          Hz; how often to check idle
"""
import json
import threading
import time

import rclpy
from rclpy.node import Node

from saltybot_social_msgs.msg import Mood

try:
    import serial
    _SERIAL_AVAILABLE = True
except ImportError:
    _SERIAL_AVAILABLE = False


class ExpressionNode(Node):
    def __init__(self):
        super().__init__("expression_node")
        self.declare_parameter("serial_port", "/dev/ttyUSB0")
        self.declare_parameter("baud_rate", 115200)
        self.declare_parameter("idle_timeout_s", 3.0)
        self.declare_parameter("control_rate", 10.0)

        self._port = self.get_parameter("serial_port").value
        self._baud = self.get_parameter("baud_rate").value
        self._idle_to = self.get_parameter("idle_timeout_s").value
        rate = self.get_parameter("control_rate").value

        self._last_cmd_t = 0.0  # monotonic, 0 = never received
        self._lock = threading.Lock()
        self._ser = None

        self.create_subscription(Mood, "/social/mood", self._mood_cb, 10)
        self._timer = self.create_timer(1.0 / rate, self._tick_cb)

        if _SERIAL_AVAILABLE:
            threading.Thread(target=self._open_serial, daemon=True).start()
        else:
            self.get_logger().warn(
                "pyserial not available — serial output disabled (dry-run mode)"
            )

        self.get_logger().info(
            f"ExpressionNode ready port={self._port} baud={self._baud} "
            f"idle_timeout={self._idle_to}s"
        )

    # ── Serial management ─────────────────────────────────────────────────────

    def _open_serial(self):
        """Background thread: keep serial port open, reconnect on error."""
        while rclpy.ok():
            try:
                with self._lock:
                    self._ser = serial.Serial(
                        self._port, self._baud, timeout=1.0
                    )
                self.get_logger().info(f"Opened serial port {self._port}")
                # Hold serial open until error
                while rclpy.ok():
                    with self._lock:
                        if self._ser is None or not self._ser.is_open:
                            break
                    time.sleep(0.5)
            except Exception as exc:
                self.get_logger().warn(
                    f"Serial {self._port} error: {exc} — retry in 3 s"
                )
                with self._lock:
                    if self._ser and self._ser.is_open:
                        self._ser.close()
                    self._ser = None
                time.sleep(3.0)

    def _send(self, mood: str, intensity: float):
        """Serialise and write one JSON line to ESP32-C3."""
        line = json.dumps({"mood": mood, "intensity": round(intensity, 3)}) + "\n"
        data = line.encode("ascii")
        with self._lock:
            if self._ser and self._ser.is_open:
                try:
                    self._ser.write(data)
                except Exception as exc:
                    self.get_logger().warn(f"Serial write error: {exc}")
                    self._ser.close()
                    self._ser = None
            else:
                # Dry-run: log the frame
                self.get_logger().debug(f"[dry-run] → {line.strip()}")

    # ── Callbacks ─────────────────────────────────────────────────────────────

    def _mood_cb(self, msg: Mood):
        intensity = max(0.0, min(1.0, float(msg.intensity)))
        self._send(msg.mood, intensity)
        self._last_cmd_t = time.monotonic()

    def _tick_cb(self):
        """Periodic check — send idle if we've been quiet too long."""
        if self._last_cmd_t == 0.0:
            return
        if (time.monotonic() - self._last_cmd_t) >= self._idle_to:
            self._send("idle", 1.0)
            # Reset so we don't flood the ESP32 with idle frames
            self._last_cmd_t = time.monotonic()


# ── Entry point ───────────────────────────────────────────────────────────────

def main(args=None):
    rclpy.init(args=args)
    node = ExpressionNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.try_shutdown()


if __name__ == "__main__":
    main()

View File

@@ -21,12 +21,14 @@ setup(
    zip_safe=True,
    maintainer='seb',
    maintainer_email='seb@vayrette.com',
    description='Multi-modal person identity fusion and state tracking for saltybot',
    description='Social interaction layer — person state tracking, LED expression + attention',
    license='MIT',
    tests_require=['pytest'],
    entry_points={
        'console_scripts': [
            'person_state_tracker = saltybot_social.person_state_tracker_node:main',
            'expression_node = saltybot_social.expression_node:main',
            'attention_node = saltybot_social.attention_node:main',
        ],
    },
)

View File

@@ -0,0 +1,140 @@
"""
test_attention.py -- Unit tests for attention_node helpers.

No ROS2 / serial / GPU dependencies -- runs with plain pytest.
"""
import math

import pytest

from saltybot_social.attention_node import (
    clamp,
    compute_attention_cmd,
    select_target,
)


# ── Helpers ───────────────────────────────────────────────────────────────────

class FakePerson:
    def __init__(self, track_id, bearing_rad, confidence=0.9, is_speaking=False):
        self.track_id = track_id
        self.bearing_rad = bearing_rad
        self.confidence = confidence
        self.is_speaking = is_speaking


# ── clamp ─────────────────────────────────────────────────────────────────────

class TestClamp:
    def test_within_range(self):
        assert clamp(0.5, 0.0, 1.0) == 0.5

    def test_below_lo(self):
        assert clamp(-0.5, 0.0, 1.0) == 0.0

    def test_above_hi(self):
        assert clamp(1.5, 0.0, 1.0) == 1.0

    def test_at_boundary(self):
        assert clamp(0.0, 0.0, 1.0) == 0.0
        assert clamp(1.0, 0.0, 1.0) == 1.0


# ── select_target ─────────────────────────────────────────────────────────────

class TestSelectTarget:
    def test_empty_returns_none(self):
        assert select_target([], active_id=-1) is None

    def test_picks_active_id(self):
        persons = [FakePerson(1, 0.1, confidence=0.6),
                   FakePerson(2, 0.5, confidence=0.9)]
        result = select_target(persons, active_id=1)
        assert result.track_id == 1

    def test_falls_back_to_highest_confidence(self):
        persons = [FakePerson(1, 0.1, confidence=0.6),
                   FakePerson(2, 0.5, confidence=0.9),
                   FakePerson(3, 0.2, confidence=0.4)]
        result = select_target(persons, active_id=-1)
        assert result.track_id == 2

    def test_single_person_no_active(self):
        persons = [FakePerson(5, 0.3)]
        result = select_target(persons, active_id=-1)
        assert result.track_id == 5

    def test_active_id_not_in_list_falls_back(self):
        persons = [FakePerson(1, 0.0, confidence=0.5),
                   FakePerson(2, 0.2, confidence=0.95)]
        result = select_target(persons, active_id=99)
        assert result.track_id == 2


# ── compute_attention_cmd ─────────────────────────────────────────────────────

class TestComputeAttentionCmd:
    def test_inside_dead_zone_zero(self):
        cmd = compute_attention_cmd(
            bearing_rad=0.05,
            dead_zone_rad=0.15,
            kp_angular=1.0,
            max_angular_vel=0.8,
        )
        assert cmd == 0.0

    def test_exactly_at_dead_zone_boundary_zero(self):
        cmd = compute_attention_cmd(
            bearing_rad=0.15,
            dead_zone_rad=0.15,
            kp_angular=1.0,
            max_angular_vel=0.8,
        )
        assert cmd == 0.0

    def test_positive_bearing_gives_positive_angular_z(self):
        cmd = compute_attention_cmd(
            bearing_rad=0.4,
            dead_zone_rad=0.15,
            kp_angular=1.0,
            max_angular_vel=0.8,
        )
        assert cmd > 0.0

    def test_negative_bearing_gives_negative_angular_z(self):
        cmd = compute_attention_cmd(
            bearing_rad=-0.4,
            dead_zone_rad=0.15,
            kp_angular=1.0,
            max_angular_vel=0.8,
        )
        assert cmd < 0.0

    def test_clamps_to_max_angular_vel(self):
        cmd = compute_attention_cmd(
            bearing_rad=math.pi,  # 180° off
            dead_zone_rad=0.15,
            kp_angular=2.0,
            max_angular_vel=0.8,
        )
        assert abs(cmd) <= 0.8

    def test_proportional_scaling(self):
        """Double bearing → double cmd (within linear region)."""
        kp = 1.0
        dz = 0.05
        cmd1 = compute_attention_cmd(0.3, dz, kp, 10.0)
        cmd2 = compute_attention_cmd(0.6, dz, kp, 10.0)
        assert abs(cmd2 - 2.0 * cmd1) < 0.01

    def test_zero_bearing_zero_cmd(self):
        cmd = compute_attention_cmd(
            bearing_rad=0.0,
            dead_zone_rad=0.1,
            kp_angular=1.0,
            max_angular_vel=0.8,
        )
        assert cmd == 0.0

View File

@@ -0,0 +1,6 @@
enrollment_node:
  ros__parameters:
    db_path: '/mnt/nvme/saltybot/gallery/persons.db'
    voice_samples_dir: '/mnt/nvme/saltybot/gallery/voice'
    auto_enroll_phrase: 'remember me my name is'
    n_samples_default: 10

View File

@@ -0,0 +1,22 @@
"""Launch file for saltybot social enrollment node."""
import os

from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch_ros.actions import Node


def generate_launch_description():
    pkg_share = get_package_share_directory('saltybot_social_enrollment')
    config_file = os.path.join(pkg_share, 'config', 'enrollment_params.yaml')
    return LaunchDescription([
        Node(
            package='saltybot_social_enrollment',
            executable='enrollment_node',
            name='enrollment_node',
            parameters=[config_file],
            output='screen',
        ),
    ])

View File

@@ -0,0 +1,23 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
  <name>saltybot_social_enrollment</name>
  <version>0.1.0</version>
  <description>Person enrollment system for saltybot social interaction</description>
  <maintainer email="seb@vayrette.com">seb</maintainer>
  <license>MIT</license>
  <buildtool_depend>ament_python</buildtool_depend>
  <depend>rclpy</depend>
  <depend>sensor_msgs</depend>
  <depend>cv_bridge</depend>
  <depend>std_msgs</depend>
  <depend>saltybot_social_msgs</depend>
  <exec_depend>python3-numpy</exec_depend>
  <export>
    <build_type>ament_python</build_type>
  </export>
</package>

View File

@ -0,0 +1,184 @@
#!/usr/bin/env python3
"""enrollment_cli.py -- Gallery management CLI for saltybot social.

Usage:
    ros2 run saltybot_social_enrollment enrollment_cli enroll --name "Alice" [--samples 15] [--mode face]
    ros2 run saltybot_social_enrollment enrollment_cli list
    ros2 run saltybot_social_enrollment enrollment_cli delete --id 3
    ros2 run saltybot_social_enrollment enrollment_cli rename --id 2 --name "Bob"
"""
import argparse
import sys

import rclpy
from rclpy.node import Node

from saltybot_social_msgs.srv import (
    EnrollPerson,
    ListPersons,
    DeletePerson,
    UpdatePerson,
)


class EnrollmentCLI(Node):
    def __init__(self):
        super().__init__('enrollment_cli')
        self._enroll_client = self.create_client(
            EnrollPerson, '/social/enroll'
        )
        self._list_client = self.create_client(
            ListPersons, '/social/persons/list'
        )
        self._delete_client = self.create_client(
            DeletePerson, '/social/persons/delete'
        )
        self._update_client = self.create_client(
            UpdatePerson, '/social/persons/update'
        )

    def enroll(self, name: str, n_samples: int = 10, mode: str = 'face'):
        if not self._enroll_client.wait_for_service(timeout_sec=5.0):
            print('ERROR: /social/enroll service not available')
            return False
        req = EnrollPerson.Request()
        req.name = name
        req.mode = mode
        req.n_samples = n_samples
        print(f'Enrolling "{name}" ({mode}, {n_samples} samples)...')
        future = self._enroll_client.call_async(req)
        rclpy.spin_until_future_complete(self, future, timeout_sec=60.0)
        if future.result() is None:
            print('ERROR: Enrollment timed out')
            return False
        resp = future.result()
        if resp.success:
            print(f'OK: Enrolled "{name}" as person_id={resp.person_id}')
        else:
            print(f'FAILED: {resp.message}')
        return resp.success

    def list_persons(self):
        if not self._list_client.wait_for_service(timeout_sec=5.0):
            print('ERROR: /social/persons/list service not available')
            return
        req = ListPersons.Request()
        future = self._list_client.call_async(req)
        rclpy.spin_until_future_complete(self, future, timeout_sec=10.0)
        if future.result() is None:
            print('ERROR: List request timed out')
            return
        resp = future.result()
        if not resp.persons:
            print('Gallery is empty.')
            return
        print(f'{"ID":>4} {"Name":<20} {"Samples":>7} {"Embedding Dim":>13}')
        print('-' * 50)
        for p in resp.persons:
            dim = len(p.embedding) if p.embedding else 0
            print(f'{p.person_id:>4} {p.person_name:<20} {p.sample_count:>7} {dim:>13}')

    def delete(self, person_id: int):
        if not self._delete_client.wait_for_service(timeout_sec=5.0):
            print('ERROR: /social/persons/delete service not available')
            return False
        req = DeletePerson.Request()
        req.person_id = person_id
        future = self._delete_client.call_async(req)
        rclpy.spin_until_future_complete(self, future, timeout_sec=10.0)
        if future.result() is None:
            print('ERROR: Delete request timed out')
            return False
        resp = future.result()
        if resp.success:
            print(f'OK: {resp.message}')
        else:
            print(f'FAILED: {resp.message}')
        return resp.success

    def rename(self, person_id: int, new_name: str):
        if not self._update_client.wait_for_service(timeout_sec=5.0):
            print('ERROR: /social/persons/update service not available')
            return False
        req = UpdatePerson.Request()
        req.person_id = person_id
        req.new_name = new_name
        future = self._update_client.call_async(req)
        rclpy.spin_until_future_complete(self, future, timeout_sec=10.0)
        if future.result() is None:
            print('ERROR: Rename request timed out')
            return False
        resp = future.result()
        if resp.success:
            print(f'OK: {resp.message}')
        else:
            print(f'FAILED: {resp.message}')
        return resp.success


def main(args=None):
    parser = argparse.ArgumentParser(
        description='saltybot person enrollment CLI'
    )
    subparsers = parser.add_subparsers(dest='command', required=True)

    # enroll
    enroll_p = subparsers.add_parser('enroll', help='Enroll a new person')
    enroll_p.add_argument('--name', required=True, help='Person name')
    enroll_p.add_argument('--samples', type=int, default=10,
                          help='Number of face samples (default: 10)')
    enroll_p.add_argument('--mode', default='face',
                          help='Enrollment mode (default: face)')

    # list
    subparsers.add_parser('list', help='List enrolled persons')

    # delete
    delete_p = subparsers.add_parser('delete', help='Delete a person')
    delete_p.add_argument('--id', type=int, required=True,
                          help='Person ID to delete')

    # rename
    rename_p = subparsers.add_parser('rename', help='Rename a person')
    rename_p.add_argument('--id', type=int, required=True,
                          help='Person ID to rename')
    rename_p.add_argument('--name', required=True, help='New name')

    parsed = parser.parse_args(sys.argv[1:])

    rclpy.init()
    cli = EnrollmentCLI()
    try:
        if parsed.command == 'enroll':
            cli.enroll(parsed.name, parsed.samples, parsed.mode)
        elif parsed.command == 'list':
            cli.list_persons()
        elif parsed.command == 'delete':
            cli.delete(parsed.id)
        elif parsed.command == 'rename':
            cli.rename(parsed.id, parsed.name)
    finally:
        cli.destroy_node()
        rclpy.try_shutdown()


if __name__ == '__main__':
    main()

View File

@ -0,0 +1,302 @@
"""enrollment_node.py -- ROS2 person enrollment node for saltybot social.

Coordinates person enrollment:
- Forwards /social/enroll to face_recognizer's service
- Owns the persistent SQLite gallery (PersonDB)
- Voice-triggered enrollment via "remember me my name is X"
- Gallery management services (list/delete/update)
- Syncs the DB from the /social/faces/embeddings topic
"""
import threading

import numpy as np
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, DurabilityPolicy
from std_msgs.msg import String

from saltybot_social_msgs.msg import (
    FaceDetectionArray,
    FaceEmbedding,
    FaceEmbeddingArray,
)
from saltybot_social_msgs.srv import (
    EnrollPerson,
    ListPersons,
    DeletePerson,
    UpdatePerson,
)
from saltybot_social_enrollment.person_db import PersonDB


class EnrollmentNode(Node):
    def __init__(self):
        super().__init__('enrollment_node')
        # Parameters
        self.declare_parameter('db_path', '/mnt/nvme/saltybot/gallery/persons.db')
        self.declare_parameter('voice_samples_dir', '/mnt/nvme/saltybot/gallery/voice')
        self.declare_parameter('auto_enroll_phrase', 'remember me my name is')
        self.declare_parameter('n_samples_default', 10)
        db_path = self.get_parameter('db_path').value
        self._voice_dir = self.get_parameter('voice_samples_dir').value
        self._phrase = self.get_parameter('auto_enroll_phrase').value
        self._n_samples = self.get_parameter('n_samples_default').value

        self._db = PersonDB(db_path)
        self.get_logger().info(f'PersonDB initialized at {db_path}')

        # Client to face_recognizer's enroll service
        self._enroll_client = self.create_client(
            EnrollPerson, '/social/face_recognizer/enroll'
        )

        # QoS profiles
        best_effort_qos = QoSProfile(
            depth=10,
            reliability=ReliabilityPolicy.BEST_EFFORT,
            durability=DurabilityPolicy.VOLATILE,
        )
        reliable_qos = QoSProfile(
            depth=1,
            reliability=ReliabilityPolicy.RELIABLE,
            durability=DurabilityPolicy.VOLATILE,
        )
        status_qos = QoSProfile(
            depth=1,
            reliability=ReliabilityPolicy.BEST_EFFORT,
            durability=DurabilityPolicy.VOLATILE,
        )

        # Subscriptions
        self.create_subscription(
            FaceDetectionArray, '/social/faces/detections',
            self._on_detections, best_effort_qos
        )
        self.create_subscription(
            FaceEmbeddingArray, '/social/faces/embeddings',
            self._on_embeddings, reliable_qos
        )
        self.create_subscription(
            String, '/social/speech/transcript',
            self._on_transcript, best_effort_qos
        )
        self.create_subscription(
            String, '/social/speech/command',
            self._on_command, best_effort_qos
        )

        # Services
        self.create_service(EnrollPerson, '/social/enroll', self._handle_enroll)
        self.create_service(ListPersons, '/social/persons/list', self._handle_list)
        self.create_service(DeletePerson, '/social/persons/delete', self._handle_delete)
        self.create_service(UpdatePerson, '/social/persons/update', self._handle_update)

        # Publishers
        self._pub_embeddings = self.create_publisher(
            FaceEmbeddingArray, '/social/faces/embeddings', reliable_qos
        )
        self._pub_status = self.create_publisher(
            String, '/social/enrollment/status', status_qos
        )
        self.get_logger().info('EnrollmentNode ready')

    # ---- Voice-triggered enrollment ----
    def _on_transcript(self, msg: String):
        text = msg.data.lower()
        phrase = self._phrase.lower()
        if phrase in text:
            idx = text.index(phrase) + len(phrase)
            name = text[idx:].strip()
            # Clean up: take first 3 words max as the name
            words = name.split()
            if words:
                name = ' '.join(words[:3]).title()
                self.get_logger().info(f'Voice enrollment triggered: "{name}"')
                self._trigger_voice_enroll(name)

    def _on_command(self, msg: String):
        # Reserved for explicit voice commands (e.g., "enroll Alice")
        pass

    def _trigger_voice_enroll(self, name: str):
        if not self._enroll_client.wait_for_service(timeout_sec=1.0):
            self.get_logger().warn(
                'face_recognizer enroll service not available'
            )
            self._publish_status('Enrollment failed: face_recognizer unavailable')
            return
        req = EnrollPerson.Request()
        req.name = name
        req.mode = 'face'
        req.n_samples = self._n_samples
        future = self._enroll_client.call_async(req)
        future.add_done_callback(
            lambda f: self._on_enroll_done(f, name)
        )
        self._publish_status(f'Enrolling "{name}"... look at the camera')

    def _on_enroll_done(self, future, name: str):
        try:
            resp = future.result()
            if resp.success:
                status = f'Enrolled "{name}" (id={resp.person_id})'
                self.get_logger().info(status)
            else:
                status = f'Enrollment failed for "{name}": {resp.message}'
                self.get_logger().warn(status)
            self._publish_status(status)
        except Exception as e:
            self.get_logger().error(f'Enroll call failed: {e}')
            self._publish_status(f'Enrollment error: {e}')

    # ---- Face detection callback (during enrollment) ----
    def _on_detections(self, msg: FaceDetectionArray):
        # Reserved for future direct enrollment (without face_recognizer)
        pass

    # ---- Embeddings sync from face_recognizer ----
    def _on_embeddings(self, msg: FaceEmbeddingArray):
        for emb in msg.embeddings:
            existing = self._db.get_person(emb.person_id)
            if existing is None:
                arr = np.array(emb.embedding, dtype=np.float32)
                self._db.add_person(
                    emb.person_name, arr, emb.sample_count
                )
                self.get_logger().info(
                    f'Synced new person from face_recognizer: '
                    f'{emb.person_name} (id={emb.person_id})'
                )

    # ---- Service handlers ----
    def _handle_enroll(self, request, response):
        """Forward enrollment to the face_recognizer service."""
        if not self._enroll_client.wait_for_service(timeout_sec=2.0):
            response.success = False
            response.message = 'face_recognizer service unavailable'
            return response
        # Use threading.Event to bridge the async call inside a service callback.
        # NOTE: this wait blocks the executing callback thread, so the node must
        # be spun by a MultiThreadedExecutor (with the client in a separate
        # callback group) for the nested call to ever complete.
        event = threading.Event()
        result_holder = {}
        req = EnrollPerson.Request()
        req.name = request.name
        req.mode = request.mode
        req.n_samples = request.n_samples
        future = self._enroll_client.call_async(req)

        def _done(f):
            try:
                result_holder['resp'] = f.result()
            except Exception as e:
                result_holder['err'] = str(e)
            event.set()

        future.add_done_callback(_done)
        success = event.wait(timeout=35.0)
        if not success:
            response.success = False
            response.message = 'Enrollment timed out'
        elif 'resp' in result_holder:
            resp = result_holder['resp']
            response.success = resp.success
            response.message = resp.message
            response.person_id = resp.person_id
        else:
            response.success = False
            response.message = result_holder.get('err', 'Unknown error')
        return response

    def _handle_list(self, request, response):
        persons = self._db.get_all()
        response.persons = []
        for p in persons:
            fe = FaceEmbedding()
            fe.person_id = p['id']
            fe.person_name = p['name']
            fe.sample_count = p['sample_count']
            fe.enrolled_at.sec = int(p['enrolled_at'])
            fe.enrolled_at.nanosec = int(
                (p['enrolled_at'] - int(p['enrolled_at'])) * 1e9
            )
            if p['embedding'] is not None:
                fe.embedding = p['embedding'].tolist()
            response.persons.append(fe)
        return response

    def _handle_delete(self, request, response):
        success = self._db.delete_person(request.person_id)
        response.success = success
        if success:
            response.message = f'Deleted person {request.person_id}'
            self.get_logger().info(response.message)
            self._publish_embeddings_from_db()
        else:
            response.message = f'Person {request.person_id} not found'
        return response

    def _handle_update(self, request, response):
        success = self._db.update_name(request.person_id, request.new_name)
        response.success = success
        if success:
            response.message = (
                f'Updated person {request.person_id} name to "{request.new_name}"'
            )
            self.get_logger().info(response.message)
            self._publish_embeddings_from_db()
        else:
            response.message = f'Person {request.person_id} not found'
        return response

    # ---- Helpers ----
    def _publish_status(self, text: str):
        msg = String()
        msg.data = text
        self._pub_status.publish(msg)

    def _publish_embeddings_from_db(self):
        persons = self._db.get_all()
        arr = FaceEmbeddingArray()
        arr.header.stamp = self.get_clock().now().to_msg()
        for p in persons:
            if p['embedding'] is not None:
                fe = FaceEmbedding()
                fe.person_id = p['id']
                fe.person_name = p['name']
                fe.sample_count = p['sample_count']
                fe.enrolled_at.sec = int(p['enrolled_at'])
                fe.enrolled_at.nanosec = int(
                    (p['enrolled_at'] - int(p['enrolled_at'])) * 1e9
                )
                fe.embedding = p['embedding'].tolist()
                arr.embeddings.append(fe)
        self._pub_embeddings.publish(arr)


def main(args=None):
    rclpy.init(args=args)
    node = EnrollmentNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.try_shutdown()


if __name__ == '__main__':
    main()
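The `threading.Event` bridge used in `_handle_enroll` above can be sketched outside ROS2 with a plain `concurrent.futures` future; `wait_for_future` is an illustrative name, not part of the node:

```python
import threading
from concurrent.futures import ThreadPoolExecutor


def wait_for_future(future, timeout_sec):
    """Block until `future` completes or the timeout elapses; None on timeout/error."""
    event = threading.Event()
    holder = {}

    def _done(f):
        try:
            holder['resp'] = f.result()
        except Exception as e:  # record the error instead of raising in the callback
            holder['err'] = str(e)
        event.set()

    future.add_done_callback(_done)
    if not event.wait(timeout=timeout_sec):
        return None  # timed out
    return holder.get('resp')


with ThreadPoolExecutor() as pool:
    print(wait_for_future(pool.submit(lambda: 21 * 2), 5.0))  # 42
```

The same caveat applies as in the node: whatever drives the future (here the pool's worker thread, there the ROS2 executor) must keep running while the caller blocks in `event.wait`.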

View File

@ -0,0 +1,138 @@
"""person_db.py -- Persistent SQLite person gallery for saltybot enrollment."""
import json
import sqlite3
import threading
import time
from pathlib import Path

import numpy as np


class PersonDB:
    """Thread-safe SQLite-backed person gallery.

    Schema:
        persons(id INTEGER PRIMARY KEY, name TEXT, enrolled_at REAL,
                sample_count INTEGER, embedding BLOB, metadata TEXT)
        voice_samples(id INTEGER PRIMARY KEY, person_id INTEGER REFERENCES persons,
                      recorded_at REAL, sample_path TEXT)
    """

    def __init__(self, db_path: str):
        self._db_path = db_path
        self._lock = threading.Lock()
        Path(db_path).parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    def _init_db(self):
        with self._connect() as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS persons (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT NOT NULL,
                    enrolled_at REAL NOT NULL,
                    sample_count INTEGER DEFAULT 1,
                    embedding BLOB,
                    metadata TEXT DEFAULT '{}'
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS voice_samples (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    person_id INTEGER REFERENCES persons(id),
                    recorded_at REAL NOT NULL,
                    sample_path TEXT NOT NULL
                )
            """)

    def _connect(self):
        return sqlite3.connect(self._db_path)

    def add_person(self, name: str, embedding: np.ndarray, sample_count: int = 1,
                   metadata: dict = None) -> int:
        """Add a new person. Returns the new person_id."""
        emb_blob = embedding.astype(np.float32).tobytes() if embedding is not None else None
        now = time.time()
        with self._lock:
            with self._connect() as conn:
                cur = conn.execute(
                    "INSERT INTO persons (name, enrolled_at, sample_count, embedding, metadata) "
                    "VALUES (?, ?, ?, ?, ?)",
                    (name, now, sample_count, emb_blob, json.dumps(metadata or {}))
                )
                return cur.lastrowid

    def update_embedding(self, person_id: int, embedding: np.ndarray,
                         sample_count: int) -> bool:
        emb_blob = embedding.astype(np.float32).tobytes()
        with self._lock:
            with self._connect() as conn:
                conn.execute(
                    "UPDATE persons SET embedding=?, sample_count=? WHERE id=?",
                    (emb_blob, sample_count, person_id)
                )
                return conn.total_changes > 0

    def update_name(self, person_id: int, new_name: str) -> bool:
        with self._lock:
            with self._connect() as conn:
                conn.execute(
                    "UPDATE persons SET name=? WHERE id=?",
                    (new_name, person_id)
                )
                return conn.total_changes > 0

    def delete_person(self, person_id: int) -> bool:
        with self._lock:
            with self._connect() as conn:
                conn.execute(
                    "DELETE FROM voice_samples WHERE person_id=?",
                    (person_id,)
                )
                conn.execute(
                    "DELETE FROM persons WHERE id=?",
                    (person_id,)
                )
                return conn.total_changes > 0

    def get_all(self) -> list:
        """Return a list of dicts with id, name, enrolled_at, sample_count, embedding."""
        with self._lock:
            with self._connect() as conn:
                rows = conn.execute(
                    "SELECT id, name, enrolled_at, sample_count, embedding FROM persons"
                ).fetchall()
        result = []
        for row in rows:
            emb = np.frombuffer(row[4], dtype=np.float32).copy() if row[4] else None
            result.append({
                'id': row[0], 'name': row[1], 'enrolled_at': row[2],
                'sample_count': row[3], 'embedding': emb
            })
        return result

    def get_person(self, person_id: int) -> dict | None:
        with self._lock:
            with self._connect() as conn:
                row = conn.execute(
                    "SELECT id, name, enrolled_at, sample_count, embedding "
                    "FROM persons WHERE id=?",
                    (person_id,)
                ).fetchone()
        if row is None:
            return None
        emb = np.frombuffer(row[4], dtype=np.float32).copy() if row[4] else None
        return {
            'id': row[0], 'name': row[1], 'enrolled_at': row[2],
            'sample_count': row[3], 'embedding': emb
        }

    def add_voice_sample(self, person_id: int, sample_path: str) -> int:
        with self._lock:
            with self._connect() as conn:
                cur = conn.execute(
                    "INSERT INTO voice_samples (person_id, recorded_at, sample_path) "
                    "VALUES (?, ?, ?)",
                    (person_id, time.time(), sample_path)
                )
                return cur.lastrowid
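The embedding round-trip PersonDB relies on (float32 `tobytes()` into a BLOB column, `np.frombuffer` back out) can be verified in isolation with an in-memory database; the one-column table here is a stripped-down stand-in, not the real schema:

```python
import sqlite3

import numpy as np

# Store a float32 embedding as a raw BLOB and read it back (in-memory DB).
conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE persons (id INTEGER PRIMARY KEY, embedding BLOB)')

original = np.random.rand(512).astype(np.float32)
conn.execute('INSERT INTO persons (embedding) VALUES (?)', (original.tobytes(),))

blob = conn.execute('SELECT embedding FROM persons').fetchone()[0]
restored = np.frombuffer(blob, dtype=np.float32).copy()  # .copy() makes it writable

assert np.array_equal(original, restored)
```

The `.copy()` matters: `np.frombuffer` returns a read-only view over the bytes object, and downstream code (e.g. in-place normalization) expects a writable array.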

View File

@ -0,0 +1,4 @@
[develop]
script_dir=$base/lib/saltybot_social_enrollment
[install]
install_scripts=$base/lib/saltybot_social_enrollment

View File

@ -0,0 +1,29 @@
from setuptools import setup
import os
from glob import glob

package_name = 'saltybot_social_enrollment'

setup(
    name=package_name,
    version='0.1.0',
    packages=[package_name],
    data_files=[
        ('share/ament_index/resource_index/packages', ['resource/' + package_name]),
        ('share/' + package_name, ['package.xml']),
        (os.path.join('share', package_name, 'launch'), glob('launch/*.py')),
        (os.path.join('share', package_name, 'config'), glob('config/*.yaml')),
    ],
    install_requires=['setuptools'],
    zip_safe=True,
    maintainer='seb',
    maintainer_email='seb@vayrette.com',
    description='Person enrollment system for saltybot social interaction',
    license='MIT',
    entry_points={
        'console_scripts': [
            'enrollment_node = saltybot_social_enrollment.enrollment_node:main',
            'enrollment_cli = saltybot_social_enrollment.enrollment_cli:main',
        ],
    },
)

View File

@ -0,0 +1,11 @@
face_recognizer:
  ros__parameters:
    scrfd_engine_path: '/mnt/nvme/saltybot/models/scrfd_2.5g.engine'
    scrfd_onnx_path: '/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx'
    arcface_engine_path: '/mnt/nvme/saltybot/models/arcface_r50.engine'
    arcface_onnx_path: '/mnt/nvme/saltybot/models/arcface_r50.onnx'
    gallery_dir: '/mnt/nvme/saltybot/gallery'
    recognition_threshold: 0.35
    publish_debug_image: false
    max_faces: 10
    scrfd_conf_thresh: 0.5

View File

@ -0,0 +1,80 @@
"""
face_recognition.launch.py -- Launch file for the SCRFD + ArcFace face recognition node.

Launches the face_recognizer node with configurable model paths and parameters.
The RealSense camera must be running separately (e.g., via realsense.launch.py).
"""
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node


def generate_launch_description():
    """Generate launch description for the face recognition pipeline."""
    return LaunchDescription([
        DeclareLaunchArgument(
            'scrfd_engine_path',
            default_value='/mnt/nvme/saltybot/models/scrfd_2.5g.engine',
            description='Path to SCRFD TensorRT engine file',
        ),
        DeclareLaunchArgument(
            'scrfd_onnx_path',
            default_value='/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx',
            description='Path to SCRFD ONNX model file (fallback)',
        ),
        DeclareLaunchArgument(
            'arcface_engine_path',
            default_value='/mnt/nvme/saltybot/models/arcface_r50.engine',
            description='Path to ArcFace TensorRT engine file',
        ),
        DeclareLaunchArgument(
            'arcface_onnx_path',
            default_value='/mnt/nvme/saltybot/models/arcface_r50.onnx',
            description='Path to ArcFace ONNX model file (fallback)',
        ),
        DeclareLaunchArgument(
            'gallery_dir',
            default_value='/mnt/nvme/saltybot/gallery',
            description='Directory for persistent face gallery storage',
        ),
        DeclareLaunchArgument(
            'recognition_threshold',
            default_value='0.35',
            description='Cosine similarity threshold for face recognition',
        ),
        DeclareLaunchArgument(
            'publish_debug_image',
            default_value='false',
            description='Publish annotated debug image to /social/faces/debug_image',
        ),
        DeclareLaunchArgument(
            'max_faces',
            default_value='10',
            description='Maximum faces to process per frame',
        ),
        DeclareLaunchArgument(
            'scrfd_conf_thresh',
            default_value='0.5',
            description='SCRFD detection confidence threshold',
        ),
        Node(
            package='saltybot_social_face',
            executable='face_recognition',
            name='face_recognizer',
            output='screen',
            parameters=[{
                'scrfd_engine_path': LaunchConfiguration('scrfd_engine_path'),
                'scrfd_onnx_path': LaunchConfiguration('scrfd_onnx_path'),
                'arcface_engine_path': LaunchConfiguration('arcface_engine_path'),
                'arcface_onnx_path': LaunchConfiguration('arcface_onnx_path'),
                'gallery_dir': LaunchConfiguration('gallery_dir'),
                'recognition_threshold': LaunchConfiguration('recognition_threshold'),
                'publish_debug_image': LaunchConfiguration('publish_debug_image'),
                'max_faces': LaunchConfiguration('max_faces'),
                'scrfd_conf_thresh': LaunchConfiguration('scrfd_conf_thresh'),
            }],
        ),
    ])

View File

@ -0,0 +1,27 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
  <name>saltybot_social_face</name>
  <version>0.1.0</version>
  <description>SCRFD face detection and ArcFace recognition for SaltyBot social interactions</description>
  <maintainer email="seb@vayrette.com">seb</maintainer>
  <license>MIT</license>
  <depend>rclpy</depend>
  <depend>sensor_msgs</depend>
  <depend>cv_bridge</depend>
  <depend>image_transport</depend>
  <depend>saltybot_social_msgs</depend>
  <exec_depend>python3-numpy</exec_depend>
  <exec_depend>python3-opencv</exec_depend>
  <test_depend>ament_copyright</test_depend>
  <test_depend>ament_flake8</test_depend>
  <test_depend>ament_pep257</test_depend>
  <test_depend>python3-pytest</test_depend>
  <export>
    <build_type>ament_python</build_type>
  </export>
</package>

View File

@ -0,0 +1 @@
"""SaltyBot social face detection and recognition package."""

View File

@ -0,0 +1,316 @@
"""
arcface_recognizer.py -- ArcFace face embedding extraction and gallery matching.

Performs face alignment using 5-point landmarks (insightface standard reference),
extracts 512-dimensional embeddings via ArcFace (TRT FP16 or ONNX fallback),
and matches against a persistent gallery using cosine similarity.
"""
import os
import logging
from typing import Optional

import numpy as np
import cv2

logger = logging.getLogger(__name__)

# InsightFace standard reference landmarks for 112x112 alignment
ARCFACE_SRC = np.array([
    [38.2946, 51.6963],   # left eye
    [73.5318, 51.5014],   # right eye
    [56.0252, 71.7366],   # nose
    [41.5493, 92.3655],   # left mouth
    [70.7299, 92.2041],   # right mouth
], dtype=np.float32)


# -- Inference backends --------------------------------------------------------

class _TRTBackend:
    """TensorRT inference engine for ArcFace."""

    def __init__(self, engine_path: str):
        import tensorrt as trt
        import pycuda.driver as cuda
        import pycuda.autoinit  # noqa: F401
        self._cuda = cuda
        trt_logger = trt.Logger(trt.Logger.WARNING)
        with open(engine_path, 'rb') as f, trt.Runtime(trt_logger) as runtime:
            self._engine = runtime.deserialize_cuda_engine(f.read())
        self._context = self._engine.create_execution_context()
        self._inputs = []
        self._outputs = []
        self._bindings = []
        for i in range(self._engine.num_io_tensors):
            name = self._engine.get_tensor_name(i)
            dtype = trt.nptype(self._engine.get_tensor_dtype(name))
            shape = tuple(self._engine.get_tensor_shape(name))
            nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
            host_mem = cuda.pagelocked_empty(shape, dtype)
            device_mem = cuda.mem_alloc(nbytes)
            self._bindings.append(int(device_mem))
            if self._engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                self._inputs.append({'host': host_mem, 'device': device_mem})
            else:
                self._outputs.append({'host': host_mem, 'device': device_mem,
                                      'shape': shape})
        self._stream = cuda.Stream()

    def infer(self, input_data: np.ndarray) -> np.ndarray:
        """Run inference and return the embedding vector."""
        np.copyto(self._inputs[0]['host'], input_data.ravel())
        self._cuda.memcpy_htod_async(
            self._inputs[0]['device'], self._inputs[0]['host'], self._stream)
        self._context.execute_async_v2(self._bindings, self._stream.handle)
        for out in self._outputs:
            self._cuda.memcpy_dtoh_async(out['host'], out['device'], self._stream)
        self._stream.synchronize()
        return self._outputs[0]['host'].reshape(self._outputs[0]['shape']).copy()


class _ONNXBackend:
    """ONNX Runtime inference (CUDA EP with CPU fallback)."""

    def __init__(self, onnx_path: str):
        import onnxruntime as ort
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        self._session = ort.InferenceSession(onnx_path, providers=providers)
        self._input_name = self._session.get_inputs()[0].name

    def infer(self, input_data: np.ndarray) -> np.ndarray:
        """Run inference and return the embedding vector."""
        results = self._session.run(None, {self._input_name: input_data})
        return results[0]


# -- Face alignment ------------------------------------------------------------

def align_face(bgr: np.ndarray, landmarks_10: list[float]) -> np.ndarray:
    """Align a face to 112x112 using 5-point landmarks.

    Args:
        bgr: Source BGR image.
        landmarks_10: Flat list of 10 floats [x0,y0, x1,y1, ..., x4,y4].

    Returns:
        Aligned BGR face crop of shape (112, 112, 3).
    """
    src_pts = np.array(landmarks_10, dtype=np.float32).reshape(5, 2)
    M, _ = cv2.estimateAffinePartial2D(src_pts, ARCFACE_SRC)
    if M is None:
        # Fallback: simple crop and resize from a bbox-like region
        cx = np.mean(src_pts[:, 0])
        cy = np.mean(src_pts[:, 1])
        spread = max(np.ptp(src_pts[:, 0]), np.ptp(src_pts[:, 1])) * 1.5
        half = spread / 2
        x1 = max(0, int(cx - half))
        y1 = max(0, int(cy - half))
        x2 = min(bgr.shape[1], int(cx + half))
        y2 = min(bgr.shape[0], int(cy + half))
        crop = bgr[y1:y2, x1:x2]
        return cv2.resize(crop, (112, 112), interpolation=cv2.INTER_LINEAR)
    aligned = cv2.warpAffine(bgr, M, (112, 112), borderMode=cv2.BORDER_REPLICATE)
    return aligned


# -- Main recognizer class -----------------------------------------------------

class ArcFaceRecognizer:
    """ArcFace face embedding extractor and gallery matcher.

    Args:
        engine_path: Path to TensorRT engine file.
        onnx_path: Path to ONNX model file (used if engine not available).
    """

    def __init__(self, engine_path: str = '', onnx_path: str = ''):
        self._backend: Optional[_TRTBackend | _ONNXBackend] = None
        self.gallery: dict[int, dict] = {}
        # Try TRT first, then ONNX
        if engine_path and os.path.isfile(engine_path):
            try:
                self._backend = _TRTBackend(engine_path)
                logger.info('ArcFace TensorRT backend loaded: %s', engine_path)
                return
            except Exception as e:
                logger.warning('ArcFace TRT load failed (%s), trying ONNX', e)
        if onnx_path and os.path.isfile(onnx_path):
            try:
                self._backend = _ONNXBackend(onnx_path)
                logger.info('ArcFace ONNX backend loaded: %s', onnx_path)
                return
            except Exception as e:
                logger.error('ArcFace ONNX load failed: %s', e)
        logger.error('No ArcFace model loaded. Recognition will be unavailable.')

    @property
    def is_loaded(self) -> bool:
        """Return True if a backend is loaded and ready."""
        return self._backend is not None

    def embed(self, bgr_face_112x112: np.ndarray) -> np.ndarray:
        """Extract a 512-dim L2-normalized embedding from a 112x112 aligned face.

        Args:
            bgr_face_112x112: Aligned face crop, BGR, shape (112, 112, 3).

        Returns:
            L2-normalized embedding of shape (512,).
        """
        if self._backend is None:
            return np.zeros(512, dtype=np.float32)
        # Preprocess: BGR->RGB, /255, subtract 0.5, divide by 0.5 -> [1,3,112,112]
        rgb = cv2.cvtColor(bgr_face_112x112, cv2.COLOR_BGR2RGB).astype(np.float32)
        rgb = rgb / 255.0
        rgb = (rgb - 0.5) / 0.5
        blob = rgb.transpose(2, 0, 1)[np.newaxis]  # [1, 3, 112, 112]
        blob = np.ascontiguousarray(blob)
        output = self._backend.infer(blob)
        embedding = output.flatten()[:512].astype(np.float32)
        # L2 normalize
        norm = np.linalg.norm(embedding)
        if norm > 0:
            embedding = embedding / norm
        return embedding

    def align_and_embed(self, bgr_image: np.ndarray, landmarks_10: list[float]) -> np.ndarray:
        """Align the face using landmarks and extract an embedding.

        Args:
            bgr_image: Full BGR image.
            landmarks_10: Flat list of 10 floats from SCRFD detection.

        Returns:
            L2-normalized embedding of shape (512,).
        """
        aligned = align_face(bgr_image, landmarks_10)
        return self.embed(aligned)

    def load_gallery(self, gallery_path: str) -> None:
        """Load the gallery from a .npz file with JSON metadata sidecar.

        Args:
            gallery_path: Path to the .npz gallery file.
        """
        import json
        if not os.path.isfile(gallery_path):
            logger.info('No gallery file at %s, starting empty.', gallery_path)
            self.gallery = {}
            return
        data = np.load(gallery_path, allow_pickle=False)
        meta_path = gallery_path.replace('.npz', '_meta.json')
        if os.path.isfile(meta_path):
            with open(meta_path, 'r') as f:
                meta = json.load(f)
        else:
            meta = {}
        self.gallery = {}
        for key in data.files:
            pid = int(key)
            embedding = data[key].astype(np.float32)
            norm = np.linalg.norm(embedding)
            if norm > 0:
                embedding = embedding / norm
            info = meta.get(str(pid), {})
            self.gallery[pid] = {
                'name': info.get('name', f'person_{pid}'),
                'embedding': embedding,
                'samples': info.get('samples', 1),
                'enrolled_at': info.get('enrolled_at', 0.0),
            }
        logger.info('Gallery loaded: %d persons from %s', len(self.gallery), gallery_path)

    def save_gallery(self, gallery_path: str) -> None:
        """Save the gallery to a .npz file with JSON metadata sidecar.

        Args:
            gallery_path: Path to the .npz gallery file.
        """
        import json
        arrays = {}
        meta = {}
        for pid, info in self.gallery.items():
            arrays[str(pid)] = info['embedding']
            meta[str(pid)] = {
                'name': info['name'],
                'samples': info['samples'],
                'enrolled_at': info['enrolled_at'],
            }
        os.makedirs(os.path.dirname(gallery_path) or '.', exist_ok=True)
        np.savez(gallery_path, **arrays)
        meta_path = gallery_path.replace('.npz', '_meta.json')
        with open(meta_path, 'w') as f:
            json.dump(meta, f, indent=2)
        logger.info('Gallery saved: %d persons to %s', len(self.gallery), gallery_path)

    def match(self, embedding: np.ndarray, threshold: float = 0.35) -> tuple[int, str, float]:
        """Match an embedding against the gallery.

        Args:
            embedding: L2-normalized query embedding of shape (512,).
            threshold: Minimum cosine similarity for a match.

        Returns:
            (person_id, person_name, score) or (-1, '', 0.0) if no match.
        """
        if not self.gallery:
            return (-1, '', 0.0)
        best_pid = -1
        best_name = ''
        best_score = 0.0
        for pid, info in self.gallery.items():
            score = float(np.dot(embedding, info['embedding']))
            if score > best_score:
                best_score = score
                best_pid = pid
                best_name = info['name']
        if best_score >= threshold:
            return (best_pid, best_name, best_score)
        return (-1, '', 0.0)

    def enroll(self, person_id: int, person_name: str, embeddings_list: list[np.ndarray]) -> None:
        """Enroll a person by averaging multiple embeddings.

        Args:
            person_id: Unique integer ID for this person.
            person_name: Human-readable name.
            embeddings_list: List of L2-normalized embeddings to average.
        """
        import time as _time
        if not embeddings_list:
            return
        mean_emb = np.mean(embeddings_list, axis=0).astype(np.float32)
        norm = np.linalg.norm(mean_emb)
        if norm > 0:
            mean_emb = mean_emb / norm
        self.gallery[person_id] = {
            'name': person_name,
            'embedding': mean_emb,
            'samples': len(embeddings_list),
            'enrolled_at': _time.time(),
        }
        logger.info('Enrolled person %d (%s) with %d samples.',
                    person_id, person_name, len(embeddings_list))
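Because all embeddings are L2-normalized, `match` reduces cosine similarity to a plain dot product over the gallery. The same logic can be exercised standalone with numpy; `unit` and the 2-D toy gallery below are illustrative stand-ins, not part of the module:

```python
import numpy as np


def match(query, gallery, threshold=0.35):
    """Return (person_id, score) of the best match, or (-1, 0.0) below threshold."""
    best_pid, best_score = -1, 0.0
    for pid, emb in gallery.items():
        score = float(np.dot(query, emb))  # cosine similarity for unit vectors
        if score > best_score:
            best_pid, best_score = pid, score
    return (best_pid, best_score) if best_score >= threshold else (-1, 0.0)


def unit(v):
    """L2-normalize a vector."""
    v = np.asarray(v, dtype=np.float32)
    return v / np.linalg.norm(v)


gallery = {1: unit([1.0, 0.0]), 2: unit([0.0, 1.0])}
print(match(unit([0.9, 0.1]), gallery))  # person 1, score around 0.994
```

A query orthogonal to every gallery entry scores 0.0 and falls below the 0.35 threshold, which is why the node can treat `(-1, '', 0.0)` as "unknown person".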

View File

@ -0,0 +1,78 @@
#!/usr/bin/env python3
"""
enrollment_cli.py -- CLI tool for enrolling persons via the /social/enroll service.
Usage:
ros2 run saltybot_social_face enrollment_cli -- --name Alice --mode face --samples 10
"""
import argparse
import sys
import rclpy
from rclpy.node import Node
from saltybot_social_msgs.srv import EnrollPerson
class EnrollmentCLI(Node):
"""Simple CLI node that calls the EnrollPerson service."""
def __init__(self, name: str, mode: str, n_samples: int):
super().__init__('enrollment_cli')
self._client = self.create_client(EnrollPerson, '/social/enroll')
self.get_logger().info('Waiting for /social/enroll service...')
if not self._client.wait_for_service(timeout_sec=10.0):
self.get_logger().error('Service /social/enroll not available.')
return
request = EnrollPerson.Request()
request.name = name
request.mode = mode
request.n_samples = n_samples
self.get_logger().info(
'Enrolling "%s" (mode=%s, samples=%d)...', name, mode, n_samples)
future = self._client.call_async(request)
rclpy.spin_until_future_complete(self, future, timeout_sec=120.0)
if future.result() is not None:
result = future.result()
if result.success:
self.get_logger().info(
'Enrollment successful: person_id=%d, %s',
result.person_id, result.message)
else:
self.get_logger().error(
'Enrollment failed: %s', result.message)
else:
self.get_logger().error('Enrollment service call timed out or failed.')
def main(args=None):
"""Entry point for enrollment CLI."""
parser = argparse.ArgumentParser(description='Enroll a person for face recognition.')
parser.add_argument('--name', type=str, required=True,
help='Name of the person to enroll.')
parser.add_argument('--mode', type=str, default='face',
choices=['face', 'voice', 'both'],
help='Enrollment mode (default: face).')
parser.add_argument('--samples', type=int, default=10,
help='Number of face samples to collect (default: 10).')
# Parse only known args so ROS2 remapping args pass through
parsed, remaining = parser.parse_known_args(args=sys.argv[1:])
rclpy.init(args=remaining)
    node = EnrollmentCLI(parsed.name, parsed.mode, parsed.samples)
    # The node does all its work in __init__; just tear down cleanly.
    node.destroy_node()
    rclpy.shutdown()
if __name__ == '__main__':
main()
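The `parse_known_args` call above is what lets the tool's own flags coexist with ROS2 remapping arguments on the same command line: known flags are consumed, everything else is forwarded to `rclpy.init`. A minimal standalone sketch (the argv values are illustrative):

```python
import argparse

parser = argparse.ArgumentParser(description='Enroll a person.')
parser.add_argument('--name', type=str, required=True)
parser.add_argument('--mode', type=str, default='face',
                    choices=['face', 'voice', 'both'])
parser.add_argument('--samples', type=int, default=10)

# Mixed CLI: tool flags plus ROS2 remapping args
argv = ['--name', 'Alice', '--samples', '5',
        '--ros-args', '-r', '__node:=my_enroll_cli']
parsed, remaining = parser.parse_known_args(argv)

print(parsed.name, parsed.mode, parsed.samples)  # Alice face 5
print(remaining)  # ['--ros-args', '-r', '__node:=my_enroll_cli']
```

Unrecognized tokens land in `remaining` instead of raising an error, which is exactly what the CLI passes on to `rclpy.init(args=remaining)`.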


@ -0,0 +1,206 @@
"""
face_gallery.py -- Persistent face embedding gallery backed by numpy .npz + JSON.
Thread-safe gallery storage for face recognition. Embeddings are stored in a
.npz file, with a sidecar metadata.json containing names, sample counts, and
enrollment timestamps. Auto-increment IDs start at 1.
"""
import json
import logging
import os
import threading
import time
import numpy as np
logger = logging.getLogger(__name__)
class FaceGallery:
"""Persistent, thread-safe face embedding gallery.
Args:
gallery_dir: Directory for gallery.npz and metadata.json files.
"""
def __init__(self, gallery_dir: str):
self._gallery_dir = gallery_dir
self._npz_path = os.path.join(gallery_dir, 'gallery.npz')
self._meta_path = os.path.join(gallery_dir, 'metadata.json')
self._gallery: dict[int, dict] = {}
self._next_id = 1
self._lock = threading.Lock()
def load(self) -> None:
"""Load gallery from disk. Populates internal gallery dict."""
with self._lock:
self._gallery = {}
self._next_id = 1
if not os.path.isfile(self._npz_path):
logger.info('No gallery file at %s, starting empty.', self._npz_path)
return
data = np.load(self._npz_path, allow_pickle=False)
meta: dict = {}
if os.path.isfile(self._meta_path):
with open(self._meta_path, 'r') as f:
meta = json.load(f)
for key in data.files:
pid = int(key)
embedding = data[key].astype(np.float32)
norm = np.linalg.norm(embedding)
if norm > 0:
embedding = embedding / norm
info = meta.get(str(pid), {})
self._gallery[pid] = {
'name': info.get('name', f'person_{pid}'),
'embedding': embedding,
'samples': info.get('samples', 1),
'enrolled_at': info.get('enrolled_at', 0.0),
}
if pid >= self._next_id:
self._next_id = pid + 1
logger.info('Gallery loaded: %d persons from %s',
len(self._gallery), self._npz_path)
def save(self) -> None:
"""Save gallery to disk (npz + JSON sidecar)."""
with self._lock:
os.makedirs(self._gallery_dir, exist_ok=True)
arrays = {}
meta = {}
for pid, info in self._gallery.items():
arrays[str(pid)] = info['embedding']
meta[str(pid)] = {
'name': info['name'],
'samples': info['samples'],
'enrolled_at': info['enrolled_at'],
}
np.savez(self._npz_path, **arrays)
with open(self._meta_path, 'w') as f:
json.dump(meta, f, indent=2)
logger.info('Gallery saved: %d persons to %s',
len(self._gallery), self._npz_path)
def add_person(self, name: str, embedding: np.ndarray, samples: int = 1) -> int:
"""Add a new person to the gallery.
Args:
name: Person's name.
embedding: L2-normalized 512-dim embedding.
samples: Number of samples used to compute the embedding.
Returns:
Assigned person_id (auto-increment integer).
"""
with self._lock:
pid = self._next_id
self._next_id += 1
emb = embedding.astype(np.float32)
norm = np.linalg.norm(emb)
if norm > 0:
emb = emb / norm
self._gallery[pid] = {
'name': name,
'embedding': emb,
'samples': samples,
'enrolled_at': time.time(),
}
logger.info('Added person %d (%s), %d samples.', pid, name, samples)
return pid
def update_name(self, person_id: int, new_name: str) -> bool:
"""Update a person's name.
Args:
person_id: The ID of the person to update.
new_name: New name string.
Returns:
True if the person was found and updated.
"""
with self._lock:
if person_id not in self._gallery:
return False
self._gallery[person_id]['name'] = new_name
return True
def delete_person(self, person_id: int) -> bool:
"""Remove a person from the gallery.
Args:
person_id: The ID of the person to delete.
Returns:
True if the person was found and removed.
"""
with self._lock:
if person_id not in self._gallery:
return False
del self._gallery[person_id]
logger.info('Deleted person %d.', person_id)
return True
def get_all(self) -> list[dict]:
"""Get all gallery entries.
Returns:
List of dicts with keys: person_id, name, embedding, samples, enrolled_at.
"""
with self._lock:
result = []
for pid, info in self._gallery.items():
result.append({
'person_id': pid,
'name': info['name'],
'embedding': info['embedding'].copy(),
'samples': info['samples'],
'enrolled_at': info['enrolled_at'],
})
return result
def match(self, query_embedding: np.ndarray, threshold: float = 0.35) -> tuple[int, str, float]:
"""Match a query embedding against the gallery using cosine similarity.
Args:
query_embedding: L2-normalized 512-dim embedding.
threshold: Minimum cosine similarity for a match.
Returns:
(person_id, name, score) or (-1, '', 0.0) if no match.
"""
with self._lock:
if not self._gallery:
return (-1, '', 0.0)
best_pid = -1
best_name = ''
best_score = 0.0
query = query_embedding.astype(np.float32)
norm = np.linalg.norm(query)
if norm > 0:
query = query / norm
for pid, info in self._gallery.items():
score = float(np.dot(query, info['embedding']))
if score > best_score:
best_score = score
best_pid = pid
best_name = info['name']
if best_score >= threshold:
return (best_pid, best_name, best_score)
return (-1, '', 0.0)
def __len__(self) -> int:
with self._lock:
return len(self._gallery)
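The matching step in `FaceGallery.match` above is plain cosine similarity between unit vectors (a dot product once both sides are L2-normalized). A self-contained sketch of that core logic, using illustrative random 512-dim embeddings in place of real ArcFace outputs:

```python
import numpy as np

def cosine_match(query, gallery, threshold=0.35):
    """gallery: dict mapping person_id -> (name, L2-normalized embedding)."""
    q = query.astype(np.float32)
    norm = np.linalg.norm(q)
    if norm > 0:
        q = q / norm
    best_pid, best_name, best_score = -1, '', 0.0
    for pid, (name, emb) in gallery.items():
        score = float(np.dot(q, emb))  # cosine similarity of unit vectors
        if score > best_score:
            best_pid, best_name, best_score = pid, name, score
    if best_score >= threshold:
        return best_pid, best_name, best_score
    return -1, '', 0.0

rng = np.random.default_rng(0)
unit = lambda v: (v / np.linalg.norm(v)).astype(np.float32)
gallery = {1: ('Alice', unit(rng.normal(size=512))),
           2: ('Bob', unit(rng.normal(size=512)))}

pid, name, score = cosine_match(gallery[1][1], gallery)
print(pid, name)  # 1 Alice
```

Random high-dimensional unit vectors are nearly orthogonal, so an unenrolled query scores close to 0 and falls below the 0.35 threshold.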


@ -0,0 +1,431 @@
"""
face_recognition_node.py -- ROS2 node for SCRFD face detection + ArcFace recognition.
Pipeline:
1. Subscribe to /camera/color/image_raw (RealSense D435i color stream).
2. Run SCRFD face detection (TensorRT FP16 or ONNX fallback).
3. For each detected face, align and extract ArcFace embedding.
4. Match embedding against persistent gallery.
5. Publish FaceDetectionArray with identified faces.
Services:
/social/enroll -- Enroll a new person (collects N face samples).
/social/persons/list -- List all enrolled persons.
/social/persons/delete -- Delete a person from the gallery.
/social/persons/update -- Update a person's name.
"""
import time
import numpy as np
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy, DurabilityPolicy
import cv2
from cv_bridge import CvBridge
from sensor_msgs.msg import Image
from builtin_interfaces.msg import Time
from saltybot_social_msgs.msg import (
FaceDetection,
FaceDetectionArray,
FaceEmbedding,
FaceEmbeddingArray,
)
from saltybot_social_msgs.srv import (
EnrollPerson,
ListPersons,
DeletePerson,
UpdatePerson,
)
from .scrfd_detector import SCRFDDetector
from .arcface_recognizer import ArcFaceRecognizer
from .face_gallery import FaceGallery
class FaceRecognitionNode(Node):
"""ROS2 node: SCRFD face detection + ArcFace gallery matching."""
def __init__(self):
super().__init__('face_recognizer')
self._bridge = CvBridge()
self._frame_count = 0
self._fps_t0 = time.monotonic()
# -- Parameters --------------------------------------------------------
self.declare_parameter('scrfd_engine_path',
'/mnt/nvme/saltybot/models/scrfd_2.5g.engine')
self.declare_parameter('scrfd_onnx_path',
'/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx')
self.declare_parameter('arcface_engine_path',
'/mnt/nvme/saltybot/models/arcface_r50.engine')
self.declare_parameter('arcface_onnx_path',
'/mnt/nvme/saltybot/models/arcface_r50.onnx')
self.declare_parameter('gallery_dir', '/mnt/nvme/saltybot/gallery')
self.declare_parameter('recognition_threshold', 0.35)
self.declare_parameter('publish_debug_image', False)
self.declare_parameter('max_faces', 10)
self.declare_parameter('scrfd_conf_thresh', 0.5)
self._recognition_threshold = self.get_parameter('recognition_threshold').value
self._pub_debug_flag = self.get_parameter('publish_debug_image').value
self._max_faces = self.get_parameter('max_faces').value
# -- Models ------------------------------------------------------------
self._detector = SCRFDDetector(
engine_path=self.get_parameter('scrfd_engine_path').value,
onnx_path=self.get_parameter('scrfd_onnx_path').value,
conf_thresh=self.get_parameter('scrfd_conf_thresh').value,
)
self._recognizer = ArcFaceRecognizer(
engine_path=self.get_parameter('arcface_engine_path').value,
onnx_path=self.get_parameter('arcface_onnx_path').value,
)
# -- Gallery -----------------------------------------------------------
gallery_dir = self.get_parameter('gallery_dir').value
self._gallery = FaceGallery(gallery_dir)
self._gallery.load()
        self.get_logger().info(f'Gallery loaded: {len(self._gallery)} persons.')
# -- Enrollment state --------------------------------------------------
self._enrolling = None # {name, samples_needed, collected: [embeddings]}
# -- QoS profiles ------------------------------------------------------
best_effort_qos = QoSProfile(
reliability=ReliabilityPolicy.BEST_EFFORT,
history=HistoryPolicy.KEEP_LAST,
depth=1,
)
reliable_qos = QoSProfile(
reliability=ReliabilityPolicy.RELIABLE,
durability=DurabilityPolicy.TRANSIENT_LOCAL,
history=HistoryPolicy.KEEP_LAST,
depth=1,
)
# -- Subscribers -------------------------------------------------------
self.create_subscription(
Image,
'/camera/color/image_raw',
self._on_image,
best_effort_qos,
)
# -- Publishers --------------------------------------------------------
self._pub_detections = self.create_publisher(
FaceDetectionArray, '/social/faces/detections', best_effort_qos)
self._pub_embeddings = self.create_publisher(
FaceEmbeddingArray, '/social/faces/embeddings', reliable_qos)
if self._pub_debug_flag:
self._pub_debug_img = self.create_publisher(
Image, '/social/faces/debug_image', best_effort_qos)
        # -- Services ----------------------------------------------------------
        # Services get their own callback group so the blocking enroll handler
        # can run in parallel with the image callback under a
        # MultiThreadedExecutor (the default group would serialize them).
        from rclpy.callback_groups import MutuallyExclusiveCallbackGroup
        self._srv_group = MutuallyExclusiveCallbackGroup()
        self.create_service(EnrollPerson, '/social/enroll', self._handle_enroll,
                            callback_group=self._srv_group)
        self.create_service(ListPersons, '/social/persons/list', self._handle_list,
                            callback_group=self._srv_group)
        self.create_service(DeletePerson, '/social/persons/delete', self._handle_delete,
                            callback_group=self._srv_group)
        self.create_service(UpdatePerson, '/social/persons/update', self._handle_update,
                            callback_group=self._srv_group)
# Publish initial gallery state
self._publish_gallery_embeddings()
self.get_logger().info('FaceRecognitionNode ready.')
# -- Image callback --------------------------------------------------------
def _on_image(self, msg: Image):
"""Process incoming camera frame: detect, recognize, publish."""
try:
bgr = self._bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')
except Exception as e:
            self.get_logger().error(f'Image decode error: {e}',
                                    throttle_duration_sec=5.0)
return
# Detect faces
detections = self._detector.detect(bgr)
# Limit face count
if len(detections) > self._max_faces:
detections = sorted(detections, key=lambda d: d['score'], reverse=True)
detections = detections[:self._max_faces]
# Build output message
det_array = FaceDetectionArray()
det_array.header = msg.header
for det in detections:
# Extract embedding and match gallery
embedding = self._recognizer.align_and_embed(bgr, det['kps'])
pid, pname, score = self._gallery.match(
embedding, self._recognition_threshold)
# Handle enrollment: collect embedding from largest face
if self._enrolling is not None:
self._enrollment_collect(det, embedding)
# Build FaceDetection message
face_msg = FaceDetection()
face_msg.header = msg.header
face_msg.face_id = pid
face_msg.person_name = pname
face_msg.confidence = det['score']
face_msg.recognition_score = score
bbox = det['bbox']
face_msg.bbox_x = bbox[0]
face_msg.bbox_y = bbox[1]
face_msg.bbox_w = bbox[2] - bbox[0]
face_msg.bbox_h = bbox[3] - bbox[1]
kps = det['kps']
for i in range(10):
face_msg.landmarks[i] = kps[i]
det_array.faces.append(face_msg)
        self._pub_detections.publish(det_array)
        # Finalize this frame's enrollment sample (no-op unless enrolling)
        self._enrollment_frame_end()
# Debug image
if self._pub_debug_flag and hasattr(self, '_pub_debug_img'):
debug_img = self._draw_debug(bgr, detections, det_array.faces)
self._pub_debug_img.publish(
self._bridge.cv2_to_imgmsg(debug_img, encoding='bgr8'))
        # FPS logging
        self._frame_count += 1
        if self._frame_count % 30 == 0:
            elapsed = time.monotonic() - self._fps_t0
            fps = 30.0 / elapsed if elapsed > 0 else 0.0
            self._fps_t0 = time.monotonic()
            self.get_logger().info(f'FPS: {fps:.1f} | faces: {len(detections)}')
# -- Enrollment logic ------------------------------------------------------
def _enrollment_collect(self, det: dict, embedding: np.ndarray):
"""Collect an embedding sample during enrollment (largest face only)."""
if self._enrolling is None:
return
# Only collect from the largest face (by bbox area)
bbox = det['bbox']
area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
if not hasattr(self, '_enroll_best_area'):
self._enroll_best_area = 0.0
self._enroll_best_embedding = None
if area > self._enroll_best_area:
self._enroll_best_area = area
self._enroll_best_embedding = embedding
def _enrollment_frame_end(self):
"""Called at end of each frame to finalize enrollment sample collection."""
if self._enrolling is None or self._enroll_best_embedding is None:
return
self._enrolling['collected'].append(self._enroll_best_embedding)
self._enroll_best_area = 0.0
self._enroll_best_embedding = None
collected = len(self._enrolling['collected'])
needed = self._enrolling['samples_needed']
        self.get_logger().info(
            f'Enrollment: {collected}/{needed} samples for "{self._enrolling["name"]}".')
        if collected >= needed:
            # Finalize enrollment
            name = self._enrolling['name']
            embeddings = self._enrolling['collected']
            mean_emb = np.mean(embeddings, axis=0).astype(np.float32)
            norm = np.linalg.norm(mean_emb)
            if norm > 0:
                mean_emb = mean_emb / norm
            pid = self._gallery.add_person(name, mean_emb, samples=len(embeddings))
            self._gallery.save()
            self._publish_gallery_embeddings()
            self.get_logger().info(f'Enrollment complete: person {pid} ({name}).')
            # Store the result for the service callback, which reads result_pid
            # and then clears the state. Clearing it here would crash the
            # handler's wait loop with a None dereference.
            self._enrolling['result_pid'] = pid
            self._enrolling['done'] = True
# -- Service handlers ------------------------------------------------------
def _handle_enroll(self, request, response):
"""Handle EnrollPerson service: start collecting face samples."""
name = request.name.strip()
if not name:
response.success = False
response.message = 'Name cannot be empty.'
response.person_id = -1
return response
n_samples = request.n_samples if request.n_samples > 0 else 10
        self.get_logger().info(
            f'Starting enrollment for "{name}" ({n_samples} samples).')
# Set enrollment state — frames will collect embeddings
self._enrolling = {
'name': name,
'samples_needed': n_samples,
'collected': [],
'done': False,
'result_pid': -1,
}
self._enroll_best_area = 0.0
self._enroll_best_embedding = None
        # Block until enrollment is done. This handler must be able to run
        # concurrently with the image callback (separate callback groups under
        # a MultiThreadedExecutor); otherwise this wait can only time out.
        timeout_sec = n_samples * 2.0 + 10.0  # generous timeout
        t0 = time.monotonic()
        while self._enrolling is not None and not self._enrolling.get('done', False):
            # Finalize any pending frame collection
            self._enrollment_frame_end()
            if time.monotonic() - t0 > timeout_sec:
                self._enrolling = None
                response.success = False
                response.message = f'Enrollment timed out after {timeout_sec:.0f}s.'
                response.person_id = -1
                return response
            time.sleep(0.1)
        pid = self._enrolling.get('result_pid', -1) if self._enrolling else -1
        self._enrolling = None
        response.success = True
        response.message = f'Enrolled "{name}" with {n_samples} samples.'
        response.person_id = pid
        return response
def _handle_list(self, request, response):
"""Handle ListPersons service: return all gallery entries."""
entries = self._gallery.get_all()
for entry in entries:
emb_msg = FaceEmbedding()
emb_msg.person_id = entry['person_id']
emb_msg.person_name = entry['name']
emb_msg.embedding = entry['embedding'].tolist()
emb_msg.sample_count = entry['samples']
secs = int(entry['enrolled_at'])
nsecs = int((entry['enrolled_at'] - secs) * 1e9)
emb_msg.enrolled_at = Time(sec=secs, nanosec=nsecs)
response.persons.append(emb_msg)
return response
def _handle_delete(self, request, response):
"""Handle DeletePerson service: remove a person from the gallery."""
if self._gallery.delete_person(request.person_id):
self._gallery.save()
self._publish_gallery_embeddings()
response.success = True
response.message = f'Deleted person {request.person_id}.'
else:
response.success = False
response.message = f'Person {request.person_id} not found.'
return response
def _handle_update(self, request, response):
"""Handle UpdatePerson service: rename a person."""
new_name = request.new_name.strip()
if not new_name:
response.success = False
response.message = 'New name cannot be empty.'
return response
if self._gallery.update_name(request.person_id, new_name):
self._gallery.save()
self._publish_gallery_embeddings()
response.success = True
response.message = f'Updated person {request.person_id} to "{new_name}".'
else:
response.success = False
response.message = f'Person {request.person_id} not found.'
return response
# -- Gallery publishing ----------------------------------------------------
def _publish_gallery_embeddings(self):
"""Publish current gallery as FaceEmbeddingArray (latched-like)."""
entries = self._gallery.get_all()
msg = FaceEmbeddingArray()
msg.header.stamp = self.get_clock().now().to_msg()
for entry in entries:
emb_msg = FaceEmbedding()
emb_msg.person_id = entry['person_id']
emb_msg.person_name = entry['name']
emb_msg.embedding = entry['embedding'].tolist()
emb_msg.sample_count = entry['samples']
secs = int(entry['enrolled_at'])
nsecs = int((entry['enrolled_at'] - secs) * 1e9)
emb_msg.enrolled_at = Time(sec=secs, nanosec=nsecs)
msg.embeddings.append(emb_msg)
self._pub_embeddings.publish(msg)
# -- Debug image -----------------------------------------------------------
def _draw_debug(self, bgr: np.ndarray, detections: list[dict],
face_msgs: list) -> np.ndarray:
"""Draw bounding boxes, landmarks, and names on the image."""
vis = bgr.copy()
for det, face_msg in zip(detections, face_msgs):
bbox = det['bbox']
x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
# Color: green if recognized, yellow if unknown
if face_msg.face_id >= 0:
color = (0, 255, 0)
label = f'{face_msg.person_name} ({face_msg.recognition_score:.2f})'
else:
color = (0, 255, 255)
label = f'unknown ({face_msg.confidence:.2f})'
cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2)
cv2.putText(vis, label, (x1, y1 - 8),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
# Draw landmarks
kps = det['kps']
for k in range(5):
px, py = int(kps[k * 2]), int(kps[k * 2 + 1])
cv2.circle(vis, (px, py), 2, (0, 0, 255), -1)
return vis
# -- Entry point ---------------------------------------------------------------
def main(args=None):
    """ROS2 entry point for face_recognition node."""
    from rclpy.executors import MultiThreadedExecutor
    rclpy.init(args=args)
    node = FaceRecognitionNode()
    # Multi-threaded executor: lets the blocking /social/enroll handler wait
    # while image callbacks keep arriving on another thread.
    executor = MultiThreadedExecutor()
    executor.add_node(node)
    try:
        executor.spin()
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.shutdown()
if __name__ == '__main__':
main()
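Enrollment finalization above averages the collected embeddings and re-normalizes, because a mean of unit vectors is generally shorter than unit length. A tiny numeric check with two stand-in 2-dim "embeddings":

```python
import numpy as np

# Two orthogonal unit 'embeddings' (stand-ins for ArcFace outputs)
samples = [np.array([1.0, 0.0], dtype=np.float32),
           np.array([0.0, 1.0], dtype=np.float32)]

mean_emb = np.mean(samples, axis=0).astype(np.float32)  # [0.5, 0.5]
print(np.linalg.norm(mean_emb))  # ~0.707: no longer unit length

# Same renormalization as the enrollment code
norm = np.linalg.norm(mean_emb)
if norm > 0:
    mean_emb = mean_emb / norm
print(mean_emb)  # ~[0.7071, 0.7071], unit length again
```

Without the renormalization, the averaged prototype would systematically score lower cosine similarity against live embeddings than a unit-length one.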


@ -0,0 +1,350 @@
"""
scrfd_detector.py -- SCRFD face detection with TensorRT FP16 + ONNX fallback.
SCRFD (Sample and Computation Redistribution for Face Detection) produces
9 output tensors across 3 strides (8, 16, 32), each with score, bbox, and
keypoint branches. This module handles anchor generation, bbox/keypoint
decoding, and NMS to produce final face detections.
"""
import os
import logging
from typing import Optional
import numpy as np
import cv2
logger = logging.getLogger(__name__)
_STRIDES = [8, 16, 32]
_NUM_ANCHORS = 2 # anchors per cell per stride
# -- Inference backends --------------------------------------------------------
class _TRTBackend:
"""TensorRT inference engine for SCRFD."""
def __init__(self, engine_path: str):
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit # noqa: F401
self._cuda = cuda
trt_logger = trt.Logger(trt.Logger.WARNING)
with open(engine_path, 'rb') as f, trt.Runtime(trt_logger) as runtime:
self._engine = runtime.deserialize_cuda_engine(f.read())
self._context = self._engine.create_execution_context()
self._inputs = []
self._outputs = []
self._output_names = []
self._bindings = []
for i in range(self._engine.num_io_tensors):
name = self._engine.get_tensor_name(i)
dtype = trt.nptype(self._engine.get_tensor_dtype(name))
shape = tuple(self._engine.get_tensor_shape(name))
nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
host_mem = cuda.pagelocked_empty(shape, dtype)
device_mem = cuda.mem_alloc(nbytes)
self._bindings.append(int(device_mem))
if self._engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
self._inputs.append({'host': host_mem, 'device': device_mem})
else:
self._outputs.append({'host': host_mem, 'device': device_mem,
'shape': shape})
self._output_names.append(name)
self._stream = cuda.Stream()
def infer(self, input_data: np.ndarray) -> list[np.ndarray]:
"""Run inference and return output tensors."""
        # Host buffer was allocated with the engine's input binding shape,
        # so copy without flattening (a raveled source would not broadcast).
        np.copyto(self._inputs[0]['host'], input_data)
self._cuda.memcpy_htod_async(
self._inputs[0]['device'], self._inputs[0]['host'], self._stream)
self._context.execute_async_v2(self._bindings, self._stream.handle)
results = []
for out in self._outputs:
self._cuda.memcpy_dtoh_async(out['host'], out['device'], self._stream)
self._stream.synchronize()
for out in self._outputs:
results.append(out['host'].reshape(out['shape']).copy())
return results
class _ONNXBackend:
"""ONNX Runtime inference (CUDA EP with CPU fallback)."""
def __init__(self, onnx_path: str):
import onnxruntime as ort
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
self._session = ort.InferenceSession(onnx_path, providers=providers)
self._input_name = self._session.get_inputs()[0].name
self._output_names = [o.name for o in self._session.get_outputs()]
def infer(self, input_data: np.ndarray) -> list[np.ndarray]:
"""Run inference and return output tensors."""
return self._session.run(None, {self._input_name: input_data})
# -- NMS ----------------------------------------------------------------------
def _nms(boxes: np.ndarray, scores: np.ndarray, iou_thresh: float) -> list[int]:
"""Non-maximum suppression. boxes: [N, 4] as x1,y1,x2,y2."""
if len(boxes) == 0:
return []
x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
areas = (x2 - x1) * (y2 - y1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(int(i))
if order.size == 1:
break
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-6)
remaining = np.where(iou <= iou_thresh)[0]
order = order[remaining + 1]
return keep
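A quick numeric check of the suppression scheme above, restated inline so it runs standalone: two heavily overlapping boxes collapse to the higher-scoring one, while a distant box survives.

```python
import numpy as np

def nms(boxes, scores, iou_thresh):
    # Greedy NMS, same logic as _nms above: keep the highest score,
    # drop neighbours whose IoU with it exceeds the threshold.
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        if order.size == 1:
            break
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
        iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-6)
        order = order[np.where(iou <= iou_thresh)[0] + 1]
    return keep

boxes = np.array([[0, 0, 10, 10],     # overlaps the next box (IoU ~0.68)
                  [1, 1, 11, 11],
                  [20, 20, 30, 30]], dtype=np.float32)
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)

print(nms(boxes, scores, iou_thresh=0.4))  # [0, 2]
```

Raising the IoU threshold above 0.68 keeps all three boxes, which is a handy way to sanity-check a chosen `nms_iou` value.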
# -- Anchor generation ---------------------------------------------------------
def _generate_anchors(input_h: int, input_w: int) -> dict[int, np.ndarray]:
"""Generate anchor centers for each stride.
Returns dict mapping stride -> array of shape [H*W*num_anchors, 2],
where each row is (cx, cy) in input pixel coordinates.
"""
anchors = {}
for stride in _STRIDES:
feat_h = input_h // stride
feat_w = input_w // stride
        grid_y, grid_x = np.mgrid[:feat_h, :feat_w]
        # _NUM_ANCHORS consecutive entries per cell:
        # [anchor0_cell0, anchor1_cell0, anchor0_cell1, ...]
        centers = np.repeat(
            np.stack([grid_x.ravel(), grid_y.ravel()], axis=1).astype(np.float32),
            _NUM_ANCHORS, axis=0
        )
        centers = (centers + 0.5) * stride
        anchors[stride] = centers
return anchors
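The anchor layout can be sanity-checked standalone: at a 640x640 input the three strides yield 80x80, 40x40, and 20x20 grids with two anchors per cell. A minimal restatement of the generator above (constants copied from the module):

```python
import numpy as np

STRIDES = [8, 16, 32]
NUM_ANCHORS = 2  # anchors per cell per stride

def generate_anchors(input_h, input_w):
    anchors = {}
    for stride in STRIDES:
        feat_h, feat_w = input_h // stride, input_w // stride
        grid_y, grid_x = np.mgrid[:feat_h, :feat_w]
        # NUM_ANCHORS consecutive copies of each cell center
        centers = np.repeat(
            np.stack([grid_x.ravel(), grid_y.ravel()], axis=1).astype(np.float32),
            NUM_ANCHORS, axis=0)
        anchors[stride] = (centers + 0.5) * stride
    return anchors

a = generate_anchors(640, 640)
print({s: v.shape[0] for s, v in a.items()})  # {8: 12800, 16: 3200, 32: 800}
print(a[8][:4])  # cell (0,0) twice at (4, 4), then cell (1,0) twice at (12, 4)
```

The per-stride counts (12800 + 3200 + 800) must match the flattened score tensor lengths in `_decode_outputs`, which is what the trim/skip guard there checks.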
# -- Main detector class -------------------------------------------------------
class SCRFDDetector:
"""SCRFD face detector with TensorRT FP16 and ONNX fallback.
Args:
engine_path: Path to TensorRT engine file.
onnx_path: Path to ONNX model file (used if engine not available).
conf_thresh: Minimum confidence for detections.
nms_iou: IoU threshold for NMS.
input_size: Model input resolution (square).
"""
def __init__(
self,
engine_path: str = '',
onnx_path: str = '',
conf_thresh: float = 0.5,
nms_iou: float = 0.4,
input_size: int = 640,
):
self._conf_thresh = conf_thresh
self._nms_iou = nms_iou
self._input_size = input_size
self._backend: Optional[_TRTBackend | _ONNXBackend] = None
self._anchors = _generate_anchors(input_size, input_size)
# Try TRT first, then ONNX
if engine_path and os.path.isfile(engine_path):
try:
self._backend = _TRTBackend(engine_path)
logger.info('SCRFD TensorRT backend loaded: %s', engine_path)
return
except Exception as e:
logger.warning('SCRFD TRT load failed (%s), trying ONNX', e)
if onnx_path and os.path.isfile(onnx_path):
try:
self._backend = _ONNXBackend(onnx_path)
logger.info('SCRFD ONNX backend loaded: %s', onnx_path)
return
except Exception as e:
logger.error('SCRFD ONNX load failed: %s', e)
logger.error('No SCRFD model loaded. Detection will be unavailable.')
@property
def is_loaded(self) -> bool:
"""Return True if a backend is loaded and ready."""
return self._backend is not None
def detect(self, bgr: np.ndarray) -> list[dict]:
"""Detect faces in a BGR image.
Args:
bgr: Input image in BGR format, shape (H, W, 3).
Returns:
List of dicts with keys:
bbox: [x1, y1, x2, y2] in original image coordinates
kps: [x0,y0, x1,y1, ..., x4,y4] 10 floats, 5 landmarks
score: detection confidence
"""
if self._backend is None:
return []
orig_h, orig_w = bgr.shape[:2]
tensor, scale, pad_w, pad_h = self._preprocess(bgr)
outputs = self._backend.infer(tensor)
detections = self._decode_outputs(outputs)
detections = self._rescale(detections, scale, pad_w, pad_h, orig_w, orig_h)
return detections
def _preprocess(self, bgr: np.ndarray) -> tuple[np.ndarray, float, int, int]:
"""Resize to input_size x input_size with letterbox padding, normalize."""
h, w = bgr.shape[:2]
size = self._input_size
scale = min(size / h, size / w)
new_w, new_h = int(w * scale), int(h * scale)
pad_w = (size - new_w) // 2
pad_h = (size - new_h) // 2
resized = cv2.resize(bgr, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
canvas = np.full((size, size, 3), 0, dtype=np.uint8)
canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized
# Normalize: subtract 127.5, divide 128.0
blob = canvas.astype(np.float32)
blob = (blob - 127.5) / 128.0
# HWC -> NCHW
blob = blob.transpose(2, 0, 1)[np.newaxis]
blob = np.ascontiguousarray(blob)
return blob, scale, pad_w, pad_h
def _decode_outputs(self, outputs: list[np.ndarray]) -> list[dict]:
"""Decode SCRFD 9-output format into face detections.
SCRFD outputs 9 tensors, 3 per stride (score, bbox, kps):
score_8, bbox_8, kps_8, score_16, bbox_16, kps_16, score_32, bbox_32, kps_32
"""
all_scores = []
all_bboxes = []
all_kps = []
for idx, stride in enumerate(_STRIDES):
score_out = outputs[idx * 3].squeeze() # [H*W*num_anchors]
bbox_out = outputs[idx * 3 + 1].squeeze() # [H*W*num_anchors, 4]
kps_out = outputs[idx * 3 + 2].squeeze() # [H*W*num_anchors, 10]
if score_out.ndim == 0:
continue
            # Flatten scores; n = anchor count at this stride
            score_out = score_out.ravel()
            n = score_out.shape[0]
if bbox_out.ndim == 1:
bbox_out = bbox_out.reshape(-1, 4)
if kps_out.ndim == 1:
kps_out = kps_out.reshape(-1, 10)
# Filter by confidence
mask = score_out > self._conf_thresh
if not mask.any():
continue
scores = score_out[mask]
bboxes = bbox_out[mask]
kps = kps_out[mask]
anchors = self._anchors[stride]
# Trim or pad anchors to match output count
if anchors.shape[0] > n:
anchors = anchors[:n]
elif anchors.shape[0] < n:
continue
anchors = anchors[mask]
# Decode bboxes: center = anchor + pred[:2]*stride, size = exp(pred[2:])*stride
cx = anchors[:, 0] + bboxes[:, 0] * stride
cy = anchors[:, 1] + bboxes[:, 1] * stride
w = np.exp(bboxes[:, 2]) * stride
h = np.exp(bboxes[:, 3]) * stride
x1 = cx - w / 2.0
y1 = cy - h / 2.0
x2 = cx + w / 2.0
y2 = cy + h / 2.0
decoded_bboxes = np.stack([x1, y1, x2, y2], axis=1)
# Decode keypoints: kp = anchor + pred * stride
decoded_kps = np.zeros_like(kps)
for k in range(5):
decoded_kps[:, k * 2] = anchors[:, 0] + kps[:, k * 2] * stride
decoded_kps[:, k * 2 + 1] = anchors[:, 1] + kps[:, k * 2 + 1] * stride
all_scores.append(scores)
all_bboxes.append(decoded_bboxes)
all_kps.append(decoded_kps)
if not all_scores:
return []
scores = np.concatenate(all_scores)
bboxes = np.concatenate(all_bboxes)
kps = np.concatenate(all_kps)
# NMS
keep = _nms(bboxes, scores, self._nms_iou)
results = []
for i in keep:
results.append({
'bbox': bboxes[i].tolist(),
'kps': kps[i].tolist(),
'score': float(scores[i]),
})
return results
def _rescale(
self,
detections: list[dict],
scale: float,
pad_w: int,
pad_h: int,
orig_w: int,
orig_h: int,
) -> list[dict]:
"""Rescale detections from model input space to original image space."""
for det in detections:
bbox = det['bbox']
bbox[0] = max(0.0, (bbox[0] - pad_w) / scale)
bbox[1] = max(0.0, (bbox[1] - pad_h) / scale)
bbox[2] = min(float(orig_w), (bbox[2] - pad_w) / scale)
bbox[3] = min(float(orig_h), (bbox[3] - pad_h) / scale)
det['bbox'] = bbox
kps = det['kps']
for k in range(5):
kps[k * 2] = (kps[k * 2] - pad_w) / scale
kps[k * 2 + 1] = (kps[k * 2 + 1] - pad_h) / scale
det['kps'] = kps
return detections
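The letterbox bookkeeping in `_preprocess` and `_rescale` above is invertible pure arithmetic; here is a sketch of just that math, without the actual resize (no cv2 needed), on an illustrative 1280x720 frame:

```python
def letterbox_params(h, w, size=640):
    # Same math as _preprocess: uniform scale plus centered padding
    scale = min(size / h, size / w)
    new_w, new_h = int(w * scale), int(h * scale)
    pad_w = (size - new_w) // 2
    pad_h = (size - new_h) // 2
    return scale, pad_w, pad_h

def to_model(x, y, scale, pad_w, pad_h):
    # Forward map: original image coords -> model input coords
    return x * scale + pad_w, y * scale + pad_h

def to_image(x, y, scale, pad_w, pad_h):
    # Inverse map, as in _rescale (clamping omitted for brevity)
    return (x - pad_w) / scale, (y - pad_h) / scale

scale, pad_w, pad_h = letterbox_params(720, 1280)
print(scale, pad_w, pad_h)  # 0.5 0 140

mx, my = to_model(100, 200, scale, pad_w, pad_h)
print(to_image(mx, my, scale, pad_w, pad_h))  # (100.0, 200.0)
```

A 1280x720 frame scales by 0.5 into a 640x360 strip, padded by 140 px top and bottom; the round trip recovers the original coordinates exactly.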


@ -0,0 +1,112 @@
#!/usr/bin/env python3
"""
build_face_trt_engines.py -- Build TensorRT FP16 engines for SCRFD and ArcFace.
Converts ONNX model files to optimized TensorRT engines with FP16 precision
for fast inference on Jetson Orin Nano Super.
Usage:
python3 build_face_trt_engines.py \
--scrfd-onnx /path/to/scrfd_2.5g_bnkps.onnx \
--arcface-onnx /path/to/arcface_r50.onnx \
--output-dir /mnt/nvme/saltybot/models \
--fp16 --workspace-mb 1024
"""
import argparse
import os
import time
def build_engine(onnx_path: str, engine_path: str, fp16: bool, workspace_mb: int):
"""Build a TensorRT engine from an ONNX model.
Args:
onnx_path: Path to the source ONNX model file.
engine_path: Output path for the serialized TensorRT engine.
fp16: Enable FP16 precision.
workspace_mb: Maximum workspace size in megabytes.
"""
import tensorrt as trt
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
print(f'Parsing ONNX model: {onnx_path}')
t0 = time.monotonic()
with open(onnx_path, 'rb') as f:
if not parser.parse(f.read()):
for i in range(parser.num_errors):
print(f' ONNX parse error: {parser.get_error(i)}')
raise RuntimeError(f'Failed to parse {onnx_path}')
parse_time = time.monotonic() - t0
print(f' Parsed in {parse_time:.1f}s')
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
workspace_mb * (1 << 20))
if fp16:
if builder.platform_has_fast_fp16:
config.set_flag(trt.BuilderFlag.FP16)
print(' FP16 enabled.')
else:
print(' Warning: FP16 not supported on this platform, using FP32.')
    print('Building engine (this may take several minutes)...')
t0 = time.monotonic()
serialized = builder.build_serialized_network(network, config)
build_time = time.monotonic() - t0
if serialized is None:
raise RuntimeError('Engine build failed.')
os.makedirs(os.path.dirname(engine_path) or '.', exist_ok=True)
with open(engine_path, 'wb') as f:
f.write(serialized)
size_mb = os.path.getsize(engine_path) / (1 << 20)
print(f' Engine saved: {engine_path} ({size_mb:.1f} MB, built in {build_time:.1f}s)')
def main():
"""Main entry point for TRT engine building."""
parser = argparse.ArgumentParser(
description='Build TensorRT FP16 engines for SCRFD and ArcFace.')
parser.add_argument('--scrfd-onnx', type=str, default='',
help='Path to SCRFD ONNX model.')
parser.add_argument('--arcface-onnx', type=str, default='',
help='Path to ArcFace ONNX model.')
parser.add_argument('--output-dir', type=str,
default='/mnt/nvme/saltybot/models',
help='Output directory for engine files.')
parser.add_argument('--fp16', action='store_true', default=True,
help='Enable FP16 precision (default: True).')
parser.add_argument('--no-fp16', action='store_false', dest='fp16',
help='Disable FP16 (use FP32 only).')
parser.add_argument('--workspace-mb', type=int, default=1024,
help='TRT workspace size in MB (default: 1024).')
args = parser.parse_args()
if not args.scrfd_onnx and not args.arcface_onnx:
parser.error('At least one of --scrfd-onnx or --arcface-onnx is required.')
if args.scrfd_onnx:
engine_path = os.path.join(args.output_dir, 'scrfd_2.5g.engine')
print(f'\n=== Building SCRFD engine ===')
build_engine(args.scrfd_onnx, engine_path, args.fp16, args.workspace_mb)
if args.arcface_onnx:
engine_path = os.path.join(args.output_dir, 'arcface_r50.engine')
print(f'\n=== Building ArcFace engine ===')
build_engine(args.arcface_onnx, engine_path, args.fp16, args.workspace_mb)
print('\nDone.')
if __name__ == '__main__':
main()

@ -0,0 +1,4 @@
[develop]
script_dir=$base/lib/saltybot_social_face
[install]
install_scripts=$base/lib/saltybot_social_face

@ -0,0 +1,30 @@
"""Setup for saltybot_social_face package."""
from setuptools import find_packages, setup
package_name = 'saltybot_social_face'
setup(
name=package_name,
version='0.1.0',
packages=find_packages(exclude=['test']),
data_files=[
('share/ament_index/resource_index/packages', ['resource/' + package_name]),
('share/' + package_name, ['package.xml']),
('share/' + package_name + '/launch', ['launch/face_recognition.launch.py']),
('share/' + package_name + '/config', ['config/face_recognition_params.yaml']),
],
install_requires=['setuptools'],
zip_safe=True,
maintainer='seb',
maintainer_email='seb@vayrette.com',
description='SCRFD face detection and ArcFace recognition for SaltyBot social interactions',
license='MIT',
tests_require=['pytest'],
entry_points={
'console_scripts': [
'face_recognition = saltybot_social_face.face_recognition_node:main',
'enrollment_cli = saltybot_social_face.enrollment_cli:main',
],
},
)

@ -8,7 +8,7 @@ find_package(geometry_msgs REQUIRED)
find_package(builtin_interfaces REQUIRED)
rosidl_generate_interfaces(${PROJECT_NAME}
# Social perception (from sl-perception)
# Issue #80 face detection + recognition
"msg/FaceDetection.msg"
"msg/FaceDetectionArray.msg"
"msg/FaceEmbedding.msg"
@ -19,7 +19,11 @@ rosidl_generate_interfaces(${PROJECT_NAME}
"srv/ListPersons.srv"
"srv/DeletePerson.srv"
"srv/UpdatePerson.srv"
# Personality system (Issue #84)
# Issue #86 LED expression + motor attention
"msg/Mood.msg"
"msg/Person.msg"
"msg/PersonArray.msg"
# Issue #84 personality system
"msg/PersonalityState.msg"
"srv/QueryMood.srv"
DEPENDENCIES std_msgs geometry_msgs builtin_interfaces

@ -0,0 +1,7 @@
# Mood.msg — social expression command sent to the LED display node.
#
# mood : one of "happy", "curious", "annoyed", "playful", "idle"
# intensity : 0.0 (off) to 1.0 (full brightness)
string mood
float32 intensity
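
A minimal sketch of building a Mood command in plain Python, outside ROS. The helper name `make_mood` and the dict representation are illustrative assumptions, not part of the package; the point is enforcing the documented contract (mood is one of the five listed strings, intensity is clamped to 0..1):

```python
# Hypothetical helper mirroring the Mood.msg contract; names are illustrative.
VALID_MOODS = {"happy", "curious", "annoyed", "playful", "idle"}

def make_mood(mood: str, intensity: float) -> dict:
    """Validate the mood name and clamp intensity to the documented 0..1 range."""
    if mood not in VALID_MOODS:
        raise ValueError(f"unknown mood: {mood!r}")
    return {"mood": mood, "intensity": max(0.0, min(1.0, intensity))}

print(make_mood("playful", 1.7))  # intensity clamped to 1.0
```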

View File

@ -0,0 +1,17 @@
# Person.msg — single tracked person for social attention.
#
# bearing_rad : signed bearing in base_link frame (rad)
# positive = person to the left (CCW), negative = right (CW)
# distance_m : estimated distance in metres; 0 if unknown
# confidence : detection confidence 0..1
# is_speaking : true if mic DOA or VAD identifies this person as speaking
# source : "camera" | "mic_doa"
std_msgs/Header header
int32 track_id
float32 bearing_rad
float32 distance_m
float32 confidence
bool is_speaking
string source
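
To make the sign convention concrete, here is a hedged sketch of mapping a face-centre pixel column to `bearing_rad`, assuming a pinhole camera with a known horizontal FOV. The function name and FOV handling are illustrative assumptions, not the node's actual code:

```python
import math

def pixel_to_bearing(cx_px: float, image_width: int, hfov_deg: float) -> float:
    """Positive result = person left of image centre (CCW), per Person.msg."""
    half_w = image_width / 2.0
    offset = (cx_px - half_w) / half_w  # -1 (left edge) .. +1 (right edge)
    # A pixel left of centre must map to a positive bearing, hence the sign flip.
    return math.atan(-offset * math.tan(math.radians(hfov_deg / 2.0)))

print(pixel_to_bearing(320, 640, 90.0))  # centre pixel -> 0.0 rad
```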

@ -0,0 +1,9 @@
# PersonArray.msg — all detected persons plus the current attention target.
#
# persons : all currently tracked persons
# active_id : track_id of the active/speaking person; -1 if none
std_msgs/Header header
saltybot_social_msgs/Person[] persons
int32 active_id
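
One plausible reading of the `active_id` rule (the node's real selection policy may differ) is: prefer a speaking person, break ties by detection confidence, and fall back to -1 when nobody is speaking. Sketched with plain dicts standing in for Person messages:

```python
def pick_active_id(persons: list) -> int:
    """Attention target per the PersonArray contract (illustrative policy)."""
    speakers = [p for p in persons if p["is_speaking"]]
    if not speakers:
        return -1
    return max(speakers, key=lambda p: p["confidence"])["track_id"]

tracks = [
    {"track_id": 3, "confidence": 0.9, "is_speaking": False},
    {"track_id": 7, "confidence": 0.6, "is_speaking": True},
]
print(pick_active_id(tracks))  # -> 7 (the speaker wins despite lower confidence)
```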

@ -0,0 +1,4 @@
sensor_msgs/Image crop
---
bool success
float32[512] embedding
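
The 512-float embedding returned here is the kind of vector that is typically matched against enrolled identities with cosine similarity (standard practice for ArcFace-style embeddings). A self-contained sketch; the 0.35 threshold is an illustrative assumption, not the package's actual value:

```python
import math

def cosine_similarity(a, b) -> float:
    """Cosine similarity between two embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb) if na and nb else 0.0

def is_same_person(a, b, threshold: float = 0.35) -> bool:
    # Threshold is illustrative; real deployments tune it on enrollment data.
    return cosine_similarity(a, b) >= threshold
```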