feat(social): person enrollment system #87 #95
@ -0,0 +1,6 @@
|
||||
enrollment_node:
|
||||
ros__parameters:
|
||||
db_path: '/mnt/nvme/saltybot/gallery/persons.db'
|
||||
voice_samples_dir: '/mnt/nvme/saltybot/gallery/voice'
|
||||
auto_enroll_phrase: 'remember me my name is'
|
||||
n_samples_default: 10
|
||||
@ -0,0 +1,22 @@
|
||||
"""Launch file for saltybot social enrollment node."""
|
||||
|
||||
import os
|
||||
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
from launch import LaunchDescription
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
pkg_share = get_package_share_directory('saltybot_social_enrollment')
|
||||
config_file = os.path.join(pkg_share, 'config', 'enrollment_params.yaml')
|
||||
|
||||
return LaunchDescription([
|
||||
Node(
|
||||
package='saltybot_social_enrollment',
|
||||
executable='enrollment_node',
|
||||
name='enrollment_node',
|
||||
parameters=[config_file],
|
||||
output='screen',
|
||||
),
|
||||
])
|
||||
23
jetson/ros2_ws/src/saltybot_social_enrollment/package.xml
Normal file
23
jetson/ros2_ws/src/saltybot_social_enrollment/package.xml
Normal file
@ -0,0 +1,23 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_social_enrollment</name>
|
||||
<version>0.1.0</version>
|
||||
<description>Person enrollment system for saltybot social interaction</description>
|
||||
<maintainer email="seb@vayrette.com">seb</maintainer>
|
||||
<license>MIT</license>
|
||||
|
||||
<buildtool_depend>ament_python</buildtool_depend>
|
||||
|
||||
<depend>rclpy</depend>
|
||||
<depend>sensor_msgs</depend>
|
||||
<depend>cv_bridge</depend>
|
||||
<depend>std_msgs</depend>
|
||||
<depend>saltybot_social_msgs</depend>
|
||||
|
||||
<exec_depend>python3-numpy</exec_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1,184 @@
|
||||
#!/usr/bin/env python3
|
||||
"""enrollment_cli.py -- Gallery management CLI for saltybot social.
|
||||
|
||||
Usage:
|
||||
ros2 run saltybot_social_enrollment enrollment_cli enroll --name "Alice" [--samples 15] [--mode face]
|
||||
ros2 run saltybot_social_enrollment enrollment_cli list
|
||||
ros2 run saltybot_social_enrollment enrollment_cli delete --id 3
|
||||
ros2 run saltybot_social_enrollment enrollment_cli rename --id 2 --name "Bob"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
|
||||
from saltybot_social_msgs.srv import (
|
||||
EnrollPerson,
|
||||
ListPersons,
|
||||
DeletePerson,
|
||||
UpdatePerson,
|
||||
)
|
||||
|
||||
|
||||
class EnrollmentCLI(Node):
|
||||
def __init__(self):
|
||||
super().__init__('enrollment_cli')
|
||||
self._enroll_client = self.create_client(
|
||||
EnrollPerson, '/social/enroll'
|
||||
)
|
||||
self._list_client = self.create_client(
|
||||
ListPersons, '/social/persons/list'
|
||||
)
|
||||
self._delete_client = self.create_client(
|
||||
DeletePerson, '/social/persons/delete'
|
||||
)
|
||||
self._update_client = self.create_client(
|
||||
UpdatePerson, '/social/persons/update'
|
||||
)
|
||||
|
||||
def enroll(self, name: str, n_samples: int = 10, mode: str = 'face'):
|
||||
if not self._enroll_client.wait_for_service(timeout_sec=5.0):
|
||||
print('ERROR: /social/enroll service not available')
|
||||
return False
|
||||
|
||||
req = EnrollPerson.Request()
|
||||
req.name = name
|
||||
req.mode = mode
|
||||
req.n_samples = n_samples
|
||||
|
||||
print(f'Enrolling "{name}" ({mode}, {n_samples} samples)...')
|
||||
future = self._enroll_client.call_async(req)
|
||||
rclpy.spin_until_future_complete(self, future, timeout_sec=60.0)
|
||||
|
||||
if future.result() is None:
|
||||
print('ERROR: Enrollment timed out')
|
||||
return False
|
||||
|
||||
resp = future.result()
|
||||
if resp.success:
|
||||
print(f'OK: Enrolled "{name}" as person_id={resp.person_id}')
|
||||
else:
|
||||
print(f'FAILED: {resp.message}')
|
||||
return resp.success
|
||||
|
||||
def list_persons(self):
|
||||
if not self._list_client.wait_for_service(timeout_sec=5.0):
|
||||
print('ERROR: /social/persons/list service not available')
|
||||
return
|
||||
|
||||
req = ListPersons.Request()
|
||||
future = self._list_client.call_async(req)
|
||||
rclpy.spin_until_future_complete(self, future, timeout_sec=10.0)
|
||||
|
||||
if future.result() is None:
|
||||
print('ERROR: List request timed out')
|
||||
return
|
||||
|
||||
resp = future.result()
|
||||
if not resp.persons:
|
||||
print('Gallery is empty.')
|
||||
return
|
||||
|
||||
print(f'{"ID":>4} {"Name":<20} {"Samples":>7} {"Embedding Dim":>13}')
|
||||
print('-' * 50)
|
||||
for p in resp.persons:
|
||||
dim = len(p.embedding) if p.embedding else 0
|
||||
print(f'{p.person_id:>4} {p.person_name:<20} {p.sample_count:>7} {dim:>13}')
|
||||
|
||||
def delete(self, person_id: int):
|
||||
if not self._delete_client.wait_for_service(timeout_sec=5.0):
|
||||
print('ERROR: /social/persons/delete service not available')
|
||||
return False
|
||||
|
||||
req = DeletePerson.Request()
|
||||
req.person_id = person_id
|
||||
|
||||
future = self._delete_client.call_async(req)
|
||||
rclpy.spin_until_future_complete(self, future, timeout_sec=10.0)
|
||||
|
||||
if future.result() is None:
|
||||
print('ERROR: Delete request timed out')
|
||||
return False
|
||||
|
||||
resp = future.result()
|
||||
if resp.success:
|
||||
print(f'OK: {resp.message}')
|
||||
else:
|
||||
print(f'FAILED: {resp.message}')
|
||||
return resp.success
|
||||
|
||||
def rename(self, person_id: int, new_name: str):
|
||||
if not self._update_client.wait_for_service(timeout_sec=5.0):
|
||||
print('ERROR: /social/persons/update service not available')
|
||||
return False
|
||||
|
||||
req = UpdatePerson.Request()
|
||||
req.person_id = person_id
|
||||
req.new_name = new_name
|
||||
|
||||
future = self._update_client.call_async(req)
|
||||
rclpy.spin_until_future_complete(self, future, timeout_sec=10.0)
|
||||
|
||||
if future.result() is None:
|
||||
print('ERROR: Rename request timed out')
|
||||
return False
|
||||
|
||||
resp = future.result()
|
||||
if resp.success:
|
||||
print(f'OK: {resp.message}')
|
||||
else:
|
||||
print(f'FAILED: {resp.message}')
|
||||
return resp.success
|
||||
|
||||
|
||||
def main(args=None):
|
||||
parser = argparse.ArgumentParser(
|
||||
description='saltybot person enrollment CLI'
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest='command', required=True)
|
||||
|
||||
# enroll
|
||||
enroll_p = subparsers.add_parser('enroll', help='Enroll a new person')
|
||||
enroll_p.add_argument('--name', required=True, help='Person name')
|
||||
enroll_p.add_argument('--samples', type=int, default=10,
|
||||
help='Number of face samples (default: 10)')
|
||||
enroll_p.add_argument('--mode', default='face',
|
||||
help='Enrollment mode (default: face)')
|
||||
|
||||
# list
|
||||
subparsers.add_parser('list', help='List enrolled persons')
|
||||
|
||||
# delete
|
||||
delete_p = subparsers.add_parser('delete', help='Delete a person')
|
||||
delete_p.add_argument('--id', type=int, required=True,
|
||||
help='Person ID to delete')
|
||||
|
||||
# rename
|
||||
rename_p = subparsers.add_parser('rename', help='Rename a person')
|
||||
rename_p.add_argument('--id', type=int, required=True,
|
||||
help='Person ID to rename')
|
||||
rename_p.add_argument('--name', required=True, help='New name')
|
||||
|
||||
parsed = parser.parse_args(sys.argv[1:])
|
||||
|
||||
rclpy.init()
|
||||
cli = EnrollmentCLI()
|
||||
|
||||
try:
|
||||
if parsed.command == 'enroll':
|
||||
cli.enroll(parsed.name, parsed.samples, parsed.mode)
|
||||
elif parsed.command == 'list':
|
||||
cli.list_persons()
|
||||
elif parsed.command == 'delete':
|
||||
cli.delete(parsed.id)
|
||||
elif parsed.command == 'rename':
|
||||
cli.rename(parsed.id, parsed.name)
|
||||
finally:
|
||||
cli.destroy_node()
|
||||
rclpy.try_shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -0,0 +1,302 @@
|
||||
"""enrollment_node.py -- ROS2 person enrollment node for saltybot social.
|
||||
|
||||
Coordinates person enrollment:
|
||||
- Forwards /social/enroll to face_recognizer's service
|
||||
- Owns persistent SQLite gallery (PersonDB)
|
||||
- Voice-triggered enrollment via "remember me my name is X"
|
||||
- Gallery management services (list/delete/update)
|
||||
- Syncs DB from /social/faces/embeddings topic
|
||||
"""
|
||||
|
||||
import threading
|
||||
import numpy as np
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile, ReliabilityPolicy, DurabilityPolicy
|
||||
from std_msgs.msg import String
|
||||
|
||||
from saltybot_social_msgs.msg import (
|
||||
FaceDetectionArray,
|
||||
FaceEmbedding,
|
||||
FaceEmbeddingArray,
|
||||
)
|
||||
from saltybot_social_msgs.srv import (
|
||||
EnrollPerson,
|
||||
ListPersons,
|
||||
DeletePerson,
|
||||
UpdatePerson,
|
||||
)
|
||||
|
||||
from saltybot_social_enrollment.person_db import PersonDB
|
||||
|
||||
|
||||
class EnrollmentNode(Node):
|
||||
def __init__(self):
|
||||
super().__init__('enrollment_node')
|
||||
|
||||
# Parameters
|
||||
self.declare_parameter('db_path', '/mnt/nvme/saltybot/gallery/persons.db')
|
||||
self.declare_parameter('voice_samples_dir', '/mnt/nvme/saltybot/gallery/voice')
|
||||
self.declare_parameter('auto_enroll_phrase', 'remember me my name is')
|
||||
self.declare_parameter('n_samples_default', 10)
|
||||
|
||||
db_path = self.get_parameter('db_path').value
|
||||
self._voice_dir = self.get_parameter('voice_samples_dir').value
|
||||
self._phrase = self.get_parameter('auto_enroll_phrase').value
|
||||
self._n_samples = self.get_parameter('n_samples_default').value
|
||||
|
||||
self._db = PersonDB(db_path)
|
||||
self.get_logger().info(f'PersonDB initialized at {db_path}')
|
||||
|
||||
# Client to face_recognizer's enroll service
|
||||
self._enroll_client = self.create_client(
|
||||
EnrollPerson, '/social/face_recognizer/enroll'
|
||||
)
|
||||
|
||||
# QoS profiles
|
||||
best_effort_qos = QoSProfile(
|
||||
depth=10,
|
||||
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||
durability=DurabilityPolicy.VOLATILE,
|
||||
)
|
||||
reliable_qos = QoSProfile(
|
||||
depth=1,
|
||||
reliability=ReliabilityPolicy.RELIABLE,
|
||||
durability=DurabilityPolicy.VOLATILE,
|
||||
)
|
||||
status_qos = QoSProfile(
|
||||
depth=1,
|
||||
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||
durability=DurabilityPolicy.VOLATILE,
|
||||
)
|
||||
|
||||
# Subscriptions
|
||||
self.create_subscription(
|
||||
FaceDetectionArray, '/social/faces/detections',
|
||||
self._on_detections, best_effort_qos
|
||||
)
|
||||
self.create_subscription(
|
||||
FaceEmbeddingArray, '/social/faces/embeddings',
|
||||
self._on_embeddings, reliable_qos
|
||||
)
|
||||
self.create_subscription(
|
||||
String, '/social/speech/transcript',
|
||||
self._on_transcript, best_effort_qos
|
||||
)
|
||||
self.create_subscription(
|
||||
String, '/social/speech/command',
|
||||
self._on_command, best_effort_qos
|
||||
)
|
||||
|
||||
# Services
|
||||
self.create_service(EnrollPerson, '/social/enroll', self._handle_enroll)
|
||||
self.create_service(ListPersons, '/social/persons/list', self._handle_list)
|
||||
self.create_service(DeletePerson, '/social/persons/delete', self._handle_delete)
|
||||
self.create_service(UpdatePerson, '/social/persons/update', self._handle_update)
|
||||
|
||||
# Publishers
|
||||
self._pub_embeddings = self.create_publisher(
|
||||
FaceEmbeddingArray, '/social/faces/embeddings', reliable_qos
|
||||
)
|
||||
self._pub_status = self.create_publisher(
|
||||
String, '/social/enrollment/status', status_qos
|
||||
)
|
||||
|
||||
self.get_logger().info('EnrollmentNode ready')
|
||||
|
||||
# ---- Voice-triggered enrollment ----
|
||||
|
||||
def _on_transcript(self, msg: String):
|
||||
text = msg.data.lower()
|
||||
phrase = self._phrase.lower()
|
||||
if phrase in text:
|
||||
idx = text.index(phrase) + len(phrase)
|
||||
name = text[idx:].strip()
|
||||
# Clean up: take first 3 words max as the name
|
||||
words = name.split()
|
||||
if words:
|
||||
name = ' '.join(words[:3]).title()
|
||||
self.get_logger().info(f'Voice enrollment triggered: "{name}"')
|
||||
self._trigger_voice_enroll(name)
|
||||
|
||||
def _on_command(self, msg: String):
|
||||
# Reserved for explicit voice commands (e.g., "enroll Alice")
|
||||
pass
|
||||
|
||||
def _trigger_voice_enroll(self, name: str):
|
||||
if not self._enroll_client.wait_for_service(timeout_sec=1.0):
|
||||
self.get_logger().warn(
|
||||
'face_recognizer enroll service not available'
|
||||
)
|
||||
self._publish_status(f'Enrollment failed: face_recognizer unavailable')
|
||||
return
|
||||
|
||||
req = EnrollPerson.Request()
|
||||
req.name = name
|
||||
req.mode = 'face'
|
||||
req.n_samples = self._n_samples
|
||||
future = self._enroll_client.call_async(req)
|
||||
future.add_done_callback(
|
||||
lambda f: self._on_enroll_done(f, name)
|
||||
)
|
||||
self._publish_status(f'Enrolling "{name}"... look at the camera')
|
||||
|
||||
def _on_enroll_done(self, future, name: str):
|
||||
try:
|
||||
resp = future.result()
|
||||
if resp.success:
|
||||
status = f'Enrolled "{name}" (id={resp.person_id})'
|
||||
self.get_logger().info(status)
|
||||
else:
|
||||
status = f'Enrollment failed for "{name}": {resp.message}'
|
||||
self.get_logger().warn(status)
|
||||
self._publish_status(status)
|
||||
except Exception as e:
|
||||
self.get_logger().error(f'Enroll call failed: {e}')
|
||||
self._publish_status(f'Enrollment error: {e}')
|
||||
|
||||
# ---- Face detection callback (during enrollment) ----
|
||||
|
||||
def _on_detections(self, msg: FaceDetectionArray):
|
||||
# Reserved for future direct enrollment (without face_recognizer)
|
||||
pass
|
||||
|
||||
# ---- Embeddings sync from face_recognizer ----
|
||||
|
||||
def _on_embeddings(self, msg: FaceEmbeddingArray):
|
||||
for emb in msg.embeddings:
|
||||
existing = self._db.get_person(emb.person_id)
|
||||
if existing is None:
|
||||
arr = np.array(emb.embedding, dtype=np.float32)
|
||||
self._db.add_person(
|
||||
emb.person_name, arr, emb.sample_count
|
||||
)
|
||||
self.get_logger().info(
|
||||
f'Synced new person from face_recognizer: '
|
||||
f'{emb.person_name} (id={emb.person_id})'
|
||||
)
|
||||
|
||||
# ---- Service handlers ----
|
||||
|
||||
def _handle_enroll(self, request, response):
|
||||
"""Forward enrollment to face_recognizer service."""
|
||||
if not self._enroll_client.wait_for_service(timeout_sec=2.0):
|
||||
response.success = False
|
||||
response.message = 'face_recognizer service unavailable'
|
||||
return response
|
||||
|
||||
# Use threading.Event to bridge async call in service callback
|
||||
event = threading.Event()
|
||||
result_holder = {}
|
||||
|
||||
req = EnrollPerson.Request()
|
||||
req.name = request.name
|
||||
req.mode = request.mode
|
||||
req.n_samples = request.n_samples
|
||||
|
||||
future = self._enroll_client.call_async(req)
|
||||
|
||||
def _done(f):
|
||||
try:
|
||||
result_holder['resp'] = f.result()
|
||||
except Exception as e:
|
||||
result_holder['err'] = str(e)
|
||||
event.set()
|
||||
|
||||
future.add_done_callback(_done)
|
||||
success = event.wait(timeout=35.0)
|
||||
|
||||
if not success:
|
||||
response.success = False
|
||||
response.message = 'Enrollment timed out'
|
||||
elif 'resp' in result_holder:
|
||||
resp = result_holder['resp']
|
||||
response.success = resp.success
|
||||
response.message = resp.message
|
||||
response.person_id = resp.person_id
|
||||
else:
|
||||
response.success = False
|
||||
response.message = result_holder.get('err', 'Unknown error')
|
||||
|
||||
return response
|
||||
|
||||
def _handle_list(self, request, response):
|
||||
persons = self._db.get_all()
|
||||
response.persons = []
|
||||
for p in persons:
|
||||
fe = FaceEmbedding()
|
||||
fe.person_id = p['id']
|
||||
fe.person_name = p['name']
|
||||
fe.sample_count = p['sample_count']
|
||||
fe.enrolled_at.sec = int(p['enrolled_at'])
|
||||
fe.enrolled_at.nanosec = int(
|
||||
(p['enrolled_at'] - int(p['enrolled_at'])) * 1e9
|
||||
)
|
||||
if p['embedding'] is not None:
|
||||
fe.embedding = p['embedding'].tolist()
|
||||
response.persons.append(fe)
|
||||
return response
|
||||
|
||||
def _handle_delete(self, request, response):
|
||||
success = self._db.delete_person(request.person_id)
|
||||
response.success = success
|
||||
if success:
|
||||
response.message = f'Deleted person {request.person_id}'
|
||||
self.get_logger().info(response.message)
|
||||
self._publish_embeddings_from_db()
|
||||
else:
|
||||
response.message = f'Person {request.person_id} not found'
|
||||
return response
|
||||
|
||||
def _handle_update(self, request, response):
|
||||
success = self._db.update_name(request.person_id, request.new_name)
|
||||
response.success = success
|
||||
if success:
|
||||
response.message = f'Updated person {request.person_id} name to "{request.new_name}"'
|
||||
self.get_logger().info(response.message)
|
||||
self._publish_embeddings_from_db()
|
||||
else:
|
||||
response.message = f'Person {request.person_id} not found'
|
||||
return response
|
||||
|
||||
# ---- Helpers ----
|
||||
|
||||
def _publish_status(self, text: str):
|
||||
msg = String()
|
||||
msg.data = text
|
||||
self._pub_status.publish(msg)
|
||||
|
||||
def _publish_embeddings_from_db(self):
|
||||
persons = self._db.get_all()
|
||||
arr = FaceEmbeddingArray()
|
||||
arr.header.stamp = self.get_clock().now().to_msg()
|
||||
for p in persons:
|
||||
if p['embedding'] is not None:
|
||||
fe = FaceEmbedding()
|
||||
fe.person_id = p['id']
|
||||
fe.person_name = p['name']
|
||||
fe.sample_count = p['sample_count']
|
||||
fe.enrolled_at.sec = int(p['enrolled_at'])
|
||||
fe.enrolled_at.nanosec = int(
|
||||
(p['enrolled_at'] - int(p['enrolled_at'])) * 1e9
|
||||
)
|
||||
fe.embedding = p['embedding'].tolist()
|
||||
arr.embeddings.append(fe)
|
||||
self._pub_embeddings.publish(arr)
|
||||
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = EnrollmentNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.try_shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -0,0 +1,138 @@
|
||||
"""person_db.py -- Persistent SQLite person gallery for saltybot enrollment."""
|
||||
|
||||
import sqlite3
|
||||
import json
|
||||
import time
|
||||
import numpy as np
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class PersonDB:
|
||||
"""Thread-safe SQLite-backed person gallery.
|
||||
|
||||
Schema:
|
||||
persons(id INTEGER PRIMARY KEY, name TEXT, enrolled_at REAL,
|
||||
sample_count INTEGER, embedding BLOB, metadata TEXT)
|
||||
voice_samples(id INTEGER PRIMARY KEY, person_id INTEGER REFERENCES persons,
|
||||
recorded_at REAL, sample_path TEXT)
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str):
|
||||
self._db_path = db_path
|
||||
self._lock = threading.Lock()
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
self._init_db()
|
||||
|
||||
def _init_db(self):
|
||||
with self._connect() as conn:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS persons (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
enrolled_at REAL NOT NULL,
|
||||
sample_count INTEGER DEFAULT 1,
|
||||
embedding BLOB,
|
||||
metadata TEXT DEFAULT '{}'
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS voice_samples (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
person_id INTEGER REFERENCES persons(id),
|
||||
recorded_at REAL NOT NULL,
|
||||
sample_path TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
|
||||
def _connect(self):
|
||||
return sqlite3.connect(self._db_path)
|
||||
|
||||
def add_person(self, name: str, embedding: np.ndarray, sample_count: int = 1,
|
||||
metadata: dict = None) -> int:
|
||||
"""Add a new person. Returns new person_id."""
|
||||
emb_blob = embedding.astype(np.float32).tobytes() if embedding is not None else None
|
||||
now = time.time()
|
||||
with self._lock:
|
||||
with self._connect() as conn:
|
||||
cur = conn.execute(
|
||||
"INSERT INTO persons (name, enrolled_at, sample_count, embedding, metadata) "
|
||||
"VALUES (?, ?, ?, ?, ?)",
|
||||
(name, now, sample_count, emb_blob, json.dumps(metadata or {}))
|
||||
)
|
||||
return cur.lastrowid
|
||||
|
||||
def update_embedding(self, person_id: int, embedding: np.ndarray,
|
||||
sample_count: int) -> bool:
|
||||
emb_blob = embedding.astype(np.float32).tobytes()
|
||||
with self._lock:
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"UPDATE persons SET embedding=?, sample_count=? WHERE id=?",
|
||||
(emb_blob, sample_count, person_id)
|
||||
)
|
||||
return conn.total_changes > 0
|
||||
|
||||
def update_name(self, person_id: int, new_name: str) -> bool:
|
||||
with self._lock:
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"UPDATE persons SET name=? WHERE id=?",
|
||||
(new_name, person_id)
|
||||
)
|
||||
return conn.total_changes > 0
|
||||
|
||||
def delete_person(self, person_id: int) -> bool:
|
||||
with self._lock:
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"DELETE FROM voice_samples WHERE person_id=?",
|
||||
(person_id,)
|
||||
)
|
||||
conn.execute(
|
||||
"DELETE FROM persons WHERE id=?",
|
||||
(person_id,)
|
||||
)
|
||||
return conn.total_changes > 0
|
||||
|
||||
def get_all(self) -> list:
|
||||
"""Returns list of dicts with id, name, enrolled_at, sample_count, embedding."""
|
||||
with self._lock:
|
||||
with self._connect() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT id, name, enrolled_at, sample_count, embedding FROM persons"
|
||||
).fetchall()
|
||||
result = []
|
||||
for row in rows:
|
||||
emb = np.frombuffer(row[4], dtype=np.float32).copy() if row[4] else None
|
||||
result.append({
|
||||
'id': row[0], 'name': row[1], 'enrolled_at': row[2],
|
||||
'sample_count': row[3], 'embedding': emb
|
||||
})
|
||||
return result
|
||||
|
||||
def get_person(self, person_id: int) -> dict | None:
|
||||
with self._lock:
|
||||
with self._connect() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT id, name, enrolled_at, sample_count, embedding "
|
||||
"FROM persons WHERE id=?",
|
||||
(person_id,)
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
emb = np.frombuffer(row[4], dtype=np.float32).copy() if row[4] else None
|
||||
return {
|
||||
'id': row[0], 'name': row[1], 'enrolled_at': row[2],
|
||||
'sample_count': row[3], 'embedding': emb
|
||||
}
|
||||
|
||||
def add_voice_sample(self, person_id: int, sample_path: str) -> int:
|
||||
with self._lock:
|
||||
with self._connect() as conn:
|
||||
cur = conn.execute(
|
||||
"INSERT INTO voice_samples (person_id, recorded_at, sample_path) "
|
||||
"VALUES (?, ?, ?)",
|
||||
(person_id, time.time(), sample_path)
|
||||
)
|
||||
return cur.lastrowid
|
||||
4
jetson/ros2_ws/src/saltybot_social_enrollment/setup.cfg
Normal file
4
jetson/ros2_ws/src/saltybot_social_enrollment/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/saltybot_social_enrollment
|
||||
[install]
|
||||
install_scripts=$base/lib/saltybot_social_enrollment
|
||||
29
jetson/ros2_ws/src/saltybot_social_enrollment/setup.py
Normal file
29
jetson/ros2_ws/src/saltybot_social_enrollment/setup.py
Normal file
@ -0,0 +1,29 @@
|
||||
from setuptools import setup
|
||||
import os
|
||||
from glob import glob
|
||||
|
||||
package_name = 'saltybot_social_enrollment'
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
version='0.1.0',
|
||||
packages=[package_name],
|
||||
data_files=[
|
||||
('share/ament_index/resource_index/packages', ['resource/' + package_name]),
|
||||
('share/' + package_name, ['package.xml']),
|
||||
(os.path.join('share', package_name, 'launch'), glob('launch/*.py')),
|
||||
(os.path.join('share', package_name, 'config'), glob('config/*.yaml')),
|
||||
],
|
||||
install_requires=['setuptools'],
|
||||
zip_safe=True,
|
||||
maintainer='seb',
|
||||
maintainer_email='seb@vayrette.com',
|
||||
description='Person enrollment system for saltybot social interaction',
|
||||
license='MIT',
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'enrollment_node = saltybot_social_enrollment.enrollment_node:main',
|
||||
'enrollment_cli = saltybot_social_enrollment.enrollment_cli:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
@ -0,0 +1,4 @@
|
||||
sensor_msgs/Image crop
|
||||
---
|
||||
bool success
|
||||
float32[512] embedding
|
||||
Loading…
x
Reference in New Issue
Block a user