Merge pull request 'feat(social): person enrollment system #87' (#95) from sl-perception/social-enrollment into main
This commit is contained in:
commit
d6a6965af6
@ -0,0 +1,6 @@
|
|||||||
|
enrollment_node:
|
||||||
|
ros__parameters:
|
||||||
|
db_path: '/mnt/nvme/saltybot/gallery/persons.db'
|
||||||
|
voice_samples_dir: '/mnt/nvme/saltybot/gallery/voice'
|
||||||
|
auto_enroll_phrase: 'remember me my name is'
|
||||||
|
n_samples_default: 10
|
||||||
@ -0,0 +1,22 @@
|
|||||||
|
"""Launch file for saltybot social enrollment node."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from ament_index_python.packages import get_package_share_directory
|
||||||
|
from launch import LaunchDescription
|
||||||
|
from launch_ros.actions import Node
|
||||||
|
|
||||||
|
|
||||||
|
def generate_launch_description():
|
||||||
|
pkg_share = get_package_share_directory('saltybot_social_enrollment')
|
||||||
|
config_file = os.path.join(pkg_share, 'config', 'enrollment_params.yaml')
|
||||||
|
|
||||||
|
return LaunchDescription([
|
||||||
|
Node(
|
||||||
|
package='saltybot_social_enrollment',
|
||||||
|
executable='enrollment_node',
|
||||||
|
name='enrollment_node',
|
||||||
|
parameters=[config_file],
|
||||||
|
output='screen',
|
||||||
|
),
|
||||||
|
])
|
||||||
23
jetson/ros2_ws/src/saltybot_social_enrollment/package.xml
Normal file
23
jetson/ros2_ws/src/saltybot_social_enrollment/package.xml
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
<?xml version="1.0"?>
|
||||||
|
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||||
|
<package format="3">
|
||||||
|
<name>saltybot_social_enrollment</name>
|
||||||
|
<version>0.1.0</version>
|
||||||
|
<description>Person enrollment system for saltybot social interaction</description>
|
||||||
|
<maintainer email="seb@vayrette.com">seb</maintainer>
|
||||||
|
<license>MIT</license>
|
||||||
|
|
||||||
|
<buildtool_depend>ament_python</buildtool_depend>
|
||||||
|
|
||||||
|
<depend>rclpy</depend>
|
||||||
|
<depend>sensor_msgs</depend>
|
||||||
|
<depend>cv_bridge</depend>
|
||||||
|
<depend>std_msgs</depend>
|
||||||
|
<depend>saltybot_social_msgs</depend>
|
||||||
|
|
||||||
|
<exec_depend>python3-numpy</exec_depend>
|
||||||
|
|
||||||
|
<export>
|
||||||
|
<build_type>ament_python</build_type>
|
||||||
|
</export>
|
||||||
|
</package>
|
||||||
@ -0,0 +1,184 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""enrollment_cli.py -- Gallery management CLI for saltybot social.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ros2 run saltybot_social_enrollment enrollment_cli enroll --name "Alice" [--samples 15] [--mode face]
|
||||||
|
ros2 run saltybot_social_enrollment enrollment_cli list
|
||||||
|
ros2 run saltybot_social_enrollment enrollment_cli delete --id 3
|
||||||
|
ros2 run saltybot_social_enrollment enrollment_cli rename --id 2 --name "Bob"
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
|
||||||
|
from saltybot_social_msgs.srv import (
|
||||||
|
EnrollPerson,
|
||||||
|
ListPersons,
|
||||||
|
DeletePerson,
|
||||||
|
UpdatePerson,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EnrollmentCLI(Node):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__('enrollment_cli')
|
||||||
|
self._enroll_client = self.create_client(
|
||||||
|
EnrollPerson, '/social/enroll'
|
||||||
|
)
|
||||||
|
self._list_client = self.create_client(
|
||||||
|
ListPersons, '/social/persons/list'
|
||||||
|
)
|
||||||
|
self._delete_client = self.create_client(
|
||||||
|
DeletePerson, '/social/persons/delete'
|
||||||
|
)
|
||||||
|
self._update_client = self.create_client(
|
||||||
|
UpdatePerson, '/social/persons/update'
|
||||||
|
)
|
||||||
|
|
||||||
|
def enroll(self, name: str, n_samples: int = 10, mode: str = 'face'):
|
||||||
|
if not self._enroll_client.wait_for_service(timeout_sec=5.0):
|
||||||
|
print('ERROR: /social/enroll service not available')
|
||||||
|
return False
|
||||||
|
|
||||||
|
req = EnrollPerson.Request()
|
||||||
|
req.name = name
|
||||||
|
req.mode = mode
|
||||||
|
req.n_samples = n_samples
|
||||||
|
|
||||||
|
print(f'Enrolling "{name}" ({mode}, {n_samples} samples)...')
|
||||||
|
future = self._enroll_client.call_async(req)
|
||||||
|
rclpy.spin_until_future_complete(self, future, timeout_sec=60.0)
|
||||||
|
|
||||||
|
if future.result() is None:
|
||||||
|
print('ERROR: Enrollment timed out')
|
||||||
|
return False
|
||||||
|
|
||||||
|
resp = future.result()
|
||||||
|
if resp.success:
|
||||||
|
print(f'OK: Enrolled "{name}" as person_id={resp.person_id}')
|
||||||
|
else:
|
||||||
|
print(f'FAILED: {resp.message}')
|
||||||
|
return resp.success
|
||||||
|
|
||||||
|
def list_persons(self):
|
||||||
|
if not self._list_client.wait_for_service(timeout_sec=5.0):
|
||||||
|
print('ERROR: /social/persons/list service not available')
|
||||||
|
return
|
||||||
|
|
||||||
|
req = ListPersons.Request()
|
||||||
|
future = self._list_client.call_async(req)
|
||||||
|
rclpy.spin_until_future_complete(self, future, timeout_sec=10.0)
|
||||||
|
|
||||||
|
if future.result() is None:
|
||||||
|
print('ERROR: List request timed out')
|
||||||
|
return
|
||||||
|
|
||||||
|
resp = future.result()
|
||||||
|
if not resp.persons:
|
||||||
|
print('Gallery is empty.')
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f'{"ID":>4} {"Name":<20} {"Samples":>7} {"Embedding Dim":>13}')
|
||||||
|
print('-' * 50)
|
||||||
|
for p in resp.persons:
|
||||||
|
dim = len(p.embedding) if p.embedding else 0
|
||||||
|
print(f'{p.person_id:>4} {p.person_name:<20} {p.sample_count:>7} {dim:>13}')
|
||||||
|
|
||||||
|
def delete(self, person_id: int):
|
||||||
|
if not self._delete_client.wait_for_service(timeout_sec=5.0):
|
||||||
|
print('ERROR: /social/persons/delete service not available')
|
||||||
|
return False
|
||||||
|
|
||||||
|
req = DeletePerson.Request()
|
||||||
|
req.person_id = person_id
|
||||||
|
|
||||||
|
future = self._delete_client.call_async(req)
|
||||||
|
rclpy.spin_until_future_complete(self, future, timeout_sec=10.0)
|
||||||
|
|
||||||
|
if future.result() is None:
|
||||||
|
print('ERROR: Delete request timed out')
|
||||||
|
return False
|
||||||
|
|
||||||
|
resp = future.result()
|
||||||
|
if resp.success:
|
||||||
|
print(f'OK: {resp.message}')
|
||||||
|
else:
|
||||||
|
print(f'FAILED: {resp.message}')
|
||||||
|
return resp.success
|
||||||
|
|
||||||
|
def rename(self, person_id: int, new_name: str):
|
||||||
|
if not self._update_client.wait_for_service(timeout_sec=5.0):
|
||||||
|
print('ERROR: /social/persons/update service not available')
|
||||||
|
return False
|
||||||
|
|
||||||
|
req = UpdatePerson.Request()
|
||||||
|
req.person_id = person_id
|
||||||
|
req.new_name = new_name
|
||||||
|
|
||||||
|
future = self._update_client.call_async(req)
|
||||||
|
rclpy.spin_until_future_complete(self, future, timeout_sec=10.0)
|
||||||
|
|
||||||
|
if future.result() is None:
|
||||||
|
print('ERROR: Rename request timed out')
|
||||||
|
return False
|
||||||
|
|
||||||
|
resp = future.result()
|
||||||
|
if resp.success:
|
||||||
|
print(f'OK: {resp.message}')
|
||||||
|
else:
|
||||||
|
print(f'FAILED: {resp.message}')
|
||||||
|
return resp.success
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None):
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description='saltybot person enrollment CLI'
|
||||||
|
)
|
||||||
|
subparsers = parser.add_subparsers(dest='command', required=True)
|
||||||
|
|
||||||
|
# enroll
|
||||||
|
enroll_p = subparsers.add_parser('enroll', help='Enroll a new person')
|
||||||
|
enroll_p.add_argument('--name', required=True, help='Person name')
|
||||||
|
enroll_p.add_argument('--samples', type=int, default=10,
|
||||||
|
help='Number of face samples (default: 10)')
|
||||||
|
enroll_p.add_argument('--mode', default='face',
|
||||||
|
help='Enrollment mode (default: face)')
|
||||||
|
|
||||||
|
# list
|
||||||
|
subparsers.add_parser('list', help='List enrolled persons')
|
||||||
|
|
||||||
|
# delete
|
||||||
|
delete_p = subparsers.add_parser('delete', help='Delete a person')
|
||||||
|
delete_p.add_argument('--id', type=int, required=True,
|
||||||
|
help='Person ID to delete')
|
||||||
|
|
||||||
|
# rename
|
||||||
|
rename_p = subparsers.add_parser('rename', help='Rename a person')
|
||||||
|
rename_p.add_argument('--id', type=int, required=True,
|
||||||
|
help='Person ID to rename')
|
||||||
|
rename_p.add_argument('--name', required=True, help='New name')
|
||||||
|
|
||||||
|
parsed = parser.parse_args(sys.argv[1:])
|
||||||
|
|
||||||
|
rclpy.init()
|
||||||
|
cli = EnrollmentCLI()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if parsed.command == 'enroll':
|
||||||
|
cli.enroll(parsed.name, parsed.samples, parsed.mode)
|
||||||
|
elif parsed.command == 'list':
|
||||||
|
cli.list_persons()
|
||||||
|
elif parsed.command == 'delete':
|
||||||
|
cli.delete(parsed.id)
|
||||||
|
elif parsed.command == 'rename':
|
||||||
|
cli.rename(parsed.id, parsed.name)
|
||||||
|
finally:
|
||||||
|
cli.destroy_node()
|
||||||
|
rclpy.try_shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@ -0,0 +1,302 @@
|
|||||||
|
"""enrollment_node.py -- ROS2 person enrollment node for saltybot social.
|
||||||
|
|
||||||
|
Coordinates person enrollment:
|
||||||
|
- Forwards /social/enroll to face_recognizer's service
|
||||||
|
- Owns persistent SQLite gallery (PersonDB)
|
||||||
|
- Voice-triggered enrollment via "remember me my name is X"
|
||||||
|
- Gallery management services (list/delete/update)
|
||||||
|
- Syncs DB from /social/faces/embeddings topic
|
||||||
|
"""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile, ReliabilityPolicy, DurabilityPolicy
|
||||||
|
from std_msgs.msg import String
|
||||||
|
|
||||||
|
from saltybot_social_msgs.msg import (
|
||||||
|
FaceDetectionArray,
|
||||||
|
FaceEmbedding,
|
||||||
|
FaceEmbeddingArray,
|
||||||
|
)
|
||||||
|
from saltybot_social_msgs.srv import (
|
||||||
|
EnrollPerson,
|
||||||
|
ListPersons,
|
||||||
|
DeletePerson,
|
||||||
|
UpdatePerson,
|
||||||
|
)
|
||||||
|
|
||||||
|
from saltybot_social_enrollment.person_db import PersonDB
|
||||||
|
|
||||||
|
|
||||||
|
class EnrollmentNode(Node):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__('enrollment_node')
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
self.declare_parameter('db_path', '/mnt/nvme/saltybot/gallery/persons.db')
|
||||||
|
self.declare_parameter('voice_samples_dir', '/mnt/nvme/saltybot/gallery/voice')
|
||||||
|
self.declare_parameter('auto_enroll_phrase', 'remember me my name is')
|
||||||
|
self.declare_parameter('n_samples_default', 10)
|
||||||
|
|
||||||
|
db_path = self.get_parameter('db_path').value
|
||||||
|
self._voice_dir = self.get_parameter('voice_samples_dir').value
|
||||||
|
self._phrase = self.get_parameter('auto_enroll_phrase').value
|
||||||
|
self._n_samples = self.get_parameter('n_samples_default').value
|
||||||
|
|
||||||
|
self._db = PersonDB(db_path)
|
||||||
|
self.get_logger().info(f'PersonDB initialized at {db_path}')
|
||||||
|
|
||||||
|
# Client to face_recognizer's enroll service
|
||||||
|
self._enroll_client = self.create_client(
|
||||||
|
EnrollPerson, '/social/face_recognizer/enroll'
|
||||||
|
)
|
||||||
|
|
||||||
|
# QoS profiles
|
||||||
|
best_effort_qos = QoSProfile(
|
||||||
|
depth=10,
|
||||||
|
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||||
|
durability=DurabilityPolicy.VOLATILE,
|
||||||
|
)
|
||||||
|
reliable_qos = QoSProfile(
|
||||||
|
depth=1,
|
||||||
|
reliability=ReliabilityPolicy.RELIABLE,
|
||||||
|
durability=DurabilityPolicy.VOLATILE,
|
||||||
|
)
|
||||||
|
status_qos = QoSProfile(
|
||||||
|
depth=1,
|
||||||
|
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||||
|
durability=DurabilityPolicy.VOLATILE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Subscriptions
|
||||||
|
self.create_subscription(
|
||||||
|
FaceDetectionArray, '/social/faces/detections',
|
||||||
|
self._on_detections, best_effort_qos
|
||||||
|
)
|
||||||
|
self.create_subscription(
|
||||||
|
FaceEmbeddingArray, '/social/faces/embeddings',
|
||||||
|
self._on_embeddings, reliable_qos
|
||||||
|
)
|
||||||
|
self.create_subscription(
|
||||||
|
String, '/social/speech/transcript',
|
||||||
|
self._on_transcript, best_effort_qos
|
||||||
|
)
|
||||||
|
self.create_subscription(
|
||||||
|
String, '/social/speech/command',
|
||||||
|
self._on_command, best_effort_qos
|
||||||
|
)
|
||||||
|
|
||||||
|
# Services
|
||||||
|
self.create_service(EnrollPerson, '/social/enroll', self._handle_enroll)
|
||||||
|
self.create_service(ListPersons, '/social/persons/list', self._handle_list)
|
||||||
|
self.create_service(DeletePerson, '/social/persons/delete', self._handle_delete)
|
||||||
|
self.create_service(UpdatePerson, '/social/persons/update', self._handle_update)
|
||||||
|
|
||||||
|
# Publishers
|
||||||
|
self._pub_embeddings = self.create_publisher(
|
||||||
|
FaceEmbeddingArray, '/social/faces/embeddings', reliable_qos
|
||||||
|
)
|
||||||
|
self._pub_status = self.create_publisher(
|
||||||
|
String, '/social/enrollment/status', status_qos
|
||||||
|
)
|
||||||
|
|
||||||
|
self.get_logger().info('EnrollmentNode ready')
|
||||||
|
|
||||||
|
# ---- Voice-triggered enrollment ----
|
||||||
|
|
||||||
|
def _on_transcript(self, msg: String):
|
||||||
|
text = msg.data.lower()
|
||||||
|
phrase = self._phrase.lower()
|
||||||
|
if phrase in text:
|
||||||
|
idx = text.index(phrase) + len(phrase)
|
||||||
|
name = text[idx:].strip()
|
||||||
|
# Clean up: take first 3 words max as the name
|
||||||
|
words = name.split()
|
||||||
|
if words:
|
||||||
|
name = ' '.join(words[:3]).title()
|
||||||
|
self.get_logger().info(f'Voice enrollment triggered: "{name}"')
|
||||||
|
self._trigger_voice_enroll(name)
|
||||||
|
|
||||||
|
def _on_command(self, msg: String):
|
||||||
|
# Reserved for explicit voice commands (e.g., "enroll Alice")
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _trigger_voice_enroll(self, name: str):
|
||||||
|
if not self._enroll_client.wait_for_service(timeout_sec=1.0):
|
||||||
|
self.get_logger().warn(
|
||||||
|
'face_recognizer enroll service not available'
|
||||||
|
)
|
||||||
|
self._publish_status(f'Enrollment failed: face_recognizer unavailable')
|
||||||
|
return
|
||||||
|
|
||||||
|
req = EnrollPerson.Request()
|
||||||
|
req.name = name
|
||||||
|
req.mode = 'face'
|
||||||
|
req.n_samples = self._n_samples
|
||||||
|
future = self._enroll_client.call_async(req)
|
||||||
|
future.add_done_callback(
|
||||||
|
lambda f: self._on_enroll_done(f, name)
|
||||||
|
)
|
||||||
|
self._publish_status(f'Enrolling "{name}"... look at the camera')
|
||||||
|
|
||||||
|
def _on_enroll_done(self, future, name: str):
|
||||||
|
try:
|
||||||
|
resp = future.result()
|
||||||
|
if resp.success:
|
||||||
|
status = f'Enrolled "{name}" (id={resp.person_id})'
|
||||||
|
self.get_logger().info(status)
|
||||||
|
else:
|
||||||
|
status = f'Enrollment failed for "{name}": {resp.message}'
|
||||||
|
self.get_logger().warn(status)
|
||||||
|
self._publish_status(status)
|
||||||
|
except Exception as e:
|
||||||
|
self.get_logger().error(f'Enroll call failed: {e}')
|
||||||
|
self._publish_status(f'Enrollment error: {e}')
|
||||||
|
|
||||||
|
# ---- Face detection callback (during enrollment) ----
|
||||||
|
|
||||||
|
def _on_detections(self, msg: FaceDetectionArray):
|
||||||
|
# Reserved for future direct enrollment (without face_recognizer)
|
||||||
|
pass
|
||||||
|
|
||||||
|
# ---- Embeddings sync from face_recognizer ----
|
||||||
|
|
||||||
|
def _on_embeddings(self, msg: FaceEmbeddingArray):
|
||||||
|
for emb in msg.embeddings:
|
||||||
|
existing = self._db.get_person(emb.person_id)
|
||||||
|
if existing is None:
|
||||||
|
arr = np.array(emb.embedding, dtype=np.float32)
|
||||||
|
self._db.add_person(
|
||||||
|
emb.person_name, arr, emb.sample_count
|
||||||
|
)
|
||||||
|
self.get_logger().info(
|
||||||
|
f'Synced new person from face_recognizer: '
|
||||||
|
f'{emb.person_name} (id={emb.person_id})'
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- Service handlers ----
|
||||||
|
|
||||||
|
def _handle_enroll(self, request, response):
|
||||||
|
"""Forward enrollment to face_recognizer service."""
|
||||||
|
if not self._enroll_client.wait_for_service(timeout_sec=2.0):
|
||||||
|
response.success = False
|
||||||
|
response.message = 'face_recognizer service unavailable'
|
||||||
|
return response
|
||||||
|
|
||||||
|
# Use threading.Event to bridge async call in service callback
|
||||||
|
event = threading.Event()
|
||||||
|
result_holder = {}
|
||||||
|
|
||||||
|
req = EnrollPerson.Request()
|
||||||
|
req.name = request.name
|
||||||
|
req.mode = request.mode
|
||||||
|
req.n_samples = request.n_samples
|
||||||
|
|
||||||
|
future = self._enroll_client.call_async(req)
|
||||||
|
|
||||||
|
def _done(f):
|
||||||
|
try:
|
||||||
|
result_holder['resp'] = f.result()
|
||||||
|
except Exception as e:
|
||||||
|
result_holder['err'] = str(e)
|
||||||
|
event.set()
|
||||||
|
|
||||||
|
future.add_done_callback(_done)
|
||||||
|
success = event.wait(timeout=35.0)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
response.success = False
|
||||||
|
response.message = 'Enrollment timed out'
|
||||||
|
elif 'resp' in result_holder:
|
||||||
|
resp = result_holder['resp']
|
||||||
|
response.success = resp.success
|
||||||
|
response.message = resp.message
|
||||||
|
response.person_id = resp.person_id
|
||||||
|
else:
|
||||||
|
response.success = False
|
||||||
|
response.message = result_holder.get('err', 'Unknown error')
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _handle_list(self, request, response):
|
||||||
|
persons = self._db.get_all()
|
||||||
|
response.persons = []
|
||||||
|
for p in persons:
|
||||||
|
fe = FaceEmbedding()
|
||||||
|
fe.person_id = p['id']
|
||||||
|
fe.person_name = p['name']
|
||||||
|
fe.sample_count = p['sample_count']
|
||||||
|
fe.enrolled_at.sec = int(p['enrolled_at'])
|
||||||
|
fe.enrolled_at.nanosec = int(
|
||||||
|
(p['enrolled_at'] - int(p['enrolled_at'])) * 1e9
|
||||||
|
)
|
||||||
|
if p['embedding'] is not None:
|
||||||
|
fe.embedding = p['embedding'].tolist()
|
||||||
|
response.persons.append(fe)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _handle_delete(self, request, response):
|
||||||
|
success = self._db.delete_person(request.person_id)
|
||||||
|
response.success = success
|
||||||
|
if success:
|
||||||
|
response.message = f'Deleted person {request.person_id}'
|
||||||
|
self.get_logger().info(response.message)
|
||||||
|
self._publish_embeddings_from_db()
|
||||||
|
else:
|
||||||
|
response.message = f'Person {request.person_id} not found'
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _handle_update(self, request, response):
|
||||||
|
success = self._db.update_name(request.person_id, request.new_name)
|
||||||
|
response.success = success
|
||||||
|
if success:
|
||||||
|
response.message = f'Updated person {request.person_id} name to "{request.new_name}"'
|
||||||
|
self.get_logger().info(response.message)
|
||||||
|
self._publish_embeddings_from_db()
|
||||||
|
else:
|
||||||
|
response.message = f'Person {request.person_id} not found'
|
||||||
|
return response
|
||||||
|
|
||||||
|
# ---- Helpers ----
|
||||||
|
|
||||||
|
def _publish_status(self, text: str):
|
||||||
|
msg = String()
|
||||||
|
msg.data = text
|
||||||
|
self._pub_status.publish(msg)
|
||||||
|
|
||||||
|
def _publish_embeddings_from_db(self):
|
||||||
|
persons = self._db.get_all()
|
||||||
|
arr = FaceEmbeddingArray()
|
||||||
|
arr.header.stamp = self.get_clock().now().to_msg()
|
||||||
|
for p in persons:
|
||||||
|
if p['embedding'] is not None:
|
||||||
|
fe = FaceEmbedding()
|
||||||
|
fe.person_id = p['id']
|
||||||
|
fe.person_name = p['name']
|
||||||
|
fe.sample_count = p['sample_count']
|
||||||
|
fe.enrolled_at.sec = int(p['enrolled_at'])
|
||||||
|
fe.enrolled_at.nanosec = int(
|
||||||
|
(p['enrolled_at'] - int(p['enrolled_at'])) * 1e9
|
||||||
|
)
|
||||||
|
fe.embedding = p['embedding'].tolist()
|
||||||
|
arr.embeddings.append(fe)
|
||||||
|
self._pub_embeddings.publish(arr)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None):
|
||||||
|
rclpy.init(args=args)
|
||||||
|
node = EnrollmentNode()
|
||||||
|
try:
|
||||||
|
rclpy.spin(node)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
node.destroy_node()
|
||||||
|
rclpy.try_shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@ -0,0 +1,138 @@
|
|||||||
|
"""person_db.py -- Persistent SQLite person gallery for saltybot enrollment."""
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
import threading
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
class PersonDB:
|
||||||
|
"""Thread-safe SQLite-backed person gallery.
|
||||||
|
|
||||||
|
Schema:
|
||||||
|
persons(id INTEGER PRIMARY KEY, name TEXT, enrolled_at REAL,
|
||||||
|
sample_count INTEGER, embedding BLOB, metadata TEXT)
|
||||||
|
voice_samples(id INTEGER PRIMARY KEY, person_id INTEGER REFERENCES persons,
|
||||||
|
recorded_at REAL, sample_path TEXT)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db_path: str):
|
||||||
|
self._db_path = db_path
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._init_db()
|
||||||
|
|
||||||
|
def _init_db(self):
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS persons (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
enrolled_at REAL NOT NULL,
|
||||||
|
sample_count INTEGER DEFAULT 1,
|
||||||
|
embedding BLOB,
|
||||||
|
metadata TEXT DEFAULT '{}'
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS voice_samples (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
person_id INTEGER REFERENCES persons(id),
|
||||||
|
recorded_at REAL NOT NULL,
|
||||||
|
sample_path TEXT NOT NULL
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
def _connect(self):
|
||||||
|
return sqlite3.connect(self._db_path)
|
||||||
|
|
||||||
|
def add_person(self, name: str, embedding: np.ndarray, sample_count: int = 1,
|
||||||
|
metadata: dict = None) -> int:
|
||||||
|
"""Add a new person. Returns new person_id."""
|
||||||
|
emb_blob = embedding.astype(np.float32).tobytes() if embedding is not None else None
|
||||||
|
now = time.time()
|
||||||
|
with self._lock:
|
||||||
|
with self._connect() as conn:
|
||||||
|
cur = conn.execute(
|
||||||
|
"INSERT INTO persons (name, enrolled_at, sample_count, embedding, metadata) "
|
||||||
|
"VALUES (?, ?, ?, ?, ?)",
|
||||||
|
(name, now, sample_count, emb_blob, json.dumps(metadata or {}))
|
||||||
|
)
|
||||||
|
return cur.lastrowid
|
||||||
|
|
||||||
|
def update_embedding(self, person_id: int, embedding: np.ndarray,
|
||||||
|
sample_count: int) -> bool:
|
||||||
|
emb_blob = embedding.astype(np.float32).tobytes()
|
||||||
|
with self._lock:
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE persons SET embedding=?, sample_count=? WHERE id=?",
|
||||||
|
(emb_blob, sample_count, person_id)
|
||||||
|
)
|
||||||
|
return conn.total_changes > 0
|
||||||
|
|
||||||
|
def update_name(self, person_id: int, new_name: str) -> bool:
|
||||||
|
with self._lock:
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE persons SET name=? WHERE id=?",
|
||||||
|
(new_name, person_id)
|
||||||
|
)
|
||||||
|
return conn.total_changes > 0
|
||||||
|
|
||||||
|
def delete_person(self, person_id: int) -> bool:
|
||||||
|
with self._lock:
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"DELETE FROM voice_samples WHERE person_id=?",
|
||||||
|
(person_id,)
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"DELETE FROM persons WHERE id=?",
|
||||||
|
(person_id,)
|
||||||
|
)
|
||||||
|
return conn.total_changes > 0
|
||||||
|
|
||||||
|
def get_all(self) -> list:
|
||||||
|
"""Returns list of dicts with id, name, enrolled_at, sample_count, embedding."""
|
||||||
|
with self._lock:
|
||||||
|
with self._connect() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
"SELECT id, name, enrolled_at, sample_count, embedding FROM persons"
|
||||||
|
).fetchall()
|
||||||
|
result = []
|
||||||
|
for row in rows:
|
||||||
|
emb = np.frombuffer(row[4], dtype=np.float32).copy() if row[4] else None
|
||||||
|
result.append({
|
||||||
|
'id': row[0], 'name': row[1], 'enrolled_at': row[2],
|
||||||
|
'sample_count': row[3], 'embedding': emb
|
||||||
|
})
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_person(self, person_id: int) -> dict | None:
|
||||||
|
with self._lock:
|
||||||
|
with self._connect() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT id, name, enrolled_at, sample_count, embedding "
|
||||||
|
"FROM persons WHERE id=?",
|
||||||
|
(person_id,)
|
||||||
|
).fetchone()
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
emb = np.frombuffer(row[4], dtype=np.float32).copy() if row[4] else None
|
||||||
|
return {
|
||||||
|
'id': row[0], 'name': row[1], 'enrolled_at': row[2],
|
||||||
|
'sample_count': row[3], 'embedding': emb
|
||||||
|
}
|
||||||
|
|
||||||
|
def add_voice_sample(self, person_id: int, sample_path: str) -> int:
|
||||||
|
with self._lock:
|
||||||
|
with self._connect() as conn:
|
||||||
|
cur = conn.execute(
|
||||||
|
"INSERT INTO voice_samples (person_id, recorded_at, sample_path) "
|
||||||
|
"VALUES (?, ?, ?)",
|
||||||
|
(person_id, time.time(), sample_path)
|
||||||
|
)
|
||||||
|
return cur.lastrowid
|
||||||
4
jetson/ros2_ws/src/saltybot_social_enrollment/setup.cfg
Normal file
4
jetson/ros2_ws/src/saltybot_social_enrollment/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
[develop]
|
||||||
|
script_dir=$base/lib/saltybot_social_enrollment
|
||||||
|
[install]
|
||||||
|
install_scripts=$base/lib/saltybot_social_enrollment
|
||||||
29
jetson/ros2_ws/src/saltybot_social_enrollment/setup.py
Normal file
29
jetson/ros2_ws/src/saltybot_social_enrollment/setup.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from setuptools import setup
|
||||||
|
import os
|
||||||
|
from glob import glob
|
||||||
|
|
||||||
|
package_name = 'saltybot_social_enrollment'
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name=package_name,
|
||||||
|
version='0.1.0',
|
||||||
|
packages=[package_name],
|
||||||
|
data_files=[
|
||||||
|
('share/ament_index/resource_index/packages', ['resource/' + package_name]),
|
||||||
|
('share/' + package_name, ['package.xml']),
|
||||||
|
(os.path.join('share', package_name, 'launch'), glob('launch/*.py')),
|
||||||
|
(os.path.join('share', package_name, 'config'), glob('config/*.yaml')),
|
||||||
|
],
|
||||||
|
install_requires=['setuptools'],
|
||||||
|
zip_safe=True,
|
||||||
|
maintainer='seb',
|
||||||
|
maintainer_email='seb@vayrette.com',
|
||||||
|
description='Person enrollment system for saltybot social interaction',
|
||||||
|
license='MIT',
|
||||||
|
entry_points={
|
||||||
|
'console_scripts': [
|
||||||
|
'enrollment_node = saltybot_social_enrollment.enrollment_node:main',
|
||||||
|
'enrollment_cli = saltybot_social_enrollment.enrollment_cli:main',
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
@ -0,0 +1,4 @@
|
|||||||
|
sensor_msgs/Image crop
|
||||||
|
---
|
||||||
|
bool success
|
||||||
|
float32[512] embedding
|
||||||
Loading…
x
Reference in New Issue
Block a user