feat: first encounter orchestrator (Issue #400) #402

Merged
sl-jetson merged 2 commits from sl-jetson/issue-400-encounter-launch into main 2026-03-04 13:30:16 -05:00
11 changed files with 1087 additions and 0 deletions

View File

@ -0,0 +1,18 @@
first_encounter:
# Directory to queue encounter data (JSON files)
encounter_queue_dir: "/home/seb/encounter-queue"
# Face confidence threshold for "unknown" detection
unknown_face_threshold: 0.3 # Below this = not in gallery
# Timeout for waiting for STT responses (seconds)
small_talk_timeout: 10.0
# Timeout before considering person "left" (seconds)
person_away_timeout: 30.0
# Auto-enroll person after conversation
auto_enroll: true
# State machine flow:
# DETECT → GREET → ASK_NAME → SMALL_TALK → ENROLL → FAREWELL → save to queue

View File

@ -0,0 +1,23 @@
"""Launch file for first encounter orchestrator node."""
from launch import LaunchDescription
from launch_ros.actions import Node
def generate_launch_description():
"""Generate launch description."""
encounter_node = Node(
package="saltybot_first_encounter",
executable="first_encounter_node",
name="first_encounter",
parameters=[
{"encounter_queue_dir": "/home/seb/encounter-queue"},
{"unknown_face_threshold": 0.3},
{"small_talk_timeout": 10.0},
{"person_away_timeout": 30.0},
{"auto_enroll": True},
],
output="screen",
)
return LaunchDescription([encounter_node])

View File

@ -0,0 +1,28 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_first_encounter</name>
<version>0.1.0</version>
<description>
First encounter orchestrator node for unknown person detection and enrollment.
State machine: DETECT → GREET → ASK_NAME → SMALL_TALK → ENROLL → FAREWELL
</description>
<maintainer email="sl-controls@saltylab.local">sl-controls</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>std_msgs</depend>
<depend>geometry_msgs</depend>
<depend>sensor_msgs</depend>
<buildtool_depend>ament_python</buildtool_depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>python3-pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -0,0 +1,293 @@
#!/usr/bin/env python3
"""First encounter state machine orchestrator.
Detects unknown persons and manages enrollment workflow.
State machine: DETECT GREET ASK_NAME SMALL_TALK ENROLL FAREWELL
Subscribes to:
/saltybot/person_tracker (detected persons with face match confidence)
/saltybot/stt_result (STT transcriptions)
/saltybot/person_left (signal when person walks away)
Publishes:
/social/orchestrator/state (JSON: state, person_id, encounter_data)
/saltybot/tts_request (Piper TTS triggers)
Stores encounter data as JSON files in /home/seb/encounter-queue/
"""
import json
import time
import threading
from enum import Enum
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Optional, Dict
import rclpy
from rclpy.node import Node
from std_msgs.msg import String, Bool
class EncounterState(Enum):
"""First encounter state machine states."""
IDLE = "IDLE"
DETECT = "DETECT"
GREET = "GREET"
ASK_NAME = "ASK_NAME"
SMALL_TALK = "SMALL_TALK"
ENROLL = "ENROLL"
FAREWELL = "FAREWELL"
@dataclass
class EncounterData:
"""Encounter metadata and responses."""
person_id: str
timestamp: float
state: str
name: Optional[str] = None
context: Optional[str] = None
greeting_response: Optional[str] = None
interests: list = None
enrollment_success: bool = False
duration_sec: float = 0.0
notes: str = ""
def __post_init__(self):
if self.interests is None:
self.interests = []
class FirstEncounterOrchestrator(Node):
"""First encounter state machine."""
# TTS prompts for each state
TTS_PROMPTS = {
EncounterState.GREET: "Hi there! I'm SaltyBot. What's your name?",
EncounterState.ASK_NAME: "Could you please tell me your name?",
EncounterState.SMALL_TALK: "Nice to meet you! What are you interested in?",
EncounterState.ENROLL: "I'm saving your face to remember you next time!",
EncounterState.FAREWELL: "It was great meeting you! Goodbye!",
}
def __init__(self):
super().__init__("first_encounter")
# Parameters
self.declare_parameter("encounter_queue_dir", "/home/seb/encounter-queue")
self.declare_parameter("unknown_face_threshold", 0.3) # Confidence below threshold = unknown
self.declare_parameter("small_talk_timeout", 10.0) # Seconds to wait for STT
self.declare_parameter("person_away_timeout", 30.0) # Seconds before person "left"
self.declare_parameter("auto_enroll", True)
self.encounter_queue_dir = Path(self.get_parameter("encounter_queue_dir").value)
self.unknown_threshold = self.get_parameter("unknown_face_threshold").value
self.small_talk_timeout = self.get_parameter("small_talk_timeout").value
self.person_away_timeout = self.get_parameter("person_away_timeout").value
self.auto_enroll = self.get_parameter("auto_enroll").value
# Create encounter queue directory
self.encounter_queue_dir.mkdir(parents=True, exist_ok=True)
# State
self.current_state = EncounterState.IDLE
self.current_person_id: Optional[str] = None
self.encounter_data: Optional[EncounterData] = None
self.encounter_start_time = 0.0
self.last_tracker_update = time.time()
self.state_lock = threading.Lock()
self.stt_response = None
self.stt_ready = False
# Subscriptions
self.create_subscription(String, "/saltybot/person_tracker", self._on_person_track, 10)
self.create_subscription(String, "/saltybot/stt_result", self._on_stt_result, 10)
self.create_subscription(Bool, "/saltybot/person_left", self._on_person_left, 10)
# Publishers
self.pub_orchestrator_state = self.create_publisher(String, "/social/orchestrator/state", 10)
self.pub_tts_request = self.create_publisher(String, "/saltybot/tts_request", 10)
# Timer for state machine update loop
self.create_timer(0.5, self._state_machine_update)
self.get_logger().info(
f"First encounter orchestrator initialized. Queue: {self.encounter_queue_dir}"
)
def _on_person_track(self, msg: String) -> None:
"""Handle person tracker update - detect unknown faces."""
try:
data = json.loads(msg.data)
person_id = data.get("person_id")
face_confidence = data.get("face_confidence", 1.0)
# Detect unknown face (low confidence = not in gallery)
if face_confidence < self.unknown_threshold:
with self.state_lock:
if self.current_state == EncounterState.IDLE:
self.current_person_id = person_id
self.current_state = EncounterState.DETECT
self.encounter_start_time = time.time()
self.encounter_data = EncounterData(
person_id=person_id,
timestamp=self.encounter_start_time,
state=EncounterState.DETECT.value
)
self.get_logger().info(f"Unknown person detected: {person_id}")
self.last_tracker_update = time.time()
except json.JSONDecodeError:
self.get_logger().error(f"Invalid JSON in person tracker: {msg.data}")
def _on_stt_result(self, msg: String) -> None:
"""Handle STT result."""
self.stt_response = msg.data
self.stt_ready = True
self.get_logger().debug(f"STT result: {msg.data}")
def _on_person_left(self, msg: Bool) -> None:
"""Handle person walking away - save partial data."""
if msg.data and self.encounter_data:
self.get_logger().info(f"Person {self.current_person_id} left. Saving encounter data.")
self._save_encounter_data("interrupted")
self._reset_encounter()
def _send_tts_request(self, text: str) -> None:
"""Send TTS request to Piper."""
tts_msg = String(data=json.dumps({
"text": text,
"voice": "default",
"speed": 1.0
}))
self.pub_tts_request.publish(tts_msg)
def _publish_state(self) -> None:
"""Publish current orchestrator state."""
state_data = {
"state": self.current_state.value,
"person_id": self.current_person_id,
}
if self.encounter_data:
state_data.update({
"name": self.encounter_data.name,
"context": self.encounter_data.context,
})
self.pub_orchestrator_state.publish(String(data=json.dumps(state_data)))
def _wait_for_stt(self, timeout: float) -> Optional[str]:
"""Wait for STT result with timeout."""
self.stt_ready = False
self.stt_response = None
start_time = time.time()
while time.time() - start_time < timeout:
if self.stt_ready:
response = self.stt_response
self.stt_ready = False
return response
time.sleep(0.1)
return None
def _state_machine_update(self) -> None:
"""Update state machine."""
with self.state_lock:
state = self.current_state
if state == EncounterState.IDLE:
pass # Waiting for detection
elif state == EncounterState.DETECT:
# Transition: DETECT → GREET
self.current_state = EncounterState.GREET
self._send_tts_request(self.TTS_PROMPTS[EncounterState.GREET])
self._publish_state()
elif state == EncounterState.GREET:
# Wait for STT response (handled asynchronously)
response = self._wait_for_stt(self.small_talk_timeout)
if response:
self.encounter_data.greeting_response = response
self.current_state = EncounterState.ASK_NAME
self._send_tts_request(self.TTS_PROMPTS[EncounterState.ASK_NAME])
self._publish_state()
elif state == EncounterState.ASK_NAME:
# Capture name from STT
response = self._wait_for_stt(self.small_talk_timeout)
if response:
self.encounter_data.name = response
self.current_state = EncounterState.SMALL_TALK
self._send_tts_request(self.TTS_PROMPTS[EncounterState.SMALL_TALK])
self._publish_state()
elif state == EncounterState.SMALL_TALK:
# Capture interests/context
response = self._wait_for_stt(self.small_talk_timeout)
if response:
self.encounter_data.context = response
self.encounter_data.interests = response.split(",")
self.current_state = EncounterState.ENROLL
self._send_tts_request(self.TTS_PROMPTS[EncounterState.ENROLL])
self._publish_state()
elif state == EncounterState.ENROLL:
# In real implementation, trigger face enrollment API
if self.auto_enroll:
self.encounter_data.enrollment_success = True
self.current_state = EncounterState.FAREWELL
self._send_tts_request(self.TTS_PROMPTS[EncounterState.FAREWELL])
self._publish_state()
elif state == EncounterState.FAREWELL:
# Complete encounter - save data and reset
self.encounter_data.duration_sec = time.time() - self.encounter_start_time
self._save_encounter_data("completed")
self._reset_encounter()
def _save_encounter_data(self, status: str) -> None:
"""Save encounter data to JSON file."""
if not self.encounter_data:
return
self.encounter_data.notes = status
self.encounter_data.state = self.current_state.value
filename = (
self.encounter_queue_dir /
f"encounter_{self.encounter_data.person_id}_{int(self.encounter_data.timestamp)}.json"
)
try:
with open(filename, 'w') as f:
json.dump(asdict(self.encounter_data), f, indent=2)
self.get_logger().info(f"Encounter data saved: {filename}")
except Exception as e:
self.get_logger().error(f"Failed to save encounter data: {e}")
def _reset_encounter(self) -> None:
"""Reset encounter state."""
with self.state_lock:
self.current_state = EncounterState.IDLE
self.current_person_id = None
self.encounter_data = None
self.stt_response = None
self.stt_ready = False
def main(args=None):
rclpy.init(args=args)
node = FirstEncounterOrchestrator()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,4 @@
[develop]
script-dir=$base/lib/saltybot_first_encounter
[egg_info]
tag_date = 0

View File

@ -0,0 +1,27 @@
from setuptools import setup
package_name = "saltybot_first_encounter"
setup(
name=package_name,
version="0.1.0",
packages=[package_name],
data_files=[
("share/ament_index/resource_index/packages", [f"resource/{package_name}"]),
(f"share/{package_name}", ["package.xml"]),
(f"share/{package_name}/launch", ["launch/first_encounter.launch.py"]),
(f"share/{package_name}/config", ["config/encounter_params.yaml"]),
],
install_requires=["setuptools"],
zip_safe=True,
maintainer="sl-controls",
maintainer_email="sl-controls@saltylab.local",
description="First encounter state machine orchestrator",
license="MIT",
tests_require=["pytest"],
entry_points={
"console_scripts": [
"first_encounter_node = saltybot_first_encounter.first_encounter_node:main",
],
},
)

View File

@ -0,0 +1,266 @@
#!/usr/bin/env python3
"""Encounter data sync service for offline-first queue management.
Monitors a local encounter queue directory and syncs JSON files to cloud
API when internet connectivity is available. Implements exponential backoff
retry strategy and manages processed files.
Watched directory: /home/seb/encounter-queue/
Synced directory: /home/seb/encounter-queue/synced/
Published topics:
/social/encounter_sync_status (std_msgs/String) - Sync status updates
"""
import json
import os
import shutil
import time
from pathlib import Path
from typing import Optional
from datetime import datetime
import socket
import urllib.request
import urllib.error
import rclpy
from rclpy.node import Node
from std_msgs.msg import String
class EncounterSyncService(Node):
"""ROS2 node for syncing encounter data to cloud API."""
def __init__(self):
super().__init__("encounter_sync_service")
# Parameters
self.declare_parameter(
"api_url",
"https://api.openclaw.io/encounters", # Default placeholder
)
self.declare_parameter("queue_dir", "/home/seb/encounter-queue")
self.declare_parameter("synced_subdir", "synced")
self.declare_parameter("check_interval", 5.0) # seconds
self.declare_parameter("connectivity_check_url", "https://www.google.com")
self.declare_parameter("connectivity_timeout", 3.0) # seconds
self.declare_parameter("max_retries", 5)
self.declare_parameter("initial_backoff", 1.0) # seconds
self.declare_parameter("max_backoff", 300.0) # 5 minutes
self.api_url = self.get_parameter("api_url").value
self.queue_dir = Path(self.get_parameter("queue_dir").value)
self.synced_subdir = self.get_parameter("synced_subdir").value
self.check_interval = self.get_parameter("check_interval").value
self.connectivity_url = self.get_parameter("connectivity_check_url").value
self.connectivity_timeout = self.get_parameter("connectivity_timeout").value
self.max_retries = self.get_parameter("max_retries").value
self.initial_backoff = self.get_parameter("initial_backoff").value
self.max_backoff = self.get_parameter("max_backoff").value
# Ensure queue directory exists
self.queue_dir.mkdir(parents=True, exist_ok=True)
self.synced_dir = self.queue_dir / self.synced_subdir
self.synced_dir.mkdir(parents=True, exist_ok=True)
# Publisher for sync status
self.pub_status = self.create_publisher(String, "/social/encounter_sync_status", 10)
# Track retry state per file
self.retry_counts = {}
self.backoff_times = {}
# Main processing timer
self.create_timer(self.check_interval, self._sync_loop)
self.get_logger().info(
f"Encounter sync service initialized. "
f"Queue: {self.queue_dir}, API: {self.api_url}"
)
self._publish_status("initialized", f"Queue: {self.queue_dir}")
def _sync_loop(self) -> None:
"""Main loop: check connectivity and sync queued files."""
# Check internet connectivity
if not self._check_connectivity():
self._publish_status("offline", "No internet connectivity")
return
self._publish_status("online", "Internet connectivity detected")
# Get all JSON files in queue directory (not in synced subdirectory)
queued_files = [
f
for f in self.queue_dir.glob("*.json")
if f.is_file() and not f.parent.name == self.synced_subdir
]
if not queued_files:
return
self.get_logger().debug(f"Found {len(queued_files)} queued encounter files")
for encounter_file in queued_files:
self._sync_file(encounter_file)
def _check_connectivity(self) -> bool:
"""Check internet connectivity via HTTP ping.
Returns:
True if connected, False otherwise
"""
try:
request = urllib.request.Request(
self.connectivity_url, method="HEAD"
)
with urllib.request.urlopen(request, timeout=self.connectivity_timeout):
return True
except (urllib.error.URLError, socket.timeout, OSError):
return False
def _sync_file(self, filepath: Path) -> None:
"""Attempt to sync a single encounter file, with exponential backoff retry.
Args:
filepath: Path to JSON file to sync
"""
file_id = filepath.name
# Check if we should retry this file
if file_id in self.retry_counts:
if self.retry_counts[file_id] >= self.max_retries:
self.get_logger().error(
f"Max retries exceeded for {file_id}, moving to synced with error flag"
)
self._move_to_synced(filepath, error=True)
del self.retry_counts[file_id]
if file_id in self.backoff_times:
del self.backoff_times[file_id]
return
# Check backoff timer
if file_id in self.backoff_times:
if time.time() < self.backoff_times[file_id]:
return # Not yet time to retry
# Attempt upload
try:
with open(filepath, "r") as f:
encounter_data = json.load(f)
success = self._upload_encounter(encounter_data)
if success:
self.get_logger().info(f"Successfully synced {file_id}")
self._move_to_synced(filepath, error=False)
if file_id in self.retry_counts:
del self.retry_counts[file_id]
if file_id in self.backoff_times:
del self.backoff_times[file_id]
self._publish_status("synced", f"File: {file_id}")
else:
self._handle_sync_failure(file_id)
except (json.JSONDecodeError, IOError) as e:
self.get_logger().error(f"Failed to read {file_id}: {e}")
self._move_to_synced(filepath, error=True)
if file_id in self.retry_counts:
del self.retry_counts[file_id]
if file_id in self.backoff_times:
del self.backoff_times[file_id]
def _upload_encounter(self, encounter_data: dict) -> bool:
"""Upload encounter data to cloud API.
Args:
encounter_data: Encounter JSON data
Returns:
True if successful, False otherwise
"""
try:
json_bytes = json.dumps(encounter_data).encode("utf-8")
request = urllib.request.Request(
self.api_url,
data=json_bytes,
headers={"Content-Type": "application/json"},
method="POST",
)
with urllib.request.urlopen(request, timeout=10.0) as response:
return response.status == 200 or response.status == 201
except (urllib.error.URLError, socket.timeout, OSError, json.JSONEncodeError) as e:
self.get_logger().warning(f"Upload failed: {e}")
return False
def _handle_sync_failure(self, file_id: str) -> None:
"""Handle sync failure with exponential backoff.
Args:
file_id: Filename identifier
"""
if file_id not in self.retry_counts:
self.retry_counts[file_id] = 0
self.backoff_times[file_id] = 0
self.retry_counts[file_id] += 1
backoff = min(
self.initial_backoff * (2 ** (self.retry_counts[file_id] - 1)),
self.max_backoff,
)
self.backoff_times[file_id] = time.time() + backoff
self.get_logger().warning(
f"Sync failed for {file_id}, retry {self.retry_counts[file_id]}/{self.max_retries} "
f"in {backoff:.1f}s"
)
self._publish_status(
"retry",
f"File: {file_id}, attempt {self.retry_counts[file_id]}/{self.max_retries}",
)
def _move_to_synced(self, filepath: Path, error: bool = False) -> None:
"""Move processed file to synced directory.
Args:
filepath: Path to file
error: Whether file had an error during sync
"""
timestamp = datetime.now().isoformat()
status_suffix = "_error" if error else ""
new_name = f"{filepath.stem}_{timestamp}{status_suffix}.json"
dest_path = self.synced_dir / new_name
try:
shutil.move(str(filepath), str(dest_path))
self.get_logger().debug(f"Moved {filepath.name} to synced/")
except OSError as e:
self.get_logger().error(f"Failed to move {filepath.name} to synced: {e}")
def _publish_status(self, status: str, details: str = "") -> None:
"""Publish sync status update.
Args:
status: Status string (e.g., 'online', 'offline', 'synced', 'retry')
details: Additional details
"""
timestamp = datetime.now().isoformat()
message = f"{timestamp} | {status.upper()} | {details}" if details else timestamp
msg = String()
msg.data = message
self.pub_status.publish(msg)
def main(args=None):
rclpy.init(args=args)
node = EncounterSyncService()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,428 @@
#!/usr/bin/env python3
"""social_enrollment_node.py -- First Encounter enrollment with face + voice biometrics.
Triggered by FirstEncounterOrchestrator when state transitions to ENROLL.
Captures:
- Face embedding (via SCRFD + ArcFace from RealSense RGB)
- Voice speaker embedding (via ECAPA-TDNN)
- RealSense RGB photo snapshot
- Metadata (name, context, timestamp)
Stores to:
- /home/seb/encounter-queue/{person_id}_{timestamp}.json (for offline cloud sync)
- Local speaker_embeddings.json (for immediate voice recognition)
- Face gallery (via EnrollPerson service to face_recognizer)
Subscribes to:
/social/orchestrator/state (JSON: state, person_id, name, context)
/social/faces/embeddings (FaceEmbeddingArray with ArcFace embeddings)
/camera/color/image_raw (RealSense RGB for snapshots)
/social/speech/speaker_embedding (speaker embedding from ECAPA-TDNN)
Publishes:
/social/enrollment/status (JSON: person_id, status, person_db_id)
"""
import json
import time
import threading
import numpy as np
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Optional, Dict
from datetime import datetime
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, ReliabilityPolicy, DurabilityPolicy
from std_msgs.msg import String
from sensor_msgs.msg import Image
import cv2
from cv_bridge import CvBridge
from saltybot_social_msgs.msg import FaceEmbeddingArray
from saltybot_social_msgs.srv import EnrollPerson
from saltybot_social_enrollment.person_db import PersonDB
@dataclass
class EnrollmentRequest:
"""Enrollment request from first encounter."""
person_id: str
name: str
context: Optional[str] = None
timestamp: float = 0.0
face_embedding: Optional[np.ndarray] = None
voice_embedding: Optional[np.ndarray] = None
photo_data: Optional[bytes] = None # JPEG encoded
class SocialEnrollmentNode(Node):
"""Face + voice enrollment during first encounter."""
def __init__(self):
super().__init__('social_enrollment')
# Parameters
self.declare_parameter('encounter_queue_dir', '/home/seb/encounter-queue')
self.declare_parameter('speaker_embeddings_path', '/home/seb/speaker_embeddings.json')
self.declare_parameter('photos_dir', '/home/seb/encounter-photos')
self.declare_parameter('face_recognizer_service', '/social/face_recognizer/enroll')
self.declare_parameter('embedding_dim_face', 512)
self.declare_parameter('embedding_dim_voice', 192)
self.queue_dir = Path(self.get_parameter('encounter_queue_dir').value)
self.speaker_embeddings_path = Path(self.get_parameter('speaker_embeddings_path').value)
self.photos_dir = Path(self.get_parameter('photos_dir').value)
self.face_service_name = self.get_parameter('face_recognizer_service').value
self.face_emb_dim = self.get_parameter('embedding_dim_face').value
self.voice_emb_dim = self.get_parameter('embedding_dim_voice').value
# Create directories
self.queue_dir.mkdir(parents=True, exist_ok=True)
self.photos_dir.mkdir(parents=True, exist_ok=True)
# Initialize PersonDB
self._db = PersonDB(str(self.queue_dir.parent / 'persons.db'))
self.get_logger().info(f'PersonDB initialized')
# CV bridge for image conversion
self._bridge = CvBridge()
# State
self._enrollment_request: Optional[EnrollmentRequest] = None
self._lock = threading.Lock()
self._latest_face_embedding: Optional[np.ndarray] = None
self._latest_voice_embedding: Optional[np.ndarray] = None
self._latest_image: Optional[Image] = None
self._face_embedding_timestamp = 0.0
self._voice_embedding_timestamp = 0.0
self._image_timestamp = 0.0
# QoS profiles
best_effort_qos = QoSProfile(
depth=10,
reliability=ReliabilityPolicy.BEST_EFFORT,
durability=DurabilityPolicy.VOLATILE,
)
reliable_qos = QoSProfile(
depth=1,
reliability=ReliabilityPolicy.RELIABLE,
durability=DurabilityPolicy.VOLATILE,
)
# Subscriptions
self.create_subscription(
String, '/social/orchestrator/state',
self._on_orchestrator_state, reliable_qos
)
self.create_subscription(
FaceEmbeddingArray, '/social/faces/embeddings',
self._on_face_embeddings, reliable_qos
)
self.create_subscription(
Image, '/camera/color/image_raw',
self._on_camera_image, best_effort_qos
)
self.create_subscription(
String, '/social/speech/speaker_embedding',
self._on_speaker_embedding, best_effort_qos
)
# Service clients
self._enroll_face_client = self.create_client(
EnrollPerson, self.face_service_name
)
# Publishers
self._pub_status = self.create_publisher(
String, '/social/enrollment/status', reliable_qos
)
# Timer for enrollment timeout handling
self.create_timer(0.5, self._enrollment_timeout_check)
self.get_logger().info(
f'Social enrollment node initialized. '
f'Queue: {self.queue_dir}, '
f'Speakers: {self.speaker_embeddings_path}'
)
def _on_orchestrator_state(self, msg: String) -> None:
"""Handle orchestrator state transitions."""
try:
state_data = json.loads(msg.data)
state = state_data.get('state')
if state == 'ENROLL':
person_id = state_data.get('person_id')
name = state_data.get('name')
context = state_data.get('context')
with self._lock:
self._enrollment_request = EnrollmentRequest(
person_id=person_id,
name=name,
context=context,
timestamp=time.time()
)
self._face_embedding_timestamp = 0.0
self._voice_embedding_timestamp = 0.0
self._image_timestamp = 0.0
self.get_logger().info(
f'Enrollment triggered: {name} (ID: {person_id})'
)
except json.JSONDecodeError as e:
self.get_logger().error(f'Invalid orchestrator state JSON: {e}')
def _on_face_embeddings(self, msg: FaceEmbeddingArray) -> None:
"""Capture face embedding from social face recognizer."""
if not msg.embeddings:
return
with self._lock:
if self._enrollment_request is None:
return
# Take first detected face embedding
face_emb = msg.embeddings[0]
emb_array = np.frombuffer(face_emb.embedding, dtype=np.float32)
if len(emb_array) == self.face_emb_dim:
self._latest_face_embedding = emb_array.copy()
self._face_embedding_timestamp = time.time()
self.get_logger().debug(
f'Face embedding captured: {face_emb.track_id}'
)
def _on_speaker_embedding(self, msg: String) -> None:
"""Capture voice speaker embedding from ECAPA-TDNN."""
try:
emb_data = json.loads(msg.data)
emb_values = emb_data.get('embedding')
if emb_values:
with self._lock:
if self._enrollment_request is None:
return
emb_array = np.array(emb_values, dtype=np.float32)
if len(emb_array) == self.voice_emb_dim:
self._latest_voice_embedding = emb_array.copy()
self._voice_embedding_timestamp = time.time()
self.get_logger().debug(
f'Voice embedding captured: {len(emb_array)} dims'
)
except json.JSONDecodeError as e:
self.get_logger().error(f'Invalid speaker embedding JSON: {e}')
def _on_camera_image(self, msg: Image) -> None:
"""Capture RealSense RGB image for enrollment photo."""
try:
with self._lock:
if self._enrollment_request is None:
return
# Store latest image
self._latest_image = msg
self._image_timestamp = time.time()
except Exception as e:
self.get_logger().error(f'Error capturing camera image: {e}')
def _enrollment_timeout_check(self) -> None:
"""Check if enrollment data is ready or timed out."""
with self._lock:
if self._enrollment_request is None:
return
now = time.time()
timeout = 10.0 # 10 seconds to collect embeddings
# Check if all data collected
has_face = self._latest_face_embedding is not None and \
(now - self._face_embedding_timestamp < 5.0)
has_voice = self._latest_voice_embedding is not None and \
(now - self._voice_embedding_timestamp < 5.0)
has_image = self._latest_image is not None and \
(now - self._image_timestamp < 5.0)
# If we have face + voice, proceed with enrollment
if has_face and has_voice:
self._complete_enrollment()
# If timeout exceeded, save what we have
elif (now - self._enrollment_request.timestamp) > timeout:
self.get_logger().warn(
f'Enrollment timeout for {self._enrollment_request.name}. '
f'Proceeding with available data.'
)
self._complete_enrollment()
def _complete_enrollment(self) -> None:
"""Complete enrollment process."""
request = self._enrollment_request
if request is None:
return
try:
# Save enrollment data to queue
enroll_data = {
'person_id': request.person_id,
'name': request.name,
'context': request.context,
'timestamp': request.timestamp,
'datetime': datetime.fromtimestamp(request.timestamp).isoformat(),
'face_embedding_shape': list(self._latest_face_embedding.shape)
if self._latest_face_embedding is not None else None,
'voice_embedding_shape': list(self._latest_voice_embedding.shape)
if self._latest_voice_embedding is not None else None,
}
# Save queue JSON
queue_file = self.queue_dir / f"enrollment_{request.person_id}_{int(request.timestamp)}.json"
with open(queue_file, 'w') as f:
json.dump(enroll_data, f, indent=2)
self.get_logger().info(f'Enrollment data queued: {queue_file}')
# Save photo if available
photo_id = None
if self._latest_image is not None:
photo_id = self._save_enrollment_photo(request)
# Add to PersonDB with embeddings
person_db_id = self._db.add_person(
name=request.name,
embedding=self._latest_face_embedding,
sample_count=1,
metadata={
'encounter_person_id': request.person_id,
'context': request.context,
'photo_id': photo_id,
'timestamp': request.timestamp,
}
)
self.get_logger().info(f'Added to PersonDB: ID {person_db_id}')
# Update speaker embeddings JSON
self._update_speaker_embeddings(person_db_id, request)
# Enroll face via face_recognizer service
self._enroll_face(person_db_id, request)
# Publish success status
self._publish_enrollment_status('success', person_db_id)
except Exception as e:
self.get_logger().error(f'Enrollment failed for {request.name}: {e}')
self._publish_enrollment_status('failed', None)
finally:
self._enrollment_request = None
self._latest_face_embedding = None
self._latest_voice_embedding = None
self._latest_image = None
def _save_enrollment_photo(self, request: EnrollmentRequest) -> str:
"""Save enrollment photo from RealSense."""
try:
if self._latest_image is None:
return None
cv_image = self._bridge.imgmsg_to_cv2(self._latest_image, 'bgr8')
photo_id = f"{request.person_id}_{int(request.timestamp)}"
photo_path = self.photos_dir / f"{photo_id}.jpg"
cv2.imwrite(str(photo_path), cv_image)
self.get_logger().info(f'Enrollment photo saved: {photo_path}')
return photo_id
except Exception as e:
self.get_logger().error(f'Failed to save enrollment photo: {e}')
return None
def _update_speaker_embeddings(self, person_db_id: int, request: EnrollmentRequest) -> None:
"""Update speaker_embeddings.json with voice embedding."""
try:
if self._latest_voice_embedding is None:
return
# Load existing embeddings
speaker_db = {}
if self.speaker_embeddings_path.exists():
with open(self.speaker_embeddings_path, 'r') as f:
speaker_db = json.load(f)
# Add new embedding
speaker_db[str(person_db_id)] = {
'name': request.name,
'person_id': request.person_id,
'embedding': self._latest_voice_embedding.tolist(),
'timestamp': request.timestamp,
}
# Save updated embeddings
with open(self.speaker_embeddings_path, 'w') as f:
json.dump(speaker_db, f, indent=2)
self.get_logger().info(
f'Speaker embedding saved for {request.name}'
)
except Exception as e:
self.get_logger().error(f'Failed to update speaker embeddings: {e}')
def _enroll_face(self, person_db_id: int, request: EnrollmentRequest) -> None:
"""Enroll face via face_recognizer service."""
try:
if self._latest_face_embedding is None:
return
if not self._enroll_face_client.wait_for_service(timeout_sec=2.0):
self.get_logger().warn(
f'Face recognizer service not available. Skipping face enrollment.'
)
return
# Call EnrollPerson service
req = EnrollPerson.Request()
req.name = request.name
req.mode = 'face'
req.n_samples = 1
future = self._enroll_face_client.call_async(req)
self.get_logger().info(f'Face enrollment service called for {request.name}')
except Exception as e:
self.get_logger().error(f'Face enrollment service call failed: {e}')
def _publish_enrollment_status(self, status: str, person_db_id: Optional[int]) -> None:
"""Publish enrollment completion status."""
try:
status_msg = {
'status': status,
'person_id': self._enrollment_request.person_id if self._enrollment_request else None,
'name': self._enrollment_request.name if self._enrollment_request else None,
'person_db_id': person_db_id,
'timestamp': time.time(),
}
self._pub_status.publish(String(data=json.dumps(status_msg)))
except Exception as e:
self.get_logger().error(f'Failed to publish enrollment status: {e}')
def main(args=None):
rclpy.init(args=args)
node = SocialEnrollmentNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == '__main__':
main()