Merge pull request 'feat(social): person enrollment system #87' (#95) from sl-perception/social-enrollment into main

This commit is contained in:
sl-jetson 2026-03-01 23:55:16 -05:00
commit d6a6965af6
11 changed files with 712 additions and 0 deletions

View File

@ -0,0 +1,6 @@
# Parameters for the enrollment node; each key below is declared and read in
# EnrollmentNode.__init__ (enrollment_node.py) with the same default value.
enrollment_node:
ros__parameters:
# SQLite gallery file handed to PersonDB; PersonDB creates the parent
# directory if it does not exist.
db_path: '/mnt/nvme/saltybot/gallery/persons.db'
# Directory for voice samples. The node stores this value, but the visible
# node code does not use it yet.
voice_samples_dir: '/mnt/nvme/saltybot/gallery/voice'
# Phrase searched for (case-insensitively) in /social/speech/transcript;
# up to three words following it are used as the person's name.
auto_enroll_phrase: 'remember me my name is'
# Sample count sent with voice-triggered enrollment requests.
n_samples_default: 10

View File

@ -0,0 +1,22 @@
"""Launch file for saltybot social enrollment node."""
import os
from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch_ros.actions import Node
def generate_launch_description():
    """Build the launch description that starts the enrollment node.

    The node is started from the ``saltybot_social_enrollment`` package with
    its parameters loaded from ``config/enrollment_params.yaml`` inside the
    package's share directory, and its output routed to the screen.
    """
    share_dir = get_package_share_directory('saltybot_social_enrollment')
    params_yaml = os.path.join(share_dir, 'config', 'enrollment_params.yaml')

    enrollment = Node(
        package='saltybot_social_enrollment',
        executable='enrollment_node',
        name='enrollment_node',
        parameters=[params_yaml],
        output='screen',
    )
    return LaunchDescription([enrollment])

View File

@ -0,0 +1,23 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<!-- Package manifest for the saltybot person enrollment package (ament_python build). -->
<package format="3">
<name>saltybot_social_enrollment</name>
<version>0.1.0</version>
<description>Person enrollment system for saltybot social interaction</description>
<maintainer email="seb@vayrette.com">seb</maintainer>
<license>MIT</license>
<buildtool_depend>ament_python</buildtool_depend>
<!-- rclpy, std_msgs and saltybot_social_msgs are imported by the package's node and CLI modules. -->
<!-- NOTE(review): sensor_msgs and cv_bridge are not imported by the node, CLI or DB modules
     reviewed alongside this manifest; confirm they are still needed. -->
<depend>rclpy</depend>
<depend>sensor_msgs</depend>
<depend>cv_bridge</depend>
<depend>std_msgs</depend>
<depend>saltybot_social_msgs</depend>
<!-- numpy is used for embedding (de)serialisation in person_db.py and enrollment_node.py. -->
<exec_depend>python3-numpy</exec_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -0,0 +1,184 @@
#!/usr/bin/env python3
"""enrollment_cli.py -- Gallery management CLI for saltybot social.
Usage:
ros2 run saltybot_social_enrollment enrollment_cli enroll --name "Alice" [--samples 15] [--mode face]
ros2 run saltybot_social_enrollment enrollment_cli list
ros2 run saltybot_social_enrollment enrollment_cli delete --id 3
ros2 run saltybot_social_enrollment enrollment_cli rename --id 2 --name "Bob"
"""
import argparse
import sys
import rclpy
from rclpy.node import Node
from saltybot_social_msgs.srv import (
EnrollPerson,
ListPersons,
DeletePerson,
UpdatePerson,
)
class EnrollmentCLI(Node):
    """ROS 2 client node behind the ``enrollment_cli`` gallery commands.

    Each public method issues one request to the enrollment node's services
    (``/social/enroll`` and ``/social/persons/{list,delete,update}``), prints
    a human-readable outcome and returns whether the operation succeeded.
    """

    # Seconds to wait for a service to be discovered before giving up.
    _DISCOVERY_TIMEOUT_SEC = 5.0

    def __init__(self):
        super().__init__('enrollment_cli')
        self._enroll_client = self.create_client(
            EnrollPerson, '/social/enroll'
        )
        self._list_client = self.create_client(
            ListPersons, '/social/persons/list'
        )
        self._delete_client = self.create_client(
            DeletePerson, '/social/persons/delete'
        )
        self._update_client = self.create_client(
            UpdatePerson, '/social/persons/update'
        )

    def _call(self, client, request, *, service, action, timeout_sec,
              announce=None):
        """Send ``request`` through ``client`` and wait for the response.

        Args:
            client: Service client to use.
            request: Populated request message.
            service: Service name, used only in the "not available" message.
            action: Human-readable label used in error messages.
            timeout_sec: Maximum time to spin while waiting for the response.
            announce: Optional line printed once the service is available,
                right before the request is sent.

        Returns:
            The response message, or ``None`` if the service was unavailable,
            the call timed out or the call failed (an ``ERROR:`` line has
            already been printed in those cases).
        """
        if not client.wait_for_service(timeout_sec=self._DISCOVERY_TIMEOUT_SEC):
            print(f'ERROR: {service} service not available')
            return None
        if announce:
            print(announce)

        future = client.call_async(request)
        rclpy.spin_until_future_complete(self, future, timeout_sec=timeout_sec)

        # Check done() rather than result(): result() re-raises a failed
        # call's exception, which previously escaped as a raw traceback.
        if not future.done():
            print(f'ERROR: {action} timed out')
            return None
        try:
            response = future.result()
        except Exception as exc:  # CLI boundary: report any call failure
            print(f'ERROR: {action} failed: {exc}')
            return None
        if response is None:
            print(f'ERROR: {action} returned no response')
            return None
        return response

    @staticmethod
    def _report(response):
        """Print the service's message as OK/FAILED and return its success flag."""
        prefix = 'OK' if response.success else 'FAILED'
        print(f'{prefix}: {response.message}')
        return response.success

    def enroll(self, name: str, n_samples: int = 10, mode: str = 'face'):
        """Enroll ``name`` via ``/social/enroll``; return True on success."""
        req = EnrollPerson.Request()
        req.name = name
        req.mode = mode
        req.n_samples = n_samples
        resp = self._call(
            self._enroll_client, req,
            service='/social/enroll',
            action='Enrollment',
            timeout_sec=60.0,
            announce=f'Enrolling "{name}" ({mode}, {n_samples} samples)...',
        )
        if resp is None:
            return False
        if resp.success:
            print(f'OK: Enrolled "{name}" as person_id={resp.person_id}')
        else:
            print(f'FAILED: {resp.message}')
        return resp.success

    def list_persons(self):
        """Print the gallery as a table.

        Returns:
            True when the listing was retrieved (an empty gallery counts),
            False when the service was unavailable or the call failed.
        """
        resp = self._call(
            self._list_client, ListPersons.Request(),
            service='/social/persons/list',
            action='List request',
            timeout_sec=10.0,
        )
        if resp is None:
            return False
        if not resp.persons:
            print('Gallery is empty.')
            return True
        print(f'{"ID":>4} {"Name":<20} {"Samples":>7} {"Embedding Dim":>13}')
        print('-' * 50)
        for p in resp.persons:
            # len() instead of truth-testing: array-typed message fields
            # (e.g. numpy arrays) raise on ``if field``; len() is 0 when empty.
            dim = len(p.embedding)
            print(f'{p.person_id:>4} {p.person_name:<20} {p.sample_count:>7} {dim:>13}')
        return True

    def delete(self, person_id: int):
        """Delete ``person_id`` from the gallery; return True on success."""
        req = DeletePerson.Request()
        req.person_id = person_id
        resp = self._call(
            self._delete_client, req,
            service='/social/persons/delete',
            action='Delete request',
            timeout_sec=10.0,
        )
        return False if resp is None else self._report(resp)

    def rename(self, person_id: int, new_name: str):
        """Rename ``person_id`` to ``new_name``; return True on success."""
        req = UpdatePerson.Request()
        req.person_id = person_id
        req.new_name = new_name
        resp = self._call(
            self._update_client, req,
            service='/social/persons/update',
            action='Rename request',
            timeout_sec=10.0,
        )
        return False if resp is None else self._report(resp)
def main(args=None):
    """Entry point for the ``enrollment_cli`` console script.

    Args:
        args: Command-line arguments without the program name. When ``None``
            (the console-script case) ``sys.argv[1:]`` is parsed.

    Returns:
        Process exit status: 1 when the selected command reported failure,
        0 otherwise. A command that returns ``None`` is treated as success.
    """
    parser = argparse.ArgumentParser(
        description='saltybot person enrollment CLI'
    )
    subparsers = parser.add_subparsers(dest='command', required=True)

    # enroll
    enroll_p = subparsers.add_parser('enroll', help='Enroll a new person')
    enroll_p.add_argument('--name', required=True, help='Person name')
    enroll_p.add_argument('--samples', type=int, default=10,
                          help='Number of face samples (default: 10)')
    enroll_p.add_argument('--mode', default='face',
                          help='Enrollment mode (default: face)')

    # list
    subparsers.add_parser('list', help='List enrolled persons')

    # delete
    delete_p = subparsers.add_parser('delete', help='Delete a person')
    delete_p.add_argument('--id', type=int, required=True,
                          help='Person ID to delete')

    # rename
    rename_p = subparsers.add_parser('rename', help='Rename a person')
    rename_p.add_argument('--id', type=int, required=True,
                          help='Person ID to rename')
    rename_p.add_argument('--name', required=True, help='New name')

    # Honour an explicit argument list; previously ``args`` was ignored.
    argv = sys.argv[1:] if args is None else list(args)
    parsed = parser.parse_args(argv)

    rclpy.init()
    cli = EnrollmentCLI()
    outcome = None
    try:
        if parsed.command == 'enroll':
            outcome = cli.enroll(parsed.name, parsed.samples, parsed.mode)
        elif parsed.command == 'list':
            outcome = cli.list_persons()
        elif parsed.command == 'delete':
            outcome = cli.delete(parsed.id)
        elif parsed.command == 'rename':
            outcome = cli.rename(parsed.id, parsed.name)
    finally:
        cli.destroy_node()
        rclpy.try_shutdown()

    # Surface failures to the shell so scripts can react to them.
    return 1 if outcome is not None and not outcome else 0


if __name__ == '__main__':
    sys.exit(main())

View File

@ -0,0 +1,302 @@
"""enrollment_node.py -- ROS2 person enrollment node for saltybot social.
Coordinates person enrollment:
- Forwards /social/enroll to face_recognizer's service
- Owns persistent SQLite gallery (PersonDB)
- Voice-triggered enrollment via "remember me my name is X"
- Gallery management services (list/delete/update)
- Syncs DB from /social/faces/embeddings topic
"""
import threading
import numpy as np
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, DurabilityPolicy
from std_msgs.msg import String
from saltybot_social_msgs.msg import (
FaceDetectionArray,
FaceEmbedding,
FaceEmbeddingArray,
)
from saltybot_social_msgs.srv import (
EnrollPerson,
ListPersons,
DeletePerson,
UpdatePerson,
)
from saltybot_social_enrollment.person_db import PersonDB
class EnrollmentNode(Node):
    """Coordinate person enrollment and own the persistent person gallery.

    What this node does:
      * forwards ``/social/enroll`` requests to
        ``/social/face_recognizer/enroll``;
      * starts an enrollment when a speech transcript contains the
        configured trigger phrase;
      * serves list/delete/update requests from the SQLite ``PersonDB``;
      * imports persons seen on ``/social/faces/embeddings`` whose id is not
        in the database yet.

    ``_handle_enroll`` blocks until face_recognizer answers. The enroll
    service and the enroll client therefore live in their own callback
    groups, and ``main`` spins the node with a multi-threaded executor;
    with a single-threaded executor the response could never be processed
    while the handler is waiting.
    """

    # Upper bound (seconds) on how long /social/enroll waits for
    # face_recognizer before reporting a timeout to the caller.
    _ENROLL_FORWARD_TIMEOUT_SEC = 35.0

    def __init__(self):
        # Imported here so this class does not rely on a change to the
        # module-level import block.
        from rclpy.callback_groups import MutuallyExclusiveCallbackGroup

        super().__init__('enrollment_node')

        # Parameters
        self.declare_parameter('db_path', '/mnt/nvme/saltybot/gallery/persons.db')
        self.declare_parameter('voice_samples_dir', '/mnt/nvme/saltybot/gallery/voice')
        self.declare_parameter('auto_enroll_phrase', 'remember me my name is')
        self.declare_parameter('n_samples_default', 10)
        db_path = self.get_parameter('db_path').value
        self._voice_dir = self.get_parameter('voice_samples_dir').value
        self._phrase = self.get_parameter('auto_enroll_phrase').value
        self._n_samples = self.get_parameter('n_samples_default').value

        self._db = PersonDB(db_path)
        self.get_logger().info(f'PersonDB initialized at {db_path}')

        # Separate groups for the forwarding service and its client so the
        # client's response can be handled while the service callback blocks,
        # and so a long enrollment does not stall the other callbacks.
        self._enroll_client_group = MutuallyExclusiveCallbackGroup()
        self._enroll_service_group = MutuallyExclusiveCallbackGroup()

        # Client to face_recognizer's enroll service
        self._enroll_client = self.create_client(
            EnrollPerson, '/social/face_recognizer/enroll',
            callback_group=self._enroll_client_group,
        )

        # QoS profiles
        best_effort_qos = QoSProfile(
            depth=10,
            reliability=ReliabilityPolicy.BEST_EFFORT,
            durability=DurabilityPolicy.VOLATILE,
        )
        reliable_qos = QoSProfile(
            depth=1,
            reliability=ReliabilityPolicy.RELIABLE,
            durability=DurabilityPolicy.VOLATILE,
        )
        status_qos = QoSProfile(
            depth=1,
            reliability=ReliabilityPolicy.BEST_EFFORT,
            durability=DurabilityPolicy.VOLATILE,
        )

        # Subscriptions
        self.create_subscription(
            FaceDetectionArray, '/social/faces/detections',
            self._on_detections, best_effort_qos
        )
        self.create_subscription(
            FaceEmbeddingArray, '/social/faces/embeddings',
            self._on_embeddings, reliable_qos
        )
        self.create_subscription(
            String, '/social/speech/transcript',
            self._on_transcript, best_effort_qos
        )
        self.create_subscription(
            String, '/social/speech/command',
            self._on_command, best_effort_qos
        )

        # Services
        self.create_service(
            EnrollPerson, '/social/enroll', self._handle_enroll,
            callback_group=self._enroll_service_group,
        )
        self.create_service(ListPersons, '/social/persons/list', self._handle_list)
        self.create_service(DeletePerson, '/social/persons/delete', self._handle_delete)
        self.create_service(UpdatePerson, '/social/persons/update', self._handle_update)

        # Publishers
        self._pub_embeddings = self.create_publisher(
            FaceEmbeddingArray, '/social/faces/embeddings', reliable_qos
        )
        self._pub_status = self.create_publisher(
            String, '/social/enrollment/status', status_qos
        )
        self.get_logger().info('EnrollmentNode ready')

    # ---- Voice-triggered enrollment ----

    def _on_transcript(self, msg: String):
        """Start an enrollment when the transcript contains the trigger phrase.

        Matching is case-insensitive. At most the first three words after
        the phrase are used as the name, title-cased.
        """
        phrase = self._phrase.lower()
        if not phrase:
            # An empty phrase would match every transcript; treat it as
            # "voice enrollment disabled".
            return
        text = msg.data.lower()
        start = text.find(phrase)
        if start < 0:
            return
        words = text[start + len(phrase):].split()
        if not words:
            return
        name = ' '.join(words[:3]).title()
        self.get_logger().info(f'Voice enrollment triggered: "{name}"')
        self._trigger_voice_enroll(name)

    def _on_command(self, msg: String):
        # Reserved for explicit voice commands (e.g., "enroll Alice")
        pass

    def _trigger_voice_enroll(self, name: str):
        """Send a non-blocking face enrollment request for ``name``."""
        if not self._enroll_client.wait_for_service(timeout_sec=1.0):
            self.get_logger().warn(
                'face_recognizer enroll service not available'
            )
            self._publish_status('Enrollment failed: face_recognizer unavailable')
            return
        req = EnrollPerson.Request()
        req.name = name
        req.mode = 'face'
        req.n_samples = self._n_samples
        future = self._enroll_client.call_async(req)
        future.add_done_callback(
            lambda f: self._on_enroll_done(f, name)
        )
        self._publish_status(f'Enrolling "{name}"... look at the camera')

    def _on_enroll_done(self, future, name: str):
        """Log and publish the outcome of a voice-triggered enrollment."""
        try:
            resp = future.result()
            if resp.success:
                status = f'Enrolled "{name}" (id={resp.person_id})'
                self.get_logger().info(status)
            else:
                status = f'Enrollment failed for "{name}": {resp.message}'
                self.get_logger().warn(status)
            self._publish_status(status)
        except Exception as e:
            self.get_logger().error(f'Enroll call failed: {e}')
            self._publish_status(f'Enrollment error: {e}')

    # ---- Face detection callback (during enrollment) ----

    def _on_detections(self, msg: FaceDetectionArray):
        # Reserved for future direct enrollment (without face_recognizer)
        pass

    # ---- Embeddings sync from face_recognizer ----

    def _on_embeddings(self, msg: FaceEmbeddingArray):
        """Add persons from ``msg`` whose ``person_id`` is not in the DB.

        This node also publishes on the same topic, so it may receive its
        own gallery snapshots; those ids already exist and are skipped.
        """
        for emb in msg.embeddings:
            existing = self._db.get_person(emb.person_id)
            if existing is None:
                arr = np.array(emb.embedding, dtype=np.float32)
                # NOTE(review): add_person() assigns its own autoincrement id,
                # which is not guaranteed to equal emb.person_id. If the two
                # diverge, the lookup above misses again on the next message
                # and the person is inserted again -- confirm the id scheme
                # shared with face_recognizer.
                self._db.add_person(
                    emb.person_name, arr, emb.sample_count
                )
                self.get_logger().info(
                    f'Synced new person from face_recognizer: '
                    f'{emb.person_name} (id={emb.person_id})'
                )

    # ---- Service handlers ----

    def _handle_enroll(self, request, response):
        """Forward enrollment to face_recognizer and wait for its answer.

        Blocks the calling executor thread for up to
        ``_ENROLL_FORWARD_TIMEOUT_SEC`` seconds; requires the callback-group
        and executor setup described on the class.
        """
        if not self._enroll_client.wait_for_service(timeout_sec=2.0):
            response.success = False
            response.message = 'face_recognizer service unavailable'
            return response

        # threading.Event bridges the asynchronous client call back into
        # this synchronous service callback.
        event = threading.Event()
        result_holder = {}

        req = EnrollPerson.Request()
        req.name = request.name
        req.mode = request.mode
        req.n_samples = request.n_samples
        future = self._enroll_client.call_async(req)

        def _done(f):
            try:
                result_holder['resp'] = f.result()
            except Exception as e:
                result_holder['err'] = str(e)
            event.set()

        future.add_done_callback(_done)

        finished = event.wait(timeout=self._ENROLL_FORWARD_TIMEOUT_SEC)
        if not finished:
            response.success = False
            response.message = 'Enrollment timed out'
        elif 'resp' in result_holder:
            resp = result_holder['resp']
            response.success = resp.success
            response.message = resp.message
            response.person_id = resp.person_id
        else:
            response.success = False
            response.message = result_holder.get('err', 'Unknown error')
        return response

    def _handle_list(self, request, response):
        """Return every person in the DB, including those without an embedding."""
        response.persons = [
            self._to_face_embedding(p) for p in self._db.get_all()
        ]
        return response

    def _handle_delete(self, request, response):
        """Delete a person and republish the gallery on success."""
        success = self._db.delete_person(request.person_id)
        response.success = success
        if success:
            response.message = f'Deleted person {request.person_id}'
            self.get_logger().info(response.message)
            self._publish_embeddings_from_db()
        else:
            response.message = f'Person {request.person_id} not found'
        return response

    def _handle_update(self, request, response):
        """Rename a person and republish the gallery on success."""
        success = self._db.update_name(request.person_id, request.new_name)
        response.success = success
        if success:
            response.message = f'Updated person {request.person_id} name to "{request.new_name}"'
            self.get_logger().info(response.message)
            self._publish_embeddings_from_db()
        else:
            response.message = f'Person {request.person_id} not found'
        return response

    # ---- Helpers ----

    @staticmethod
    def _to_face_embedding(person: dict) -> FaceEmbedding:
        """Convert a ``PersonDB`` row dict into a ``FaceEmbedding`` message.

        ``enrolled_at`` (float seconds) is split into whole seconds and
        nanoseconds; the embedding is left at its default when the row has
        none.
        """
        fe = FaceEmbedding()
        fe.person_id = person['id']
        fe.person_name = person['name']
        fe.sample_count = person['sample_count']
        whole_sec = int(person['enrolled_at'])
        fe.enrolled_at.sec = whole_sec
        fe.enrolled_at.nanosec = int((person['enrolled_at'] - whole_sec) * 1e9)
        if person['embedding'] is not None:
            fe.embedding = person['embedding'].tolist()
        return fe

    def _publish_status(self, text: str):
        """Publish ``text`` on ``/social/enrollment/status``."""
        msg = String()
        msg.data = text
        self._pub_status.publish(msg)

    def _publish_embeddings_from_db(self):
        """Publish all persons that have an embedding on the embeddings topic."""
        arr = FaceEmbeddingArray()
        arr.header.stamp = self.get_clock().now().to_msg()
        for p in self._db.get_all():
            if p['embedding'] is not None:
                arr.embeddings.append(self._to_face_embedding(p))
        self._pub_embeddings.publish(arr)


def main(args=None):
    """Run the enrollment node until interrupted.

    A ``MultiThreadedExecutor`` is required: ``EnrollmentNode._handle_enroll``
    waits synchronously for a client response, which another executor thread
    has to deliver.
    """
    from rclpy.executors import MultiThreadedExecutor

    rclpy.init(args=args)
    node = EnrollmentNode()
    executor = MultiThreadedExecutor()
    try:
        rclpy.spin(node, executor=executor)
    except KeyboardInterrupt:
        pass
    finally:
        executor.shutdown()
        node.destroy_node()
        rclpy.try_shutdown()


if __name__ == '__main__':
    main()

View File

@ -0,0 +1,138 @@
"""person_db.py -- Persistent SQLite person gallery for saltybot enrollment."""
import contextlib
import json
import sqlite3
import threading
import time
from pathlib import Path

import numpy as np
class PersonDB:
"""Thread-safe SQLite-backed person gallery.
Schema:
persons(id INTEGER PRIMARY KEY, name TEXT, enrolled_at REAL,
sample_count INTEGER, embedding BLOB, metadata TEXT)
voice_samples(id INTEGER PRIMARY KEY, person_id INTEGER REFERENCES persons,
recorded_at REAL, sample_path TEXT)
"""
def __init__(self, db_path: str):
self._db_path = db_path
self._lock = threading.Lock()
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
self._init_db()
def _init_db(self):
with self._connect() as conn:
conn.execute("""
CREATE TABLE IF NOT EXISTS persons (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
enrolled_at REAL NOT NULL,
sample_count INTEGER DEFAULT 1,
embedding BLOB,
metadata TEXT DEFAULT '{}'
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS voice_samples (
id INTEGER PRIMARY KEY AUTOINCREMENT,
person_id INTEGER REFERENCES persons(id),
recorded_at REAL NOT NULL,
sample_path TEXT NOT NULL
)
""")
def _connect(self):
return sqlite3.connect(self._db_path)
def add_person(self, name: str, embedding: np.ndarray, sample_count: int = 1,
metadata: dict = None) -> int:
"""Add a new person. Returns new person_id."""
emb_blob = embedding.astype(np.float32).tobytes() if embedding is not None else None
now = time.time()
with self._lock:
with self._connect() as conn:
cur = conn.execute(
"INSERT INTO persons (name, enrolled_at, sample_count, embedding, metadata) "
"VALUES (?, ?, ?, ?, ?)",
(name, now, sample_count, emb_blob, json.dumps(metadata or {}))
)
return cur.lastrowid
def update_embedding(self, person_id: int, embedding: np.ndarray,
sample_count: int) -> bool:
emb_blob = embedding.astype(np.float32).tobytes()
with self._lock:
with self._connect() as conn:
conn.execute(
"UPDATE persons SET embedding=?, sample_count=? WHERE id=?",
(emb_blob, sample_count, person_id)
)
return conn.total_changes > 0
def update_name(self, person_id: int, new_name: str) -> bool:
with self._lock:
with self._connect() as conn:
conn.execute(
"UPDATE persons SET name=? WHERE id=?",
(new_name, person_id)
)
return conn.total_changes > 0
def delete_person(self, person_id: int) -> bool:
with self._lock:
with self._connect() as conn:
conn.execute(
"DELETE FROM voice_samples WHERE person_id=?",
(person_id,)
)
conn.execute(
"DELETE FROM persons WHERE id=?",
(person_id,)
)
return conn.total_changes > 0
def get_all(self) -> list:
"""Returns list of dicts with id, name, enrolled_at, sample_count, embedding."""
with self._lock:
with self._connect() as conn:
rows = conn.execute(
"SELECT id, name, enrolled_at, sample_count, embedding FROM persons"
).fetchall()
result = []
for row in rows:
emb = np.frombuffer(row[4], dtype=np.float32).copy() if row[4] else None
result.append({
'id': row[0], 'name': row[1], 'enrolled_at': row[2],
'sample_count': row[3], 'embedding': emb
})
return result
def get_person(self, person_id: int) -> dict | None:
with self._lock:
with self._connect() as conn:
row = conn.execute(
"SELECT id, name, enrolled_at, sample_count, embedding "
"FROM persons WHERE id=?",
(person_id,)
).fetchone()
if row is None:
return None
emb = np.frombuffer(row[4], dtype=np.float32).copy() if row[4] else None
return {
'id': row[0], 'name': row[1], 'enrolled_at': row[2],
'sample_count': row[3], 'embedding': emb
}
def add_voice_sample(self, person_id: int, sample_path: str) -> int:
with self._lock:
with self._connect() as conn:
cur = conn.execute(
"INSERT INTO voice_samples (person_id, recorded_at, sample_path) "
"VALUES (?, ?, ?)",
(person_id, time.time(), sample_path)
)
return cur.lastrowid

View File

@ -0,0 +1,4 @@
# Install the package's console scripts under lib/saltybot_social_enrollment
# for both develop and regular installs.
# NOTE(review): presumably the location `ros2 run` searches for executables
# in an ament_python package -- confirm against the build tooling in use.
[develop]
script_dir=$base/lib/saltybot_social_enrollment
[install]
install_scripts=$base/lib/saltybot_social_enrollment

View File

@ -0,0 +1,29 @@
from setuptools import setup
import os
from glob import glob
package_name = 'saltybot_social_enrollment'

# Non-Python files installed alongside the package: the ament resource
# marker, the manifest, and every launch/config file found at build time.
_launch_dir = os.path.join('share', package_name, 'launch')
_config_dir = os.path.join('share', package_name, 'config')
_data_files = [
    ('share/ament_index/resource_index/packages', ['resource/' + package_name]),
    ('share/' + package_name, ['package.xml']),
    (_launch_dir, glob('launch/*.py')),
    (_config_dir, glob('config/*.yaml')),
]

# Executables exposed by the package: the enrollment node and the gallery CLI.
_console_scripts = [
    'enrollment_node = saltybot_social_enrollment.enrollment_node:main',
    'enrollment_cli = saltybot_social_enrollment.enrollment_cli:main',
]

setup(
    name=package_name,
    version='0.1.0',
    packages=[package_name],
    data_files=_data_files,
    install_requires=['setuptools'],
    zip_safe=True,
    maintainer='seb',
    maintainer_email='seb@vayrette.com',
    description='Person enrollment system for saltybot social interaction',
    license='MIT',
    entry_points={'console_scripts': _console_scripts},
)

View File

@ -0,0 +1,4 @@
# Request: a single image.
# NOTE(review): the field name suggests a cropped face image -- confirm with the service provider.
sensor_msgs/Image crop
---
# Response: success flag plus a fixed-length vector of 512 float32 values.
# NOTE(review): what `embedding` contains when success is false is not specified here.
bool success
float32[512] embedding