Compare commits

...

7 Commits

Author SHA1 Message Date
44771751e2 feat(social): personality system — SOUL.md persona, mood engine, relationship DB (Issue #84)
New packages:
- saltybot_social_msgs: PersonalityState.msg + QueryMood.srv custom interfaces
- saltybot_social_personality: full personality node

Features:
- SOUL.md YAML/Markdown persona file: name, humor_level (0-10), sass_level (0-10),
  base_mood, per-tier greeting templates, mood prefix strings
- Hot-reload: SoulWatcher polls SOUL.md every reload_interval seconds, applies
  changes live without restarting the node
- Per-person relationship memory in SQLite: score, interaction_count,
  first/last_seen, learned preferences (JSON), full interaction log
- Mood engine (pure functions): happy | curious | annoyed | playful
  driven by relationship score, interaction count, recent event window (120s)
- Greeting personalisation: stranger | regular | favorite tiers
  keyed on interaction count thresholds from SOUL.md
- Publishes /social/personality/state (PersonalityState) at publish_rate Hz
- /social/personality/query_mood (QueryMood) service for on-demand mood query
- Full ROS2 dynamic reconfigure: soul_file, db_path, reload_interval, publish_rate
- 52 unit tests, no ROS2 runtime required

ROS2 interfaces:
  Sub: /social/person_detected  (std_msgs/String JSON)
  Pub: /social/personality/state (saltybot_social_msgs/PersonalityState)
  Srv: /social/personality/query_mood (saltybot_social_msgs/QueryMood)
2026-03-01 23:56:05 -05:00
dc746ccedc Merge pull request 'feat(social): face detection + recognition #80' (#96) from sl-perception/social-face-detection into main 2026-03-01 23:55:18 -05:00
d6a6965af6 Merge pull request 'feat(social): person enrollment system #87' (#95) from sl-perception/social-enrollment into main 2026-03-01 23:55:16 -05:00
35b940e1f5 Merge pull request 'feat(social): Issue #86 — physical expression + motor attention' (#94) from sl-firmware/social-expression into main 2026-03-01 23:55:14 -05:00
5143e5bfac feat(social): Issue #86 — physical expression + motor attention
ESP32-C3 NeoPixel sketch (esp32/social_expression/social_expression.ino):
  - Adafruit NeoPixel + ArduinoJson, serial JSON protocol 115200 8N1
  - Mood→colour: happy=green, curious=blue, annoyed=red, playful=rainbow
  - Idle breathing animation (sine-modulated warm white)
  - Falls back to idle automatically after IDLE_TIMEOUT_MS (3 s) with no command

ROS2 saltybot_social_msgs (new package):
  - Mood.msg — {mood, intensity}
  - Person.msg — {track_id, bearing_rad, distance_m, confidence, is_speaking, source}
  - PersonArray.msg — {persons[], active_id}

ROS2 saltybot_social (new package):
  - expression_node: subscribes /social/mood → JSON serial to ESP32-C3
    reconnects on port error; sends idle frame after idle_timeout_s
  - attention_node: subscribes /social/persons → /cmd_vel rotation-only
    proportional control with dead zone; prefers active speaker, falls
    back to highest-confidence person; lost-target idle after 2 s
  - launch/social.launch.py — combined launch
  - config YAML for both nodes with documented parameters
  - test/test_attention.py — 15 pytest-only unit tests

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-01 23:35:59 -05:00
5c4f18e46c feat(social): person enrollment system — SQLite gallery + voice trigger (Issue #87)
- saltybot_social_msgs: 6 msg + 5 srv definitions for social interaction
- saltybot_social_enrollment: enrollment_node + enrollment_cli
- PersonDB: thread-safe SQLite-backed gallery (embeddings, voice samples)
- Voice-triggered enrollment via "remember me my name is X" phrase
- CLI: enroll/list/delete/rename via ros2 run
- Services: /social/enroll, /social/persons/list|delete|update
- Gallery sync from /social/faces/embeddings topic
2026-03-01 23:32:26 -05:00
f61a03b3c5 feat(social): face detection + recognition (SCRFD + ArcFace TRT FP16, Issue #80)
Add two new ROS2 packages for the social sprint:

saltybot_social_msgs (ament_cmake):
- FaceDetection, FaceDetectionArray, FaceEmbedding, FaceEmbeddingArray
- PersonState, PersonStateArray
- EnrollPerson, ListPersons, DeletePerson, UpdatePerson services

saltybot_social_face (ament_python):
- SCRFDDetector: SCRFD face detection with TRT FP16 + ONNX fallback
  - 640x640 input, 3-stride anchor decoding, NMS
- ArcFaceRecognizer: 512-dim embedding extraction with gallery matching
  - 5-point landmark alignment to 112x112, cosine similarity
- FaceGallery: thread-safe persistent gallery (npz + JSON sidecar)
- FaceRecognitionNode: ROS2 node subscribing /camera/color/image_raw,
  publishing /social/faces/detections, /social/faces/embeddings
- Enrollment via /social/enroll service (N-sample face averaging)
- Launch file, config YAML, TRT engine builder script

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-01 23:31:48 -05:00
53 changed files with 5024 additions and 4 deletions

View File

@ -0,0 +1,206 @@
/*
* social_expression.ino -- LED mood display for saltybot social layer.
*
* Hardware
*
* MCU : ESP32-C3 (e.g. Seeed XIAO ESP32-C3 or equivalent)
* LEDs : NeoPixel ring or strip, DATA_PIN GPIO5, power via 5V rail
*
* Serial protocol (from Orin Nano, 115200 8N1)
*
* One JSON line per command, newline-terminated:
* {"mood":"happy","intensity":0.8}\n
*
* mood values : "happy" | "curious" | "annoyed" | "playful" | "idle"
* intensity : 0.0 (off) .. 1.0 (full brightness)
*
* If no command arrives for IDLE_TIMEOUT_MS, the node enters the
* breathing idle animation automatically.
*
* Mood colour mapping
*
* happy green #00FF00
* curious blue #0040FF
* annoyed red #FF1000
* playful rainbow (cycling hue across all pixels)
* idle soft white breathing (sine modulated)
*
* Dependencies (Arduino Library Manager)
*
* Adafruit NeoPixel 1.11.0
* ArduinoJson 7.0.0
*
* Build: Arduino IDE 2 / arduino-cli with board "ESP32C3 Dev Module"
* Board manager URL: https://raw.githubusercontent.com/espressif/arduino-esp32/gh-pages/package_esp32_index.json
*/
#include <Adafruit_NeoPixel.h>
#include <ArduinoJson.h>
// ── Hardware config ──────────────────────────────────────────────────────────
#define LED_PIN 5 // GPIO connected to NeoPixel data line
#define LED_COUNT 16 // pixels in ring / strip — adjust as needed
#define LED_BRIGHTNESS 120 // global cap 0-255 (≈47%), protects PSU
// ── Timing constants ─────────────────────────────────────────────────────────
#define IDLE_TIMEOUT_MS 3000 // fall back to breathing after 3 s silence
#define LOOP_INTERVAL_MS 20 // animation tick ≈ 50 Hz
// ── Colours (R, G, B) ────────────────────────────────────────────────────────
static const uint8_t COL_HAPPY[3] = { 0, 220, 0 }; // green
static const uint8_t COL_CURIOUS[3] = { 0, 64, 255 }; // blue
static const uint8_t COL_ANNOYED[3] = {255, 16, 0 }; // red
// ── State ────────────────────────────────────────────────────────────────────
enum Mood { MOOD_IDLE, MOOD_HAPPY, MOOD_CURIOUS, MOOD_ANNOYED, MOOD_PLAYFUL };
static Mood g_mood = MOOD_IDLE;
static float g_intensity = 1.0f;
static uint32_t g_last_cmd_ms = 0; // millis() of last received command
// Animation counters
static uint16_t g_rainbow_hue = 0; // 0..65535, cycles for playful
static uint32_t g_last_tick = 0;
// Serial receive buffer
static char g_serial_buf[128];
static uint8_t g_buf_pos = 0;
// ── NeoPixel object ──────────────────────────────────────────────────────────
Adafruit_NeoPixel strip(LED_COUNT, LED_PIN, NEO_GRB + NEO_KHZ800);
// ── Helpers ──────────────────────────────────────────────────────────────────
static uint8_t scale(uint8_t v, float intensity) {
  return (uint8_t)(v * intensity);
}

/*
 * Breathing: sine envelope over time.
 * Returns brightness factor 0.0..1.0, period ~4 s.
 */
static float breath_factor(uint32_t now_ms) {
  float phase = (float)(now_ms % 4000) / 4000.0f;  // 0..1 per period
  return 0.08f + 0.35f * (0.5f + 0.5f * sinf(2.0f * M_PI * phase));
}

static void set_all(uint8_t r, uint8_t g, uint8_t b) {
  for (int i = 0; i < LED_COUNT; i++) {
    strip.setPixelColor(i, strip.Color(r, g, b));
  }
}

// ── Animation drivers ────────────────────────────────────────────────────────
static void animate_solid(const uint8_t col[3], float intensity) {
  set_all(scale(col[0], intensity),
          scale(col[1], intensity),
          scale(col[2], intensity));
}

static void animate_breathing(uint32_t now_ms, float intensity) {
  float bf = breath_factor(now_ms) * intensity;
  // Warm white: R=255 G=200 B=120
  set_all(scale(255, bf), scale(200, bf), scale(120, bf));
}

static void animate_rainbow(float intensity) {
  // Spread full wheel across all pixels
  for (int i = 0; i < LED_COUNT; i++) {
    uint16_t hue = g_rainbow_hue + (uint16_t)((float)i / LED_COUNT * 65536.0f);
    uint32_t rgb = strip.ColorHSV(hue, 255,
                                  (uint8_t)(255.0f * intensity));
    strip.setPixelColor(i, rgb);
  }
  // Advance hue each tick (full cycle in ~6 s at 50 Hz)
  g_rainbow_hue += 218;
}

// ── Serial parser ────────────────────────────────────────────────────────────
static void parse_command(const char *line) {
  StaticJsonDocument<128> doc;
  DeserializationError err = deserializeJson(doc, line);
  if (err) return;
  const char *mood_str = doc["mood"] | "";
  float intensity = doc["intensity"] | 1.0f;
  if (intensity < 0.0f) intensity = 0.0f;
  if (intensity > 1.0f) intensity = 1.0f;
  if (strcmp(mood_str, "happy") == 0) g_mood = MOOD_HAPPY;
  else if (strcmp(mood_str, "curious") == 0) g_mood = MOOD_CURIOUS;
  else if (strcmp(mood_str, "annoyed") == 0) g_mood = MOOD_ANNOYED;
  else if (strcmp(mood_str, "playful") == 0) g_mood = MOOD_PLAYFUL;
  else g_mood = MOOD_IDLE;
  g_intensity = intensity;
  g_last_cmd_ms = millis();
}

static void read_serial(void) {
  while (Serial.available()) {
    char c = (char)Serial.read();
    if (c == '\n' || c == '\r') {
      if (g_buf_pos > 0) {
        g_serial_buf[g_buf_pos] = '\0';
        parse_command(g_serial_buf);
        g_buf_pos = 0;
      }
    } else if (g_buf_pos < (sizeof(g_serial_buf) - 1)) {
      g_serial_buf[g_buf_pos++] = c;
    } else {
      // Buffer overflow — discard line
      g_buf_pos = 0;
    }
  }
}

// ── Arduino entry points ─────────────────────────────────────────────────────
void setup(void) {
  Serial.begin(115200);
  strip.begin();
  strip.setBrightness(LED_BRIGHTNESS);
  strip.clear();
  strip.show();
  g_last_tick = millis();
}

void loop(void) {
  uint32_t now = millis();
  read_serial();
  // Fall back to idle if no command for IDLE_TIMEOUT_MS
  if ((now - g_last_cmd_ms) > IDLE_TIMEOUT_MS) {
    g_mood = MOOD_IDLE;
  }
  // Throttle animation ticks
  if ((now - g_last_tick) < LOOP_INTERVAL_MS) return;
  g_last_tick = now;
  switch (g_mood) {
    case MOOD_HAPPY:
      animate_solid(COL_HAPPY, g_intensity);
      break;
    case MOOD_CURIOUS:
      animate_solid(COL_CURIOUS, g_intensity);
      break;
    case MOOD_ANNOYED:
      animate_solid(COL_ANNOYED, g_intensity);
      break;
    case MOOD_PLAYFUL:
      animate_rainbow(g_intensity);
      break;
    case MOOD_IDLE:
    default:
      animate_breathing(now, g_intensity > 0.0f ? g_intensity : 1.0f);
      break;
  }
  strip.show();
}
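The serial protocol in the sketch header can be exercised from the host side. A minimal Python sketch of the frame format (the `mood_frame` helper below is hypothetical, not part of the repo; the actual expression_node additionally applies `round(intensity, 3)` before sending):

```python
import json

def mood_frame(mood: str, intensity: float) -> bytes:
    """Build one newline-terminated JSON command frame, intensity clamped 0..1."""
    intensity = max(0.0, min(1.0, intensity))
    return (json.dumps({"mood": mood, "intensity": intensity}) + "\n").encode("ascii")

# With pyserial, a frame could then be written as e.g.:
#   serial.Serial("/dev/esp32-social", 115200).write(mood_frame("happy", 0.8))
```

The `/dev/esp32-social` path assumes the udev symlink suggested in expression_params.yaml.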

View File

@ -0,0 +1,30 @@
# attention_params.yaml — Social attention controller configuration.
#
# ── Proportional gain ─────────────────────────────────────────────────────────
# angular_z = clamp(kp_angular * bearing_rad, ±max_angular_vel)
#
# kp_angular : 1.0 → 1.0 rad/s per 1 rad error (57° → full speed).
# Increase for snappier tracking; decrease if robot oscillates.
kp_angular: 1.0 # rad/s per rad
# ── Velocity limits ──────────────────────────────────────────────────────────
# Hard cap on rotation output. In social mode the balance loop must not be
# disturbed by large sudden angular commands; keep this ≤ 1.0 rad/s.
max_angular_vel: 0.8 # rad/s
# ── Dead zone ────────────────────────────────────────────────────────────────
# If |bearing| ≤ dead_zone_rad, rotation is suppressed.
# 0.15 rad ≈ 8.6° — prevents chattering when person is roughly centred.
dead_zone_rad: 0.15 # radians
# ── Control loop ─────────────────────────────────────────────────────────────
control_rate: 20.0 # Hz
# ── Lost-target timeout ───────────────────────────────────────────────────────
# If no /social/persons message arrives for this many seconds, publish zero
# and enter IDLE state. The cmd_vel bridge deadman will also kick in.
lost_timeout_s: 2.0 # seconds
# ── Master enable ─────────────────────────────────────────────────────────────
# Toggle at runtime: ros2 param set /attention_node attention_enabled false
attention_enabled: true
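# Worked through, the gain / dead-zone / clamp interaction above behaves as
# follows (standalone sketch of the documented formula, not the node's code):

```python
def attention_cmd(bearing_rad, kp_angular=1.0, max_angular_vel=0.8,
                  dead_zone_rad=0.15):
    """angular_z = clamp(kp_angular * bearing_rad, ±max_angular_vel)."""
    if abs(bearing_rad) <= dead_zone_rad:
        return 0.0  # person roughly centred: suppress rotation
    return max(-max_angular_vel, min(max_angular_vel, kp_angular * bearing_rad))

# attention_cmd(0.10) -> 0.0  (inside the 0.15 rad dead zone)
# attention_cmd(0.50) -> 0.5  (linear region: kp * bearing)
# attention_cmd(2.00) -> 0.8  (clamped to max_angular_vel)
```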

View File

@ -0,0 +1,22 @@
# expression_params.yaml — LED mood display bridge configuration.
#
# serial_port : udev symlink or device path for the ESP32-C3 USB-CDC port.
# Recommended: create a udev rule:
# SUBSYSTEM=="tty", ATTRS{idVendor}=="303a", ATTRS{idProduct}=="1001",
# SYMLINK+="esp32-social"
#
# baud_rate : must match social_expression.ino (default 115200)
#
# idle_timeout_s : if no /social/mood message arrives for this many seconds,
# the node sends {"mood":"idle","intensity":1.0} so the ESP32 breathing
# animation is synchronised with node awareness.
#
# control_rate : how often to check the idle timeout (Hz).
# Does NOT gate the mood messages — those are forwarded immediately.
expression_node:
  ros__parameters:
    serial_port: /dev/esp32-social
    baud_rate: 115200
    idle_timeout_s: 3.0
    control_rate: 10.0

View File

@ -1,3 +1,18 @@
"""
social.launch.py -- Launch the full saltybot social stack.
Includes:
  person_state_tracker  multi-modal person identity fusion (Issue #82)
  expression_node       /social/mood → ESP32-C3 NeoPixel serial (Issue #86)
  attention_node        /social/persons → /cmd_vel rotation (Issue #86)
Usage:
  ros2 launch saltybot_social social.launch.py
  ros2 launch saltybot_social social.launch.py serial_port:=/dev/ttyUSB1
"""
import os
from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
@ -5,7 +20,12 @@ from launch_ros.actions import Node
def generate_launch_description():
    pkg = get_package_share_directory("saltybot_social")
    exp_cfg = os.path.join(pkg, "config", "expression_params.yaml")
    att_cfg = os.path.join(pkg, "config", "attention_params.yaml")
    return LaunchDescription([
        # person_state_tracker args (Issue #82)
        DeclareLaunchArgument(
            'engagement_distance',
            default_value='2.0',
@ -21,6 +41,19 @@ def generate_launch_description():
            default_value='false',
            description='Whether UWB anchor data is available'
        ),
        # expression_node args (Issue #86)
        DeclareLaunchArgument("serial_port", default_value="/dev/esp32-social"),
        DeclareLaunchArgument("baud_rate", default_value="115200"),
        DeclareLaunchArgument("idle_timeout_s", default_value="3.0"),
        # attention_node args (Issue #86)
        DeclareLaunchArgument("kp_angular", default_value="1.0"),
        DeclareLaunchArgument("max_angular_vel", default_value="0.8"),
        DeclareLaunchArgument("dead_zone_rad", default_value="0.15"),
        DeclareLaunchArgument("lost_timeout_s", default_value="2.0"),
        DeclareLaunchArgument("attention_enabled", default_value="true"),
        Node(
            package='saltybot_social',
            executable='person_state_tracker',
@ -32,4 +65,36 @@ def generate_launch_description():
                'uwb_enabled': LaunchConfiguration('uwb_enabled'),
            }],
        ),
        Node(
            package="saltybot_social",
            executable="expression_node",
            name="expression_node",
            output="screen",
            parameters=[
                exp_cfg,
                {
                    "serial_port": LaunchConfiguration("serial_port"),
                    "baud_rate": LaunchConfiguration("baud_rate"),
                    "idle_timeout_s": LaunchConfiguration("idle_timeout_s"),
                },
            ],
        ),
        Node(
            package="saltybot_social",
            executable="attention_node",
            name="attention_node",
            output="screen",
            parameters=[
                att_cfg,
                {
                    "kp_angular": LaunchConfiguration("kp_angular"),
                    "max_angular_vel": LaunchConfiguration("max_angular_vel"),
                    "dead_zone_rad": LaunchConfiguration("dead_zone_rad"),
                    "lost_timeout_s": LaunchConfiguration("lost_timeout_s"),
                    "attention_enabled": LaunchConfiguration("attention_enabled"),
                },
            ],
        ),
    ])

View File

@ -3,7 +3,12 @@
<package format="3">
  <name>saltybot_social</name>
  <version>0.1.0</version>
  <description>
    Social interaction layer for saltybot.
    person_state_tracker: multi-modal person identity fusion (Issue #82).
    expression_node: bridges /social/mood to ESP32-C3 NeoPixel ring over serial (Issue #86).
    attention_node: rotates robot toward active speaker via /social/persons bearing (Issue #86).
  </description>
  <maintainer email="seb@vayrette.com">seb</maintainer>
  <license>MIT</license>
  <depend>rclpy</depend>

View File

@ -0,0 +1,205 @@
"""
attention_node.py -- Social attention controller for saltybot.
Rotates the robot to face the active speaker or most confident detected person.
Publishes rotation-only cmd_vel (linear.x = 0) so balance is not disturbed.
Subscribes
  /social/persons (saltybot_social_msgs/PersonArray)
    Contains all tracked persons; active_id marks the current speaker.
    bearing_rad: signed bearing in base_link (+ve = left/CCW).
Publishes
  /cmd_vel (geometry_msgs/Twist)
    angular.z = clamp(kp_angular * bearing, ±max_angular_vel)
    linear.x  = 0 always (social mode, not follow mode)
Control law
  1. Pick target: person with active_id; fall back to highest-confidence person.
  2. Dead zone: if |bearing| ≤ dead_zone_rad → angular_z = 0.
  3. Proportional: angular_z = kp_angular * bearing.
  4. Clamp to ±max_angular_vel.
  5. If no fresh person data for lost_timeout_s → publish zero and stay idle.
State machine
  ATTENDING  fresh target; publishing proportional rotation
  IDLE       no target; publishing zero
Parameters
  kp_angular        (float) 1.0   proportional gain (rad/s per rad error)
  max_angular_vel   (float) 0.8   hard cap (rad/s)
  dead_zone_rad     (float) 0.15  ~8.6° dead zone around robot heading
  control_rate      (float) 20.0  Hz
  lost_timeout_s    (float) 2.0   seconds before going idle
  attention_enabled (bool)  True  runtime kill switch
"""
import math
import time
import rclpy
from rclpy.node import Node
from geometry_msgs.msg import Twist
from saltybot_social_msgs.msg import PersonArray
# ── Pure helpers (used by tests) ──────────────────────────────────────────────
def clamp(v: float, lo: float, hi: float) -> float:
    return max(lo, min(hi, v))


def select_target(persons, active_id: int):
    """
    Pick the attention target from a PersonArray's persons list.
    Priority:
      1. Person whose track_id == active_id (speaking / explicitly active).
      2. Person with highest confidence (fallback when no speaker declared).
    Returns the selected Person message, or None if persons is empty.
    """
    if not persons:
        return None
    # Prefer the declared active speaker
    for p in persons:
        if p.track_id == active_id:
            return p
    # Fall back to most confident detection
    return max(persons, key=lambda p: p.confidence)


def compute_attention_cmd(
    bearing_rad: float,
    dead_zone_rad: float,
    kp_angular: float,
    max_angular_vel: float,
) -> float:
    """
    Proportional rotation command toward bearing.
    Returns angular_z (rad/s).
    Zero inside dead zone; proportional + clamped outside.
    """
    if abs(bearing_rad) <= dead_zone_rad:
        return 0.0
    return clamp(kp_angular * bearing_rad, -max_angular_vel, max_angular_vel)


# ── ROS2 node ─────────────────────────────────────────────────────────────────
class AttentionNode(Node):
    def __init__(self):
        super().__init__("attention_node")
        self.declare_parameter("kp_angular", 1.0)
        self.declare_parameter("max_angular_vel", 0.8)
        self.declare_parameter("dead_zone_rad", 0.15)
        self.declare_parameter("control_rate", 20.0)
        self.declare_parameter("lost_timeout_s", 2.0)
        self.declare_parameter("attention_enabled", True)
        self._reload_params()
        # State
        self._last_persons_t = 0.0  # monotonic; 0 = never received
        self._persons = []
        self._active_id = -1
        self._state = "idle"
        self.create_subscription(
            PersonArray, "/social/persons", self._persons_cb, 10
        )
        self._cmd_pub = self.create_publisher(Twist, "/cmd_vel", 10)
        rate = self._p["control_rate"]
        self._timer = self.create_timer(1.0 / rate, self._control_cb)
        self.get_logger().info(
            f"AttentionNode ready kp={self._p['kp_angular']} "
            f"dead_zone={math.degrees(self._p['dead_zone_rad']):.1f}° "
            f"max_ω={self._p['max_angular_vel']} rad/s"
        )

    def _reload_params(self):
        self._p = {
            "kp_angular": self.get_parameter("kp_angular").value,
            "max_angular_vel": self.get_parameter("max_angular_vel").value,
            "dead_zone_rad": self.get_parameter("dead_zone_rad").value,
            "control_rate": self.get_parameter("control_rate").value,
            "lost_timeout_s": self.get_parameter("lost_timeout_s").value,
            "attention_enabled": self.get_parameter("attention_enabled").value,
        }

    # ── Callbacks ─────────────────────────────────────────────────────────────
    def _persons_cb(self, msg: PersonArray):
        self._persons = msg.persons
        self._active_id = msg.active_id
        self._last_persons_t = time.monotonic()
        if self._state == "idle" and msg.persons:
            self._state = "attending"
            self.get_logger().info("Attention: person detected — attending")

    def _control_cb(self):
        self._reload_params()
        twist = Twist()
        if not self._p["attention_enabled"]:
            self._cmd_pub.publish(twist)
            return
        now = time.monotonic()
        fresh = (
            self._last_persons_t > 0.0
            and (now - self._last_persons_t) < self._p["lost_timeout_s"]
        )
        if not fresh:
            if self._state == "attending":
                self._state = "idle"
                self.get_logger().info("Attention: no person — idle")
            self._cmd_pub.publish(twist)
            return
        target = select_target(self._persons, self._active_id)
        if target is None:
            self._cmd_pub.publish(twist)
            return
        self._state = "attending"
        twist.angular.z = compute_attention_cmd(
            bearing_rad=target.bearing_rad,
            dead_zone_rad=self._p["dead_zone_rad"],
            kp_angular=self._p["kp_angular"],
            max_angular_vel=self._p["max_angular_vel"],
        )
        self._cmd_pub.publish(twist)


# ── Entry point ───────────────────────────────────────────────────────────────
def main(args=None):
    rclpy.init(args=args)
    node = AttentionNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.try_shutdown()


if __name__ == "__main__":
    main()
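The target-selection priority can be tried without ROS2 or the package installed. A standalone mirror of `select_target` (the `pick_target` name and dict stand-ins for Person messages are illustrative, not repo code):

```python
def pick_target(persons, active_id):
    """Active speaker first, else highest confidence, else None."""
    if not persons:
        return None
    for p in persons:
        if p["track_id"] == active_id:  # declared active speaker wins
            return p
    return max(persons, key=lambda p: p["confidence"])  # fallback

people = [{"track_id": 1, "confidence": 0.6},
          {"track_id": 2, "confidence": 0.9}]
# pick_target(people, active_id=1)["track_id"]  -> 1 (declared speaker wins)
# pick_target(people, active_id=-1)["track_id"] -> 2 (confidence fallback)
```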

View File

@ -0,0 +1,157 @@
"""
expression_node.py -- LED mood display bridge for saltybot social layer.
Subscribes
  /social/mood (saltybot_social_msgs/Mood)
    mood      : "happy" | "curious" | "annoyed" | "playful" | "idle"
    intensity : 0.0 .. 1.0
Behaviour
  * On every /social/mood message, serialises the command as a single-line
    JSON frame and writes it over the ESP32-C3 serial link:
      {"mood":"happy","intensity":0.80}\n
  * If no mood message arrives for idle_timeout_s, sends an explicit
    idle command so the ESP32 breathing animation is synchronised with the
    node's own awareness of whether the robot is active.
  * Re-opens the serial port automatically if the device disconnects.
Parameters
  serial_port    (str)   /dev/ttyUSB0  ESP32-C3 USB-CDC device
  baud_rate      (int)   115200
  idle_timeout_s (float) 3.0           seconds before sending idle
  control_rate   (float) 10.0          Hz; how often to check idle
"""
import json
import threading
import time
import rclpy
from rclpy.node import Node
from saltybot_social_msgs.msg import Mood
try:
    import serial
    _SERIAL_AVAILABLE = True
except ImportError:
    _SERIAL_AVAILABLE = False


class ExpressionNode(Node):
    def __init__(self):
        super().__init__("expression_node")
        self.declare_parameter("serial_port", "/dev/ttyUSB0")
        self.declare_parameter("baud_rate", 115200)
        self.declare_parameter("idle_timeout_s", 3.0)
        self.declare_parameter("control_rate", 10.0)
        self._port = self.get_parameter("serial_port").value
        self._baud = self.get_parameter("baud_rate").value
        self._idle_to = self.get_parameter("idle_timeout_s").value
        rate = self.get_parameter("control_rate").value
        self._last_cmd_t = 0.0  # monotonic, 0 = never received
        self._lock = threading.Lock()
        self._ser = None
        self.create_subscription(Mood, "/social/mood", self._mood_cb, 10)
        self._timer = self.create_timer(1.0 / rate, self._tick_cb)
        if _SERIAL_AVAILABLE:
            threading.Thread(target=self._open_serial, daemon=True).start()
        else:
            self.get_logger().warn(
                "pyserial not available — serial output disabled (dry-run mode)"
            )
        self.get_logger().info(
            f"ExpressionNode ready port={self._port} baud={self._baud} "
            f"idle_timeout={self._idle_to}s"
        )

    # ── Serial management ─────────────────────────────────────────────────────
    def _open_serial(self):
        """Background thread: keep serial port open, reconnect on error."""
        while rclpy.ok():
            try:
                with self._lock:
                    self._ser = serial.Serial(
                        self._port, self._baud, timeout=1.0
                    )
                self.get_logger().info(f"Opened serial port {self._port}")
                # Hold serial open until error
                while rclpy.ok():
                    with self._lock:
                        if self._ser is None or not self._ser.is_open:
                            break
                    time.sleep(0.5)
            except Exception as exc:
                self.get_logger().warn(
                    f"Serial {self._port} error: {exc} — retry in 3 s"
                )
                with self._lock:
                    if self._ser and self._ser.is_open:
                        self._ser.close()
                    self._ser = None
                time.sleep(3.0)

    def _send(self, mood: str, intensity: float):
        """Serialise and write one JSON line to ESP32-C3."""
        line = json.dumps({"mood": mood, "intensity": round(intensity, 3)}) + "\n"
        data = line.encode("ascii")
        with self._lock:
            if self._ser and self._ser.is_open:
                try:
                    self._ser.write(data)
                except Exception as exc:
                    self.get_logger().warn(f"Serial write error: {exc}")
                    self._ser.close()
                    self._ser = None
            else:
                # Dry-run: log the frame
                self.get_logger().debug(f"[dry-run] → {line.strip()}")

    # ── Callbacks ─────────────────────────────────────────────────────────────
    def _mood_cb(self, msg: Mood):
        intensity = max(0.0, min(1.0, float(msg.intensity)))
        self._send(msg.mood, intensity)
        self._last_cmd_t = time.monotonic()

    def _tick_cb(self):
        """Periodic check — send idle if we've been quiet too long."""
        if self._last_cmd_t == 0.0:
            return
        if (time.monotonic() - self._last_cmd_t) >= self._idle_to:
            self._send("idle", 1.0)
            # Reset so we don't flood the ESP32 with idle frames
            self._last_cmd_t = time.monotonic()


# ── Entry point ───────────────────────────────────────────────────────────────
def main(args=None):
    rclpy.init(args=args)
    node = ExpressionNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.try_shutdown()


if __name__ == "__main__":
    main()
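The idle-timeout gate in `_tick_cb` reduces to a small pure predicate. A standalone mirror (the `should_send_idle` name is illustrative, not repo code):

```python
def should_send_idle(last_cmd_t: float, now: float, timeout_s: float) -> bool:
    """True once a previously received mood command has aged past timeout_s."""
    if last_cmd_t == 0.0:
        return False  # never received a command: stay silent
    return (now - last_cmd_t) >= timeout_s

# should_send_idle(0.0, 100.0, 3.0) -> False (no command yet)
# should_send_idle(10.0, 12.0, 3.0) -> False (still fresh)
# should_send_idle(10.0, 13.5, 3.0) -> True  (idle frame due)
```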

View File

@ -21,12 +21,14 @@ setup(
    zip_safe=True,
    maintainer='seb',
    maintainer_email='seb@vayrette.com',
    description='Social interaction layer — person state tracking, LED expression + attention',
    license='MIT',
    tests_require=['pytest'],
    entry_points={
        'console_scripts': [
            'person_state_tracker = saltybot_social.person_state_tracker_node:main',
            'expression_node = saltybot_social.expression_node:main',
            'attention_node = saltybot_social.attention_node:main',
        ],
    },
)

View File

@ -0,0 +1,140 @@
"""
test_attention.py -- Unit tests for attention_node helpers.
No ROS2 / serial / GPU dependencies; runs with plain pytest.
"""
import math
import pytest
from saltybot_social.attention_node import (
    clamp,
    compute_attention_cmd,
    select_target,
)
# ── Helpers ───────────────────────────────────────────────────────────────────
class FakePerson:
    def __init__(self, track_id, bearing_rad, confidence=0.9, is_speaking=False):
        self.track_id = track_id
        self.bearing_rad = bearing_rad
        self.confidence = confidence
        self.is_speaking = is_speaking


# ── clamp ─────────────────────────────────────────────────────────────────────
class TestClamp:
    def test_within_range(self):
        assert clamp(0.5, 0.0, 1.0) == 0.5

    def test_below_lo(self):
        assert clamp(-0.5, 0.0, 1.0) == 0.0

    def test_above_hi(self):
        assert clamp(1.5, 0.0, 1.0) == 1.0

    def test_at_boundary(self):
        assert clamp(0.0, 0.0, 1.0) == 0.0
        assert clamp(1.0, 0.0, 1.0) == 1.0


# ── select_target ─────────────────────────────────────────────────────────────
class TestSelectTarget:
    def test_empty_returns_none(self):
        assert select_target([], active_id=-1) is None

    def test_picks_active_id(self):
        persons = [FakePerson(1, 0.1, confidence=0.6),
                   FakePerson(2, 0.5, confidence=0.9)]
        result = select_target(persons, active_id=1)
        assert result.track_id == 1

    def test_falls_back_to_highest_confidence(self):
        persons = [FakePerson(1, 0.1, confidence=0.6),
                   FakePerson(2, 0.5, confidence=0.9),
                   FakePerson(3, 0.2, confidence=0.4)]
        result = select_target(persons, active_id=-1)
        assert result.track_id == 2

    def test_single_person_no_active(self):
        persons = [FakePerson(5, 0.3)]
        result = select_target(persons, active_id=-1)
        assert result.track_id == 5

    def test_active_id_not_in_list_falls_back(self):
        persons = [FakePerson(1, 0.0, confidence=0.5),
                   FakePerson(2, 0.2, confidence=0.95)]
        result = select_target(persons, active_id=99)
        assert result.track_id == 2


# ── compute_attention_cmd ─────────────────────────────────────────────────────
class TestComputeAttentionCmd:
    def test_inside_dead_zone_zero(self):
        cmd = compute_attention_cmd(
            bearing_rad=0.05,
            dead_zone_rad=0.15,
            kp_angular=1.0,
            max_angular_vel=0.8,
        )
        assert cmd == 0.0

    def test_exactly_at_dead_zone_boundary_zero(self):
        cmd = compute_attention_cmd(
            bearing_rad=0.15,
            dead_zone_rad=0.15,
            kp_angular=1.0,
            max_angular_vel=0.8,
        )
        assert cmd == 0.0

    def test_positive_bearing_gives_positive_angular_z(self):
        cmd = compute_attention_cmd(
            bearing_rad=0.4,
            dead_zone_rad=0.15,
            kp_angular=1.0,
            max_angular_vel=0.8,
        )
        assert cmd > 0.0

    def test_negative_bearing_gives_negative_angular_z(self):
        cmd = compute_attention_cmd(
            bearing_rad=-0.4,
            dead_zone_rad=0.15,
            kp_angular=1.0,
            max_angular_vel=0.8,
        )
        assert cmd < 0.0

    def test_clamps_to_max_angular_vel(self):
        cmd = compute_attention_cmd(
            bearing_rad=math.pi,  # 180° off
            dead_zone_rad=0.15,
            kp_angular=2.0,
            max_angular_vel=0.8,
        )
        assert abs(cmd) <= 0.8

    def test_proportional_scaling(self):
        """Double bearing → double cmd (within linear region)."""
        kp = 1.0
        dz = 0.05
        cmd1 = compute_attention_cmd(0.3, dz, kp, 10.0)
        cmd2 = compute_attention_cmd(0.6, dz, kp, 10.0)
        assert abs(cmd2 - 2.0 * cmd1) < 0.01

    def test_zero_bearing_zero_cmd(self):
        cmd = compute_attention_cmd(
            bearing_rad=0.0,
            dead_zone_rad=0.1,
            kp_angular=1.0,
            max_angular_vel=0.8,
        )
        assert cmd == 0.0

View File

@ -0,0 +1,6 @@
enrollment_node:
  ros__parameters:
    db_path: '/mnt/nvme/saltybot/gallery/persons.db'
    voice_samples_dir: '/mnt/nvme/saltybot/gallery/voice'
    auto_enroll_phrase: 'remember me my name is'
    n_samples_default: 10

View File

@ -0,0 +1,22 @@
"""Launch file for saltybot social enrollment node."""
import os
from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch_ros.actions import Node
def generate_launch_description():
    pkg_share = get_package_share_directory('saltybot_social_enrollment')
    config_file = os.path.join(pkg_share, 'config', 'enrollment_params.yaml')
    return LaunchDescription([
        Node(
            package='saltybot_social_enrollment',
            executable='enrollment_node',
            name='enrollment_node',
            parameters=[config_file],
            output='screen',
        ),
    ])

View File

@ -0,0 +1,23 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
  <name>saltybot_social_enrollment</name>
  <version>0.1.0</version>
  <description>Person enrollment system for saltybot social interaction</description>
  <maintainer email="seb@vayrette.com">seb</maintainer>
  <license>MIT</license>
  <buildtool_depend>ament_python</buildtool_depend>
  <depend>rclpy</depend>
  <depend>sensor_msgs</depend>
  <depend>cv_bridge</depend>
  <depend>std_msgs</depend>
  <depend>saltybot_social_msgs</depend>
  <exec_depend>python3-numpy</exec_depend>
  <export>
    <build_type>ament_python</build_type>
  </export>
</package>


@@ -0,0 +1,184 @@
#!/usr/bin/env python3
"""enrollment_cli.py -- Gallery management CLI for saltybot social.
Usage:
ros2 run saltybot_social_enrollment enrollment_cli enroll --name "Alice" [--samples 15] [--mode face]
ros2 run saltybot_social_enrollment enrollment_cli list
ros2 run saltybot_social_enrollment enrollment_cli delete --id 3
ros2 run saltybot_social_enrollment enrollment_cli rename --id 2 --name "Bob"
"""
import argparse
import sys
import rclpy
from rclpy.node import Node
from saltybot_social_msgs.srv import (
EnrollPerson,
ListPersons,
DeletePerson,
UpdatePerson,
)
class EnrollmentCLI(Node):
def __init__(self):
super().__init__('enrollment_cli')
self._enroll_client = self.create_client(
EnrollPerson, '/social/enroll'
)
self._list_client = self.create_client(
ListPersons, '/social/persons/list'
)
self._delete_client = self.create_client(
DeletePerson, '/social/persons/delete'
)
self._update_client = self.create_client(
UpdatePerson, '/social/persons/update'
)
def enroll(self, name: str, n_samples: int = 10, mode: str = 'face'):
if not self._enroll_client.wait_for_service(timeout_sec=5.0):
print('ERROR: /social/enroll service not available')
return False
req = EnrollPerson.Request()
req.name = name
req.mode = mode
req.n_samples = n_samples
print(f'Enrolling "{name}" ({mode}, {n_samples} samples)...')
future = self._enroll_client.call_async(req)
rclpy.spin_until_future_complete(self, future, timeout_sec=60.0)
if future.result() is None:
print('ERROR: Enrollment timed out')
return False
resp = future.result()
if resp.success:
print(f'OK: Enrolled "{name}" as person_id={resp.person_id}')
else:
print(f'FAILED: {resp.message}')
return resp.success
def list_persons(self):
if not self._list_client.wait_for_service(timeout_sec=5.0):
print('ERROR: /social/persons/list service not available')
return
req = ListPersons.Request()
future = self._list_client.call_async(req)
rclpy.spin_until_future_complete(self, future, timeout_sec=10.0)
if future.result() is None:
print('ERROR: List request timed out')
return
resp = future.result()
if not resp.persons:
print('Gallery is empty.')
return
print(f'{"ID":>4} {"Name":<20} {"Samples":>7} {"Embedding Dim":>13}')
print('-' * 50)
for p in resp.persons:
dim = len(p.embedding) if p.embedding else 0
print(f'{p.person_id:>4} {p.person_name:<20} {p.sample_count:>7} {dim:>13}')
def delete(self, person_id: int):
if not self._delete_client.wait_for_service(timeout_sec=5.0):
print('ERROR: /social/persons/delete service not available')
return False
req = DeletePerson.Request()
req.person_id = person_id
future = self._delete_client.call_async(req)
rclpy.spin_until_future_complete(self, future, timeout_sec=10.0)
if future.result() is None:
print('ERROR: Delete request timed out')
return False
resp = future.result()
if resp.success:
print(f'OK: {resp.message}')
else:
print(f'FAILED: {resp.message}')
return resp.success
def rename(self, person_id: int, new_name: str):
if not self._update_client.wait_for_service(timeout_sec=5.0):
print('ERROR: /social/persons/update service not available')
return False
req = UpdatePerson.Request()
req.person_id = person_id
req.new_name = new_name
future = self._update_client.call_async(req)
rclpy.spin_until_future_complete(self, future, timeout_sec=10.0)
if future.result() is None:
print('ERROR: Rename request timed out')
return False
resp = future.result()
if resp.success:
print(f'OK: {resp.message}')
else:
print(f'FAILED: {resp.message}')
return resp.success
def main(args=None):
parser = argparse.ArgumentParser(
description='saltybot person enrollment CLI'
)
subparsers = parser.add_subparsers(dest='command', required=True)
# enroll
enroll_p = subparsers.add_parser('enroll', help='Enroll a new person')
enroll_p.add_argument('--name', required=True, help='Person name')
enroll_p.add_argument('--samples', type=int, default=10,
help='Number of face samples (default: 10)')
enroll_p.add_argument('--mode', default='face',
help='Enrollment mode (default: face)')
# list
subparsers.add_parser('list', help='List enrolled persons')
# delete
delete_p = subparsers.add_parser('delete', help='Delete a person')
delete_p.add_argument('--id', type=int, required=True,
help='Person ID to delete')
# rename
rename_p = subparsers.add_parser('rename', help='Rename a person')
rename_p.add_argument('--id', type=int, required=True,
help='Person ID to rename')
rename_p.add_argument('--name', required=True, help='New name')
    # Parse only known args so ROS2 remapping args pass through
    parsed, remaining = parser.parse_known_args(sys.argv[1:])
    rclpy.init(args=remaining)
cli = EnrollmentCLI()
try:
if parsed.command == 'enroll':
cli.enroll(parsed.name, parsed.samples, parsed.mode)
elif parsed.command == 'list':
cli.list_persons()
elif parsed.command == 'delete':
cli.delete(parsed.id)
elif parsed.command == 'rename':
cli.rename(parsed.id, parsed.name)
finally:
cli.destroy_node()
rclpy.try_shutdown()
if __name__ == '__main__':
main()


@@ -0,0 +1,302 @@
"""enrollment_node.py -- ROS2 person enrollment node for saltybot social.
Coordinates person enrollment:
- Forwards /social/enroll to face_recognizer's service
- Owns persistent SQLite gallery (PersonDB)
- Voice-triggered enrollment via "remember me my name is X"
- Gallery management services (list/delete/update)
- Syncs DB from /social/faces/embeddings topic
"""
import threading
import numpy as np
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, DurabilityPolicy
from std_msgs.msg import String
from saltybot_social_msgs.msg import (
FaceDetectionArray,
FaceEmbedding,
FaceEmbeddingArray,
)
from saltybot_social_msgs.srv import (
EnrollPerson,
ListPersons,
DeletePerson,
UpdatePerson,
)
from saltybot_social_enrollment.person_db import PersonDB
class EnrollmentNode(Node):
def __init__(self):
super().__init__('enrollment_node')
# Parameters
self.declare_parameter('db_path', '/mnt/nvme/saltybot/gallery/persons.db')
self.declare_parameter('voice_samples_dir', '/mnt/nvme/saltybot/gallery/voice')
self.declare_parameter('auto_enroll_phrase', 'remember me my name is')
self.declare_parameter('n_samples_default', 10)
db_path = self.get_parameter('db_path').value
self._voice_dir = self.get_parameter('voice_samples_dir').value
self._phrase = self.get_parameter('auto_enroll_phrase').value
self._n_samples = self.get_parameter('n_samples_default').value
self._db = PersonDB(db_path)
self.get_logger().info(f'PersonDB initialized at {db_path}')
        # Client to face_recognizer's enroll service. The client gets its own
        # reentrant callback group so _handle_enroll can block on its future
        # without starving the response (requires a multi-threaded executor).
        from rclpy.callback_groups import ReentrantCallbackGroup
        self._client_cb_group = ReentrantCallbackGroup()
        self._enroll_client = self.create_client(
            EnrollPerson, '/social/face_recognizer/enroll',
            callback_group=self._client_cb_group,
        )
# QoS profiles
best_effort_qos = QoSProfile(
depth=10,
reliability=ReliabilityPolicy.BEST_EFFORT,
durability=DurabilityPolicy.VOLATILE,
)
reliable_qos = QoSProfile(
depth=1,
reliability=ReliabilityPolicy.RELIABLE,
durability=DurabilityPolicy.VOLATILE,
)
status_qos = QoSProfile(
depth=1,
reliability=ReliabilityPolicy.BEST_EFFORT,
durability=DurabilityPolicy.VOLATILE,
)
# Subscriptions
self.create_subscription(
FaceDetectionArray, '/social/faces/detections',
self._on_detections, best_effort_qos
)
self.create_subscription(
FaceEmbeddingArray, '/social/faces/embeddings',
self._on_embeddings, reliable_qos
)
self.create_subscription(
String, '/social/speech/transcript',
self._on_transcript, best_effort_qos
)
self.create_subscription(
String, '/social/speech/command',
self._on_command, best_effort_qos
)
# Services
self.create_service(EnrollPerson, '/social/enroll', self._handle_enroll)
self.create_service(ListPersons, '/social/persons/list', self._handle_list)
self.create_service(DeletePerson, '/social/persons/delete', self._handle_delete)
self.create_service(UpdatePerson, '/social/persons/update', self._handle_update)
# Publishers
self._pub_embeddings = self.create_publisher(
FaceEmbeddingArray, '/social/faces/embeddings', reliable_qos
)
self._pub_status = self.create_publisher(
String, '/social/enrollment/status', status_qos
)
self.get_logger().info('EnrollmentNode ready')
# ---- Voice-triggered enrollment ----
def _on_transcript(self, msg: String):
text = msg.data.lower()
phrase = self._phrase.lower()
if phrase in text:
idx = text.index(phrase) + len(phrase)
name = text[idx:].strip()
# Clean up: take first 3 words max as the name
words = name.split()
if words:
name = ' '.join(words[:3]).title()
self.get_logger().info(f'Voice enrollment triggered: "{name}"')
self._trigger_voice_enroll(name)
def _on_command(self, msg: String):
# Reserved for explicit voice commands (e.g., "enroll Alice")
pass
def _trigger_voice_enroll(self, name: str):
if not self._enroll_client.wait_for_service(timeout_sec=1.0):
self.get_logger().warn(
'face_recognizer enroll service not available'
)
            self._publish_status('Enrollment failed: face_recognizer unavailable')
return
req = EnrollPerson.Request()
req.name = name
req.mode = 'face'
req.n_samples = self._n_samples
future = self._enroll_client.call_async(req)
future.add_done_callback(
lambda f: self._on_enroll_done(f, name)
)
self._publish_status(f'Enrolling "{name}"... look at the camera')
def _on_enroll_done(self, future, name: str):
try:
resp = future.result()
if resp.success:
status = f'Enrolled "{name}" (id={resp.person_id})'
self.get_logger().info(status)
else:
status = f'Enrollment failed for "{name}": {resp.message}'
self.get_logger().warn(status)
self._publish_status(status)
except Exception as e:
self.get_logger().error(f'Enroll call failed: {e}')
self._publish_status(f'Enrollment error: {e}')
# ---- Face detection callback (during enrollment) ----
def _on_detections(self, msg: FaceDetectionArray):
# Reserved for future direct enrollment (without face_recognizer)
pass
# ---- Embeddings sync from face_recognizer ----
def _on_embeddings(self, msg: FaceEmbeddingArray):
for emb in msg.embeddings:
existing = self._db.get_person(emb.person_id)
if existing is None:
arr = np.array(emb.embedding, dtype=np.float32)
self._db.add_person(
emb.person_name, arr, emb.sample_count
)
self.get_logger().info(
f'Synced new person from face_recognizer: '
f'{emb.person_name} (id={emb.person_id})'
)
# ---- Service handlers ----
def _handle_enroll(self, request, response):
"""Forward enrollment to face_recognizer service."""
if not self._enroll_client.wait_for_service(timeout_sec=2.0):
response.success = False
response.message = 'face_recognizer service unavailable'
return response
# Use threading.Event to bridge async call in service callback
event = threading.Event()
result_holder = {}
req = EnrollPerson.Request()
req.name = request.name
req.mode = request.mode
req.n_samples = request.n_samples
future = self._enroll_client.call_async(req)
def _done(f):
try:
result_holder['resp'] = f.result()
except Exception as e:
result_holder['err'] = str(e)
event.set()
future.add_done_callback(_done)
success = event.wait(timeout=35.0)
if not success:
response.success = False
response.message = 'Enrollment timed out'
elif 'resp' in result_holder:
resp = result_holder['resp']
response.success = resp.success
response.message = resp.message
response.person_id = resp.person_id
else:
response.success = False
response.message = result_holder.get('err', 'Unknown error')
return response
def _handle_list(self, request, response):
persons = self._db.get_all()
response.persons = []
for p in persons:
fe = FaceEmbedding()
fe.person_id = p['id']
fe.person_name = p['name']
fe.sample_count = p['sample_count']
fe.enrolled_at.sec = int(p['enrolled_at'])
fe.enrolled_at.nanosec = int(
(p['enrolled_at'] - int(p['enrolled_at'])) * 1e9
)
if p['embedding'] is not None:
fe.embedding = p['embedding'].tolist()
response.persons.append(fe)
return response
def _handle_delete(self, request, response):
success = self._db.delete_person(request.person_id)
response.success = success
if success:
response.message = f'Deleted person {request.person_id}'
self.get_logger().info(response.message)
self._publish_embeddings_from_db()
else:
response.message = f'Person {request.person_id} not found'
return response
def _handle_update(self, request, response):
success = self._db.update_name(request.person_id, request.new_name)
response.success = success
if success:
response.message = f'Updated person {request.person_id} name to "{request.new_name}"'
self.get_logger().info(response.message)
self._publish_embeddings_from_db()
else:
response.message = f'Person {request.person_id} not found'
return response
# ---- Helpers ----
def _publish_status(self, text: str):
msg = String()
msg.data = text
self._pub_status.publish(msg)
def _publish_embeddings_from_db(self):
persons = self._db.get_all()
arr = FaceEmbeddingArray()
arr.header.stamp = self.get_clock().now().to_msg()
for p in persons:
if p['embedding'] is not None:
fe = FaceEmbedding()
fe.person_id = p['id']
fe.person_name = p['name']
fe.sample_count = p['sample_count']
fe.enrolled_at.sec = int(p['enrolled_at'])
fe.enrolled_at.nanosec = int(
(p['enrolled_at'] - int(p['enrolled_at'])) * 1e9
)
fe.embedding = p['embedding'].tolist()
arr.embeddings.append(fe)
self._pub_embeddings.publish(arr)
def main(args=None):
    from rclpy.executors import MultiThreadedExecutor
    rclpy.init(args=args)
    node = EnrollmentNode()
    # Multi-threaded executor: _handle_enroll blocks while waiting for the
    # forwarded enroll response, so a second thread must service that future.
    executor = MultiThreadedExecutor()
    try:
        rclpy.spin(node, executor=executor)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.try_shutdown()
if __name__ == '__main__':
main()


@@ -0,0 +1,138 @@
"""person_db.py -- Persistent SQLite person gallery for saltybot enrollment."""
import sqlite3
import json
import time
import numpy as np
import threading
from pathlib import Path
class PersonDB:
"""Thread-safe SQLite-backed person gallery.
Schema:
persons(id INTEGER PRIMARY KEY, name TEXT, enrolled_at REAL,
sample_count INTEGER, embedding BLOB, metadata TEXT)
voice_samples(id INTEGER PRIMARY KEY, person_id INTEGER REFERENCES persons,
recorded_at REAL, sample_path TEXT)
"""
def __init__(self, db_path: str):
self._db_path = db_path
self._lock = threading.Lock()
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
self._init_db()
def _init_db(self):
with self._connect() as conn:
conn.execute("""
CREATE TABLE IF NOT EXISTS persons (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
enrolled_at REAL NOT NULL,
sample_count INTEGER DEFAULT 1,
embedding BLOB,
metadata TEXT DEFAULT '{}'
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS voice_samples (
id INTEGER PRIMARY KEY AUTOINCREMENT,
person_id INTEGER REFERENCES persons(id),
recorded_at REAL NOT NULL,
sample_path TEXT NOT NULL
)
""")
def _connect(self):
return sqlite3.connect(self._db_path)
def add_person(self, name: str, embedding: np.ndarray, sample_count: int = 1,
metadata: dict = None) -> int:
"""Add a new person. Returns new person_id."""
emb_blob = embedding.astype(np.float32).tobytes() if embedding is not None else None
now = time.time()
with self._lock:
with self._connect() as conn:
cur = conn.execute(
"INSERT INTO persons (name, enrolled_at, sample_count, embedding, metadata) "
"VALUES (?, ?, ?, ?, ?)",
(name, now, sample_count, emb_blob, json.dumps(metadata or {}))
)
return cur.lastrowid
def update_embedding(self, person_id: int, embedding: np.ndarray,
sample_count: int) -> bool:
emb_blob = embedding.astype(np.float32).tobytes()
with self._lock:
with self._connect() as conn:
conn.execute(
"UPDATE persons SET embedding=?, sample_count=? WHERE id=?",
(emb_blob, sample_count, person_id)
)
return conn.total_changes > 0
def update_name(self, person_id: int, new_name: str) -> bool:
with self._lock:
with self._connect() as conn:
conn.execute(
"UPDATE persons SET name=? WHERE id=?",
(new_name, person_id)
)
return conn.total_changes > 0
    def delete_person(self, person_id: int) -> bool:
        with self._lock:
            with self._connect() as conn:
                conn.execute(
                    "DELETE FROM voice_samples WHERE person_id=?",
                    (person_id,)
                )
                cur = conn.execute(
                    "DELETE FROM persons WHERE id=?",
                    (person_id,)
                )
                # Use the persons delete's rowcount: total_changes would also
                # count deleted voice samples and could report a false success.
                return cur.rowcount > 0
def get_all(self) -> list:
"""Returns list of dicts with id, name, enrolled_at, sample_count, embedding."""
with self._lock:
with self._connect() as conn:
rows = conn.execute(
"SELECT id, name, enrolled_at, sample_count, embedding FROM persons"
).fetchall()
result = []
for row in rows:
emb = np.frombuffer(row[4], dtype=np.float32).copy() if row[4] else None
result.append({
'id': row[0], 'name': row[1], 'enrolled_at': row[2],
'sample_count': row[3], 'embedding': emb
})
return result
def get_person(self, person_id: int) -> dict | None:
with self._lock:
with self._connect() as conn:
row = conn.execute(
"SELECT id, name, enrolled_at, sample_count, embedding "
"FROM persons WHERE id=?",
(person_id,)
).fetchone()
if row is None:
return None
emb = np.frombuffer(row[4], dtype=np.float32).copy() if row[4] else None
return {
'id': row[0], 'name': row[1], 'enrolled_at': row[2],
'sample_count': row[3], 'embedding': emb
}
def add_voice_sample(self, person_id: int, sample_path: str) -> int:
with self._lock:
with self._connect() as conn:
cur = conn.execute(
"INSERT INTO voice_samples (person_id, recorded_at, sample_path) "
"VALUES (?, ?, ?)",
(person_id, time.time(), sample_path)
)
return cur.lastrowid
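`add_person` stores embeddings as raw float32 bytes and `get_person` rebuilds them with `np.frombuffer`; a self-contained sketch of that round-trip against an in-memory database (table trimmed to the relevant columns):

```python
import sqlite3

import numpy as np

conn = sqlite3.connect(':memory:')
conn.execute("CREATE TABLE persons (id INTEGER PRIMARY KEY, embedding BLOB)")

emb = np.linspace(-1.0, 1.0, 512, dtype=np.float32)
conn.execute("INSERT INTO persons (embedding) VALUES (?)", (emb.tobytes(),))

row = conn.execute("SELECT embedding FROM persons WHERE id=1").fetchone()
# .copy() detaches from the read-only buffer, as PersonDB.get_person does
restored = np.frombuffer(row[0], dtype=np.float32).copy()
assert restored.shape == (512,) and np.array_equal(emb, restored)
```

The byte layout is only stable for a fixed dtype, which is why both ends pin `np.float32`.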


@@ -0,0 +1,4 @@
[develop]
script_dir=$base/lib/saltybot_social_enrollment
[install]
install_scripts=$base/lib/saltybot_social_enrollment


@@ -0,0 +1,29 @@
from setuptools import setup
import os
from glob import glob
package_name = 'saltybot_social_enrollment'
setup(
name=package_name,
version='0.1.0',
packages=[package_name],
data_files=[
('share/ament_index/resource_index/packages', ['resource/' + package_name]),
('share/' + package_name, ['package.xml']),
(os.path.join('share', package_name, 'launch'), glob('launch/*.py')),
(os.path.join('share', package_name, 'config'), glob('config/*.yaml')),
],
install_requires=['setuptools'],
zip_safe=True,
maintainer='seb',
maintainer_email='seb@vayrette.com',
description='Person enrollment system for saltybot social interaction',
license='MIT',
entry_points={
'console_scripts': [
'enrollment_node = saltybot_social_enrollment.enrollment_node:main',
'enrollment_cli = saltybot_social_enrollment.enrollment_cli:main',
],
},
)


@@ -0,0 +1,11 @@
face_recognizer:
ros__parameters:
scrfd_engine_path: '/mnt/nvme/saltybot/models/scrfd_2.5g.engine'
scrfd_onnx_path: '/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx'
arcface_engine_path: '/mnt/nvme/saltybot/models/arcface_r50.engine'
arcface_onnx_path: '/mnt/nvme/saltybot/models/arcface_r50.onnx'
gallery_dir: '/mnt/nvme/saltybot/gallery'
recognition_threshold: 0.35
publish_debug_image: false
max_faces: 10
scrfd_conf_thresh: 0.5


@@ -0,0 +1,80 @@
"""
face_recognition.launch.py -- Launch file for the SCRFD + ArcFace face recognition node.
Launches the face_recognizer node with configurable model paths and parameters.
The RealSense camera must be running separately (e.g., via realsense.launch.py).
"""
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
"""Generate launch description for face recognition pipeline."""
return LaunchDescription([
DeclareLaunchArgument(
'scrfd_engine_path',
default_value='/mnt/nvme/saltybot/models/scrfd_2.5g.engine',
description='Path to SCRFD TensorRT engine file',
),
DeclareLaunchArgument(
'scrfd_onnx_path',
default_value='/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx',
description='Path to SCRFD ONNX model file (fallback)',
),
DeclareLaunchArgument(
'arcface_engine_path',
default_value='/mnt/nvme/saltybot/models/arcface_r50.engine',
description='Path to ArcFace TensorRT engine file',
),
DeclareLaunchArgument(
'arcface_onnx_path',
default_value='/mnt/nvme/saltybot/models/arcface_r50.onnx',
description='Path to ArcFace ONNX model file (fallback)',
),
DeclareLaunchArgument(
'gallery_dir',
default_value='/mnt/nvme/saltybot/gallery',
description='Directory for persistent face gallery storage',
),
DeclareLaunchArgument(
'recognition_threshold',
default_value='0.35',
description='Cosine similarity threshold for face recognition',
),
DeclareLaunchArgument(
'publish_debug_image',
default_value='false',
description='Publish annotated debug image to /social/faces/debug_image',
),
DeclareLaunchArgument(
'max_faces',
default_value='10',
description='Maximum faces to process per frame',
),
DeclareLaunchArgument(
'scrfd_conf_thresh',
default_value='0.5',
description='SCRFD detection confidence threshold',
),
Node(
package='saltybot_social_face',
executable='face_recognition',
name='face_recognizer',
output='screen',
parameters=[{
'scrfd_engine_path': LaunchConfiguration('scrfd_engine_path'),
'scrfd_onnx_path': LaunchConfiguration('scrfd_onnx_path'),
'arcface_engine_path': LaunchConfiguration('arcface_engine_path'),
'arcface_onnx_path': LaunchConfiguration('arcface_onnx_path'),
'gallery_dir': LaunchConfiguration('gallery_dir'),
'recognition_threshold': LaunchConfiguration('recognition_threshold'),
'publish_debug_image': LaunchConfiguration('publish_debug_image'),
'max_faces': LaunchConfiguration('max_faces'),
'scrfd_conf_thresh': LaunchConfiguration('scrfd_conf_thresh'),
}],
),
])


@@ -0,0 +1,27 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_social_face</name>
<version>0.1.0</version>
<description>SCRFD face detection and ArcFace recognition for SaltyBot social interactions</description>
<maintainer email="seb@vayrette.com">seb</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>sensor_msgs</depend>
<depend>cv_bridge</depend>
<depend>image_transport</depend>
<depend>saltybot_social_msgs</depend>
<exec_depend>python3-numpy</exec_depend>
<exec_depend>python3-opencv</exec_depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>python3-pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>


@@ -0,0 +1 @@
"""SaltyBot social face detection and recognition package."""


@@ -0,0 +1,316 @@
"""
arcface_recognizer.py -- ArcFace face embedding extraction and gallery matching.
Performs face alignment using 5-point landmarks (insightface standard reference),
extracts 512-dimensional embeddings via ArcFace (TRT FP16 or ONNX fallback),
and matches against a persistent gallery using cosine similarity.
"""
import os
import logging
from typing import Optional
import numpy as np
import cv2
logger = logging.getLogger(__name__)
# InsightFace standard reference landmarks for 112x112 alignment
ARCFACE_SRC = np.array([
[38.2946, 51.6963], # left eye
[73.5318, 51.5014], # right eye
[56.0252, 71.7366], # nose
[41.5493, 92.3655], # left mouth
[70.7299, 92.2041], # right mouth
], dtype=np.float32)
# -- Inference backends --------------------------------------------------------
class _TRTBackend:
"""TensorRT inference engine for ArcFace."""
def __init__(self, engine_path: str):
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit # noqa: F401
self._cuda = cuda
trt_logger = trt.Logger(trt.Logger.WARNING)
with open(engine_path, 'rb') as f, trt.Runtime(trt_logger) as runtime:
self._engine = runtime.deserialize_cuda_engine(f.read())
self._context = self._engine.create_execution_context()
self._inputs = []
self._outputs = []
self._bindings = []
for i in range(self._engine.num_io_tensors):
name = self._engine.get_tensor_name(i)
dtype = trt.nptype(self._engine.get_tensor_dtype(name))
shape = tuple(self._engine.get_tensor_shape(name))
nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
host_mem = cuda.pagelocked_empty(shape, dtype)
device_mem = cuda.mem_alloc(nbytes)
self._bindings.append(int(device_mem))
if self._engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
self._inputs.append({'host': host_mem, 'device': device_mem})
else:
self._outputs.append({'host': host_mem, 'device': device_mem,
'shape': shape})
self._stream = cuda.Stream()
def infer(self, input_data: np.ndarray) -> np.ndarray:
"""Run inference and return the embedding vector."""
np.copyto(self._inputs[0]['host'], input_data.ravel())
self._cuda.memcpy_htod_async(
self._inputs[0]['device'], self._inputs[0]['host'], self._stream)
self._context.execute_async_v2(self._bindings, self._stream.handle)
for out in self._outputs:
self._cuda.memcpy_dtoh_async(out['host'], out['device'], self._stream)
self._stream.synchronize()
return self._outputs[0]['host'].reshape(self._outputs[0]['shape']).copy()
class _ONNXBackend:
"""ONNX Runtime inference (CUDA EP with CPU fallback)."""
def __init__(self, onnx_path: str):
import onnxruntime as ort
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
self._session = ort.InferenceSession(onnx_path, providers=providers)
self._input_name = self._session.get_inputs()[0].name
def infer(self, input_data: np.ndarray) -> np.ndarray:
"""Run inference and return the embedding vector."""
results = self._session.run(None, {self._input_name: input_data})
return results[0]
# -- Face alignment ------------------------------------------------------------
def align_face(bgr: np.ndarray, landmarks_10: list[float]) -> np.ndarray:
"""Align a face to 112x112 using 5-point landmarks.
Args:
bgr: Source BGR image.
landmarks_10: Flat list of 10 floats [x0,y0, x1,y1, ..., x4,y4].
Returns:
Aligned BGR face crop of shape (112, 112, 3).
"""
src_pts = np.array(landmarks_10, dtype=np.float32).reshape(5, 2)
M, _ = cv2.estimateAffinePartial2D(src_pts, ARCFACE_SRC)
if M is None:
# Fallback: simple crop and resize from bbox-like region
cx = np.mean(src_pts[:, 0])
cy = np.mean(src_pts[:, 1])
spread = max(np.ptp(src_pts[:, 0]), np.ptp(src_pts[:, 1])) * 1.5
half = spread / 2
x1 = max(0, int(cx - half))
y1 = max(0, int(cy - half))
x2 = min(bgr.shape[1], int(cx + half))
y2 = min(bgr.shape[0], int(cy + half))
crop = bgr[y1:y2, x1:x2]
return cv2.resize(crop, (112, 112), interpolation=cv2.INTER_LINEAR)
aligned = cv2.warpAffine(bgr, M, (112, 112), borderMode=cv2.BORDER_REPLICATE)
return aligned
# -- Main recognizer class -----------------------------------------------------
class ArcFaceRecognizer:
"""ArcFace face embedding extractor and gallery matcher.
Args:
engine_path: Path to TensorRT engine file.
onnx_path: Path to ONNX model file (used if engine not available).
"""
def __init__(self, engine_path: str = '', onnx_path: str = ''):
self._backend: Optional[_TRTBackend | _ONNXBackend] = None
self.gallery: dict[int, dict] = {}
# Try TRT first, then ONNX
if engine_path and os.path.isfile(engine_path):
try:
self._backend = _TRTBackend(engine_path)
logger.info('ArcFace TensorRT backend loaded: %s', engine_path)
return
except Exception as e:
logger.warning('ArcFace TRT load failed (%s), trying ONNX', e)
if onnx_path and os.path.isfile(onnx_path):
try:
self._backend = _ONNXBackend(onnx_path)
logger.info('ArcFace ONNX backend loaded: %s', onnx_path)
return
except Exception as e:
logger.error('ArcFace ONNX load failed: %s', e)
logger.error('No ArcFace model loaded. Recognition will be unavailable.')
@property
def is_loaded(self) -> bool:
"""Return True if a backend is loaded and ready."""
return self._backend is not None
def embed(self, bgr_face_112x112: np.ndarray) -> np.ndarray:
"""Extract 512-dim L2-normalized embedding from a 112x112 aligned face.
Args:
bgr_face_112x112: Aligned face crop, BGR, shape (112, 112, 3).
Returns:
L2-normalized embedding of shape (512,).
"""
if self._backend is None:
return np.zeros(512, dtype=np.float32)
# Preprocess: BGR->RGB, /255, subtract 0.5, divide 0.5 -> [1,3,112,112]
rgb = cv2.cvtColor(bgr_face_112x112, cv2.COLOR_BGR2RGB).astype(np.float32)
rgb = rgb / 255.0
rgb = (rgb - 0.5) / 0.5
blob = rgb.transpose(2, 0, 1)[np.newaxis] # [1, 3, 112, 112]
blob = np.ascontiguousarray(blob)
output = self._backend.infer(blob)
embedding = output.flatten()[:512].astype(np.float32)
# L2 normalize
norm = np.linalg.norm(embedding)
if norm > 0:
embedding = embedding / norm
return embedding
def align_and_embed(self, bgr_image: np.ndarray, landmarks_10: list[float]) -> np.ndarray:
"""Align face using landmarks and extract embedding.
Args:
bgr_image: Full BGR image.
landmarks_10: Flat list of 10 floats from SCRFD detection.
Returns:
L2-normalized embedding of shape (512,).
"""
aligned = align_face(bgr_image, landmarks_10)
return self.embed(aligned)
def load_gallery(self, gallery_path: str) -> None:
"""Load gallery from .npz file with JSON metadata sidecar.
Args:
gallery_path: Path to the .npz gallery file.
"""
import json
if not os.path.isfile(gallery_path):
logger.info('No gallery file at %s, starting empty.', gallery_path)
self.gallery = {}
return
data = np.load(gallery_path, allow_pickle=False)
meta_path = gallery_path.replace('.npz', '_meta.json')
if os.path.isfile(meta_path):
with open(meta_path, 'r') as f:
meta = json.load(f)
else:
meta = {}
self.gallery = {}
for key in data.files:
pid = int(key)
embedding = data[key].astype(np.float32)
norm = np.linalg.norm(embedding)
if norm > 0:
embedding = embedding / norm
info = meta.get(str(pid), {})
self.gallery[pid] = {
'name': info.get('name', f'person_{pid}'),
'embedding': embedding,
'samples': info.get('samples', 1),
'enrolled_at': info.get('enrolled_at', 0.0),
}
logger.info('Gallery loaded: %d persons from %s', len(self.gallery), gallery_path)
def save_gallery(self, gallery_path: str) -> None:
"""Save gallery to .npz file with JSON metadata sidecar.
Args:
gallery_path: Path to the .npz gallery file.
"""
import json
arrays = {}
meta = {}
for pid, info in self.gallery.items():
arrays[str(pid)] = info['embedding']
meta[str(pid)] = {
'name': info['name'],
'samples': info['samples'],
'enrolled_at': info['enrolled_at'],
}
os.makedirs(os.path.dirname(gallery_path) or '.', exist_ok=True)
np.savez(gallery_path, **arrays)
meta_path = gallery_path.replace('.npz', '_meta.json')
with open(meta_path, 'w') as f:
json.dump(meta, f, indent=2)
logger.info('Gallery saved: %d persons to %s', len(self.gallery), gallery_path)
def match(self, embedding: np.ndarray, threshold: float = 0.35) -> tuple[int, str, float]:
"""Match an embedding against the gallery.
Args:
embedding: L2-normalized query embedding of shape (512,).
threshold: Minimum cosine similarity for a match.
Returns:
(person_id, person_name, score) or (-1, '', 0.0) if no match.
"""
if not self.gallery:
return (-1, '', 0.0)
best_pid = -1
best_name = ''
best_score = 0.0
for pid, info in self.gallery.items():
score = float(np.dot(embedding, info['embedding']))
if score > best_score:
best_score = score
best_pid = pid
best_name = info['name']
if best_score >= threshold:
return (best_pid, best_name, best_score)
return (-1, '', 0.0)
def enroll(self, person_id: int, person_name: str, embeddings_list: list[np.ndarray]) -> None:
"""Enroll a person by averaging multiple embeddings.
Args:
person_id: Unique integer ID for this person.
person_name: Human-readable name.
embeddings_list: List of L2-normalized embeddings to average.
"""
import time as _time
if not embeddings_list:
return
mean_emb = np.mean(embeddings_list, axis=0).astype(np.float32)
norm = np.linalg.norm(mean_emb)
if norm > 0:
mean_emb = mean_emb / norm
self.gallery[person_id] = {
'name': person_name,
'embedding': mean_emb,
'samples': len(embeddings_list),
'enrolled_at': _time.time(),
}
logger.info('Enrolled person %d (%s) with %d samples.',
person_id, person_name, len(embeddings_list))
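`match` scores each gallery entry with a bare `np.dot` — valid only because `embed` and `load_gallery` both L2-normalize, making the dot product equal to cosine similarity. A quick sketch of that identity:

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity for arbitrary (non-normalized) vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

rng = np.random.default_rng(0)
a = rng.standard_normal(512).astype(np.float32)
b = rng.standard_normal(512).astype(np.float32)

# After L2 normalization, the plain dot product matches cosine similarity
an = a / np.linalg.norm(a)
bn = b / np.linalg.norm(b)
assert abs(float(np.dot(an, bn)) - cosine_similarity(a, b)) < 1e-5
```

Normalizing once at enroll/load time turns every subsequent match into a single dot product per gallery entry.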


@@ -0,0 +1,78 @@
#!/usr/bin/env python3
"""
enrollment_cli.py -- CLI tool for enrolling persons via the /social/enroll service.
Usage:
ros2 run saltybot_social_face enrollment_cli -- --name Alice --mode face --samples 10
"""
import argparse
import sys
import rclpy
from rclpy.node import Node
from saltybot_social_msgs.srv import EnrollPerson
class EnrollmentCLI(Node):
"""Simple CLI node that calls the EnrollPerson service."""
def __init__(self, name: str, mode: str, n_samples: int):
super().__init__('enrollment_cli')
self._client = self.create_client(EnrollPerson, '/social/enroll')
self.get_logger().info('Waiting for /social/enroll service...')
if not self._client.wait_for_service(timeout_sec=10.0):
self.get_logger().error('Service /social/enroll not available.')
return
request = EnrollPerson.Request()
request.name = name
request.mode = mode
request.n_samples = n_samples
        self.get_logger().info(
            f'Enrolling "{name}" (mode={mode}, samples={n_samples})...')
future = self._client.call_async(request)
rclpy.spin_until_future_complete(self, future, timeout_sec=120.0)
if future.result() is not None:
result = future.result()
if result.success:
                self.get_logger().info(
                    'Enrollment successful: '
                    f'person_id={result.person_id}, {result.message}')
            else:
                self.get_logger().error(f'Enrollment failed: {result.message}')
else:
self.get_logger().error('Enrollment service call timed out or failed.')
def main(args=None):
"""Entry point for enrollment CLI."""
parser = argparse.ArgumentParser(description='Enroll a person for face recognition.')
parser.add_argument('--name', type=str, required=True,
help='Name of the person to enroll.')
parser.add_argument('--mode', type=str, default='face',
choices=['face', 'voice', 'both'],
help='Enrollment mode (default: face).')
parser.add_argument('--samples', type=int, default=10,
help='Number of face samples to collect (default: 10).')
# Parse only known args so ROS2 remapping args pass through
parsed, remaining = parser.parse_known_args(args=sys.argv[1:])
rclpy.init(args=remaining)
node = EnrollmentCLI(parsed.name, parsed.mode, parsed.samples)
try:
pass # Node does all work in __init__
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()

@@ -0,0 +1,206 @@
"""
face_gallery.py -- Persistent face embedding gallery backed by numpy .npz + JSON.
Thread-safe gallery storage for face recognition. Embeddings are stored in a
.npz file, with a sidecar metadata.json containing names, sample counts, and
enrollment timestamps. Auto-increment IDs start at 1.
"""
import json
import logging
import os
import threading
import time
from typing import Optional
import numpy as np
logger = logging.getLogger(__name__)
class FaceGallery:
"""Persistent, thread-safe face embedding gallery.
Args:
gallery_dir: Directory for gallery.npz and metadata.json files.
"""
def __init__(self, gallery_dir: str):
self._gallery_dir = gallery_dir
self._npz_path = os.path.join(gallery_dir, 'gallery.npz')
self._meta_path = os.path.join(gallery_dir, 'metadata.json')
self._gallery: dict[int, dict] = {}
self._next_id = 1
self._lock = threading.Lock()
def load(self) -> None:
"""Load gallery from disk. Populates internal gallery dict."""
with self._lock:
self._gallery = {}
self._next_id = 1
if not os.path.isfile(self._npz_path):
logger.info('No gallery file at %s, starting empty.', self._npz_path)
return
data = np.load(self._npz_path, allow_pickle=False)
meta: dict = {}
if os.path.isfile(self._meta_path):
with open(self._meta_path, 'r') as f:
meta = json.load(f)
for key in data.files:
pid = int(key)
embedding = data[key].astype(np.float32)
norm = np.linalg.norm(embedding)
if norm > 0:
embedding = embedding / norm
info = meta.get(str(pid), {})
self._gallery[pid] = {
'name': info.get('name', f'person_{pid}'),
'embedding': embedding,
'samples': info.get('samples', 1),
'enrolled_at': info.get('enrolled_at', 0.0),
}
if pid >= self._next_id:
self._next_id = pid + 1
logger.info('Gallery loaded: %d persons from %s',
len(self._gallery), self._npz_path)
def save(self) -> None:
"""Save gallery to disk (npz + JSON sidecar)."""
with self._lock:
os.makedirs(self._gallery_dir, exist_ok=True)
arrays = {}
meta = {}
for pid, info in self._gallery.items():
arrays[str(pid)] = info['embedding']
meta[str(pid)] = {
'name': info['name'],
'samples': info['samples'],
'enrolled_at': info['enrolled_at'],
}
np.savez(self._npz_path, **arrays)
with open(self._meta_path, 'w') as f:
json.dump(meta, f, indent=2)
logger.info('Gallery saved: %d persons to %s',
len(self._gallery), self._npz_path)
def add_person(self, name: str, embedding: np.ndarray, samples: int = 1) -> int:
"""Add a new person to the gallery.
Args:
name: Person's name.
embedding: L2-normalized 512-dim embedding.
samples: Number of samples used to compute the embedding.
Returns:
Assigned person_id (auto-increment integer).
"""
with self._lock:
pid = self._next_id
self._next_id += 1
emb = embedding.astype(np.float32)
norm = np.linalg.norm(emb)
if norm > 0:
emb = emb / norm
self._gallery[pid] = {
'name': name,
'embedding': emb,
'samples': samples,
'enrolled_at': time.time(),
}
logger.info('Added person %d (%s), %d samples.', pid, name, samples)
return pid
def update_name(self, person_id: int, new_name: str) -> bool:
"""Update a person's name.
Args:
person_id: The ID of the person to update.
new_name: New name string.
Returns:
True if the person was found and updated.
"""
with self._lock:
if person_id not in self._gallery:
return False
self._gallery[person_id]['name'] = new_name
return True
def delete_person(self, person_id: int) -> bool:
"""Remove a person from the gallery.
Args:
person_id: The ID of the person to delete.
Returns:
True if the person was found and removed.
"""
with self._lock:
if person_id not in self._gallery:
return False
del self._gallery[person_id]
logger.info('Deleted person %d.', person_id)
return True
def get_all(self) -> list[dict]:
"""Get all gallery entries.
Returns:
List of dicts with keys: person_id, name, embedding, samples, enrolled_at.
"""
with self._lock:
result = []
for pid, info in self._gallery.items():
result.append({
'person_id': pid,
'name': info['name'],
'embedding': info['embedding'].copy(),
'samples': info['samples'],
'enrolled_at': info['enrolled_at'],
})
return result
def match(self, query_embedding: np.ndarray, threshold: float = 0.35) -> tuple[int, str, float]:
"""Match a query embedding against the gallery using cosine similarity.
Args:
query_embedding: L2-normalized 512-dim embedding.
threshold: Minimum cosine similarity for a match.
Returns:
(person_id, name, score) or (-1, '', 0.0) if no match.
"""
with self._lock:
if not self._gallery:
return (-1, '', 0.0)
best_pid = -1
best_name = ''
best_score = 0.0
query = query_embedding.astype(np.float32)
norm = np.linalg.norm(query)
if norm > 0:
query = query / norm
for pid, info in self._gallery.items():
score = float(np.dot(query, info['embedding']))
if score > best_score:
best_score = score
best_pid = pid
best_name = info['name']
if best_score >= threshold:
return (best_pid, best_name, best_score)
return (-1, '', 0.0)
def __len__(self) -> int:
with self._lock:
return len(self._gallery)
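The persistence layout above (one array per person in `gallery.npz` keyed by `str(person_id)`, names and sample counts in a `metadata.json` sidecar) can be exercised standalone. A small sketch of the round trip, mirroring what `save()` and `load()` do with toy values:

```python
import json
import os
import tempfile

import numpy as np

with tempfile.TemporaryDirectory() as d:
    npz_path = os.path.join(d, 'gallery.npz')
    meta_path = os.path.join(d, 'metadata.json')

    # Save: one array per person, keyed by str(person_id), plus JSON sidecar.
    emb = np.ones(4, dtype=np.float32) / 2.0  # toy unit-length embedding
    np.savez(npz_path, **{'1': emb})
    with open(meta_path, 'w') as f:
        json.dump({'1': {'name': 'alice', 'samples': 3}}, f)

    # Load: rebuild the in-memory dict the same way FaceGallery.load() does.
    data = np.load(npz_path, allow_pickle=False)
    with open(meta_path, 'r') as f:
        meta = json.load(f)
    gallery = {int(k): {'name': meta[k]['name'],
                        'embedding': data[k].astype(np.float32)}
               for k in data.files}

assert gallery[1]['name'] == 'alice'
assert np.allclose(gallery[1]['embedding'], emb)
```

Keying the `.npz` by stringified IDs is what lets `load()` recover `person_id` with `int(key)` and resume the auto-increment counter past the largest stored ID.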

@@ -0,0 +1,431 @@
"""
face_recognition_node.py -- ROS2 node for SCRFD face detection + ArcFace recognition.
Pipeline:
1. Subscribe to /camera/color/image_raw (RealSense D435i color stream).
2. Run SCRFD face detection (TensorRT FP16 or ONNX fallback).
3. For each detected face, align and extract ArcFace embedding.
4. Match embedding against persistent gallery.
5. Publish FaceDetectionArray with identified faces.
Services:
/social/enroll -- Enroll a new person (collects N face samples).
/social/persons/list -- List all enrolled persons.
/social/persons/delete -- Delete a person from the gallery.
/social/persons/update -- Update a person's name.
"""
import time
import numpy as np
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy, DurabilityPolicy
import cv2
from cv_bridge import CvBridge
from sensor_msgs.msg import Image
from builtin_interfaces.msg import Time
from saltybot_social_msgs.msg import (
FaceDetection,
FaceDetectionArray,
FaceEmbedding,
FaceEmbeddingArray,
)
from saltybot_social_msgs.srv import (
EnrollPerson,
ListPersons,
DeletePerson,
UpdatePerson,
)
from .scrfd_detector import SCRFDDetector
from .arcface_recognizer import ArcFaceRecognizer
from .face_gallery import FaceGallery
class FaceRecognitionNode(Node):
"""ROS2 node: SCRFD face detection + ArcFace gallery matching."""
def __init__(self):
super().__init__('face_recognizer')
self._bridge = CvBridge()
self._frame_count = 0
self._fps_t0 = time.monotonic()
# -- Parameters --------------------------------------------------------
self.declare_parameter('scrfd_engine_path',
'/mnt/nvme/saltybot/models/scrfd_2.5g.engine')
self.declare_parameter('scrfd_onnx_path',
'/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx')
self.declare_parameter('arcface_engine_path',
'/mnt/nvme/saltybot/models/arcface_r50.engine')
self.declare_parameter('arcface_onnx_path',
'/mnt/nvme/saltybot/models/arcface_r50.onnx')
self.declare_parameter('gallery_dir', '/mnt/nvme/saltybot/gallery')
self.declare_parameter('recognition_threshold', 0.35)
self.declare_parameter('publish_debug_image', False)
self.declare_parameter('max_faces', 10)
self.declare_parameter('scrfd_conf_thresh', 0.5)
self._recognition_threshold = self.get_parameter('recognition_threshold').value
self._pub_debug_flag = self.get_parameter('publish_debug_image').value
self._max_faces = self.get_parameter('max_faces').value
# -- Models ------------------------------------------------------------
self._detector = SCRFDDetector(
engine_path=self.get_parameter('scrfd_engine_path').value,
onnx_path=self.get_parameter('scrfd_onnx_path').value,
conf_thresh=self.get_parameter('scrfd_conf_thresh').value,
)
self._recognizer = ArcFaceRecognizer(
engine_path=self.get_parameter('arcface_engine_path').value,
onnx_path=self.get_parameter('arcface_onnx_path').value,
)
# -- Gallery -----------------------------------------------------------
gallery_dir = self.get_parameter('gallery_dir').value
self._gallery = FaceGallery(gallery_dir)
self._gallery.load()
        # rclpy loggers take a single pre-formatted string (no printf args)
        self.get_logger().info(f'Gallery loaded: {len(self._gallery)} persons.')
# -- Enrollment state --------------------------------------------------
self._enrolling = None # {name, samples_needed, collected: [embeddings]}
# -- QoS profiles ------------------------------------------------------
best_effort_qos = QoSProfile(
reliability=ReliabilityPolicy.BEST_EFFORT,
history=HistoryPolicy.KEEP_LAST,
depth=1,
)
reliable_qos = QoSProfile(
reliability=ReliabilityPolicy.RELIABLE,
durability=DurabilityPolicy.TRANSIENT_LOCAL,
history=HistoryPolicy.KEEP_LAST,
depth=1,
)
# -- Subscribers -------------------------------------------------------
self.create_subscription(
Image,
'/camera/color/image_raw',
self._on_image,
best_effort_qos,
)
# -- Publishers --------------------------------------------------------
self._pub_detections = self.create_publisher(
FaceDetectionArray, '/social/faces/detections', best_effort_qos)
self._pub_embeddings = self.create_publisher(
FaceEmbeddingArray, '/social/faces/embeddings', reliable_qos)
if self._pub_debug_flag:
self._pub_debug_img = self.create_publisher(
Image, '/social/faces/debug_image', best_effort_qos)
# -- Services ----------------------------------------------------------
        # /social/enroll blocks while samples are collected, so it gets a
        # reentrant callback group; with a multi-threaded executor the image
        # callback keeps running during enrollment.
        from rclpy.callback_groups import ReentrantCallbackGroup
        self._srv_group = ReentrantCallbackGroup()
        self.create_service(EnrollPerson, '/social/enroll', self._handle_enroll,
                            callback_group=self._srv_group)
        self.create_service(ListPersons, '/social/persons/list', self._handle_list)
        self.create_service(DeletePerson, '/social/persons/delete', self._handle_delete)
        self.create_service(UpdatePerson, '/social/persons/update', self._handle_update)
# Publish initial gallery state
self._publish_gallery_embeddings()
self.get_logger().info('FaceRecognitionNode ready.')
# -- Image callback --------------------------------------------------------
def _on_image(self, msg: Image):
"""Process incoming camera frame: detect, recognize, publish."""
try:
bgr = self._bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')
except Exception as e:
            self.get_logger().error(f'Image decode error: {e}',
                                    throttle_duration_sec=5.0)
return
# Detect faces
detections = self._detector.detect(bgr)
# Limit face count
if len(detections) > self._max_faces:
detections = sorted(detections, key=lambda d: d['score'], reverse=True)
detections = detections[:self._max_faces]
# Build output message
det_array = FaceDetectionArray()
det_array.header = msg.header
for det in detections:
# Extract embedding and match gallery
embedding = self._recognizer.align_and_embed(bgr, det['kps'])
pid, pname, score = self._gallery.match(
embedding, self._recognition_threshold)
# Handle enrollment: collect embedding from largest face
if self._enrolling is not None:
self._enrollment_collect(det, embedding)
# Build FaceDetection message
face_msg = FaceDetection()
face_msg.header = msg.header
face_msg.face_id = pid
face_msg.person_name = pname
face_msg.confidence = det['score']
face_msg.recognition_score = score
bbox = det['bbox']
face_msg.bbox_x = bbox[0]
face_msg.bbox_y = bbox[1]
face_msg.bbox_w = bbox[2] - bbox[0]
face_msg.bbox_h = bbox[3] - bbox[1]
kps = det['kps']
for i in range(10):
face_msg.landmarks[i] = kps[i]
det_array.faces.append(face_msg)
self._pub_detections.publish(det_array)
# Debug image
if self._pub_debug_flag and hasattr(self, '_pub_debug_img'):
debug_img = self._draw_debug(bgr, detections, det_array.faces)
self._pub_debug_img.publish(
self._bridge.cv2_to_imgmsg(debug_img, encoding='bgr8'))
# FPS logging
self._frame_count += 1
if self._frame_count % 30 == 0:
elapsed = time.monotonic() - self._fps_t0
fps = 30.0 / elapsed if elapsed > 0 else 0.0
self._fps_t0 = time.monotonic()
            self.get_logger().info(
                f'FPS: {fps:.1f} | faces: {len(detections)}')
# -- Enrollment logic ------------------------------------------------------
def _enrollment_collect(self, det: dict, embedding: np.ndarray):
"""Collect an embedding sample during enrollment (largest face only)."""
if self._enrolling is None:
return
# Only collect from the largest face (by bbox area)
bbox = det['bbox']
area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
if not hasattr(self, '_enroll_best_area'):
self._enroll_best_area = 0.0
self._enroll_best_embedding = None
if area > self._enroll_best_area:
self._enroll_best_area = area
self._enroll_best_embedding = embedding
def _enrollment_frame_end(self):
"""Called at end of each frame to finalize enrollment sample collection."""
if self._enrolling is None or self._enroll_best_embedding is None:
return
self._enrolling['collected'].append(self._enroll_best_embedding)
self._enroll_best_area = 0.0
self._enroll_best_embedding = None
collected = len(self._enrolling['collected'])
needed = self._enrolling['samples_needed']
        person = self._enrolling['name']
        self.get_logger().info(
            f'Enrollment: {collected}/{needed} samples for "{person}".')
if collected >= needed:
# Finalize enrollment
name = self._enrolling['name']
embeddings = self._enrolling['collected']
mean_emb = np.mean(embeddings, axis=0).astype(np.float32)
norm = np.linalg.norm(mean_emb)
if norm > 0:
mean_emb = mean_emb / norm
pid = self._gallery.add_person(name, mean_emb, samples=len(embeddings))
self._gallery.save()
self._publish_gallery_embeddings()
            self.get_logger().info(f'Enrollment complete: person {pid} ({name}).')
# Store result for the service callback
self._enrolling['result_pid'] = pid
self._enrolling['done'] = True
self._enrolling = None
# -- Service handlers ------------------------------------------------------
def _handle_enroll(self, request, response):
"""Handle EnrollPerson service: start collecting face samples."""
name = request.name.strip()
if not name:
response.success = False
response.message = 'Name cannot be empty.'
response.person_id = -1
return response
n_samples = request.n_samples if request.n_samples > 0 else 10
        self.get_logger().info(
            f'Starting enrollment for "{name}" ({n_samples} samples).')
# Set enrollment state — frames will collect embeddings
self._enrolling = {
'name': name,
'samples_needed': n_samples,
'collected': [],
'done': False,
'result_pid': -1,
}
self._enroll_best_area = 0.0
self._enroll_best_embedding = None
        # Block until enrollment completes. Keep a local reference to the
        # enrollment dict: _enrollment_frame_end() sets self._enrolling to
        # None on completion, so reading through self._enrolling here would
        # crash the loop condition and lose the assigned person_id.
        enroll_state = self._enrolling
        timeout_sec = n_samples * 2.0 + 10.0  # generous timeout
        t0 = time.monotonic()
        while not enroll_state.get('done', False):
            # Finalize any pending frame collection
            self._enrollment_frame_end()
            if time.monotonic() - t0 > timeout_sec:
                self._enrolling = None
                response.success = False
                response.message = f'Enrollment timed out after {timeout_sec:.0f}s.'
                response.person_id = -1
                return response
            time.sleep(0.1)  # yield; image callbacks run on other executor threads
        response.success = True
        response.message = f'Enrolled "{name}" with {n_samples} samples.'
        response.person_id = enroll_state.get('result_pid', -1)
return response
def _handle_list(self, request, response):
"""Handle ListPersons service: return all gallery entries."""
entries = self._gallery.get_all()
for entry in entries:
emb_msg = FaceEmbedding()
emb_msg.person_id = entry['person_id']
emb_msg.person_name = entry['name']
emb_msg.embedding = entry['embedding'].tolist()
emb_msg.sample_count = entry['samples']
secs = int(entry['enrolled_at'])
nsecs = int((entry['enrolled_at'] - secs) * 1e9)
emb_msg.enrolled_at = Time(sec=secs, nanosec=nsecs)
response.persons.append(emb_msg)
return response
def _handle_delete(self, request, response):
"""Handle DeletePerson service: remove a person from the gallery."""
if self._gallery.delete_person(request.person_id):
self._gallery.save()
self._publish_gallery_embeddings()
response.success = True
response.message = f'Deleted person {request.person_id}.'
else:
response.success = False
response.message = f'Person {request.person_id} not found.'
return response
def _handle_update(self, request, response):
"""Handle UpdatePerson service: rename a person."""
new_name = request.new_name.strip()
if not new_name:
response.success = False
response.message = 'New name cannot be empty.'
return response
if self._gallery.update_name(request.person_id, new_name):
self._gallery.save()
self._publish_gallery_embeddings()
response.success = True
response.message = f'Updated person {request.person_id} to "{new_name}".'
else:
response.success = False
response.message = f'Person {request.person_id} not found.'
return response
# -- Gallery publishing ----------------------------------------------------
def _publish_gallery_embeddings(self):
"""Publish current gallery as FaceEmbeddingArray (latched-like)."""
entries = self._gallery.get_all()
msg = FaceEmbeddingArray()
msg.header.stamp = self.get_clock().now().to_msg()
for entry in entries:
emb_msg = FaceEmbedding()
emb_msg.person_id = entry['person_id']
emb_msg.person_name = entry['name']
emb_msg.embedding = entry['embedding'].tolist()
emb_msg.sample_count = entry['samples']
secs = int(entry['enrolled_at'])
nsecs = int((entry['enrolled_at'] - secs) * 1e9)
emb_msg.enrolled_at = Time(sec=secs, nanosec=nsecs)
msg.embeddings.append(emb_msg)
self._pub_embeddings.publish(msg)
# -- Debug image -----------------------------------------------------------
def _draw_debug(self, bgr: np.ndarray, detections: list[dict],
face_msgs: list) -> np.ndarray:
"""Draw bounding boxes, landmarks, and names on the image."""
vis = bgr.copy()
for det, face_msg in zip(detections, face_msgs):
bbox = det['bbox']
x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
# Color: green if recognized, yellow if unknown
if face_msg.face_id >= 0:
color = (0, 255, 0)
label = f'{face_msg.person_name} ({face_msg.recognition_score:.2f})'
else:
color = (0, 255, 255)
label = f'unknown ({face_msg.confidence:.2f})'
cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2)
cv2.putText(vis, label, (x1, y1 - 8),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
# Draw landmarks
kps = det['kps']
for k in range(5):
px, py = int(kps[k * 2]), int(kps[k * 2 + 1])
cv2.circle(vis, (px, py), 2, (0, 0, 255), -1)
return vis
# -- Entry point ---------------------------------------------------------------
def main(args=None):
    """ROS2 entry point for face_recognition node."""
    from rclpy.executors import MultiThreadedExecutor
    rclpy.init(args=args)
    node = FaceRecognitionNode()
    # Multi-threaded executor: the blocking enroll service and the image
    # callback must be able to run concurrently.
    executor = MultiThreadedExecutor()
    executor.add_node(node)
    try:
        executor.spin()
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.shutdown()
if __name__ == '__main__':
main()

@@ -0,0 +1,350 @@
"""
scrfd_detector.py -- SCRFD face detection with TensorRT FP16 + ONNX fallback.
SCRFD (Sample and Computation Redistribution for Face Detection) produces
9 output tensors across 3 strides (8, 16, 32), each with score, bbox, and
keypoint branches. This module handles anchor generation, bbox/keypoint
decoding, and NMS to produce final face detections.
"""
import os
import logging
from typing import Optional
import numpy as np
import cv2
logger = logging.getLogger(__name__)
_STRIDES = [8, 16, 32]
_NUM_ANCHORS = 2 # anchors per cell per stride
# -- Inference backends --------------------------------------------------------
class _TRTBackend:
"""TensorRT inference engine for SCRFD."""
def __init__(self, engine_path: str):
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit # noqa: F401
self._cuda = cuda
trt_logger = trt.Logger(trt.Logger.WARNING)
with open(engine_path, 'rb') as f, trt.Runtime(trt_logger) as runtime:
self._engine = runtime.deserialize_cuda_engine(f.read())
self._context = self._engine.create_execution_context()
self._inputs = []
self._outputs = []
self._output_names = []
self._bindings = []
for i in range(self._engine.num_io_tensors):
name = self._engine.get_tensor_name(i)
dtype = trt.nptype(self._engine.get_tensor_dtype(name))
shape = tuple(self._engine.get_tensor_shape(name))
nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
host_mem = cuda.pagelocked_empty(shape, dtype)
device_mem = cuda.mem_alloc(nbytes)
self._bindings.append(int(device_mem))
if self._engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
self._inputs.append({'host': host_mem, 'device': device_mem})
else:
self._outputs.append({'host': host_mem, 'device': device_mem,
'shape': shape})
self._output_names.append(name)
self._stream = cuda.Stream()
def infer(self, input_data: np.ndarray) -> list[np.ndarray]:
"""Run inference and return output tensors."""
np.copyto(self._inputs[0]['host'], input_data.ravel())
self._cuda.memcpy_htod_async(
self._inputs[0]['device'], self._inputs[0]['host'], self._stream)
self._context.execute_async_v2(self._bindings, self._stream.handle)
results = []
for out in self._outputs:
self._cuda.memcpy_dtoh_async(out['host'], out['device'], self._stream)
self._stream.synchronize()
for out in self._outputs:
results.append(out['host'].reshape(out['shape']).copy())
return results
class _ONNXBackend:
"""ONNX Runtime inference (CUDA EP with CPU fallback)."""
def __init__(self, onnx_path: str):
import onnxruntime as ort
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
self._session = ort.InferenceSession(onnx_path, providers=providers)
self._input_name = self._session.get_inputs()[0].name
self._output_names = [o.name for o in self._session.get_outputs()]
def infer(self, input_data: np.ndarray) -> list[np.ndarray]:
"""Run inference and return output tensors."""
return self._session.run(None, {self._input_name: input_data})
# -- NMS ----------------------------------------------------------------------
def _nms(boxes: np.ndarray, scores: np.ndarray, iou_thresh: float) -> list[int]:
"""Non-maximum suppression. boxes: [N, 4] as x1,y1,x2,y2."""
if len(boxes) == 0:
return []
x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
areas = (x2 - x1) * (y2 - y1)
order = scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(int(i))
if order.size == 1:
break
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-6)
remaining = np.where(iou <= iou_thresh)[0]
order = order[remaining + 1]
return keep
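The IoU expression inside the loop is easy to sanity-check by hand. For example, two 2x2 boxes overlapping in a 1x2 strip:

```python
import numpy as np

a = np.array([0.0, 0.0, 2.0, 2.0])  # x1, y1, x2, y2
b = np.array([1.0, 0.0, 3.0, 2.0])

inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))  # 1.0
inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))  # 2.0
inter = inter_w * inter_h                              # 2.0
union = 4.0 + 4.0 - inter                              # 6.0
iou = inter / union
assert abs(iou - 1.0 / 3.0) < 1e-9
```

At the default `nms_iou` of 0.4, these two boxes would both survive suppression (IoU 0.33 is below the threshold); tighten the threshold below 1/3 and the lower-scoring one is dropped.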
# -- Anchor generation ---------------------------------------------------------
def _generate_anchors(input_h: int, input_w: int) -> dict[int, np.ndarray]:
"""Generate anchor centers for each stride.
Returns dict mapping stride -> array of shape [H*W*num_anchors, 2],
where each row is (cx, cy) in input pixel coordinates.
"""
anchors = {}
for stride in _STRIDES:
feat_h = input_h // stride
feat_w = input_w // stride
        grid_y, grid_x = np.mgrid[:feat_h, :feat_w]
        # Interleave anchors per cell: [anchor0_cell0, anchor1_cell0, anchor0_cell1, ...]
        centers = np.repeat(
            np.stack([grid_x.ravel(), grid_y.ravel()], axis=1).astype(np.float32),
            _NUM_ANCHORS, axis=0)
        centers = (centers + 0.5) * stride
anchors[stride] = centers
return anchors
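For the default 640x640 input, the generator above emits 2 anchors per feature-map cell, so the per-stride anchor counts work out as follows (a quick standalone check of the arithmetic):

```python
size, num_anchors = 640, 2
counts = {s: (size // s) * (size // s) * num_anchors for s in (8, 16, 32)}
assert counts == {8: 12800, 16: 3200, 32: 800}

# First cell at stride 8: both anchors share center (0.5 * 8, 0.5 * 8) = (4, 4).
first_center = (0 + 0.5) * 8
assert first_center == 4.0
```

The 16800 total anchors is also the row count the decoder must see across the three score outputs, which is why `_decode_outputs` trims or skips strides whose output length disagrees.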
# -- Main detector class -------------------------------------------------------
class SCRFDDetector:
"""SCRFD face detector with TensorRT FP16 and ONNX fallback.
Args:
engine_path: Path to TensorRT engine file.
onnx_path: Path to ONNX model file (used if engine not available).
conf_thresh: Minimum confidence for detections.
nms_iou: IoU threshold for NMS.
input_size: Model input resolution (square).
"""
def __init__(
self,
engine_path: str = '',
onnx_path: str = '',
conf_thresh: float = 0.5,
nms_iou: float = 0.4,
input_size: int = 640,
):
self._conf_thresh = conf_thresh
self._nms_iou = nms_iou
self._input_size = input_size
self._backend: Optional[_TRTBackend | _ONNXBackend] = None
self._anchors = _generate_anchors(input_size, input_size)
# Try TRT first, then ONNX
if engine_path and os.path.isfile(engine_path):
try:
self._backend = _TRTBackend(engine_path)
logger.info('SCRFD TensorRT backend loaded: %s', engine_path)
return
except Exception as e:
logger.warning('SCRFD TRT load failed (%s), trying ONNX', e)
if onnx_path and os.path.isfile(onnx_path):
try:
self._backend = _ONNXBackend(onnx_path)
logger.info('SCRFD ONNX backend loaded: %s', onnx_path)
return
except Exception as e:
logger.error('SCRFD ONNX load failed: %s', e)
logger.error('No SCRFD model loaded. Detection will be unavailable.')
@property
def is_loaded(self) -> bool:
"""Return True if a backend is loaded and ready."""
return self._backend is not None
def detect(self, bgr: np.ndarray) -> list[dict]:
"""Detect faces in a BGR image.
Args:
bgr: Input image in BGR format, shape (H, W, 3).
Returns:
List of dicts with keys:
bbox: [x1, y1, x2, y2] in original image coordinates
kps: [x0,y0, x1,y1, ..., x4,y4] 10 floats, 5 landmarks
score: detection confidence
"""
if self._backend is None:
return []
orig_h, orig_w = bgr.shape[:2]
tensor, scale, pad_w, pad_h = self._preprocess(bgr)
outputs = self._backend.infer(tensor)
detections = self._decode_outputs(outputs)
detections = self._rescale(detections, scale, pad_w, pad_h, orig_w, orig_h)
return detections
def _preprocess(self, bgr: np.ndarray) -> tuple[np.ndarray, float, int, int]:
"""Resize to input_size x input_size with letterbox padding, normalize."""
h, w = bgr.shape[:2]
size = self._input_size
scale = min(size / h, size / w)
new_w, new_h = int(w * scale), int(h * scale)
pad_w = (size - new_w) // 2
pad_h = (size - new_h) // 2
resized = cv2.resize(bgr, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
canvas = np.full((size, size, 3), 0, dtype=np.uint8)
canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized
# Normalize: subtract 127.5, divide 128.0
blob = canvas.astype(np.float32)
blob = (blob - 127.5) / 128.0
# HWC -> NCHW
blob = blob.transpose(2, 0, 1)[np.newaxis]
blob = np.ascontiguousarray(blob)
return blob, scale, pad_w, pad_h
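A quick check of the letterbox arithmetic above for a 1280x720 camera frame (the D435i color stream) going into the 640x640 model:

```python
h, w, size = 720, 1280, 640
scale = min(size / h, size / w)                # min(0.888..., 0.5) -> 0.5
new_w, new_h = int(w * scale), int(h * scale)  # 640, 360
pad_w = (size - new_w) // 2                    # 0
pad_h = (size - new_h) // 2                    # 140 black rows above and below
assert (scale, new_w, new_h, pad_w, pad_h) == (0.5, 640, 360, 0, 140)
```

`scale`, `pad_w`, and `pad_h` are exactly the values `_rescale` later needs to invert the transform, which is why `_preprocess` returns them alongside the blob.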
def _decode_outputs(self, outputs: list[np.ndarray]) -> list[dict]:
"""Decode SCRFD 9-output format into face detections.
SCRFD outputs 9 tensors, 3 per stride (score, bbox, kps):
score_8, bbox_8, kps_8, score_16, bbox_16, kps_16, score_32, bbox_32, kps_32
"""
all_scores = []
all_bboxes = []
all_kps = []
for idx, stride in enumerate(_STRIDES):
score_out = outputs[idx * 3].squeeze() # [H*W*num_anchors]
bbox_out = outputs[idx * 3 + 1].squeeze() # [H*W*num_anchors, 4]
kps_out = outputs[idx * 3 + 2].squeeze() # [H*W*num_anchors, 10]
if score_out.ndim == 0:
continue
            # Flatten scores to 1-D; n is the anchor count at this stride
            score_out = score_out.ravel()
            n = score_out.shape[0]
if bbox_out.ndim == 1:
bbox_out = bbox_out.reshape(-1, 4)
if kps_out.ndim == 1:
kps_out = kps_out.reshape(-1, 10)
# Filter by confidence
mask = score_out > self._conf_thresh
if not mask.any():
continue
scores = score_out[mask]
bboxes = bbox_out[mask]
kps = kps_out[mask]
anchors = self._anchors[stride]
# Trim or pad anchors to match output count
if anchors.shape[0] > n:
anchors = anchors[:n]
elif anchors.shape[0] < n:
continue
anchors = anchors[mask]
# Decode bboxes: center = anchor + pred[:2]*stride, size = exp(pred[2:])*stride
cx = anchors[:, 0] + bboxes[:, 0] * stride
cy = anchors[:, 1] + bboxes[:, 1] * stride
w = np.exp(bboxes[:, 2]) * stride
h = np.exp(bboxes[:, 3]) * stride
x1 = cx - w / 2.0
y1 = cy - h / 2.0
x2 = cx + w / 2.0
y2 = cy + h / 2.0
decoded_bboxes = np.stack([x1, y1, x2, y2], axis=1)
# Decode keypoints: kp = anchor + pred * stride
decoded_kps = np.zeros_like(kps)
for k in range(5):
decoded_kps[:, k * 2] = anchors[:, 0] + kps[:, k * 2] * stride
decoded_kps[:, k * 2 + 1] = anchors[:, 1] + kps[:, k * 2 + 1] * stride
all_scores.append(scores)
all_bboxes.append(decoded_bboxes)
all_kps.append(decoded_kps)
if not all_scores:
return []
scores = np.concatenate(all_scores)
bboxes = np.concatenate(all_bboxes)
kps = np.concatenate(all_kps)
# NMS
keep = _nms(bboxes, scores, self._nms_iou)
results = []
for i in keep:
results.append({
'bbox': bboxes[i].tolist(),
'kps': kps[i].tolist(),
'score': float(scores[i]),
})
return results
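Plugging numbers through the bbox decode used above (center offset plus exp-scaled size): an anchor centered at (100, 100) on stride 8 with a raw prediction of [0.5, -0.25, 1.0, 0.5] yields:

```python
import math

anchor = (100.0, 100.0)
stride = 8
pred = [0.5, -0.25, 1.0, 0.5]  # dx, dy, log-width, log-height

cx = anchor[0] + pred[0] * stride  # 104.0
cy = anchor[1] + pred[1] * stride  # 98.0
w = math.exp(pred[2]) * stride     # ~21.75
h = math.exp(pred[3]) * stride     # ~13.19
box = (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)

assert (cx, cy) == (104.0, 98.0)
assert abs(w - 21.746) < 1e-2 and abs(h - 13.189) < 1e-2
```

The `exp` on the size terms keeps predicted widths and heights strictly positive, so `x1 < x2` and `y1 < y2` hold for any raw network output.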
def _rescale(
self,
detections: list[dict],
scale: float,
pad_w: int,
pad_h: int,
orig_w: int,
orig_h: int,
) -> list[dict]:
"""Rescale detections from model input space to original image space."""
for det in detections:
bbox = det['bbox']
bbox[0] = max(0.0, (bbox[0] - pad_w) / scale)
bbox[1] = max(0.0, (bbox[1] - pad_h) / scale)
bbox[2] = min(float(orig_w), (bbox[2] - pad_w) / scale)
bbox[3] = min(float(orig_h), (bbox[3] - pad_h) / scale)
det['bbox'] = bbox
kps = det['kps']
for k in range(5):
kps[k * 2] = (kps[k * 2] - pad_w) / scale
kps[k * 2 + 1] = (kps[k * 2 + 1] - pad_h) / scale
det['kps'] = kps
return detections
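`_rescale` inverts the letterbox transform from `_preprocess`. For a 1280x720 frame (scale 0.5, pad_w 0, pad_h 140 under the preprocessing above), a box in model space maps back to original-image coordinates like so:

```python
scale, pad_w, pad_h = 0.5, 0, 140          # letterbox params for 1280x720 -> 640x640
model_box = [100.0, 240.0, 200.0, 340.0]   # x1, y1, x2, y2 in 640x640 model space

orig_box = [(model_box[0] - pad_w) / scale,
            (model_box[1] - pad_h) / scale,
            (model_box[2] - pad_w) / scale,
            (model_box[3] - pad_h) / scale]
assert orig_box == [200.0, 200.0, 400.0, 400.0]
```

The subsequent clamping to `[0, orig_w]` and `[0, orig_h]` in `_rescale` handles detections that bleed into the padding bands.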

@@ -0,0 +1,112 @@
#!/usr/bin/env python3
"""
build_face_trt_engines.py -- Build TensorRT FP16 engines for SCRFD and ArcFace.
Converts ONNX model files to optimized TensorRT engines with FP16 precision
for fast inference on Jetson Orin Nano Super.
Usage:
python3 build_face_trt_engines.py \
--scrfd-onnx /path/to/scrfd_2.5g_bnkps.onnx \
--arcface-onnx /path/to/arcface_r50.onnx \
--output-dir /mnt/nvme/saltybot/models \
--fp16 --workspace-mb 1024
"""
import argparse
import os
import time
def build_engine(onnx_path: str, engine_path: str, fp16: bool, workspace_mb: int):
"""Build a TensorRT engine from an ONNX model.
Args:
onnx_path: Path to the source ONNX model file.
engine_path: Output path for the serialized TensorRT engine.
fp16: Enable FP16 precision.
workspace_mb: Maximum workspace size in megabytes.
"""
import tensorrt as trt
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
print(f'Parsing ONNX model: {onnx_path}')
t0 = time.monotonic()
with open(onnx_path, 'rb') as f:
if not parser.parse(f.read()):
for i in range(parser.num_errors):
print(f' ONNX parse error: {parser.get_error(i)}')
raise RuntimeError(f'Failed to parse {onnx_path}')
parse_time = time.monotonic() - t0
print(f' Parsed in {parse_time:.1f}s')
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
workspace_mb * (1 << 20))
if fp16:
if builder.platform_has_fast_fp16:
config.set_flag(trt.BuilderFlag.FP16)
print(' FP16 enabled.')
else:
print(' Warning: FP16 not supported on this platform, using FP32.')
    print('Building engine (this may take several minutes)...')
t0 = time.monotonic()
serialized = builder.build_serialized_network(network, config)
build_time = time.monotonic() - t0
if serialized is None:
raise RuntimeError('Engine build failed.')
os.makedirs(os.path.dirname(engine_path) or '.', exist_ok=True)
with open(engine_path, 'wb') as f:
f.write(serialized)
size_mb = os.path.getsize(engine_path) / (1 << 20)
print(f' Engine saved: {engine_path} ({size_mb:.1f} MB, built in {build_time:.1f}s)')
def main():
"""Main entry point for TRT engine building."""
parser = argparse.ArgumentParser(
description='Build TensorRT FP16 engines for SCRFD and ArcFace.')
parser.add_argument('--scrfd-onnx', type=str, default='',
help='Path to SCRFD ONNX model.')
parser.add_argument('--arcface-onnx', type=str, default='',
help='Path to ArcFace ONNX model.')
parser.add_argument('--output-dir', type=str,
default='/mnt/nvme/saltybot/models',
help='Output directory for engine files.')
parser.add_argument('--fp16', action='store_true', default=True,
help='Enable FP16 precision (default: True).')
parser.add_argument('--no-fp16', action='store_false', dest='fp16',
help='Disable FP16 (use FP32 only).')
parser.add_argument('--workspace-mb', type=int, default=1024,
help='TRT workspace size in MB (default: 1024).')
args = parser.parse_args()
if not args.scrfd_onnx and not args.arcface_onnx:
parser.error('At least one of --scrfd-onnx or --arcface-onnx is required.')
if args.scrfd_onnx:
engine_path = os.path.join(args.output_dir, 'scrfd_2.5g.engine')
print(f'\n=== Building SCRFD engine ===')
build_engine(args.scrfd_onnx, engine_path, args.fp16, args.workspace_mb)
if args.arcface_onnx:
engine_path = os.path.join(args.output_dir, 'arcface_r50.engine')
print(f'\n=== Building ArcFace engine ===')
build_engine(args.arcface_onnx, engine_path, args.fp16, args.workspace_mb)
print('\nDone.')
if __name__ == '__main__':
main()

View File

@ -0,0 +1,4 @@
[develop]
script_dir=$base/lib/saltybot_social_face
[install]
install_scripts=$base/lib/saltybot_social_face

View File

@ -0,0 +1,30 @@
"""Setup for saltybot_social_face package."""
from setuptools import find_packages, setup
package_name = 'saltybot_social_face'
setup(
    name=package_name,
    version='0.1.0',
    packages=find_packages(exclude=['test']),
    data_files=[
        ('share/ament_index/resource_index/packages', ['resource/' + package_name]),
        ('share/' + package_name, ['package.xml']),
        ('share/' + package_name + '/launch', ['launch/face_recognition.launch.py']),
        ('share/' + package_name + '/config', ['config/face_recognition_params.yaml']),
    ],
    install_requires=['setuptools'],
    zip_safe=True,
    maintainer='seb',
    maintainer_email='seb@vayrette.com',
    description='SCRFD face detection and ArcFace recognition for SaltyBot social interactions',
    license='MIT',
    tests_require=['pytest'],
    entry_points={
        'console_scripts': [
            'face_recognition = saltybot_social_face.face_recognition_node:main',
            'enrollment_cli = saltybot_social_face.enrollment_cli:main',
        ],
    },
)

View File

@ -8,6 +8,7 @@ find_package(geometry_msgs REQUIRED)
find_package(builtin_interfaces REQUIRED)
rosidl_generate_interfaces(${PROJECT_NAME}
  # Issue #80 face detection + recognition
  "msg/FaceDetection.msg"
  "msg/FaceDetectionArray.msg"
  "msg/FaceEmbedding.msg"
@ -18,7 +19,15 @@ rosidl_generate_interfaces(${PROJECT_NAME}
  "srv/ListPersons.srv"
  "srv/DeletePerson.srv"
  "srv/UpdatePerson.srv"
  # Issue #86 LED expression + motor attention
  "msg/Mood.msg"
  "msg/Person.msg"
  "msg/PersonArray.msg"
  # Issue #84 personality system
  "msg/PersonalityState.msg"
  "srv/QueryMood.srv"
  DEPENDENCIES std_msgs geometry_msgs builtin_interfaces
)
ament_export_dependencies(rosidl_default_runtime)
ament_package()

View File

@ -0,0 +1,7 @@
# Mood.msg — social expression command sent to the LED display node.
#
# mood : one of "happy", "curious", "annoyed", "playful", "idle"
# intensity : 0.0 (off) to 1.0 (full brightness)
string mood
float32 intensity
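
A publisher of Mood.msg should clamp `intensity` into the documented 0.0 to 1.0 range and reject unknown mood strings before sending. A minimal, hypothetical helper sketching that validation (plain Python dict standing in for the message, since this snippet does not assume rclpy):

```python
# Hypothetical validation helper for Mood.msg publishers (not part of the package).
VALID_MOODS = {"happy", "curious", "annoyed", "playful", "idle"}


def make_mood(mood: str, intensity: float) -> dict:
    """Return a validated (mood, intensity) pair as a plain dict."""
    if mood not in VALID_MOODS:
        raise ValueError(f"unknown mood: {mood!r}")
    # Clamp into the documented 0.0 (off) .. 1.0 (full brightness) range
    return {"mood": mood, "intensity": max(0.0, min(1.0, float(intensity)))}
```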

View File

@ -0,0 +1,17 @@
# Person.msg — single tracked person for social attention.
#
# bearing_rad : signed bearing in base_link frame (rad)
# positive = person to the left (CCW), negative = right (CW)
# distance_m : estimated distance in metres; 0 if unknown
# confidence : detection confidence 0..1
# is_speaking : true if mic DOA or VAD identifies this person as speaking
# source : "camera" | "mic_doa"
std_msgs/Header header
int32 track_id
float32 bearing_rad
float32 distance_m
float32 confidence
bool is_speaking
string source
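
For a camera-sourced Person, `bearing_rad` would typically be derived from the detection's horizontal pixel offset. A sketch of that conversion under a simple pinhole-camera assumption (image width and FOV are illustrative values, not taken from this changeset), matching the sign convention above, positive = left (CCW):

```python
import math


def pixel_to_bearing(cx_px: float, image_width_px: int, hfov_rad: float) -> float:
    """Convert a detection's horizontal centre pixel to a signed bearing (rad).

    Assumes a pinhole camera whose optical axis is aligned with base_link x.
    Follows the Person.msg convention: positive = left (CCW), negative = right.
    """
    # Focal length in pixels from the horizontal field of view
    fx = (image_width_px / 2.0) / math.tan(hfov_rad / 2.0)
    # +x in image coordinates is to the right, which is a negative bearing
    offset = cx_px - image_width_px / 2.0
    return -math.atan2(offset, fx)
```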

View File

@ -0,0 +1,9 @@
# PersonArray.msg — all detected persons plus the current attention target.
#
# persons : all currently tracked persons
# active_id : track_id of the active/speaking person; -1 if none
std_msgs/Header header
saltybot_social_msgs/Person[] persons
int32 active_id
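
One plausible policy a publisher could use to pick `active_id` (illustrative logic, not taken from this changeset): prefer a speaking person, otherwise fall back to the highest-confidence track, and report -1 when no one is tracked:

```python
# Hypothetical attention-target selection over PersonArray-like dicts.
def pick_active_id(persons: list) -> int:
    """Pick the attention target: speaker first, then highest confidence."""
    if not persons:
        return -1
    speakers = [p for p in persons if p["is_speaking"]]
    pool = speakers or persons  # fall back to all tracks if nobody speaks
    return max(pool, key=lambda p: p["confidence"])["track_id"]
```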

View File

@ -0,0 +1,27 @@
# PersonalityState.msg — published on /social/personality/state
#
# Snapshot of the personality node's current state: active mood, relationship
# tier with the detected person, and a pre-generated greeting string.
std_msgs/Header header
# Active persona name (from SOUL.md)
string persona_name
# Current mood: happy | curious | annoyed | playful
string mood
# Person currently being addressed (empty if no one detected)
string person_id
# Relationship tier with person_id: stranger | regular | favorite
string relationship_tier
# Raw relationship score (higher = more familiar)
float32 relationship_score
# How many times we have seen this person
int32 interaction_count
# Ready-to-use greeting for person_id at current tier
string greeting_text

View File

@ -3,16 +3,25 @@
<package format="3">
  <name>saltybot_social_msgs</name>
  <version>0.1.0</version>
  <description>
    Custom ROS2 message and service definitions for saltybot social capabilities.
    Includes social perception types (face detection, person state, enrollment)
    and the personality system types (PersonalityState, QueryMood) from Issue #84.
  </description>
  <maintainer email="seb@vayrette.com">seb</maintainer>
  <license>MIT</license>
  <buildtool_depend>ament_cmake</buildtool_depend>
  <depend>std_msgs</depend>
  <depend>geometry_msgs</depend>
  <depend>builtin_interfaces</depend>
  <build_depend>rosidl_default_generators</build_depend>
  <exec_depend>rosidl_default_runtime</exec_depend>
  <member_of_group>rosidl_interface_packages</member_of_group>
  <export>
    <build_type>ament_cmake</build_type>
  </export>

View File

@ -0,0 +1,4 @@
sensor_msgs/Image crop
---
bool success
float32[512] embedding

View File

@ -0,0 +1,15 @@
# QueryMood.srv — ask the personality node for the current mood + greeting for a person
#
# Call with empty person_id to query the mood for whoever is currently tracked.
# Request
string person_id # person to query; leave empty for current tracked person
---
# Response
string mood # happy | curious | annoyed | playful
string relationship_tier # stranger | regular | favorite
float32 relationship_score
int32 interaction_count
string greeting_text # suggested greeting string
bool success
string message # error detail if success=false

View File

@ -0,0 +1,42 @@
---
# SOUL.md — Saltybot persona definition
#
# Hot-reload: personality_node watches this file and reloads on change.
# Runtime override: ros2 param set /personality_node soul_file /path/to/SOUL.md
# ── Identity ──────────────────────────────────────────────────────────────────
name: "Salty"
speaking_style: "casual and upbeat, occasional puns"
# ── Personality dials (0–10) ──────────────────────────────────────────────────
humor_level: 7 # 0 = deadpan/serious, 10 = comedian
sass_level: 4 # 0 = pure politeness, 10 = maximum sass
# ── Default mood (when no person is present or score is neutral) ──────────────
# One of: happy | curious | annoyed | playful
base_mood: "playful"
# ── Relationship thresholds (interaction counts) ──────────────────────────────
threshold_regular: 5 # interactions to graduate from stranger → regular
threshold_favorite: 20 # interactions to graduate from regular → favorite
# ── Greeting templates (use {name} placeholder for person_id or display name) ─
greeting_stranger: "Hello there! I'm Salty, nice to meet you!"
greeting_regular: "Hey {name}! Good to see you again!"
greeting_favorite: "Oh hey {name}!! You're literally my favorite person right now!"
# ── Mood-specific greeting prefixes ──────────────────────────────────────────
mood_prefix_happy: "Great timing — "
mood_prefix_curious: "Oh interesting, "
mood_prefix_annoyed: "Well, "
mood_prefix_playful: "Beep boop! "
---
# Description (ignored by the YAML parser, for human reference only)
Salty is the personality of the saltybot social robot.
She is curious about the world, genuinely happy to see people she knows,
and has a good sense of humour — especially with regulars.
Edit this file to change her personality. The node hot-reloads within
`reload_interval` seconds of any change.
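
The file above is YAML front-matter between `---` fences, followed by a free-form Markdown description that the parser ignores. The real `load_soul` is not shown in this changeset, so treat the following as an assumption: a minimal stdlib-only sketch of splitting the front-matter out and parsing flat `key: value` scalars:

```python
# Hypothetical sketch of a SOUL.md loader; the package's load_soul may differ.
def parse_soul(text: str) -> dict:
    """Parse the YAML front-matter block of a SOUL.md-style file.

    Handles only flat `key: value` scalars (int, float, quoted/bare string).
    Everything after the closing '---' is ignored, as the file documents.
    """
    parts = text.split("---")
    if len(parts) < 3:
        raise ValueError("no front-matter block found")
    soul = {}
    for line in parts[1].splitlines():
        line = line.split("#", 1)[0].strip()  # drop comments
        if ":" not in line:
            continue
        key, _, value = line.partition(":")
        value = value.strip().strip('"')
        try:
            soul[key.strip()] = int(value)       # humor_level: 7 -> int
        except ValueError:
            try:
                soul[key.strip()] = float(value)
            except ValueError:
                soul[key.strip()] = value        # name: "Salty" -> str
    return soul
```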

View File

@ -0,0 +1,28 @@
# personality_params.yaml — ROS2 parameter defaults for personality_node.
#
# Run with:
# ros2 launch saltybot_social_personality personality.launch.py
# Override inline:
# ros2 launch saltybot_social_personality personality.launch.py soul_file:=/my/SOUL.md
# Dynamic reconfigure at runtime:
# ros2 param set /personality_node soul_file /path/to/SOUL.md
# ros2 param set /personality_node publish_rate 5.0
# ── SOUL.md path ───────────────────────────────────────────────────────────────
# Path to the SOUL.md persona file. Supports hot-reload.
# Relative paths are resolved from the package share directory.
soul_file: "" # empty = use bundled default config/SOUL.md
# ── SQLite database ────────────────────────────────────────────────────────────
# Path for the per-person relationship memory database.
# Created on first run; persists across restarts.
db_path: "~/.ros/saltybot_personality.db"
# ── Hot-reload polling interval ────────────────────────────────────────────────
# How often (seconds) to check if SOUL.md has been modified.
# Lower = faster reactions to edits; higher = less disk I/O.
reload_interval: 5.0 # seconds
# ── Personality state publication rate ────────────────────────────────────────
# How often to publish /social/personality/state (PersonalityState).
publish_rate: 2.0 # Hz

View File

@ -0,0 +1,99 @@
"""
personality.launch.py — Launch the saltybot personality node.

Usage
-----
# Defaults (bundled SOUL.md, ~/.ros/saltybot_personality.db):
ros2 launch saltybot_social_personality personality.launch.py

# Custom persona file:
ros2 launch saltybot_social_personality personality.launch.py \\
    soul_file:=/home/robot/my_persona/SOUL.md

# Custom DB + faster reload:
ros2 launch saltybot_social_personality personality.launch.py \\
    db_path:=/data/saltybot.db reload_interval:=2.0

# Use a params file:
ros2 launch saltybot_social_personality personality.launch.py \\
    params_file:=/my/personality_params.yaml

Dynamic reconfigure (no restart required)
-----------------------------------------
ros2 param set /personality_node soul_file /new/SOUL.md
ros2 param set /personality_node publish_rate 5.0
"""
import os

from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument, OpaqueFunction
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node


def _launch_personality(context, *args, **kwargs):
    pkg_share = get_package_share_directory("saltybot_social_personality")
    params_file = LaunchConfiguration("params_file").perform(context)
    soul_file = LaunchConfiguration("soul_file").perform(context)
    db_path = LaunchConfiguration("db_path").perform(context)

    # Default soul_file to bundled config if not specified
    if not soul_file:
        soul_file = os.path.join(pkg_share, "config", "SOUL.md")
    # Expand ~ in db_path
    if db_path:
        db_path = os.path.expanduser(db_path)

    inline_params = {
        "soul_file": soul_file,
        "db_path": db_path or os.path.expanduser("~/.ros/saltybot_personality.db"),
        "reload_interval": float(LaunchConfiguration("reload_interval").perform(context)),
        "publish_rate": float(LaunchConfiguration("publish_rate").perform(context)),
    }
    node_params = [params_file, inline_params] if params_file else [inline_params]
    return [Node(
        package="saltybot_social_personality",
        executable="personality_node",
        name="personality_node",
        output="screen",
        parameters=node_params,
    )]


def generate_launch_description():
    pkg_share = get_package_share_directory("saltybot_social_personality")
    default_params = os.path.join(pkg_share, "config", "personality_params.yaml")
    return LaunchDescription([
        DeclareLaunchArgument(
            "params_file",
            default_value=default_params,
            description="Full path to personality_params.yaml (base config)"),
        DeclareLaunchArgument(
            "soul_file",
            default_value="",
            description="Path to SOUL.md persona file (empty = bundled default)"),
        DeclareLaunchArgument(
            "db_path",
            default_value="~/.ros/saltybot_personality.db",
            description="SQLite relationship memory database path"),
        DeclareLaunchArgument(
            "reload_interval",
            default_value="5.0",
            description="SOUL.md hot-reload polling interval (s)"),
        DeclareLaunchArgument(
            "publish_rate",
            default_value="2.0",
            description="Personality state publish rate (Hz)"),
        OpaqueFunction(function=_launch_personality),
    ])

View File

@ -0,0 +1,32 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_social_personality</name>
<version>0.1.0</version>
<description>
SOUL.md-driven personality system for saltybot.
Loads a YAML/Markdown persona file, maintains per-person relationship memory
in SQLite, computes mood (happy/curious/annoyed/playful), personalises
greetings by tier (stranger/regular/favorite), and publishes personality
state on /social/personality/state. Supports SOUL.md hot-reload and full
ROS2 dynamic reconfigure. Issue #84.
</description>
<maintainer email="sl-controls@saltylab.local">sl-controls</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>std_msgs</depend>
<depend>rcl_interfaces</depend>
<depend>saltybot_social_msgs</depend>
<buildtool_depend>ament_python</buildtool_depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>python3-pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -0,0 +1,187 @@
"""
mood_engine.py — Pure-function mood computation for the saltybot personality system.

No ROS2 imports — safe to unit-test without a live ROS2 environment.

Public API
----------
compute_mood(soul, score, interaction_count, recent_events) -> str
get_relationship_tier(soul, interaction_count) -> str
build_greeting(soul, tier, mood, person_id) -> str

Mood semantics
--------------
happy   : positive valence, comfortable familiarity
playful : high-energy, humorous (requires humor_level >= 7)
curious : low familiarity or novel person — inquisitive
annoyed : recent negative events or very low score

Tier semantics
--------------
stranger : interaction_count < threshold_regular
regular  : threshold_regular <= count < threshold_favorite
favorite : count >= threshold_favorite
"""
from __future__ import annotations

# ── Mood / tier constants ──────────────────────────────────────────────────────
MOOD_HAPPY = "happy"
MOOD_PLAYFUL = "playful"
MOOD_CURIOUS = "curious"
MOOD_ANNOYED = "annoyed"

TIER_STRANGER = "stranger"
TIER_REGULAR = "regular"
TIER_FAVORITE = "favorite"

# ── Event type constants (used by relationship_db and the node) ────────────────
EVENT_GREETING = "greeting"
EVENT_POSITIVE = "positive"
EVENT_NEGATIVE = "negative"
EVENT_DETECTION = "detection"

# How far back (seconds) to consider "recent" for mood computation
_RECENT_WINDOW_S = 120.0


# ── Mood computation ──────────────────────────────────────────────────────────
def compute_mood(
    soul: dict,
    score: float,
    interaction_count: int,
    recent_events: list,
) -> str:
    """Compute the current mood for a given person.

    Parameters
    ----------
    soul : dict
        Parsed SOUL.md configuration.
    score : float
        Relationship score for the current person (higher = more familiar).
    interaction_count : int
        Total number of times we have seen this person.
    recent_events : list of dict
        Each dict: ``{"type": str, "dt": float}`` where ``dt`` is seconds ago.
        Only events with ``dt < 120.0`` are considered "recent".

    Returns
    -------
    str
        One of: ``"happy"``, ``"playful"``, ``"curious"``, ``"annoyed"``.
    """
    base_mood = soul.get("base_mood", MOOD_PLAYFUL)
    humor_level = float(soul.get("humor_level", 5))

    # Count recent negative/positive events
    recent_neg = sum(
        1 for e in recent_events
        if e.get("type") == EVENT_NEGATIVE and e.get("dt", 1e9) < _RECENT_WINDOW_S
    )
    recent_pos = sum(
        1 for e in recent_events
        if e.get("type") in (EVENT_POSITIVE, EVENT_GREETING)
        and e.get("dt", 1e9) < _RECENT_WINDOW_S
    )

    # Hard override: multiple negatives → annoyed
    if recent_neg >= 2:
        return MOOD_ANNOYED
    # No prior interactions or brand-new person → curious
    if interaction_count == 0 or score < 1.0:
        return MOOD_CURIOUS
    # Stranger tier (low count) → curious
    threshold_regular = int(soul.get("threshold_regular", 5))
    if interaction_count < threshold_regular:
        return MOOD_CURIOUS
    # Familiar person: check positive events and humor level
    if recent_pos >= 1 or score >= 20.0:
        if humor_level >= 7:
            return MOOD_PLAYFUL
        return MOOD_HAPPY
    # High score / favorite
    threshold_fav = int(soul.get("threshold_favorite", 20))
    if interaction_count >= threshold_fav:
        if humor_level >= 7:
            return MOOD_PLAYFUL
        return MOOD_HAPPY
    return base_mood


# ── Tier classification ────────────────────────────────────────────────────────
def get_relationship_tier(soul: dict, interaction_count: int) -> str:
    """Return the relationship tier string for a given interaction count.

    Parameters
    ----------
    soul : dict
        Parsed SOUL.md configuration.
    interaction_count : int
        Total number of times we have seen this person.

    Returns
    -------
    str
        One of: ``"stranger"``, ``"regular"``, ``"favorite"``.
    """
    threshold_regular = int(soul.get("threshold_regular", 5))
    threshold_favorite = int(soul.get("threshold_favorite", 20))
    if interaction_count >= threshold_favorite:
        return TIER_FAVORITE
    if interaction_count >= threshold_regular:
        return TIER_REGULAR
    return TIER_STRANGER


# ── Greeting builder ──────────────────────────────────────────────────────────
def build_greeting(soul: dict, tier: str, mood: str, person_id: str = "") -> str:
    """Compose a greeting string for a person.

    Parameters
    ----------
    soul : dict
        Parsed SOUL.md configuration.
    tier : str
        Relationship tier (``"stranger"``, ``"regular"``, ``"favorite"``).
    mood : str
        Current mood (used to prefix the greeting).
    person_id : str
        Person identifier / display name. Substituted for ``{name}``
        in the template.

    Returns
    -------
    str
        A complete, ready-to-display greeting string.
    """
    template_key = {
        TIER_STRANGER: "greeting_stranger",
        TIER_REGULAR: "greeting_regular",
        TIER_FAVORITE: "greeting_favorite",
    }.get(tier, "greeting_stranger")
    template = soul.get(template_key, "Hello!")
    base_greeting = template.replace("{name}", person_id or "friend")

    prefix_key = f"mood_prefix_{mood}"
    prefix = soul.get(prefix_key, "")
    if prefix:
        # Avoid double punctuation / duplicate capital letters
        base_first = base_greeting[0].lower() if base_greeting else ""
        greeting = f"{prefix}{base_first}{base_greeting[1:]}"
    else:
        greeting = base_greeting
    return greeting
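
The tier boundaries documented above are inclusive at each threshold: count 4 is still a stranger, 5 graduates to regular, 20 to favorite. A standalone restatement of `get_relationship_tier` (duplicated here so the check runs without the package installed) makes the boundaries concrete:

```python
# Standalone copy of the tier logic above, for a quick boundary check.
def relationship_tier(soul: dict, interaction_count: int) -> str:
    """Mirror of get_relationship_tier using the same SOUL.md keys."""
    if interaction_count >= int(soul.get("threshold_favorite", 20)):
        return "favorite"
    if interaction_count >= int(soul.get("threshold_regular", 5)):
        return "regular"
    return "stranger"


soul = {"threshold_regular": 5, "threshold_favorite": 20}
tiers = [relationship_tier(soul, n) for n in (0, 4, 5, 19, 20)]
```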

View File

@ -0,0 +1,349 @@
"""
personality_node.py — ROS2 personality system for saltybot.

Overview
--------
Loads a SOUL.md persona file, maintains per-person relationship memory in
SQLite, computes mood, and publishes personality state. All tunable params
support ROS2 dynamic reconfigure (``ros2 param set``).

Subscriptions
-------------
/social/person_detected (std_msgs/String)
    JSON payload: ``{"person_id": "alice", "event_type": "greeting",
    "delta_score": 1.0}``
    event_type defaults to "detection" if absent.
    delta_score defaults to 0.0 if absent.

Publications
------------
/social/personality/state (saltybot_social_msgs/PersonalityState)
    Published at ``publish_rate`` Hz.

Services
--------
/social/personality/query_mood (saltybot_social_msgs/QueryMood)
    Query mood + greeting for any person_id.

Parameters (dynamic reconfigure via ros2 param set)
---------------------------------------------------
soul_file       (str)   Path to SOUL.md persona file.
db_path         (str)   SQLite database file path.
reload_interval (float) How often to poll SOUL.md for changes (s).
publish_rate    (float) State publication rate (Hz).

Usage
-----
ros2 launch saltybot_social_personality personality.launch.py
ros2 launch saltybot_social_personality personality.launch.py soul_file:=/my/SOUL.md
ros2 param set /personality_node soul_file /tmp/new_SOUL.md
"""
import json
import os

import rclpy
from rclpy.node import Node
from rcl_interfaces.msg import SetParametersResult
from std_msgs.msg import String, Header

from saltybot_social_msgs.msg import PersonalityState
from saltybot_social_msgs.srv import QueryMood

from .soul_loader import load_soul, SoulWatcher
from .mood_engine import (
    compute_mood, get_relationship_tier, build_greeting,
    EVENT_GREETING, EVENT_POSITIVE, EVENT_NEGATIVE, EVENT_DETECTION,
)
from .relationship_db import RelationshipDB

_DEFAULT_SOUL = os.path.join(
    os.path.dirname(__file__), "..", "config", "SOUL.md"
)
_DEFAULT_DB = os.path.expanduser("~/.ros/saltybot_personality.db")


class PersonalityNode(Node):

    def __init__(self):
        super().__init__("personality_node")
        # ── Parameters ────────────────────────────────────────────────────────
        self.declare_parameter("soul_file", _DEFAULT_SOUL)
        self.declare_parameter("db_path", _DEFAULT_DB)
        self.declare_parameter("reload_interval", 5.0)
        self.declare_parameter("publish_rate", 2.0)
        self._p = {}
        self._reload_ros_params()
        # ── State ─────────────────────────────────────────────────────────────
        self._soul = {}
        self._current_person = ""  # person_id currently being addressed
        self._watcher = None
        # ── Database ──────────────────────────────────────────────────────────
        self._db = RelationshipDB(self._p["db_path"])
        # ── Load initial SOUL.md ──────────────────────────────────────────────
        self._load_soul_safe()
        self._start_watcher()
        # ── Dynamic reconfigure callback ─────────────────────────────────────
        self.add_on_set_parameters_callback(self._on_params_changed)
        # ── Subscriptions ─────────────────────────────────────────────────────
        self.create_subscription(
            String, "/social/person_detected", self._person_detected_cb, 10
        )
        # ── Publishers ────────────────────────────────────────────────────────
        self._state_pub = self.create_publisher(
            PersonalityState, "/social/personality/state", 10
        )
        # ── Services ──────────────────────────────────────────────────────────
        self.create_service(
            QueryMood,
            "/social/personality/query_mood",
            self._query_mood_cb,
        )
        # ── Timers ────────────────────────────────────────────────────────────
        self._pub_timer = self.create_timer(
            1.0 / self._p["publish_rate"], self._publish_state
        )
        self.get_logger().info(
            f"PersonalityNode ready "
            f"persona={self._soul.get('name', '?')!r} "
            f"mood={self._current_mood()!r} "
            f"db={self._p['db_path']!r}"
        )

    # ── Parameter helpers ──────────────────────────────────────────────────────
    def _reload_ros_params(self):
        self._p = {
            "soul_file": self.get_parameter("soul_file").value,
            "db_path": self.get_parameter("db_path").value,
            "reload_interval": self.get_parameter("reload_interval").value,
            "publish_rate": self.get_parameter("publish_rate").value,
        }

    def _on_params_changed(self, params):
        """Dynamic reconfigure — apply changed params without restarting node."""
        for param in params:
            if param.name == "soul_file":
                # Restart watcher on new soul_file
                self._stop_watcher()
                self._p["soul_file"] = param.value
                self._load_soul_safe()
                self._start_watcher()
                self.get_logger().info(f"soul_file changed → {param.value!r}")
            elif param.name in self._p:
                self._p[param.name] = param.value
                if param.name == "publish_rate" and self._pub_timer:
                    self._pub_timer.cancel()
                    self._pub_timer = self.create_timer(
                        1.0 / max(0.1, param.value), self._publish_state
                    )
        return SetParametersResult(successful=True)

    # ── SOUL.md ────────────────────────────────────────────────────────────────
    def _load_soul_safe(self):
        try:
            path = os.path.realpath(self._p["soul_file"])
            self._soul = load_soul(path)
            self.get_logger().info(
                f"SOUL.md loaded: {self._soul.get('name', '?')!r} "
                f"humor={self._soul.get('humor_level')} "
                f"sass={self._soul.get('sass_level')} "
                f"base_mood={self._soul.get('base_mood')!r}"
            )
        except Exception as exc:
            self.get_logger().error(f"Failed to load SOUL.md: {exc}")
            if not self._soul:
                # Fall back to minimal defaults so the node stays alive
                self._soul = {
                    "name": "Salty",
                    "humor_level": 5,
                    "sass_level": 3,
                    "base_mood": "curious",
                    "threshold_regular": 5,
                    "threshold_favorite": 20,
                    "greeting_stranger": "Hello!",
                    "greeting_regular": "Hi {name}!",
                    "greeting_favorite": "Hey {name}!!",
                }

    def _start_watcher(self):
        if not self._soul:
            return
        self._watcher = SoulWatcher(
            path=self._p["soul_file"],
            on_reload=self._on_soul_reloaded,
            interval=self._p["reload_interval"],
            on_error=lambda exc: self.get_logger().warn(
                f"SOUL.md hot-reload error: {exc}"
            ),
        )
        self._watcher.start()

    def _stop_watcher(self):
        if self._watcher:
            self._watcher.stop()
            self._watcher = None

    def _on_soul_reloaded(self, soul: dict):
        self._soul = soul
        self.get_logger().info(
            f"SOUL.md reloaded: persona={soul.get('name')!r} "
            f"humor={soul.get('humor_level')} base_mood={soul.get('base_mood')!r}"
        )

    # ── Mood helpers ───────────────────────────────────────────────────────────
    def _current_mood(self) -> str:
        if not self._current_person or not self._soul:
            return self._soul.get("base_mood", "curious") if self._soul else "curious"
        person = self._db.get_person(self._current_person)
        recent = self._db.get_recent_events(self._current_person, window_s=120.0)
        return compute_mood(
            soul=self._soul,
            score=person["score"],
            interaction_count=person["interaction_count"],
            recent_events=recent,
        )

    def _state_for_person(self, person_id: str) -> dict:
        """Build a complete state dict for a given person_id."""
        person = self._db.get_person(person_id) if person_id else {
            "score": 0.0, "interaction_count": 0
        }
        recent = self._db.get_recent_events(person_id, window_s=120.0) if person_id else []
        mood = compute_mood(
            soul=self._soul,
            score=person["score"],
            interaction_count=person["interaction_count"],
            recent_events=recent,
        )
        tier = get_relationship_tier(self._soul, person["interaction_count"])
        greeting = build_greeting(self._soul, tier, mood, person_id)
        return {
            "person_id": person_id,
            "mood": mood,
            "tier": tier,
            "score": person["score"],
            "interaction_count": person["interaction_count"],
            "greeting": greeting,
        }

    # ── Callbacks ──────────────────────────────────────────────────────────────
    def _person_detected_cb(self, msg: String):
        """Handle incoming person detection / interaction event.

        Expected JSON payload::

            {
                "person_id": "alice",       # required
                "event_type": "greeting",   # optional, default "detection"
                "delta_score": 1.0          # optional, default 0.0
            }
        """
        try:
            data = json.loads(msg.data)
        except json.JSONDecodeError as exc:
            self.get_logger().warn(f"Bad JSON on /social/person_detected: {exc}")
            return
        person_id = data.get("person_id", "").strip()
        if not person_id:
            self.get_logger().warn("person_detected msg missing 'person_id'")
            return
        event_type = data.get("event_type", EVENT_DETECTION)
        delta_score = float(data.get("delta_score", 0.0))
        # Bare detection events with no explicit delta still nudge the score (+0.5)
        if event_type == EVENT_DETECTION and delta_score == 0.0:
            delta_score = 0.5
        self._db.record_interaction(
            person_id=person_id,
            event_type=event_type,
            details={k: v for k, v in data.items()
                     if k not in ("person_id", "event_type", "delta_score")},
            delta_score=delta_score,
        )
        self._current_person = person_id

    def _query_mood_cb(self, request: QueryMood.Request, response: QueryMood.Response):
        """Service handler: return mood + greeting for a specific person."""
        if not self._soul:
            response.success = False
            response.message = "SOUL.md not loaded"
            return response
        person_id = (request.person_id or self._current_person).strip()
        state = self._state_for_person(person_id)
        response.mood = state["mood"]
        response.relationship_tier = state["tier"]
        response.relationship_score = float(state["score"])
        response.interaction_count = int(state["interaction_count"])
        response.greeting_text = state["greeting"]
        response.success = True
        response.message = ""
        return response

    # ── Publish ────────────────────────────────────────────────────────────────
    def _publish_state(self):
        if not self._soul:
            return
        state = self._state_for_person(self._current_person)
        msg = PersonalityState()
        msg.header = Header()
        msg.header.stamp = self.get_clock().now().to_msg()
        msg.header.frame_id = "personality"
        msg.persona_name = str(self._soul.get("name", "Salty"))
        msg.mood = state["mood"]
        msg.person_id = state["person_id"]
        msg.relationship_tier = state["tier"]
        msg.relationship_score = float(state["score"])
        msg.interaction_count = int(state["interaction_count"])
        msg.greeting_text = state["greeting"]
        self._state_pub.publish(msg)

    # ── Lifecycle ──────────────────────────────────────────────────────────────
    def destroy_node(self):
        self._stop_watcher()
        self._db.close()
        super().destroy_node()


# ── Entry point ────────────────────────────────────────────────────────────────
def main(args=None):
    rclpy.init(args=args)
    node = PersonalityNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.try_shutdown()


if __name__ == "__main__":
    main()
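
A publisher feeding `/social/person_detected` only has to serialise the JSON payload documented above into a `std_msgs/String`. A sketch of building and validating one payload (plain Python; the actual publish through rclpy is omitted, and the helper name is illustrative):

```python
import json


def make_person_event(person_id: str, event_type: str = "detection",
                      delta_score: float = 0.0, **details) -> str:
    """Serialise a /social/person_detected payload as personality_node expects.

    Extra keyword args become the free-form fields the node stores as details.
    """
    if not person_id.strip():
        raise ValueError("person_id is required")
    payload = {"person_id": person_id, "event_type": event_type,
               "delta_score": delta_score, **details}
    return json.dumps(payload)
```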

View File

@ -0,0 +1,297 @@
"""
relationship_db.py — SQLite-backed per-person relationship memory.

No ROS2 imports — safe to unit-test without a live ROS2 environment.

Schema
------
people (
    person_id          TEXT PRIMARY KEY,
    score              REAL DEFAULT 0.0,
    interaction_count  INTEGER DEFAULT 0,
    first_seen         TEXT,   -- ISO-8601 UTC timestamp
    last_seen          TEXT,   -- ISO-8601 UTC timestamp
    prefs              TEXT    -- JSON blob for learned preferences
)
interactions (
    id          INTEGER PRIMARY KEY AUTOINCREMENT,
    person_id   TEXT,
    ts          TEXT,   -- ISO-8601 UTC timestamp
    event_type  TEXT,   -- greeting | positive | negative | detection
    details     TEXT    -- free-form JSON blob
)

Public API
----------
RelationshipDB(db_path)
    .get_person(person_id) -> dict
    .record_interaction(person_id, event_type, details, delta_score)
    .set_pref(person_id, key, value)
    .get_pref(person_id, key, default)
    .get_recent_events(person_id, window_s) -> list[dict]
    .all_people() -> list[dict]
    .close()
"""
from __future__ import annotations

import json
import os
import sqlite3
import threading
from datetime import datetime, timezone


def _now_iso() -> str:
    return datetime.now(timezone.utc).isoformat()


class RelationshipDB:
    """Thread-safe SQLite relationship store.

    Parameters
    ----------
    db_path : str
        Path to the SQLite file. Created (with parent dirs) if absent.
    """

    def __init__(self, db_path: str):
        parent = os.path.dirname(db_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        self._path = db_path
        self._lock = threading.Lock()
        self._conn = sqlite3.connect(db_path, check_same_thread=False)
        self._conn.row_factory = sqlite3.Row
        self._migrate()

    # ── Schema ────────────────────────────────────────────────────────────────
    def _migrate(self):
        with self._conn:
            self._conn.executescript("""
                CREATE TABLE IF NOT EXISTS people (
                    person_id TEXT PRIMARY KEY,
                    score REAL DEFAULT 0.0,
                    interaction_count INTEGER DEFAULT 0,
                    first_seen TEXT,
                    last_seen TEXT,
                    prefs TEXT DEFAULT '{}'
                );
                CREATE TABLE IF NOT EXISTS interactions (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    person_id TEXT NOT NULL,
                    ts TEXT NOT NULL,
                    event_type TEXT NOT NULL,
                    details TEXT DEFAULT '{}'
                );
                CREATE INDEX IF NOT EXISTS idx_interactions_person_ts
                    ON interactions (person_id, ts);
            """)

    # ── People ────────────────────────────────────────────────────────────────
    def get_person(self, person_id: str) -> dict:
        """Return the person record; inserts a default row if not found.

        Returns
        -------
        dict with keys: person_id, score, interaction_count,
        first_seen, last_seen, prefs (dict)
        """
        with self._lock:
            row = self._conn.execute(
                "SELECT * FROM people WHERE person_id = ?", (person_id,)
            ).fetchone()
            if row is None:
                now = _now_iso()
                self._conn.execute(
                    "INSERT INTO people (person_id, first_seen, last_seen) VALUES (?,?,?)",
                    (person_id, now, now),
                )
                self._conn.commit()
                return {
                    "person_id": person_id,
                    "score": 0.0,
                    "interaction_count": 0,
                    "first_seen": now,
                    "last_seen": now,
                    "prefs": {},
                }
            prefs = {}
            try:
                prefs = json.loads(row["prefs"] or "{}")
            except json.JSONDecodeError:
                pass
            return {
                "person_id": row["person_id"],
                "score": float(row["score"]),
                "interaction_count": int(row["interaction_count"]),
                "first_seen": row["first_seen"],
                "last_seen": row["last_seen"],
                "prefs": prefs,
            }

    def all_people(self) -> list:
        """Return all person records as a list of dicts."""
        with self._lock:
            rows = self._conn.execute("SELECT * FROM people ORDER BY score DESC").fetchall()
        result = []
        for row in rows:
            prefs = {}
            try:
                prefs = json.loads(row["prefs"] or "{}")
            except json.JSONDecodeError:
                pass
            result.append({
                "person_id": row["person_id"],
                "score": float(row["score"]),
                "interaction_count": int(row["interaction_count"]),
                "first_seen": row["first_seen"],
                "last_seen": row["last_seen"],
                "prefs": prefs,
            })
        return result

    # ── Interactions ──────────────────────────────────────────────────────────
    def record_interaction(
        self,
        person_id: str,
        event_type: str,
        details: dict | None = None,
        delta_score: float = 0.0,
    ):
        """Record an interaction event and update the person's score.

        Parameters
        ----------
        person_id : str
        event_type : str
One of: ``"greeting"``, ``"positive"``, ``"negative"``,
``"detection"``.
details : dict, optional
Arbitrary key/value data stored as JSON.
delta_score : float
Amount to add to the person's score (can be negative).
Interaction count is always incremented by 1.
"""
now = _now_iso()
details_json = json.dumps(details or {})
with self._lock:
# Upsert person row
self._conn.execute("""
INSERT INTO people (person_id, first_seen, last_seen)
VALUES (?, ?, ?)
ON CONFLICT(person_id) DO UPDATE SET
last_seen = excluded.last_seen
""", (person_id, now, now))
# Increment count + score
self._conn.execute("""
UPDATE people
SET interaction_count = interaction_count + 1,
score = score + ?,
last_seen = ?
WHERE person_id = ?
""", (delta_score, now, person_id))
# Insert interaction log row
self._conn.execute("""
INSERT INTO interactions (person_id, ts, event_type, details)
VALUES (?, ?, ?, ?)
""", (person_id, now, event_type, details_json))
self._conn.commit()
def get_recent_events(self, person_id: str, window_s: float = 120.0) -> list:
"""Return interaction events for *person_id* within the last *window_s* seconds.
Returns
-------
list of dict
Each dict: ``{"type": str, "dt": float, "ts": str, "details": dict}``
where ``dt`` is seconds ago (positive = in the past).
"""
from datetime import timedelta
cutoff = (
datetime.now(timezone.utc) - timedelta(seconds=window_s)
).isoformat()
with self._lock:
rows = self._conn.execute("""
SELECT ts, event_type, details FROM interactions
WHERE person_id = ? AND ts >= ?
ORDER BY ts DESC
""", (person_id, cutoff)).fetchall()
now_dt = datetime.now(timezone.utc)
result = []
for row in rows:
try:
row_dt = datetime.fromisoformat(row["ts"])
# Make timezone-aware if needed
if row_dt.tzinfo is None:
row_dt = row_dt.replace(tzinfo=timezone.utc)
dt_secs = (now_dt - row_dt).total_seconds()
except (ValueError, TypeError):
dt_secs = window_s
details = {}
try:
details = json.loads(row["details"] or "{}")
except json.JSONDecodeError:
pass
result.append({
"type": row["event_type"],
"dt": dt_secs,
"ts": row["ts"],
"details": details,
})
return result
# ── Preferences ───────────────────────────────────────────────────────────
def set_pref(self, person_id: str, key: str, value):
"""Set a learned preference for a person.
Parameters
----------
person_id, key : str
value : JSON-serialisable
"""
person = self.get_person(person_id)
prefs = person["prefs"]
prefs[key] = value
with self._lock:
self._conn.execute(
"UPDATE people SET prefs = ? WHERE person_id = ?",
(json.dumps(prefs), person_id),
)
self._conn.commit()
def get_pref(self, person_id: str, key: str, default=None):
"""Return a specific learned preference for a person."""
return self.get_person(person_id)["prefs"].get(key, default)
# ── Lifecycle ─────────────────────────────────────────────────────────────
def close(self):
"""Close the database connection."""
with self._lock:
self._conn.close()
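`record_interaction` relies on SQLite's `INSERT ... ON CONFLICT DO UPDATE` upsert so the person row never needs a separate existence check (and `get_person`, which also takes the lock, never has to be re-entered). A minimal self-contained sketch of that pattern, using an in-memory database and a hypothetical `visits` counter in place of the real schema:

```python
import sqlite3

# In-memory DB; the upsert keeps first_seen from the original INSERT
# while always refreshing last_seen, mirroring relationship_db.py.
conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE people (
        person_id TEXT PRIMARY KEY,
        first_seen TEXT,
        last_seen  TEXT,
        visits     INTEGER DEFAULT 0
    )
""")

def touch(person_id: str, ts: str) -> None:
    # Upsert: insert a new row, or only refresh last_seen on conflict.
    conn.execute("""
        INSERT INTO people (person_id, first_seen, last_seen)
        VALUES (?, ?, ?)
        ON CONFLICT(person_id) DO UPDATE SET last_seen = excluded.last_seen
    """, (person_id, ts, ts))
    conn.execute(
        "UPDATE people SET visits = visits + 1 WHERE person_id = ?",
        (person_id,),
    )
    conn.commit()

touch("alice", "2026-03-01T00:00:00+00:00")
touch("alice", "2026-03-02T00:00:00+00:00")
row = conn.execute(
    "SELECT first_seen, last_seen, visits FROM people WHERE person_id = 'alice'"
).fetchone()
print(row)  # first_seen stays at the first timestamp
```

Note that `ON CONFLICT ... DO UPDATE` requires SQLite 3.24+, which any recent Python ships with.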

View File

@ -0,0 +1,196 @@
"""
soul_loader.py: SOUL.md persona parser and hot-reload watcher.
SOUL.md format
--------------
The file uses YAML front-matter (delimited by ``---`` lines) with an optional
Markdown description body that is ignored by the parser. Example::
---
name: "Salty"
humor_level: 7
sass_level: 4
base_mood: "playful"
...
---
# Optional description text (ignored)
Public API
----------
load_soul(path) -> dict (raises on parse error)
SoulWatcher(path, cb, interval)
.start()
.stop()
.reload_now() -> dict
Pure module; no ROS2 imports.
"""
import os
import threading
import time
import yaml
# Keys that are required in every SOUL.md file
_REQUIRED_KEYS = {
"name",
"humor_level",
"sass_level",
"base_mood",
"threshold_regular",
"threshold_favorite",
"greeting_stranger",
"greeting_regular",
"greeting_favorite",
}
_VALID_MOODS = {"happy", "curious", "annoyed", "playful"}
def _extract_frontmatter(text: str) -> str:
"""Return the YAML block between the first pair of ``---`` delimiters.
Raises ``ValueError`` if the file does not contain valid front-matter.
"""
lines = text.splitlines()
delimiters = [i for i, l in enumerate(lines) if l.strip() == "---"]
if len(delimiters) < 2:
# Fewer than two delimiters: treat the whole file as plain YAML
return text
start = delimiters[0] + 1
end = delimiters[1]
return "\n".join(lines[start:end])
def load_soul(path: str) -> dict:
"""Parse a SOUL.md file and return the validated config dict.
Parameters
----------
path : str
Absolute path to the SOUL.md file.
Returns
-------
dict
Validated persona configuration.
Raises
------
FileNotFoundError
If the file does not exist.
ValueError
If the YAML is malformed or required keys are missing.
"""
if not os.path.isfile(path):
raise FileNotFoundError(f"SOUL.md not found: {path}")
with open(path, "r", encoding="utf-8") as fh:
raw = fh.read()
yaml_text = _extract_frontmatter(raw)
try:
data = yaml.safe_load(yaml_text)
except yaml.YAMLError as exc:
raise ValueError(f"SOUL.md YAML parse error in {path}: {exc}") from exc
if not isinstance(data, dict):
raise ValueError(f"SOUL.md top level must be a YAML mapping, got {type(data)}")
# Validate required keys
missing = _REQUIRED_KEYS - data.keys()
if missing:
raise ValueError(f"SOUL.md missing required keys: {sorted(missing)}")
# Validate ranges
for key in ("humor_level", "sass_level"):
val = data.get(key)
if not isinstance(val, (int, float)) or not (0 <= val <= 10):
raise ValueError(f"SOUL.md '{key}' must be a number between 0 and 10, got {val!r}")
if data.get("base_mood") not in _VALID_MOODS:
raise ValueError(
f"SOUL.md 'base_mood' must be one of {sorted(_VALID_MOODS)}, "
f"got {data.get('base_mood')!r}"
)
for key in ("threshold_regular", "threshold_favorite"):
val = data.get(key)
if not isinstance(val, int) or val < 0:
raise ValueError(f"SOUL.md '{key}' must be a non-negative integer, got {val!r}")
if data["threshold_regular"] > data["threshold_favorite"]:
raise ValueError(
"SOUL.md 'threshold_regular' must be <= 'threshold_favorite'"
)
return data
class SoulWatcher:
"""Background thread that polls SOUL.md for changes and calls a callback.
Parameters
----------
path : str
Path to the SOUL.md file to watch.
on_reload : callable
``on_reload(soul_dict)`` called whenever a valid new SOUL.md is loaded.
interval : float
Polling interval in seconds (default 5.0).
on_error : callable, optional
``on_error(exception)`` called when a reload attempt fails.
"""
def __init__(self, path: str, on_reload, interval: float = 5.0, on_error=None):
self._path = path
self._on_reload = on_reload
self._interval = interval
self._on_error = on_error
self._thread = None
self._stop_evt = threading.Event()
self._last_mtime = 0.0
# ------------------------------------------------------------------
def start(self):
"""Start the background polling thread."""
if self._thread and self._thread.is_alive():
return
self._stop_evt.clear()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
def stop(self):
"""Signal the watcher thread to stop and block until it exits."""
self._stop_evt.set()
if self._thread:
self._thread.join(timeout=self._interval + 1.0)
def reload_now(self) -> dict:
"""Force an immediate reload and return the new soul dict."""
soul = load_soul(self._path)
self._last_mtime = os.path.getmtime(self._path)
self._on_reload(soul)
return soul
# ------------------------------------------------------------------
def _run(self):
while not self._stop_evt.wait(self._interval):
try:
mtime = os.path.getmtime(self._path)
except OSError:
continue
if mtime != self._last_mtime:
try:
soul = load_soul(self._path)
except (FileNotFoundError, ValueError) as exc:
if self._on_error:
self._on_error(exc)
continue
self._last_mtime = mtime
self._on_reload(soul)
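The front-matter convention that `load_soul` consumes can be exercised in isolation. A stdlib-only sketch of the same delimiter logic (mirroring `_extract_frontmatter`; the example document is illustrative):

```python
def extract_frontmatter(text: str) -> str:
    # Same approach as soul_loader._extract_frontmatter: find the first
    # two lines that are exactly "---" and return what sits between them.
    lines = text.splitlines()
    delims = [i for i, l in enumerate(lines) if l.strip() == "---"]
    if len(delims) < 2:
        return text  # fall back to treating the whole file as YAML
    return "\n".join(lines[delims[0] + 1 : delims[1]])

doc = "---\nname: Salty\nhumor_level: 7\n---\n# Body text (ignored)\n"
fm = extract_frontmatter(doc)
print(fm)            # name: Salty / humor_level: 7
print("Body" in fm)  # False: the Markdown body is stripped
```

The extracted block is then handed to `yaml.safe_load` and validated against `_REQUIRED_KEYS` before anything touches it.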

View File

@ -0,0 +1,4 @@
[develop]
script_dir=$base/lib/saltybot_social_personality
[install]
install_scripts=$base/lib/saltybot_social_personality

View File

@ -0,0 +1,28 @@
from setuptools import setup
package_name = "saltybot_social_personality"
setup(
name=package_name,
version="0.1.0",
packages=[package_name],
data_files=[
("share/ament_index/resource_index/packages", [f"resource/{package_name}"]),
(f"share/{package_name}", ["package.xml"]),
(f"share/{package_name}/launch", ["launch/personality.launch.py"]),
(f"share/{package_name}/config", ["config/SOUL.md",
"config/personality_params.yaml"]),
],
install_requires=["setuptools", "pyyaml"],
zip_safe=True,
maintainer="sl-controls",
maintainer_email="sl-controls@saltylab.local",
description="SOUL.md-driven personality system for saltybot social interaction",
license="MIT",
tests_require=["pytest"],
entry_points={
"console_scripts": [
"personality_node = saltybot_social_personality.personality_node:main",
],
},
)
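The `data_files` entry installs a `config/personality_params.yaml` alongside SOUL.md. A hypothetical sketch of what that params file could look like, using the parameter names from the commit message (`soul_file`, `db_path`, `reload_interval`, `publish_rate`); the values and paths are illustrative, not the shipped defaults:

```yaml
# Hypothetical personality_params.yaml; parameter names follow the
# commit message, values and paths are illustrative.
personality_node:
  ros__parameters:
    soul_file: /opt/saltybot/config/SOUL.md
    db_path: /var/lib/saltybot/relationships.db
    reload_interval: 5.0   # seconds between SOUL.md mtime polls
    publish_rate: 1.0      # Hz for /social/personality/state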

View File

@ -0,0 +1,475 @@
"""
test_personality.py: Unit tests for the saltybot personality system.
No ROS2 runtime required. Tests pure functions from:
- soul_loader.py
- mood_engine.py
- relationship_db.py
Run with:
pytest jetson/ros2_ws/src/saltybot_social_personality/test/test_personality.py
"""
import os
import tempfile
import textwrap
import pytest
# ── Imports (pure modules, no ROS2) ──────────────────────────────────────────
import sys
sys.path.insert(
0,
os.path.join(os.path.dirname(__file__), "..", "saltybot_social_personality"),
)
from soul_loader import load_soul, _extract_frontmatter
from mood_engine import (
compute_mood, get_relationship_tier, build_greeting,
MOOD_HAPPY, MOOD_PLAYFUL, MOOD_CURIOUS, MOOD_ANNOYED,
TIER_STRANGER, TIER_REGULAR, TIER_FAVORITE,
EVENT_NEGATIVE, EVENT_POSITIVE, EVENT_GREETING,
)
from relationship_db import RelationshipDB
# ── Helpers ───────────────────────────────────────────────────────────────────
def _minimal_soul(**overrides) -> dict:
"""Return a valid minimal soul dict, optionally overriding keys."""
base = {
"name": "Salty",
"humor_level": 7,
"sass_level": 4,
"base_mood": "playful",
"threshold_regular": 5,
"threshold_favorite": 20,
"greeting_stranger": "Hello there!",
"greeting_regular": "Hey {name}!",
"greeting_favorite": "Oh hey {name}!!",
}
base.update(overrides)
return base
def _write_soul(content: str) -> str:
"""Write a SOUL.md string to a temp file and return the path."""
fh = tempfile.NamedTemporaryFile(
mode="w", suffix=".md", delete=False, encoding="utf-8"
)
fh.write(content)
fh.close()
return fh.name
_VALID_SOUL_CONTENT = textwrap.dedent("""\
---
name: "TestBot"
speaking_style: "casual"
humor_level: 7
sass_level: 3
base_mood: "playful"
threshold_regular: 5
threshold_favorite: 20
greeting_stranger: "Hello stranger!"
greeting_regular: "Hey {name}!"
greeting_favorite: "Oh hey {name}!!"
mood_prefix_playful: "Beep boop! "
mood_prefix_happy: "Great — "
mood_prefix_curious: "Hmm, "
mood_prefix_annoyed: "Ugh, "
---
# Description (ignored)
This is the description body.
""")
# ═══════════════════════════════════════════════════════════════════════════════
# soul_loader tests
# ═══════════════════════════════════════════════════════════════════════════════
class TestExtractFrontmatter:
def test_delimited(self):
content = "---\nkey: val\n---\n# body"
assert _extract_frontmatter(content) == "key: val"
def test_no_delimiters_returns_whole(self):
content = "key: val\nother: 123"
assert _extract_frontmatter(content) == content
def test_single_delimiter_returns_whole(self):
content = "---\nkey: val\n"
result = _extract_frontmatter(content)
assert "key: val" in result
def test_body_stripped(self):
content = "---\nname: X\n---\n# Body text\nMore body"
assert "Body text" not in _extract_frontmatter(content)
assert "name: X" in _extract_frontmatter(content)
class TestLoadSoul:
def test_valid_file_loads(self):
path = _write_soul(_VALID_SOUL_CONTENT)
try:
soul = load_soul(path)
assert soul["name"] == "TestBot"
assert soul["humor_level"] == 7
assert soul["base_mood"] == "playful"
finally:
os.unlink(path)
def test_missing_file_raises(self):
with pytest.raises(FileNotFoundError):
load_soul("/nonexistent/SOUL.md")
def test_missing_required_key_raises(self):
content = "---\nname: X\nhumor_level: 5\n---" # missing many keys
path = _write_soul(content)
try:
with pytest.raises(ValueError, match="missing required keys"):
load_soul(path)
finally:
os.unlink(path)
def test_humor_out_of_range_raises(self):
soul_str = _VALID_SOUL_CONTENT.replace("humor_level: 7", "humor_level: 11")
path = _write_soul(soul_str)
try:
with pytest.raises(ValueError, match="humor_level"):
load_soul(path)
finally:
os.unlink(path)
def test_invalid_mood_raises(self):
soul_str = _VALID_SOUL_CONTENT.replace(
'base_mood: "playful"', 'base_mood: "grumpy"'
)
path = _write_soul(soul_str)
try:
with pytest.raises(ValueError, match="base_mood"):
load_soul(path)
finally:
os.unlink(path)
def test_threshold_order_enforced(self):
soul_str = _VALID_SOUL_CONTENT.replace(
"threshold_regular: 5", "threshold_regular: 25"
)
path = _write_soul(soul_str)
try:
with pytest.raises(ValueError, match="threshold_regular"):
load_soul(path)
finally:
os.unlink(path)
def test_extra_keys_allowed(self):
content = _VALID_SOUL_CONTENT.replace(
"---\n# Description",
"custom_key: 42\n---\n# Description"
)
path = _write_soul(content)
try:
soul = load_soul(path)
assert soul.get("custom_key") == 42
finally:
os.unlink(path)
# ═══════════════════════════════════════════════════════════════════════════════
# mood_engine tests
# ═══════════════════════════════════════════════════════════════════════════════
class TestGetRelationshipTier:
def test_zero_interactions_stranger(self):
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
assert get_relationship_tier(soul, 0) == TIER_STRANGER
def test_below_regular_stranger(self):
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
assert get_relationship_tier(soul, 4) == TIER_STRANGER
def test_at_regular_threshold(self):
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
assert get_relationship_tier(soul, 5) == TIER_REGULAR
def test_above_regular_below_favorite(self):
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
assert get_relationship_tier(soul, 10) == TIER_REGULAR
def test_at_favorite_threshold(self):
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
assert get_relationship_tier(soul, 20) == TIER_FAVORITE
def test_above_favorite(self):
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
assert get_relationship_tier(soul, 100) == TIER_FAVORITE
class TestComputeMood:
def test_unknown_person_returns_curious(self):
soul = _minimal_soul(humor_level=7)
mood = compute_mood(soul, score=0.0, interaction_count=0, recent_events=[])
assert mood == MOOD_CURIOUS
def test_stranger_low_count_returns_curious(self):
soul = _minimal_soul(threshold_regular=5)
mood = compute_mood(soul, score=2.0, interaction_count=3, recent_events=[])
assert mood == MOOD_CURIOUS
def test_two_negative_events_returns_annoyed(self):
soul = _minimal_soul(threshold_regular=5)
events = [
{"type": EVENT_NEGATIVE, "dt": 30.0},
{"type": EVENT_NEGATIVE, "dt": 60.0},
]
mood = compute_mood(soul, score=10.0, interaction_count=10, recent_events=events)
assert mood == MOOD_ANNOYED
def test_one_negative_not_annoyed(self):
soul = _minimal_soul(humor_level=7, threshold_regular=5, threshold_favorite=20)
events = [{"type": EVENT_NEGATIVE, "dt": 30.0}]
# 1 negative is not enough → should still be happy/playful based on score
mood = compute_mood(soul, score=25.0, interaction_count=25, recent_events=events)
assert mood != MOOD_ANNOYED
def test_high_humor_regular_returns_playful(self):
soul = _minimal_soul(humor_level=8, threshold_regular=5, threshold_favorite=20)
events = [{"type": EVENT_POSITIVE, "dt": 10.0}]
mood = compute_mood(soul, score=10.0, interaction_count=8, recent_events=events)
assert mood == MOOD_PLAYFUL
def test_low_humor_regular_returns_happy(self):
soul = _minimal_soul(humor_level=4, threshold_regular=5, threshold_favorite=20)
events = [{"type": EVENT_POSITIVE, "dt": 10.0}]
mood = compute_mood(soul, score=10.0, interaction_count=8, recent_events=events)
assert mood == MOOD_HAPPY
def test_stale_negative_ignored(self):
soul = _minimal_soul(humor_level=8, threshold_regular=5, threshold_favorite=20)
# dt > 120s → outside the recent window → should not trigger annoyed
events = [
{"type": EVENT_NEGATIVE, "dt": 200.0},
{"type": EVENT_NEGATIVE, "dt": 300.0},
]
mood = compute_mood(soul, score=15.0, interaction_count=10, recent_events=events)
assert mood != MOOD_ANNOYED
def test_favorite_high_humor_playful(self):
soul = _minimal_soul(humor_level=9, threshold_regular=5, threshold_favorite=20)
mood = compute_mood(soul, score=50.0, interaction_count=30, recent_events=[])
assert mood == MOOD_PLAYFUL
def test_favorite_low_humor_happy(self):
soul = _minimal_soul(humor_level=3, threshold_regular=5, threshold_favorite=20)
mood = compute_mood(soul, score=50.0, interaction_count=30, recent_events=[])
assert mood == MOOD_HAPPY
class TestBuildGreeting:
def _soul(self, **kw):
return _minimal_soul(
mood_prefix_happy="Great — ",
mood_prefix_curious="Hmm, ",
mood_prefix_annoyed="Well, ",
mood_prefix_playful="Beep boop! ",
**kw,
)
def test_stranger_greeting(self):
soul = self._soul()
g = build_greeting(soul, TIER_STRANGER, MOOD_CURIOUS, "")
assert "hello" in g.lower()
def test_regular_greeting_contains_name(self):
soul = self._soul()
g = build_greeting(soul, TIER_REGULAR, MOOD_HAPPY, "alice")
assert "alice" in g
def test_favorite_greeting_contains_name(self):
soul = self._soul()
g = build_greeting(soul, TIER_FAVORITE, MOOD_PLAYFUL, "bob")
assert "bob" in g
def test_mood_prefix_applied(self):
soul = self._soul()
g = build_greeting(soul, TIER_REGULAR, MOOD_PLAYFUL, "alice")
assert g.startswith("Beep boop!")
def test_no_prefix_key_no_prefix(self):
soul = _minimal_soul() # no mood_prefix_* keys
g = build_greeting(soul, TIER_REGULAR, MOOD_HAPPY, "alice")
assert g.startswith("Hey")
def test_empty_person_id_uses_friend(self):
soul = self._soul()
g = build_greeting(soul, TIER_REGULAR, MOOD_HAPPY, "")
assert "friend" in g
def test_happy_prefix(self):
soul = self._soul()
g = build_greeting(soul, TIER_REGULAR, MOOD_HAPPY, "carol")
assert g.startswith("Great")
def test_annoyed_prefix(self):
soul = self._soul()
g = build_greeting(soul, TIER_REGULAR, MOOD_ANNOYED, "dave")
assert g.startswith("Well")
# ═══════════════════════════════════════════════════════════════════════════════
# relationship_db tests
# ═══════════════════════════════════════════════════════════════════════════════
class TestRelationshipDB:
@pytest.fixture
def db(self, tmp_path):
path = str(tmp_path / "test.db")
d = RelationshipDB(path)
yield d
d.close()
def test_get_person_creates_default(self, db):
p = db.get_person("alice")
assert p["person_id"] == "alice"
assert p["score"] == pytest.approx(0.0)
assert p["interaction_count"] == 0
def test_get_person_idempotent(self, db):
p1 = db.get_person("bob")
p2 = db.get_person("bob")
assert p1["person_id"] == p2["person_id"]
def test_record_interaction_increments_count(self, db):
db.record_interaction("alice", "detection")
db.record_interaction("alice", "detection")
p = db.get_person("alice")
assert p["interaction_count"] == 2
def test_record_interaction_updates_score(self, db):
db.record_interaction("alice", "positive", delta_score=5.0)
p = db.get_person("alice")
assert p["score"] == pytest.approx(5.0)
def test_negative_delta_reduces_score(self, db):
db.record_interaction("carol", "positive", delta_score=10.0)
db.record_interaction("carol", "negative", delta_score=-3.0)
p = db.get_person("carol")
assert p["score"] == pytest.approx(7.0)
def test_score_zero_by_default(self, db):
p = db.get_person("dave")
assert p["score"] == pytest.approx(0.0)
def test_set_and_get_pref(self, db):
db.set_pref("alice", "language", "en")
assert db.get_pref("alice", "language") == "en"
def test_get_pref_default(self, db):
assert db.get_pref("nobody", "language", "fr") == "fr"
def test_multiple_prefs_stored(self, db):
db.set_pref("alice", "lang", "en")
db.set_pref("alice", "name", "Alice")
assert db.get_pref("alice", "lang") == "en"
assert db.get_pref("alice", "name") == "Alice"
def test_all_people_returns_list(self, db):
db.record_interaction("a", "detection")
db.record_interaction("b", "detection")
people = db.all_people()
ids = {p["person_id"] for p in people}
assert {"a", "b"} <= ids
def test_get_recent_events_returns_events(self, db):
db.record_interaction("alice", "greeting", delta_score=1.0)
events = db.get_recent_events("alice", window_s=60.0)
assert len(events) == 1
assert events[0]["type"] == "greeting"
def test_get_recent_events_empty_for_new_person(self, db):
events = db.get_recent_events("nobody", window_s=60.0)
assert events == []
def test_event_dt_positive(self, db):
db.record_interaction("alice", "detection")
events = db.get_recent_events("alice", window_s=60.0)
assert events[0]["dt"] >= 0.0
def test_multiple_people_isolated(self, db):
db.record_interaction("alice", "positive", delta_score=10.0)
db.record_interaction("bob", "negative", delta_score=-5.0)
assert db.get_person("alice")["score"] == pytest.approx(10.0)
assert db.get_person("bob")["score"] == pytest.approx(-5.0)
def test_details_stored(self, db):
db.record_interaction("alice", "greeting", details={"location": "lab"})
events = db.get_recent_events("alice", window_s=60.0)
assert events[0]["details"].get("location") == "lab"
# ═══════════════════════════════════════════════════════════════════════════════
# Integration: soul → tier → mood → greeting pipeline
# ═══════════════════════════════════════════════════════════════════════════════
class TestIntegrationPipeline:
def test_stranger_pipeline(self, tmp_path):
db_path = str(tmp_path / "int.db")
db = RelationshipDB(db_path)
soul = _minimal_soul(
humor_level=7, threshold_regular=5, threshold_favorite=20,
mood_prefix_curious="Hmm, "
)
# No prior interactions
person = db.get_person("stranger_001")
events = db.get_recent_events("stranger_001")
tier = get_relationship_tier(soul, person["interaction_count"])
mood = compute_mood(soul, person["score"], person["interaction_count"], events)
greeting = build_greeting(soul, tier, mood, "stranger_001")
assert tier == TIER_STRANGER
assert mood == MOOD_CURIOUS
assert "hello" in greeting.lower()
db.close()
def test_regular_positive_pipeline(self, tmp_path):
db_path = str(tmp_path / "int2.db")
db = RelationshipDB(db_path)
soul = _minimal_soul(
humor_level=8, threshold_regular=5, threshold_favorite=20,
mood_prefix_playful="Beep! "
)
# Simulate 6 positive interactions (> threshold_regular=5)
for _ in range(6):
db.record_interaction("alice", "positive", delta_score=2.0)
person = db.get_person("alice")
events = db.get_recent_events("alice")
tier = get_relationship_tier(soul, person["interaction_count"])
mood = compute_mood(soul, person["score"], person["interaction_count"], events)
greeting = build_greeting(soul, tier, mood, "alice")
assert tier == TIER_REGULAR
assert mood == MOOD_PLAYFUL # humor_level=8, recent positive events
assert "alice" in greeting
assert greeting.startswith("Beep!")
db.close()
def test_favorite_pipeline(self, tmp_path):
db_path = str(tmp_path / "int3.db")
db = RelationshipDB(db_path)
soul = _minimal_soul(
humor_level=5, threshold_regular=5, threshold_favorite=20
)
for _ in range(25):
db.record_interaction("bob", "positive", delta_score=1.0)
person = db.get_person("bob")
tier = get_relationship_tier(soul, person["interaction_count"])
assert tier == TIER_FAVORITE
greeting = build_greeting(soul, tier, "happy", "bob")
assert "bob" in greeting
assert "Oh hey" in greeting
db.close()
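The tier tests above pin down `get_relationship_tier`'s contract: both thresholds are inclusive lower bounds on `interaction_count`. A minimal sketch consistent with those tests (names taken from the test imports; this is an illustration, not the shipped `mood_engine`):

```python
# Tier constants as assumed from the test imports.
TIER_STRANGER, TIER_REGULAR, TIER_FAVORITE = "stranger", "regular", "favorite"

def get_relationship_tier(soul: dict, interaction_count: int) -> str:
    # Inclusive thresholds: reaching threshold_regular promotes to
    # regular, reaching threshold_favorite promotes to favorite.
    if interaction_count >= soul["threshold_favorite"]:
        return TIER_FAVORITE
    if interaction_count >= soul["threshold_regular"]:
        return TIER_REGULAR
    return TIER_STRANGER

soul = {"threshold_regular": 5, "threshold_favorite": 20}
print([get_relationship_tier(soul, n) for n in (0, 4, 5, 19, 20)])
# ['stranger', 'stranger', 'regular', 'regular', 'favorite']
```

Checking `threshold_favorite` first is what makes the bands non-overlapping without any explicit upper-bound comparisons.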