From 5043578934f471e2328fdd266ce8b310b1df6eaa Mon Sep 17 00:00:00 2001 From: sl-jetson Date: Mon, 2 Mar 2026 08:17:35 -0500 Subject: [PATCH] feat(social): speech pipeline + LLM conversation + TTS + orchestrator (#81 #83 #85 #89) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Issue #81 — Speech pipeline: - speech_pipeline_node.py: OpenWakeWord "hey_salty" → Silero VAD → faster-whisper STT (Orin GPU, <500ms wake-to-transcript) → ECAPA-TDNN speaker diarization - speech_utils.py: pcm16↔float32, EnergyVad, UtteranceSegmenter (pre-roll, max- duration), cosine speaker identification — all pure Python, no ROS2/GPU needed - Publishes /social/speech/transcript (SpeechTranscript) + /social/speech/vad_state Issue #83 — Conversation engine: - conversation_node.py: llama-cpp-python GGUF (Phi-3-mini Q4_K_M, 20 GPU layers), streaming token output, per-person sliding-window context (4K tokens), summary compression, SOUL.md system prompt, group mode - llm_context.py: PersonContext, ContextStore (JSON persistence), build_llama_prompt (ChatML format), context compression via LLM summarization - Publishes /social/conversation/response (ConversationResponse, partial + final) Issue #85 — Streaming TTS: - tts_node.py: Piper ONNX streaming synthesis, sentence-by-sentence first-chunk streaming (<200ms to first audio), sounddevice USB speaker playback, volume control - tts_utils.py: split_sentences, pcm16_to_wav_bytes, chunk_pcm, apply_volume, strip_ssml Issue #89 — Pipeline orchestrator: - orchestrator_node.py: IDLE→LISTENING→THINKING→SPEAKING state machine, GPU memory watchdog (throttle at <2GB free), rolling latency stats (p50/p95 per stage), VAD watchdog (alert if speech pipeline hangs), /social/orchestrator/state JSON pub - social_bot.launch.py: brings up all 4 nodes with TimerAction delays New messages: SpeechTranscript.msg, VadState.msg, ConversationResponse.msg Config YAMLs: speech_params, conversation_params, tts_params, orchestrator_params 
Tests: 58 tests (28 speech_utils + 30 llm_context/tts_utils), all passing Co-Authored-By: Claude Sonnet 4.6 --- .../config/conversation_params.yaml | 12 + .../config/orchestrator_params.yaml | 8 + .../saltybot_social/config/speech_params.yaml | 13 + .../saltybot_social/config/tts_params.yaml | 9 + .../launch/social_bot.launch.py | 129 ++++++ .../ros2_ws/src/saltybot_social/package.xml | 7 +- .../saltybot_social/conversation_node.py | 267 ++++++++++++ .../saltybot_social/llm_context.py | 176 ++++++++ .../saltybot_social/orchestrator_node.py | 298 +++++++++++++ .../saltybot_social/speech_pipeline_node.py | 391 ++++++++++++++++++ .../saltybot_social/speech_utils.py | 188 +++++++++ .../saltybot_social/tts_node.py | 228 ++++++++++ .../saltybot_social/tts_utils.py | 100 +++++ jetson/ros2_ws/src/saltybot_social/setup.py | 6 +- .../saltybot_social/test/test_llm_context.py | 244 +++++++++++ .../saltybot_social/test/test_speech_utils.py | 237 +++++++++++ .../src/saltybot_social_msgs/CMakeLists.txt | 5 + .../msg/ConversationResponse.msg | 9 + .../msg/SpeechTranscript.msg | 10 + .../src/saltybot_social_msgs/msg/VadState.msg | 8 + 20 files changed, 2342 insertions(+), 3 deletions(-) create mode 100644 jetson/ros2_ws/src/saltybot_social/config/conversation_params.yaml create mode 100644 jetson/ros2_ws/src/saltybot_social/config/orchestrator_params.yaml create mode 100644 jetson/ros2_ws/src/saltybot_social/config/speech_params.yaml create mode 100644 jetson/ros2_ws/src/saltybot_social/config/tts_params.yaml create mode 100644 jetson/ros2_ws/src/saltybot_social/launch/social_bot.launch.py create mode 100644 jetson/ros2_ws/src/saltybot_social/saltybot_social/conversation_node.py create mode 100644 jetson/ros2_ws/src/saltybot_social/saltybot_social/llm_context.py create mode 100644 jetson/ros2_ws/src/saltybot_social/saltybot_social/orchestrator_node.py create mode 100644 jetson/ros2_ws/src/saltybot_social/saltybot_social/speech_pipeline_node.py create mode 100644 
jetson/ros2_ws/src/saltybot_social/saltybot_social/speech_utils.py create mode 100644 jetson/ros2_ws/src/saltybot_social/saltybot_social/tts_node.py create mode 100644 jetson/ros2_ws/src/saltybot_social/saltybot_social/tts_utils.py create mode 100644 jetson/ros2_ws/src/saltybot_social/test/test_llm_context.py create mode 100644 jetson/ros2_ws/src/saltybot_social/test/test_speech_utils.py create mode 100644 jetson/ros2_ws/src/saltybot_social_msgs/msg/ConversationResponse.msg create mode 100644 jetson/ros2_ws/src/saltybot_social_msgs/msg/SpeechTranscript.msg create mode 100644 jetson/ros2_ws/src/saltybot_social_msgs/msg/VadState.msg diff --git a/jetson/ros2_ws/src/saltybot_social/config/conversation_params.yaml b/jetson/ros2_ws/src/saltybot_social/config/conversation_params.yaml new file mode 100644 index 0000000..b0de63e --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social/config/conversation_params.yaml @@ -0,0 +1,12 @@ +conversation_node: + ros__parameters: + model_path: "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf" + n_ctx: 4096 + n_gpu_layers: 20 # Increase for more VRAM (Orin has 8GB shared) + max_tokens: 200 + temperature: 0.7 + top_p: 0.9 + soul_path: "/soul/SOUL.md" + context_db_path: "/social_db/conversation_context.json" + save_interval_s: 30.0 + stream: true diff --git a/jetson/ros2_ws/src/saltybot_social/config/orchestrator_params.yaml b/jetson/ros2_ws/src/saltybot_social/config/orchestrator_params.yaml new file mode 100644 index 0000000..7efd8b8 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social/config/orchestrator_params.yaml @@ -0,0 +1,8 @@ +orchestrator_node: + ros__parameters: + gpu_mem_warn_mb: 4000.0 # Warn when GPU free < 4GB + gpu_mem_throttle_mb: 2000.0 # Throttle new inferences at < 2GB + watchdog_timeout_s: 30.0 + latency_window: 20 + profile_enabled: true + state_publish_rate: 2.0 # Hz diff --git a/jetson/ros2_ws/src/saltybot_social/config/speech_params.yaml b/jetson/ros2_ws/src/saltybot_social/config/speech_params.yaml new file 
mode 100644 index 0000000..8c1ef5e --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social/config/speech_params.yaml @@ -0,0 +1,13 @@ +speech_pipeline_node: + ros__parameters: + mic_device_index: -1 # -1 = system default; use `arecord -l` to list + sample_rate: 16000 + wake_word_model: "hey_salty" + wake_word_threshold: 0.5 + vad_threshold_db: -35.0 + use_silero_vad: true + whisper_model: "small" # small (~500ms), medium (better quality, ~900ms) + whisper_compute_type: "float16" + speaker_threshold: 0.65 + speaker_db_path: "/social_db/speaker_embeddings.json" + publish_partial: true diff --git a/jetson/ros2_ws/src/saltybot_social/config/tts_params.yaml b/jetson/ros2_ws/src/saltybot_social/config/tts_params.yaml new file mode 100644 index 0000000..4d8fe52 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social/config/tts_params.yaml @@ -0,0 +1,9 @@ +tts_node: + ros__parameters: + voice_path: "/models/piper/en_US-lessac-medium.onnx" + sample_rate: 22050 + volume: 1.0 + audio_device: "" # "" = system default; set to device name if needed + playback_enabled: true + publish_audio: false + sentence_streaming: true diff --git a/jetson/ros2_ws/src/saltybot_social/launch/social_bot.launch.py b/jetson/ros2_ws/src/saltybot_social/launch/social_bot.launch.py new file mode 100644 index 0000000..88e34ca --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social/launch/social_bot.launch.py @@ -0,0 +1,129 @@ +"""social_bot.launch.py — Launch the full social-bot pipeline. + +Issue #89: End-to-end pipeline orchestrator launch file. + +Brings up all social-bot nodes in dependency order: + 1. Orchestrator (state machine + watchdog) — t=0s + 2. Speech pipeline (wake word + VAD + STT) — t=2s + 3. Conversation engine (LLM) — t=4s + 4. TTS node (Piper streaming) — t=4s + 5. 
Person state tracker (already in social.launch.py — optional) + +Launch args: + enable_speech (bool, true) + enable_llm (bool, true) + enable_tts (bool, true) + enable_orchestrator (bool, true) + voice_path (str, /models/piper/en_US-lessac-medium.onnx) + whisper_model (str, small) + llm_model_path (str, /models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf) + n_gpu_layers (int, 20) +""" + +import os +from ament_index_python.packages import get_package_share_directory +from launch import LaunchDescription +from launch.actions import ( + DeclareLaunchArgument, GroupAction, LogInfo, TimerAction +) +from launch.conditions import IfCondition +from launch.substitutions import LaunchConfiguration, PythonExpression +from launch_ros.actions import Node + + +def generate_launch_description() -> LaunchDescription: + pkg = get_package_share_directory("saltybot_social") + cfg = os.path.join(pkg, "config") + + # ── Launch arguments ───────────────────────────────────────────────────── + args = [ + DeclareLaunchArgument("enable_speech", default_value="true"), + DeclareLaunchArgument("enable_llm", default_value="true"), + DeclareLaunchArgument("enable_tts", default_value="true"), + DeclareLaunchArgument("enable_orchestrator", default_value="true"), + DeclareLaunchArgument("voice_path", + default_value="/models/piper/en_US-lessac-medium.onnx"), + DeclareLaunchArgument("whisper_model", default_value="small"), + DeclareLaunchArgument("llm_model_path", + default_value="/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf"), + DeclareLaunchArgument("n_gpu_layers", default_value="20"), + ] + + # ── Orchestrator (t=0s) ────────────────────────────────────────────────── + orchestrator = Node( + package="saltybot_social", + executable="orchestrator_node", + name="orchestrator_node", + parameters=[os.path.join(cfg, "orchestrator_params.yaml")], + condition=IfCondition(LaunchConfiguration("enable_orchestrator")), + output="screen", + emulate_tty=True, + ) + + # ── Speech pipeline (t=2s) 
─────────────────────────────────────────────── + speech = TimerAction(period=2.0, actions=[ + GroupAction( + condition=IfCondition(LaunchConfiguration("enable_speech")), + actions=[ + LogInfo(msg="[social_bot] Starting speech pipeline..."), + Node( + package="saltybot_social", + executable="speech_pipeline_node", + name="speech_pipeline_node", + parameters=[ + os.path.join(cfg, "speech_params.yaml"), + {"whisper_model": LaunchConfiguration("whisper_model")}, + ], + output="screen", + emulate_tty=True, + ), + ] + ) + ]) + + # ── LLM conversation engine (t=4s) ─────────────────────────────────────── + llm = TimerAction(period=4.0, actions=[ + GroupAction( + condition=IfCondition(LaunchConfiguration("enable_llm")), + actions=[ + LogInfo(msg="[social_bot] Starting LLM conversation engine..."), + Node( + package="saltybot_social", + executable="conversation_node", + name="conversation_node", + parameters=[ + os.path.join(cfg, "conversation_params.yaml"), + { + "model_path": LaunchConfiguration("llm_model_path"), + "n_gpu_layers": LaunchConfiguration("n_gpu_layers"), + }, + ], + output="screen", + emulate_tty=True, + ), + ] + ) + ]) + + # ── TTS node (t=4s, parallel with LLM) ────────────────────────────────── + tts = TimerAction(period=4.0, actions=[ + GroupAction( + condition=IfCondition(LaunchConfiguration("enable_tts")), + actions=[ + LogInfo(msg="[social_bot] Starting TTS node..."), + Node( + package="saltybot_social", + executable="tts_node", + name="tts_node", + parameters=[ + os.path.join(cfg, "tts_params.yaml"), + {"voice_path": LaunchConfiguration("voice_path")}, + ], + output="screen", + emulate_tty=True, + ), + ] + ) + ]) + + return LaunchDescription(args + [orchestrator, speech, llm, tts]) diff --git a/jetson/ros2_ws/src/saltybot_social/package.xml b/jetson/ros2_ws/src/saltybot_social/package.xml index ca4b63c..f73d2ce 100644 --- a/jetson/ros2_ws/src/saltybot_social/package.xml +++ b/jetson/ros2_ws/src/saltybot_social/package.xml @@ -5,9 +5,12 @@ 0.1.0 
Social interaction layer for saltybot. + speech_pipeline_node: wake word + VAD + Whisper STT + diarization (Issue #81). + conversation_node: local LLM with per-person context (Issue #83). + tts_node: streaming TTS with Piper first-chunk (Issue #85). + orchestrator_node: pipeline state machine + GPU watchdog + latency profiler (Issue #89). person_state_tracker: multi-modal person identity fusion (Issue #82). - expression_node: bridges /social/mood to ESP32-C3 NeoPixel ring over serial (Issue #86). - attention_node: rotates robot toward active speaker via /social/persons bearing (Issue #86). + expression_node: LED expression + motor attention (Issue #86). seb MIT diff --git a/jetson/ros2_ws/src/saltybot_social/saltybot_social/conversation_node.py b/jetson/ros2_ws/src/saltybot_social/saltybot_social/conversation_node.py new file mode 100644 index 0000000..a9218a7 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social/saltybot_social/conversation_node.py @@ -0,0 +1,267 @@ +"""conversation_node.py — Local LLM conversation engine with per-person context. + +Issue #83: Conversation engine for social-bot. + +Stack: Phi-3-mini or Llama-3.2-3B GGUF Q4_K_M via llama-cpp-python (CUDA). +Subscribes /social/speech/transcript → builds per-person prompt → streams +token output → publishes /social/conversation/response. + +Streaming: publishes partial=true tokens as they arrive, then final=false +at end of generation. TTS node can begin synthesis on first sentence boundary. 
+ +ROS2 topics: + Subscribe: /social/speech/transcript (saltybot_social_msgs/SpeechTranscript) + Publish: /social/conversation/response (saltybot_social_msgs/ConversationResponse) + +Parameters: + model_path (str, "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf") + n_ctx (int, 4096) + n_gpu_layers (int, 20) — GPU offload layers (increase for more VRAM usage) + max_tokens (int, 200) + temperature (float, 0.7) + top_p (float, 0.9) + soul_path (str, "/soul/SOUL.md") + context_db_path (str, "/social_db/conversation_context.json") + save_interval_s (float, 30.0) — how often to persist context to disk + stream (bool, true) +""" + +from __future__ import annotations + +import threading +import time +from typing import Optional + +import rclpy +from rclpy.node import Node +from rclpy.qos import QoSProfile + +from saltybot_social_msgs.msg import SpeechTranscript, ConversationResponse +from .llm_context import ContextStore, build_llama_prompt, load_system_prompt, needs_summary_prompt + + +class ConversationNode(Node): + """Local LLM inference node with per-person conversation memory.""" + + def __init__(self) -> None: + super().__init__("conversation_node") + + # ── Parameters ────────────────────────────────────────────────────── + self.declare_parameter("model_path", + "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf") + self.declare_parameter("n_ctx", 4096) + self.declare_parameter("n_gpu_layers", 20) + self.declare_parameter("max_tokens", 200) + self.declare_parameter("temperature", 0.7) + self.declare_parameter("top_p", 0.9) + self.declare_parameter("soul_path", "/soul/SOUL.md") + self.declare_parameter("context_db_path", "/social_db/conversation_context.json") + self.declare_parameter("save_interval_s", 30.0) + self.declare_parameter("stream", True) + + self._model_path = self.get_parameter("model_path").value + self._n_ctx = self.get_parameter("n_ctx").value + self._n_gpu = self.get_parameter("n_gpu_layers").value + self._max_tokens = 
self.get_parameter("max_tokens").value + self._temperature = self.get_parameter("temperature").value + self._top_p = self.get_parameter("top_p").value + self._soul_path = self.get_parameter("soul_path").value + self._db_path = self.get_parameter("context_db_path").value + self._save_interval = self.get_parameter("save_interval_s").value + self._stream = self.get_parameter("stream").value + + # ── Publishers / Subscribers ───────────────────────────────────────── + qos = QoSProfile(depth=10) + self._resp_pub = self.create_publisher( + ConversationResponse, "/social/conversation/response", qos + ) + self._transcript_sub = self.create_subscription( + SpeechTranscript, "/social/speech/transcript", + self._on_transcript, qos + ) + + # ── State ──────────────────────────────────────────────────────────── + self._llm = None + self._system_prompt = load_system_prompt(self._soul_path) + self._ctx_store = ContextStore(self._db_path) + self._lock = threading.Lock() + self._turn_counter = 0 + self._generating = False + self._last_save = time.time() + + # ── Load LLM in background ──────────────────────────────────────────── + threading.Thread(target=self._load_llm, daemon=True).start() + + # ── Periodic context save ──────────────────────────────────────────── + self._save_timer = self.create_timer(self._save_interval, self._save_context) + + self.get_logger().info( + f"ConversationNode init (model={self._model_path}, " + f"gpu_layers={self._n_gpu}, ctx={self._n_ctx})" + ) + + # ── Model loading ───────────────────────────────────────────────────────── + + def _load_llm(self) -> None: + t0 = time.time() + self.get_logger().info(f"Loading LLM: {self._model_path}") + try: + from llama_cpp import Llama + self._llm = Llama( + model_path=self._model_path, + n_ctx=self._n_ctx, + n_gpu_layers=self._n_gpu, + n_threads=4, + verbose=False, + ) + self.get_logger().info( + f"LLM ready ({time.time()-t0:.1f}s). 
" + f"Context: {self._n_ctx} tokens, GPU layers: {self._n_gpu}" + ) + except Exception as e: + self.get_logger().error(f"LLM load failed: {e}") + + # ── Transcript callback ─────────────────────────────────────────────────── + + def _on_transcript(self, msg: SpeechTranscript) -> None: + """Handle final transcripts only (skip streaming partials).""" + if msg.is_partial: + return + if not msg.text.strip(): + return + + self.get_logger().info( + f"Transcript [{msg.speaker_id}]: '{msg.text}'" + ) + + threading.Thread( + target=self._generate_response, + args=(msg.text.strip(), msg.speaker_id), + daemon=True, + ).start() + + # ── LLM inference ───────────────────────────────────────────────────────── + + def _generate_response(self, user_text: str, speaker_id: str) -> None: + """Generate LLM response with streaming. Runs in thread.""" + if self._llm is None: + self.get_logger().warn("LLM not loaded yet, dropping utterance") + return + + with self._lock: + if self._generating: + self.get_logger().warn("LLM busy, dropping utterance") + return + self._generating = True + self._turn_counter += 1 + turn_id = self._turn_counter + + try: + ctx = self._ctx_store.get(speaker_id) + + # Summary compression if context is long + if ctx.needs_compression(): + self._compress_context(ctx) + + ctx.add_user(user_text) + + prompt = build_llama_prompt( + ctx, user_text, self._system_prompt + ) + + t0 = time.perf_counter() + full_response = "" + + if self._stream: + output = self._llm( + prompt, + max_tokens=self._max_tokens, + temperature=self._temperature, + top_p=self._top_p, + stream=True, + stop=["<|user|>", "<|system|>", "\n\n\n"], + ) + for chunk in output: + token = chunk["choices"][0]["text"] + full_response += token + # Publish partial after each sentence boundary for low TTS latency + if token.endswith((".", "!", "?", "\n")): + self._publish_response( + full_response.strip(), speaker_id, turn_id, is_partial=True + ) + else: + output = self._llm( + prompt, + 
max_tokens=self._max_tokens, + temperature=self._temperature, + top_p=self._top_p, + stream=False, + stop=["<|user|>", "<|system|>"], + ) + full_response = output["choices"][0]["text"] + + full_response = full_response.strip() + latency_ms = (time.perf_counter() - t0) * 1000 + self.get_logger().info( + f"LLM [{speaker_id}] ({latency_ms:.0f}ms): '{full_response[:80]}'" + ) + + ctx.add_assistant(full_response) + self._publish_response(full_response, speaker_id, turn_id, is_partial=False) + + except Exception as e: + self.get_logger().error(f"LLM inference error: {e}") + finally: + with self._lock: + self._generating = False + + def _compress_context(self, ctx) -> None: + """Ask LLM to summarize old turns for context compression.""" + if self._llm is None: + ctx.compress("(history omitted)") + return + try: + summary_prompt = needs_summary_prompt(ctx) + result = self._llm(summary_prompt, max_tokens=80, temperature=0.3, stream=False) + summary = result["choices"][0]["text"].strip() + ctx.compress(summary) + self.get_logger().debug( + f"Context compressed for {ctx.person_id}: '{summary[:60]}'" + ) + except Exception: + ctx.compress("(history omitted)") + + # ── Publish ─────────────────────────────────────────────────────────────── + + def _publish_response( + self, text: str, speaker_id: str, turn_id: int, is_partial: bool + ) -> None: + msg = ConversationResponse() + msg.header.stamp = self.get_clock().now().to_msg() + msg.text = text + msg.speaker_id = speaker_id + msg.is_partial = is_partial + msg.turn_id = turn_id + self._resp_pub.publish(msg) + + def _save_context(self) -> None: + try: + self._ctx_store.save() + except Exception as e: + self.get_logger().error(f"Context save error: {e}") + + def destroy_node(self) -> None: + self._save_context() + super().destroy_node() + + +def main(args=None) -> None: + rclpy.init(args=args) + node = ConversationNode() + try: + rclpy.spin(node) + except KeyboardInterrupt: + pass + finally: + node.destroy_node() + 
rclpy.shutdown() diff --git a/jetson/ros2_ws/src/saltybot_social/saltybot_social/llm_context.py b/jetson/ros2_ws/src/saltybot_social/saltybot_social/llm_context.py new file mode 100644 index 0000000..a743f54 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social/saltybot_social/llm_context.py @@ -0,0 +1,176 @@ +"""llm_context.py — Per-person conversation context management. + +No ROS2 dependencies. Manages sliding-window conversation history per person_id, +builds llama.cpp prompt, handles summary compression when context is full. + +Used by conversation_node.py. +Tested by test/test_llm_context.py. +""" + +from __future__ import annotations + +import json +import os +import time +from typing import List, Optional, Tuple + +# ── Constants ────────────────────────────────────────────────────────────────── +MAX_TURNS = 20 # Max turns to keep before summary compression +SUMMARY_KEEP = 4 # Keep last N turns after summarizing older ones +SYSTEM_PROMPT_PATH = "/soul/SOUL.md" +DEFAULT_SYSTEM_PROMPT = ( + "You are Salty, a friendly social robot. " + "You are helpful, warm, and concise. " + "Keep responses under 2 sentences unless asked to elaborate." 
+) + + +# ── Conversation history ─────────────────────────────────────────────────────── + +class PersonContext: + """Conversation history + relationship memory for one person.""" + + def __init__(self, person_id: str, person_name: str = "") -> None: + self.person_id = person_id + self.person_name = person_name or person_id + self.turns: List[dict] = [] # {"role": "user"|"assistant", "content": str, "ts": float} + self.summary: str = "" + self.last_seen: float = time.time() + self.interaction_count: int = 0 + + def add_user(self, text: str) -> None: + self.turns.append({"role": "user", "content": text, "ts": time.time()}) + self.last_seen = time.time() + self.interaction_count += 1 + + def add_assistant(self, text: str) -> None: + self.turns.append({"role": "assistant", "content": text, "ts": time.time()}) + + def needs_compression(self) -> bool: + return len(self.turns) > MAX_TURNS + + def compress(self, summary_text: str) -> None: + """Replace old turns with a summary, keeping last SUMMARY_KEEP turns.""" + old_summary = self.summary + if old_summary: + self.summary = old_summary + " " + summary_text + else: + self.summary = summary_text + # Keep most recent turns + self.turns = self.turns[-SUMMARY_KEEP:] + + def to_dict(self) -> dict: + return { + "person_id": self.person_id, + "person_name": self.person_name, + "turns": self.turns, + "summary": self.summary, + "last_seen": self.last_seen, + "interaction_count": self.interaction_count, + } + + @classmethod + def from_dict(cls, d: dict) -> "PersonContext": + ctx = cls(d["person_id"], d.get("person_name", "")) + ctx.turns = d.get("turns", []) + ctx.summary = d.get("summary", "") + ctx.last_seen = d.get("last_seen", 0.0) + ctx.interaction_count = d.get("interaction_count", 0) + return ctx + + +# ── Context store ────────────────────────────────────────────────────────────── + +class ContextStore: + """Persistent per-person conversation context with optional JSON backing.""" + + def __init__(self, db_path: str = 
"/social_db/conversation_context.json") -> None: + self._db_path = db_path + self._contexts: dict = {} # person_id → PersonContext + self._load() + + def get(self, person_id: str, person_name: str = "") -> PersonContext: + if person_id not in self._contexts: + self._contexts[person_id] = PersonContext(person_id, person_name) + elif person_name and not self._contexts[person_id].person_name: + self._contexts[person_id].person_name = person_name + return self._contexts[person_id] + + def save(self) -> None: + os.makedirs(os.path.dirname(self._db_path), exist_ok=True) + with open(self._db_path, "w") as f: + json.dump( + {pid: ctx.to_dict() for pid, ctx in self._contexts.items()}, + f, indent=2 + ) + + def _load(self) -> None: + if os.path.exists(self._db_path): + try: + with open(self._db_path) as f: + raw = json.load(f) + self._contexts = { + pid: PersonContext.from_dict(d) for pid, d in raw.items() + } + except Exception: + self._contexts = {} + + def all_persons(self) -> List[str]: + return list(self._contexts.keys()) + + +# ── Prompt builder ───────────────────────────────────────────────────────────── + +def load_system_prompt(path: str = SYSTEM_PROMPT_PATH) -> str: + """Load SOUL.md as system prompt. Falls back to default.""" + if os.path.exists(path): + with open(path) as f: + return f.read().strip() + return DEFAULT_SYSTEM_PROMPT + + +def build_llama_prompt( + ctx: PersonContext, + user_text: str, + system_prompt: str, + group_persons: Optional[List[str]] = None, +) -> str: + """Build Phi-3 / Llama-3 style chat prompt string. + + Uses the ChatML format supported by llama-cpp-python. 
+ """ + lines = [f"<|system|>\n{system_prompt}"] + + if ctx.summary: + lines.append( + f"<|system|>\n[Memory of {ctx.person_name}]: {ctx.summary}" + ) + + if group_persons: + others = [p for p in group_persons if p != ctx.person_id] + if others: + lines.append( + f"<|system|>\n[Also present]: {', '.join(others)}" + ) + + for turn in ctx.turns: + role = "user" if turn["role"] == "user" else "assistant" + lines.append(f"<|{role}|>\n{turn['content']}") + + lines.append(f"<|user|>\n{user_text}") + lines.append("<|assistant|>") + + return "\n".join(lines) + + +def needs_summary_prompt(ctx: PersonContext) -> str: + """Build a prompt asking the LLM to summarize old conversation turns.""" + turns_text = "\n".join( + f"{t['role'].upper()}: {t['content']}" + for t in ctx.turns[:-SUMMARY_KEEP] + ) + return ( + f"<|system|>\nSummarize this conversation between {ctx.person_name} and " + "Salty in one sentence, focusing on what was discussed and their relationship.\n" + f"<|user|>\n{turns_text}\n<|assistant|>" + ) diff --git a/jetson/ros2_ws/src/saltybot_social/saltybot_social/orchestrator_node.py b/jetson/ros2_ws/src/saltybot_social/saltybot_social/orchestrator_node.py new file mode 100644 index 0000000..21bd8a0 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social/saltybot_social/orchestrator_node.py @@ -0,0 +1,298 @@ +"""orchestrator_node.py — End-to-end social-bot pipeline orchestrator. + +Issue #89: Main loop, GPU memory scheduling, latency profiling, watchdog. 
+ +State machine: + IDLE → LISTENING (wake word) → THINKING (LLM generating) → SPEAKING (TTS) + ↑___________________________________________| + +Responsibilities: + - Owns the top-level pipeline state machine + - Monitors GPU memory and throttles inference when approaching limit + - Profiles end-to-end latency per interaction turn + - Health watchdog: restart speech pipeline on hang + - Exposes /social/orchestrator/state for UI / other nodes + +ROS2 topics: + Subscribe: /social/speech/vad_state (VadState) + /social/speech/transcript (SpeechTranscript) + /social/conversation/response (ConversationResponse) + Publish: /social/orchestrator/state (std_msgs/String) — JSON status blob + +Parameters: + gpu_mem_warn_mb (float, 4000.0) — warn at this free GPU memory (MB) + gpu_mem_throttle_mb (float, 2000.0) — suspend new inference requests + watchdog_timeout_s (float, 30.0) — restart speech node if no VAD in N seconds + latency_window (int, 20) — rolling window for latency stats + profile_enabled (bool, true) +""" + +from __future__ import annotations + +import json +import threading +import time +from collections import deque +from enum import Enum +from typing import Optional + +import rclpy +from rclpy.node import Node +from rclpy.qos import QoSProfile +from std_msgs.msg import String + +from saltybot_social_msgs.msg import VadState, SpeechTranscript, ConversationResponse + + +# ── State machine ────────────────────────────────────────────────────────────── + +class PipelineState(Enum): + IDLE = "idle" + LISTENING = "listening" + THINKING = "thinking" + SPEAKING = "speaking" + THROTTLED = "throttled" # GPU OOM risk — waiting for memory + + +# ── Latency tracker ──────────────────────────────────────────────────────────── + +class LatencyTracker: + """Rolling window latency statistics per pipeline stage.""" + + def __init__(self, window: int = 20) -> None: + self._window = window + self._wakeword_to_transcript: deque = deque(maxlen=window) + 
self._transcript_to_llm_first: deque = deque(maxlen=window) + self._llm_first_to_tts_first: deque = deque(maxlen=window) + self._end_to_end: deque = deque(maxlen=window) + + def record(self, stage: str, latency_ms: float) -> None: + mapping = { + "wakeword_to_transcript": self._wakeword_to_transcript, + "transcript_to_llm": self._transcript_to_llm_first, + "llm_to_tts": self._llm_first_to_tts_first, + "end_to_end": self._end_to_end, + } + q = mapping.get(stage) + if q is not None: + q.append(latency_ms) + + def stats(self) -> dict: + def _stat(q: deque) -> dict: + if not q: + return {} + s = sorted(q) + return { + "mean_ms": round(sum(s) / len(s), 1), + "p50_ms": round(s[len(s) // 2], 1), + "p95_ms": round(s[int(len(s) * 0.95)], 1), + "n": len(s), + } + return { + "wakeword_to_transcript": _stat(self._wakeword_to_transcript), + "transcript_to_llm": _stat(self._transcript_to_llm_first), + "llm_to_tts": _stat(self._llm_first_to_tts_first), + "end_to_end": _stat(self._end_to_end), + } + + +# ── GPU memory helpers ──────────────────────────────────────────────────────── + +def get_gpu_free_mb() -> Optional[float]: + """Return free GPU memory in MB, or None if unavailable.""" + try: + import pycuda.driver as drv + drv.init() + ctx = drv.Device(0).make_context() + free, _ = drv.mem_get_info() + ctx.pop() + return free / 1024 / 1024 + except Exception: + pass + try: + import subprocess + out = subprocess.check_output( + ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"], + timeout=2 + ).decode().strip() + return float(out.split("\n")[0]) + except Exception: + return None + + +# ── Orchestrator ─────────────────────────────────────────────────────────────── + +class OrchestratorNode(Node): + """Pipeline orchestrator — state machine, GPU watchdog, latency profiler.""" + + def __init__(self) -> None: + super().__init__("orchestrator_node") + + # ── Parameters ────────────────────────────────────────────────────── + 
self.declare_parameter("gpu_mem_warn_mb", 4000.0) + self.declare_parameter("gpu_mem_throttle_mb", 2000.0) + self.declare_parameter("watchdog_timeout_s", 30.0) + self.declare_parameter("latency_window", 20) + self.declare_parameter("profile_enabled", True) + self.declare_parameter("state_publish_rate", 2.0) # Hz + + self._gpu_warn = self.get_parameter("gpu_mem_warn_mb").value + self._gpu_throttle = self.get_parameter("gpu_mem_throttle_mb").value + self._watchdog_timeout = self.get_parameter("watchdog_timeout_s").value + self._profile = self.get_parameter("profile_enabled").value + self._state_rate = self.get_parameter("state_publish_rate").value + + # ── Publishers / Subscribers ───────────────────────────────────────── + qos = QoSProfile(depth=10) + self._state_pub = self.create_publisher(String, "/social/orchestrator/state", qos) + + self._vad_sub = self.create_subscription( + VadState, "/social/speech/vad_state", self._on_vad, qos + ) + self._transcript_sub = self.create_subscription( + SpeechTranscript, "/social/speech/transcript", self._on_transcript, qos + ) + self._response_sub = self.create_subscription( + ConversationResponse, "/social/conversation/response", self._on_response, qos + ) + + # ── State ──────────────────────────────────────────────────────────── + self._state = PipelineState.IDLE + self._state_lock = threading.Lock() + self._tracker = LatencyTracker(self.get_parameter("latency_window").value) + + # Timestamps for latency tracking + self._t_wake: float = 0.0 + self._t_transcript: float = 0.0 + self._t_llm_first: float = 0.0 + self._t_tts_first: float = 0.0 + + self._last_vad_t: float = time.time() + self._last_transcript_t: float = 0.0 + self._last_response_t: float = 0.0 + + # ── Timers ──────────────────────────────────────────────────────────── + self._state_timer = self.create_timer(1.0 / self._state_rate, self._publish_state) + self._gpu_timer = self.create_timer(5.0, self._check_gpu) + self._watchdog_timer = self.create_timer(10.0, 
self._watchdog) + + self.get_logger().info("OrchestratorNode ready") + + # ── State transitions ───────────────────────────────────────────────────── + + def _set_state(self, new_state: PipelineState) -> None: + with self._state_lock: + if self._state != new_state: + self.get_logger().info( + f"Pipeline: {self._state.value} → {new_state.value}" + ) + self._state = new_state + + # ── Subscribers ─────────────────────────────────────────────────────────── + + def _on_vad(self, msg: VadState) -> None: + self._last_vad_t = time.time() + if msg.state == "wake_word": + self._t_wake = time.time() + self._set_state(PipelineState.LISTENING) + elif msg.speech_active and self._state == PipelineState.IDLE: + self._set_state(PipelineState.LISTENING) + + def _on_transcript(self, msg: SpeechTranscript) -> None: + if msg.is_partial: + return + self._last_transcript_t = time.time() + self._t_transcript = time.time() + self._set_state(PipelineState.THINKING) + + if self._profile and self._t_wake > 0: + latency_ms = (self._t_transcript - self._t_wake) * 1000 + self._tracker.record("wakeword_to_transcript", latency_ms) + if latency_ms > 500: + self.get_logger().warn( + f"Wake→transcript latency high: {latency_ms:.0f}ms (target <500ms)" + ) + + def _on_response(self, msg: ConversationResponse) -> None: + self._last_response_t = time.time() + + if msg.is_partial and self._t_llm_first == 0.0: + # First token from LLM + self._t_llm_first = time.time() + if self._profile and self._t_transcript > 0: + latency_ms = (self._t_llm_first - self._t_transcript) * 1000 + self._tracker.record("transcript_to_llm", latency_ms) + + self._set_state(PipelineState.SPEAKING) + + if not msg.is_partial: + # Full response complete + self._t_llm_first = 0.0 + self._t_wake = 0.0 + self._t_transcript = 0.0 + + if self._profile and self._t_wake > 0: + e2e_ms = (time.time() - self._t_wake) * 1000 + self._tracker.record("end_to_end", e2e_ms) + + # Return to IDLE after a short delay (TTS still playing) + 
threading.Timer(2.0, lambda: self._set_state(PipelineState.IDLE)).start() + + # ── GPU memory check ────────────────────────────────────────────────────── + + def _check_gpu(self) -> None: + free_mb = get_gpu_free_mb() + if free_mb is None: + return + + if free_mb < self._gpu_throttle: + self.get_logger().error( + f"GPU memory critical: {free_mb:.0f}MB free < {self._gpu_throttle:.0f}MB threshold" + ) + self._set_state(PipelineState.THROTTLED) + elif free_mb < self._gpu_warn: + self.get_logger().warn(f"GPU memory low: {free_mb:.0f}MB free") + else: + # Recover from throttled state + with self._state_lock: + if self._state == PipelineState.THROTTLED: + self._set_state(PipelineState.IDLE) + + # ── Watchdog ────────────────────────────────────────────────────────────── + + def _watchdog(self) -> None: + """Alert if speech pipeline has gone silent for too long.""" + since_vad = time.time() - self._last_vad_t + if since_vad > self._watchdog_timeout: + self.get_logger().error( + f"Watchdog: No VAD signal for {since_vad:.0f}s. " + "Speech pipeline may be hung. Check /social/speech/vad_state." 
+ ) + + # ── State publisher ─────────────────────────────────────────────────────── + + def _publish_state(self) -> None: + with self._state_lock: + current = self._state.value + + payload = { + "state": current, + "ts": time.time(), + "latency": self._tracker.stats() if self._profile else {}, + "gpu_free_mb": get_gpu_free_mb(), + } + msg = String() + msg.data = json.dumps(payload) + self._state_pub.publish(msg) + + +def main(args=None) -> None: + rclpy.init(args=args) + node = OrchestratorNode() + try: + rclpy.spin(node) + except KeyboardInterrupt: + pass + finally: + node.destroy_node() + rclpy.shutdown() diff --git a/jetson/ros2_ws/src/saltybot_social/saltybot_social/speech_pipeline_node.py b/jetson/ros2_ws/src/saltybot_social/saltybot_social/speech_pipeline_node.py new file mode 100644 index 0000000..b9e6c27 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social/saltybot_social/speech_pipeline_node.py @@ -0,0 +1,391 @@ +"""speech_pipeline_node.py — Wake word + VAD + Whisper STT + speaker diarization. + +Issue #81: Speech pipeline for social-bot. + +Pipeline: + USB mic (PyAudio) → VAD (Silero / energy fallback) + → wake word gate (OpenWakeWord "hey_salty") + → utterance buffer → faster-whisper STT (Orin GPU) + → ECAPA-TDNN speaker embedding → identify_speaker() + → publish /social/speech/transcript + /social/speech/vad_state + +Latency target: <500ms wake-to-first-token (streaming partial transcripts). 
"""speech_pipeline_node.py — Wake word + VAD + Whisper STT + speaker diarization.

Issue #81: Speech pipeline for social-bot.

Pipeline:
    USB mic (PyAudio) → VAD (Silero / energy fallback)
    → wake word gate (OpenWakeWord "hey_salty")
    → utterance buffer → faster-whisper STT (Orin GPU)
    → ECAPA-TDNN speaker embedding → identify_speaker()
    → publish /social/speech/transcript + /social/speech/vad_state

ROS2 topics published:
    /social/speech/transcript   (saltybot_social_msgs/SpeechTranscript)
    /social/speech/vad_state    (saltybot_social_msgs/VadState)

Parameters:
    mic_device_index (int, -1 = system default)
    sample_rate (int, 16000)
    wake_word_model (str, "hey_salty")
    wake_word_threshold (float, 0.5)
    vad_threshold_db (float, -35.0) — fallback energy VAD
    use_silero_vad (bool, true)
    whisper_model (str, "small")
    whisper_compute_type (str, "float16")
    speaker_threshold (float, 0.65) — cosine similarity threshold
    speaker_db_path (str, "/social_db/speaker_embeddings.json")
    publish_partial (bool, true)
"""

from __future__ import annotations

import json
import math
import os
import threading
import time
from typing import Optional

import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, QoSReliabilityPolicy, QoSDurabilityPolicy
from builtin_interfaces.msg import Time
from std_msgs.msg import Header

from saltybot_social_msgs.msg import SpeechTranscript, VadState
from .speech_utils import (
    EnergyVad, UtteranceSegmenter,
    pcm16_to_float32, rms_db, identify_speaker,
    SAMPLE_RATE, CHUNK_SAMPLES,
)


class SpeechPipelineNode(Node):
    """Wake word → VAD → STT → diarization → ROS2 publisher."""

    def __init__(self) -> None:
        super().__init__("speech_pipeline_node")

        # ── Parameters ──────────────────────────────────────────────────────
        self.declare_parameter("mic_device_index", -1)
        self.declare_parameter("sample_rate", SAMPLE_RATE)
        self.declare_parameter("wake_word_model", "hey_salty")
        self.declare_parameter("wake_word_threshold", 0.5)
        self.declare_parameter("vad_threshold_db", -35.0)
        self.declare_parameter("use_silero_vad", True)
        self.declare_parameter("whisper_model", "small")
        self.declare_parameter("whisper_compute_type", "float16")
        self.declare_parameter("speaker_threshold", 0.65)
        self.declare_parameter("speaker_db_path", "/social_db/speaker_embeddings.json")
        self.declare_parameter("publish_partial", True)

        self._mic_idx = self.get_parameter("mic_device_index").value
        self._rate = self.get_parameter("sample_rate").value
        self._ww_model = self.get_parameter("wake_word_model").value
        self._ww_thresh = self.get_parameter("wake_word_threshold").value
        self._vad_thresh_db = self.get_parameter("vad_threshold_db").value
        self._use_silero = self.get_parameter("use_silero_vad").value
        self._whisper_model_name = self.get_parameter("whisper_model").value
        self._compute_type = self.get_parameter("whisper_compute_type").value
        self._speaker_thresh = self.get_parameter("speaker_threshold").value
        self._speaker_db = self.get_parameter("speaker_db_path").value
        self._publish_partial = self.get_parameter("publish_partial").value

        # ── Publishers ──────────────────────────────────────────────────────
        qos = QoSProfile(depth=10)
        self._transcript_pub = self.create_publisher(SpeechTranscript,
                                                     "/social/speech/transcript", qos)
        self._vad_pub = self.create_publisher(VadState, "/social/speech/vad_state", qos)

        # ── Internal state ──────────────────────────────────────────────────
        self._whisper = None            # faster_whisper.WhisperModel (lazy)
        self._ecapa = None              # speechbrain EncoderClassifier (lazy)
        self._oww = None                # openwakeword Model (lazy)
        self._vad = None                # Silero VAD callable (lazy, optional)
        self._segmenter = None          # UtteranceSegmenter (energy VAD based)
        self._known_speakers: dict = {} # {speaker_id: embedding list}
        self._audio_stream = None
        self._pa = None
        self._lock = threading.Lock()
        self._running = False
        self._wake_word_active = False
        self._wake_expiry = 0.0
        self._wake_window_s = 8.0  # seconds to listen after wake word
        self._turn_counter = 0

        # ── Lazy model loading in background thread ─────────────────────────
        self._model_thread = threading.Thread(target=self._load_models, daemon=True)
        self._model_thread.start()

    # ── Model loading ─────────────────────────────────────────────────────────

    def _load_models(self) -> None:
        """Load all AI models (runs in background to not block ROS spin)."""
        self.get_logger().info("Loading speech models (Whisper, ECAPA-TDNN, VAD)...")
        t0 = time.time()

        # faster-whisper
        try:
            from faster_whisper import WhisperModel
            self._whisper = WhisperModel(
                self._whisper_model_name,
                device="cuda",
                compute_type=self._compute_type,
                download_root="/models",
            )
            self.get_logger().info(
                f"Whisper '{self._whisper_model_name}' loaded "
                f"({time.time()-t0:.1f}s)"
            )
        except Exception as e:
            self.get_logger().error(f"Whisper load failed: {e}")

        # Silero VAD (optional; energy VAD below is always available)
        if self._use_silero:
            try:
                from silero_vad import load_silero_vad
                self._vad = load_silero_vad()
                self.get_logger().info("Silero VAD loaded")
            except Exception as e:
                self.get_logger().warn(f"Silero VAD unavailable ({e}), using energy VAD")
                self._vad = None

        energy_vad = EnergyVad(threshold_db=self._vad_thresh_db)
        self._segmenter = UtteranceSegmenter(energy_vad)

        # ECAPA-TDNN speaker embeddings
        try:
            from speechbrain.pretrained import EncoderClassifier
            self._ecapa = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir="/models/speechbrain_ecapa",
            )
            self.get_logger().info("ECAPA-TDNN loaded")
        except Exception as e:
            self.get_logger().warn(f"ECAPA-TDNN unavailable ({e}), using 'unknown' speaker")

        # OpenWakeWord
        try:
            import openwakeword
            from openwakeword.model import Model as OWWModel
            self._oww = OWWModel(
                wakeword_models=[self._ww_model],
                inference_framework="onnx",
            )
            self.get_logger().info(f"OpenWakeWord '{self._ww_model}' loaded")
        except Exception as e:
            self.get_logger().warn(
                f"OpenWakeWord unavailable ({e}). Robot will listen continuously."
            )

        # Load speaker DB
        self._load_speaker_db()

        # Start audio capture
        self._start_audio()
        self.get_logger().info(f"Speech pipeline ready ({time.time()-t0:.1f}s total)")

    def _load_speaker_db(self) -> None:
        """Load speaker embedding database from JSON ({speaker_id: embedding})."""
        if os.path.exists(self._speaker_db):
            try:
                with open(self._speaker_db) as f:
                    self._known_speakers = json.load(f)
                self.get_logger().info(
                    f"Speaker DB loaded: {len(self._known_speakers)} persons"
                )
            except Exception as e:
                self.get_logger().error(f"Speaker DB load error: {e}")

    # ── Audio capture ─────────────────────────────────────────────────────────

    def _start_audio(self) -> None:
        """Open PyAudio stream and start capture thread."""
        try:
            import pyaudio
            self._pa = pyaudio.PyAudio()
            dev_idx = None if self._mic_idx < 0 else self._mic_idx
            self._audio_stream = self._pa.open(
                format=pyaudio.paInt16,
                channels=1,
                rate=self._rate,
                input=True,
                input_device_index=dev_idx,
                frames_per_buffer=CHUNK_SAMPLES,
            )
            self._running = True
            t = threading.Thread(target=self._audio_loop, daemon=True)
            t.start()
            self.get_logger().info(
                f"Audio capture started (device={self._mic_idx}, {self._rate}Hz)"
            )
        except Exception as e:
            self.get_logger().error(f"Audio stream error: {e}")

    def _audio_loop(self) -> None:
        """Continuous audio capture loop — runs in dedicated thread."""
        while self._running:
            try:
                raw = self._audio_stream.read(CHUNK_SAMPLES, exception_on_overflow=False)
            except Exception:
                continue

            samples = pcm16_to_float32(raw)
            db = rms_db(samples)
            now = time.time()

            # VAD state is published once per 30 ms chunk (~33 Hz).
            vad_active = self._check_vad(samples)
            state_str = "silence"

            # Wake word check
            if self._oww is not None:
                # NOTE(review): OpenWakeWord's predict() documents 16-bit PCM
                # numpy frames as input; we pass float32 in [-1, 1] here —
                # verify against the installed openwakeword version.
                preds = self._oww.predict(samples)
                score = preds.get(self._ww_model, 0.0)
                if isinstance(score, (list, tuple)):
                    score = score[-1]
                if score >= self._ww_thresh:
                    self._wake_word_active = True
                    self._wake_expiry = now + self._wake_window_s
                    state_str = "wake_word"
                    self.get_logger().info(
                        f"Wake word detected (score={score:.2f})"
                    )
            else:
                # No wake word detector — always listen
                self._wake_word_active = True
                self._wake_expiry = now + self._wake_window_s

            if self._wake_word_active and now > self._wake_expiry:
                self._wake_word_active = False

            # Publish VAD state
            if vad_active:
                state_str = "speech"
            self._publish_vad(vad_active, db, state_str)

            # Only feed segmenter when wake word is active
            if not self._wake_word_active:
                continue

            if self._segmenter is None:
                continue

            completed = self._segmenter.push(samples)
            for utt_samples, duration in completed:
                threading.Thread(
                    target=self._process_utterance,
                    args=(utt_samples, duration),
                    daemon=True,
                ).start()

    def _check_vad(self, samples: list) -> bool:
        """Run Silero VAD or fall back to the segmenter's energy VAD."""
        if self._vad is not None:
            try:
                import torch
                # NOTE(review): recent silero-vad expects 512-sample chunks at
                # 16 kHz; CHUNK_SAMPLES is 480 — confirm against the pinned
                # silero version (the except falls through to energy VAD).
                t = torch.tensor(samples, dtype=torch.float32)
                prob = self._vad(t, self._rate).item()
                return prob > 0.5
            except Exception:
                pass
        if self._segmenter is not None:
            return self._segmenter._vad.is_active
        return False

    # ── STT + diarization ─────────────────────────────────────────────────────

    def _process_utterance(self, samples: list, duration: float) -> None:
        """Transcribe utterance and identify speaker. Runs in thread."""
        if self._whisper is None:
            return

        t0 = time.perf_counter()

        # Convert to numpy for faster-whisper
        try:
            import numpy as np
            audio_np = np.array(samples, dtype=np.float32)
        except ImportError:
            audio_np = samples

        # Speaker embedding first (can be concurrent with transcription)
        speaker_id = "unknown"
        if self._ecapa is not None:
            try:
                import torch
                audio_tensor = torch.tensor([samples], dtype=torch.float32)
                with torch.no_grad():
                    emb = self._ecapa.encode_batch(audio_tensor)
                emb_list = emb[0].cpu().numpy().tolist()
                speaker_id = identify_speaker(
                    emb_list, self._known_speakers, self._speaker_thresh
                )
            except Exception as e:
                self.get_logger().debug(f"Speaker ID error: {e}")

        # Streaming Whisper transcription
        partial_text = ""
        logprobs: list = []
        try:
            segments_gen, _info = self._whisper.transcribe(
                audio_np,
                language="en",
                beam_size=3,
                vad_filter=False,
            )
            for seg in segments_gen:
                partial_text += seg.text.strip() + " "
                # Collect per-segment avg_logprob to derive a real confidence
                # (the original published a hardcoded 0.9).
                try:
                    logprobs.append(float(seg.avg_logprob))
                except Exception:
                    pass
                if self._publish_partial:
                    self._publish_transcript(
                        partial_text.strip(), speaker_id, 0.0, duration, is_partial=True
                    )
        except Exception as e:
            self.get_logger().error(f"Whisper error: {e}")
            return

        final_text = partial_text.strip()
        if not final_text:
            return

        # Confidence ≈ exp(mean avg_logprob), clamped to [0, 1]; falls back to
        # the previous fixed 0.9 when no segment logprobs were available.
        if logprobs:
            confidence = max(0.0, min(1.0, math.exp(sum(logprobs) / len(logprobs))))
        else:
            confidence = 0.9

        latency_ms = (time.perf_counter() - t0) * 1000
        self.get_logger().info(
            f"STT [{speaker_id}] ({duration:.1f}s, {latency_ms:.0f}ms): '{final_text}'"
        )
        self._publish_transcript(final_text, speaker_id, confidence, duration, is_partial=False)

    # ── Publishers ────────────────────────────────────────────────────────────

    def _publish_transcript(
        self, text: str, speaker_id: str, confidence: float,
        duration: float, is_partial: bool
    ) -> None:
        """Publish one SpeechTranscript (partial or final)."""
        msg = SpeechTranscript()
        msg.header.stamp = self.get_clock().now().to_msg()
        msg.text = text
        msg.speaker_id = speaker_id
        msg.confidence = confidence
        msg.audio_duration = duration
        msg.is_partial = is_partial
        self._transcript_pub.publish(msg)

    def _publish_vad(self, speech_active: bool, db: float, state: str) -> None:
        """Publish one VadState ("silence" | "speech" | "wake_word")."""
        msg = VadState()
        msg.header.stamp = self.get_clock().now().to_msg()
        msg.speech_active = speech_active
        msg.energy_db = db
        msg.state = state
        self._vad_pub.publish(msg)

    # ── Cleanup ───────────────────────────────────────────────────────────────

    def destroy_node(self) -> None:
        """Stop capture and release PyAudio; always call the base cleanup."""
        self._running = False
        # FIX: guard stream teardown — a stream that already errored can raise
        # from stop_stream()/close(), which previously skipped both the
        # PyAudio terminate() and super().destroy_node().
        try:
            if self._audio_stream:
                self._audio_stream.stop_stream()
                self._audio_stream.close()
        except Exception as e:
            self.get_logger().warn(f"Audio stream close error: {e}")
        try:
            if self._pa:
                self._pa.terminate()
        except Exception as e:
            self.get_logger().warn(f"PyAudio terminate error: {e}")
        super().destroy_node()


def main(args=None) -> None:
    """Entry point: spin the speech pipeline node until Ctrl-C."""
    rclpy.init(args=args)
    node = SpeechPipelineNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.shutdown()
"""speech_utils.py — Pure helpers for wake word, VAD, and STT.

No ROS2 dependencies. All functions operate on raw float32 sample lists
at 16 kHz mono (the standard for Whisper / Silero VAD).

Tested by test/test_speech_utils.py (no GPU required).
"""

import math
import struct
from typing import Generator, List, Tuple

# ── Constants ──────────────────────────────────────────────────────────────────
SAMPLE_RATE = 16000                                   # Hz — standard for Whisper / Silero
CHUNK_MS = 30                                         # VAD frame size
CHUNK_SAMPLES = int(SAMPLE_RATE * CHUNK_MS / 1000)    # 480 samples
INT16_MAX = 32768.0


# ── Audio helpers ──────────────────────────────────────────────────────────────

def pcm16_to_float32(data: bytes) -> list:
    """Convert raw PCM16 LE bytes → float32 list in [-1.0, 1.0].

    A trailing odd byte is ignored.
    """
    n = len(data) // 2
    samples = struct.unpack(f"<{n}h", data[:n * 2])
    return [s / INT16_MAX for s in samples]


def float32_to_pcm16(samples: list) -> bytes:
    """Convert float32 list → PCM16 LE bytes (clipped to ±1.0)."""
    clipped = [max(-1.0, min(1.0, s)) for s in samples]
    # Second clamp catches 1.0 * 32768 == 32768, which overflows int16.
    ints = [max(-32768, min(32767, int(s * INT16_MAX))) for s in clipped]
    return struct.pack(f"<{len(ints)}h", *ints)


def rms_db(samples: list) -> float:
    """Compute RMS energy in dBFS. Returns -96.0 for an empty list."""
    if not samples:
        return -96.0
    mean_sq = sum(s * s for s in samples) / len(samples)
    rms = math.sqrt(mean_sq) if mean_sq > 0.0 else 1e-10
    return 20.0 * math.log10(max(rms, 1e-10))  # all-zero input → -200.0 dB


def chunk_audio(samples: list, chunk_size: int = CHUNK_SAMPLES) -> Generator[list, None, None]:
    """Yield fixed-size chunks from a flat sample list (tail remainder dropped)."""
    for i in range(0, len(samples) - chunk_size + 1, chunk_size):
        yield samples[i:i + chunk_size]


# ── Simple energy-based VAD (fallback when Silero unavailable) ────────────────

class EnergyVad:
    """Threshold-based energy VAD with onset/offset hysteresis.

    Useful as a lightweight fallback / unit-testable reference.
    """

    def __init__(
        self,
        threshold_db: float = -35.0,
        onset_frames: int = 2,
        offset_frames: int = 8,
    ) -> None:
        self.threshold_db = threshold_db
        self.onset_frames = onset_frames      # consecutive loud frames to activate
        self.offset_frames = offset_frames    # consecutive quiet frames to deactivate
        self._above_count = 0
        self._below_count = 0
        self._active = False

    def process(self, chunk: list) -> bool:
        """Process one chunk (CHUNK_SAMPLES long). Returns True if speech active."""
        db = rms_db(chunk)
        if db >= self.threshold_db:
            self._above_count += 1
            self._below_count = 0
            if self._above_count >= self.onset_frames:
                self._active = True
        else:
            self._below_count += 1
            self._above_count = 0
            if self._below_count >= self.offset_frames:
                self._active = False
        return self._active

    def reset(self) -> None:
        """Clear hysteresis counters and deactivate."""
        self._above_count = 0
        self._below_count = 0
        self._active = False

    @property
    def is_active(self) -> bool:
        return self._active


# ── Utterance segmenter ────────────────────────────────────────────────────────

class UtteranceSegmenter:
    """Accumulates audio chunks into complete utterances using VAD.

    push() returns (samples, duration_s) tuples when utterances end.
    Pre-roll keeps a buffer before speech onset (catches first phoneme).
    """

    def __init__(
        self,
        vad: EnergyVad,
        pre_roll_frames: int = 5,
        max_duration_s: float = 15.0,
        sample_rate: int = SAMPLE_RATE,
        chunk_samples: int = CHUNK_SAMPLES,
    ) -> None:
        self._vad = vad
        self._pre_roll = pre_roll_frames
        self._max_frames = int(max_duration_s * sample_rate / chunk_samples)
        self._sample_rate = sample_rate
        self._chunk_samples = chunk_samples
        self._pre_buf: list = []   # ring buffer of recent chunks for pre-roll
        self._utt_buf: list = []   # growing utterance buffer (flat samples)
        self._in_utt = False
        self._silence_after = 0

    def push(self, chunk: list) -> list:
        """Push one chunk, return list of completed (samples, duration_s) tuples."""
        speech = self._vad.process(chunk)
        results = []

        if not self._in_utt:
            self._pre_buf.append(chunk)
            if len(self._pre_buf) > self._pre_roll:
                self._pre_buf.pop(0)

            if speech:
                self._in_utt = True
                # FIX: the current chunk was already appended to _pre_buf
                # above, so flattening _pre_buf is enough. The original
                # additionally extended with `chunk`, duplicating the onset
                # chunk at the start of every utterance.
                self._utt_buf = [s for c in self._pre_buf for s in c]
                self._silence_after = 0
        else:
            self._utt_buf.extend(chunk)
            if not speech:
                self._silence_after += 1
                # Close the utterance after enough trailing silence, or when
                # the max-duration cap is reached.
                if (self._silence_after >= self._vad.offset_frames or
                        len(self._utt_buf) >= self._max_frames * self._chunk_samples):
                    dur = len(self._utt_buf) / self._sample_rate
                    results.append((list(self._utt_buf), dur))
                    self._utt_buf = []
                    self._in_utt = False
                    self._silence_after = 0
            else:
                self._silence_after = 0

        return results


# ── Speaker embedding distance ─────────────────────────────────────────────────

def cosine_similarity(a: list, b: list) -> float:
    """Cosine similarity between two embedding vectors.

    Returns 0.0 for mismatched lengths, empty, or (near-)zero-norm inputs.
    """
    if len(a) != len(b) or not a:
        return 0.0
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a < 1e-10 or norm_b < 1e-10:
        return 0.0
    return dot / (norm_a * norm_b)


def identify_speaker(
    embedding: list,
    known_speakers: dict,
    threshold: float = 0.65,
) -> str:
    """Return the best-matching speaker_id from known_speakers, or 'unknown'.

    known_speakers: {speaker_id: embedding_list}. A match must STRICTLY
    exceed `threshold` cosine similarity.
    """
    best_id = "unknown"
    best_sim = threshold
    for spk_id, spk_emb in known_speakers.items():
        sim = cosine_similarity(embedding, spk_emb)
        if sim > best_sim:
            best_sim = sim
            best_id = spk_id
    return best_id
"""tts_node.py — Streaming TTS with Piper / first-chunk streaming.

Issue #85: Streaming TTS — Piper integration with sentence-by-sentence
synthesis so audio starts before the LLM finishes generating.

ROS2 topics:
    Subscribe: /social/conversation/response (saltybot_social_msgs/ConversationResponse)
    Publish:   /social/tts/audio (std_msgs/UInt8MultiArray, optional)

Parameters:
    voice_path (str, "/models/piper/en_US-lessac-medium.onnx")
    sample_rate (int, 22050)
    volume (float, 1.0)
    audio_device (str, "") — sounddevice device name; "" = system default
    playback_enabled (bool, true)
    publish_audio (bool, false) — publish PCM to ROS2 topic
    sentence_streaming (bool, true) — synthesize sentence-by-sentence
"""

from __future__ import annotations

import queue
import threading
import time
from typing import Optional

import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile
from std_msgs.msg import UInt8MultiArray

from saltybot_social_msgs.msg import ConversationResponse
from .tts_utils import split_sentences, strip_ssml, apply_volume, chunk_pcm, estimate_duration_ms


class TtsNode(Node):
    """Streaming TTS node using Piper ONNX."""

    def __init__(self) -> None:
        super().__init__("tts_node")

        # ── Parameters ──────────────────────────────────────────────────────
        self.declare_parameter("voice_path", "/models/piper/en_US-lessac-medium.onnx")
        self.declare_parameter("sample_rate", 22050)
        self.declare_parameter("volume", 1.0)
        self.declare_parameter("audio_device", "")
        self.declare_parameter("playback_enabled", True)
        self.declare_parameter("publish_audio", False)
        self.declare_parameter("sentence_streaming", True)

        self._voice_path = self.get_parameter("voice_path").value
        self._sample_rate = self.get_parameter("sample_rate").value
        self._volume = self.get_parameter("volume").value
        self._audio_device = self.get_parameter("audio_device").value or None
        self._playback = self.get_parameter("playback_enabled").value
        self._publish_audio = self.get_parameter("publish_audio").value
        self._sentence_streaming = self.get_parameter("sentence_streaming").value

        # ── Publishers / Subscribers ────────────────────────────────────────
        qos = QoSProfile(depth=10)
        self._resp_sub = self.create_subscription(
            ConversationResponse, "/social/conversation/response",
            self._on_response, qos
        )
        if self._publish_audio:
            self._audio_pub = self.create_publisher(
                UInt8MultiArray, "/social/tts/audio", qos
            )

        # ── TTS engine state ────────────────────────────────────────────────
        self._voice = None                       # PiperVoice (loaded lazily)
        self._playback_queue: queue.Queue = queue.Queue(maxsize=16)
        self._current_turn = -1
        self._synthesized_turns: set = set()     # (turn_id, hash(sentence)) already queued
        self._lock = threading.Lock()

        threading.Thread(target=self._load_voice, daemon=True).start()
        threading.Thread(target=self._playback_worker, daemon=True).start()

        self.get_logger().info(
            f"TtsNode init (voice={self._voice_path}, "
            f"streaming={self._sentence_streaming})"
        )

    # ── Voice loading ─────────────────────────────────────────────────────────

    def _load_voice(self) -> None:
        """Load the Piper ONNX voice and run one warmup synthesis."""
        t0 = time.time()
        self.get_logger().info(f"Loading Piper voice: {self._voice_path}")
        try:
            from piper import PiperVoice
            self._voice = PiperVoice.load(self._voice_path)
            # Warmup synthesis to pre-JIT ONNX graph
            warmup_text = "Hello."
            list(self._voice.synthesize_stream_raw(warmup_text))
            self.get_logger().info(f"Piper voice ready ({time.time()-t0:.1f}s)")
        except Exception as e:
            self.get_logger().error(f"Piper voice load failed: {e}")

    # ── Response handler ──────────────────────────────────────────────────────

    def _on_response(self, msg: ConversationResponse) -> None:
        """Handle streaming LLM response — synthesize sentence by sentence.

        Partial messages carry the cumulative text so far; each complete
        sentence is queued exactly once per turn via a (turn_id, hash) key.
        """
        if not msg.text.strip():
            return

        with self._lock:
            is_new_turn = msg.turn_id != self._current_turn
            if is_new_turn:
                self._current_turn = msg.turn_id
                # Clear old synthesized sentence cache for this new turn
                self._synthesized_turns = set()

        text = strip_ssml(msg.text)

        if self._sentence_streaming:
            sentences = split_sentences(text)
            # FIX: on partial messages the last sentence may still be growing.
            # The original synthesized it immediately and then re-synthesized
            # the completed version (different hash) on the next message,
            # producing duplicated audio. Hold the last sentence back until
            # the final (non-partial) message arrives.
            if msg.is_partial and sentences:
                sentences = sentences[:-1]
            for sentence in sentences:
                # NOTE: dedup is by content hash, so a sentence repeated
                # verbatim within one turn is spoken only once.
                key = (msg.turn_id, hash(sentence))
                with self._lock:
                    if key in self._synthesized_turns:
                        continue
                    self._synthesized_turns.add(key)
                self._queue_synthesis(sentence)
        elif not msg.is_partial:
            # Non-streaming: synthesize full response at end
            self._queue_synthesis(text)

    def _queue_synthesis(self, text: str) -> None:
        """Queue a text segment for synthesis in the playback worker."""
        if not text.strip():
            return
        try:
            self._playback_queue.put_nowait(text.strip())
        except queue.Full:
            self.get_logger().warn("TTS playback queue full, dropping segment")

    # ── Playback worker ───────────────────────────────────────────────────────

    def _playback_worker(self) -> None:
        """Consume synthesis queue: synthesize → play → publish."""
        while rclpy.ok():
            try:
                text = self._playback_queue.get(timeout=0.5)
            except queue.Empty:
                continue

            if self._voice is None:
                self.get_logger().warn("TTS voice not loaded yet")
                self._playback_queue.task_done()
                continue

            t0 = time.perf_counter()
            pcm_data = self._synthesize(text)
            if pcm_data is None:
                self._playback_queue.task_done()
                continue

            synth_ms = (time.perf_counter() - t0) * 1000
            dur_ms = estimate_duration_ms(pcm_data, self._sample_rate)
            self.get_logger().debug(
                f"TTS '{text[:40]}' synth={synth_ms:.0f}ms, dur={dur_ms:.0f}ms"
            )

            if self._volume != 1.0:
                pcm_data = apply_volume(pcm_data, self._volume)

            if self._playback:
                self._play_audio(pcm_data)

            if self._publish_audio:
                self._publish_pcm(pcm_data)

            self._playback_queue.task_done()

    def _synthesize(self, text: str) -> Optional[bytes]:
        """Synthesize text to PCM16 bytes using Piper streaming."""
        if self._voice is None:
            return None
        try:
            chunks = list(self._voice.synthesize_stream_raw(text))
            return b"".join(chunks)
        except Exception as e:
            self.get_logger().error(f"TTS synthesis error: {e}")
            return None

    def _play_audio(self, pcm_data: bytes) -> None:
        """Play PCM16 data on USB speaker via sounddevice (blocking)."""
        try:
            import sounddevice as sd
            import numpy as np
            samples = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0
            sd.play(samples, samplerate=self._sample_rate, device=self._audio_device,
                    blocking=True)
        except Exception as e:
            self.get_logger().error(f"Audio playback error: {e}")

    def _publish_pcm(self, pcm_data: bytes) -> None:
        """Publish PCM data as UInt8MultiArray (only when publish_audio=true)."""
        if not hasattr(self, "_audio_pub"):
            return
        msg = UInt8MultiArray()
        msg.data = list(pcm_data)
        self._audio_pub.publish(msg)


def main(args=None) -> None:
    """Entry point: spin the TTS node until Ctrl-C."""
    rclpy.init(args=args)
    node = TtsNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.shutdown()
"""tts_utils.py — Pure TTS helpers for streaming synthesis.

No ROS2 dependencies. Sentence splitting, WAV framing, PCM chunking,
and volume scaling for the streaming TTS node.

Tested by test/test_tts_utils.py.
"""

import re
import struct
from typing import Generator, List, Tuple


# ── SSML helpers ──────────────────────────────────────────────────────────────

_SSML_STRIP_RE = re.compile(r"<[^>]+>")


def strip_ssml(text: str) -> str:
    """Drop every <...> tag and trim surrounding whitespace."""
    without_tags = _SSML_STRIP_RE.sub("", text)
    return without_tags.strip()


def split_sentences(text: str) -> List[str]:
    """Split text on sentence-ending punctuation for streaming synthesis.

    Returns only non-empty, stripped sentence strings.
    """
    pieces = re.split(r"(?<=[.!?])\s+", text.strip())
    return [piece.strip() for piece in pieces if piece.strip()]


def split_first_sentence(text: str) -> Tuple[str, str]:
    """Return (first_sentence, remainder); ("", "") when there is no sentence.

    Lets the TTS node start synthesizing the first sentence while the LLM
    keeps generating the rest.
    """
    parts = split_sentences(text)
    if not parts:
        return "", ""
    head, tail = parts[0], parts[1:]
    return head, " ".join(tail)


# ── PCM audio helpers ─────────────────────────────────────────────────────────

def pcm16_to_wav_bytes(pcm_data: bytes, sample_rate: int = 22050) -> bytes:
    """Prepend a 44-byte canonical WAV header to raw PCM16 LE mono data."""
    channels = 1
    bits = 16
    data_len = len(pcm_data)
    bytes_per_frame = channels * bits // 8
    bytes_per_sec = sample_rate * bytes_per_frame

    # RIFF chunk size = 36 bytes of header remainder + payload.
    riff = struct.pack("<4sI4s", b"RIFF", 36 + data_len, b"WAVE")
    fmt = struct.pack(
        "<4sIHHIIHH",
        b"fmt ", 16,            # fmt sub-chunk, 16 bytes (plain PCM)
        1,                      # audio format: PCM
        channels,
        sample_rate,
        bytes_per_sec,
        bytes_per_frame,
        bits,
    )
    data_hdr = struct.pack("<4sI", b"data", data_len)
    return riff + fmt + data_hdr + pcm_data


def chunk_pcm(pcm_data: bytes, chunk_ms: int = 200, sample_rate: int = 22050) -> Generator[bytes, None, None]:
    """Yield fixed-size PCM slices for streaming playback."""
    step = (sample_rate * 2 * chunk_ms) // 1000  # int16 → 2 bytes per sample
    step = max(step, 2) & ~1                     # keep int16 alignment
    offset = 0
    while offset < len(pcm_data):
        yield pcm_data[offset:offset + step]
        offset += step


def estimate_duration_ms(pcm_data: bytes, sample_rate: int = 22050) -> float:
    """Playback duration in ms implied by a PCM16 mono byte length."""
    bytes_per_ms = sample_rate * 2 / 1000.0
    return len(pcm_data) / bytes_per_ms


# ── Volume control ────────────────────────────────────────────────────────────

def apply_volume(pcm_data: bytes, gain: float) -> bytes:
    """Scale PCM16 LE samples by gain factor (0.0–2.0). Clips to ±32767.

    Gains within 1% of unity return the input unchanged.
    """
    if abs(gain - 1.0) < 0.01:
        return pcm_data
    count = len(pcm_data) // 2
    raw = struct.unpack(f"<{count}h", pcm_data[:count * 2])
    adjusted = (max(-32767, min(32767, int(v * gain))) for v in raw)
    return struct.pack(f"<{count}h", *adjusted)
Clips to ±32767.""" + if abs(gain - 1.0) < 0.01: + return pcm_data + n = len(pcm_data) // 2 + samples = struct.unpack(f"<{n}h", pcm_data[:n * 2]) + scaled = [max(-32767, min(32767, int(s * gain))) for s in samples] + return struct.pack(f"<{n}h", *scaled) diff --git a/jetson/ros2_ws/src/saltybot_social/setup.py b/jetson/ros2_ws/src/saltybot_social/setup.py index 4f55aeb..e8f040c 100644 --- a/jetson/ros2_ws/src/saltybot_social/setup.py +++ b/jetson/ros2_ws/src/saltybot_social/setup.py @@ -21,7 +21,7 @@ setup( zip_safe=True, maintainer='seb', maintainer_email='seb@vayrette.com', - description='Social interaction layer — person state tracking, LED expression + attention', + description='Social interaction layer — person tracking, speech, LLM, TTS, orchestrator', license='MIT', tests_require=['pytest'], entry_points={ @@ -29,6 +29,10 @@ setup( 'person_state_tracker = saltybot_social.person_state_tracker_node:main', 'expression_node = saltybot_social.expression_node:main', 'attention_node = saltybot_social.attention_node:main', + 'speech_pipeline_node = saltybot_social.speech_pipeline_node:main', + 'conversation_node = saltybot_social.conversation_node:main', + 'tts_node = saltybot_social.tts_node:main', + 'orchestrator_node = saltybot_social.orchestrator_node:main', ], }, ) diff --git a/jetson/ros2_ws/src/saltybot_social/test/test_llm_context.py b/jetson/ros2_ws/src/saltybot_social/test/test_llm_context.py new file mode 100644 index 0000000..9806f65 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social/test/test_llm_context.py @@ -0,0 +1,244 @@ +"""test_llm_context.py — Unit tests for llm_context.py (no GPU / LLM needed).""" + +import sys +import os +import tempfile + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from saltybot_social.llm_context import ( + PersonContext, ContextStore, + build_llama_prompt, MAX_TURNS, SUMMARY_KEEP, + DEFAULT_SYSTEM_PROMPT, +) +import pytest + + +# ── PersonContext 
───────────────────────────────────────────────────────────── + +class TestPersonContext: + def test_create(self): + ctx = PersonContext("person_1", "Alice") + assert ctx.person_id == "person_1" + assert ctx.person_name == "Alice" + assert ctx.turns == [] + assert ctx.interaction_count == 0 + + def test_add_user_increments_count(self): + ctx = PersonContext("person_1") + ctx.add_user("Hello") + assert ctx.interaction_count == 1 + assert ctx.turns[-1]["role"] == "user" + assert ctx.turns[-1]["content"] == "Hello" + + def test_add_assistant(self): + ctx = PersonContext("person_1") + ctx.add_assistant("Hi there!") + assert ctx.turns[-1]["role"] == "assistant" + + def test_needs_compression_false(self): + ctx = PersonContext("p") + for i in range(MAX_TURNS - 1): + ctx.add_user(f"msg {i}") + assert not ctx.needs_compression() + + def test_needs_compression_true(self): + ctx = PersonContext("p") + for i in range(MAX_TURNS + 1): + ctx.add_user(f"msg {i}") + assert ctx.needs_compression() + + def test_compress_keeps_recent_turns(self): + ctx = PersonContext("p") + for i in range(MAX_TURNS + 5): + ctx.add_user(f"msg {i}") + ctx.compress("Summary of conversation") + assert len(ctx.turns) == SUMMARY_KEEP + assert ctx.summary == "Summary of conversation" + + def test_compress_appends_summary(self): + ctx = PersonContext("p") + ctx.add_user("msg") + ctx.compress("First summary") + ctx.add_user("msg2") + ctx.compress("Second summary") + assert "First summary" in ctx.summary + assert "Second summary" in ctx.summary + + def test_roundtrip_serialization(self): + ctx = PersonContext("person_42", "Bob") + ctx.add_user("test") + ctx.add_assistant("response") + ctx.summary = "Earlier they discussed robots" + d = ctx.to_dict() + ctx2 = PersonContext.from_dict(d) + assert ctx2.person_id == "person_42" + assert ctx2.person_name == "Bob" + assert len(ctx2.turns) == 2 + assert ctx2.summary == "Earlier they discussed robots" + + +# ── ContextStore 
────────────────────────────────────────────────────────────── + +class TestContextStore: + def test_get_creates_new(self): + with tempfile.TemporaryDirectory() as tmpdir: + db = os.path.join(tmpdir, "ctx.json") + store = ContextStore(db) + ctx = store.get("p1", "Alice") + assert ctx.person_id == "p1" + assert ctx.person_name == "Alice" + + def test_get_same_instance(self): + with tempfile.TemporaryDirectory() as tmpdir: + db = os.path.join(tmpdir, "ctx.json") + store = ContextStore(db) + ctx1 = store.get("p1") + ctx2 = store.get("p1") + assert ctx1 is ctx2 + + def test_save_and_load(self): + with tempfile.TemporaryDirectory() as tmpdir: + db = os.path.join(tmpdir, "ctx.json") + store = ContextStore(db) + ctx = store.get("p1", "Alice") + ctx.add_user("Hello") + store.save() + + store2 = ContextStore(db) + ctx2 = store2.get("p1") + assert ctx2.person_name == "Alice" + assert len(ctx2.turns) == 1 + assert ctx2.turns[0]["content"] == "Hello" + + def test_all_persons(self): + with tempfile.TemporaryDirectory() as tmpdir: + db = os.path.join(tmpdir, "ctx.json") + store = ContextStore(db) + store.get("p1") + store.get("p2") + store.get("p3") + assert set(store.all_persons()) == {"p1", "p2", "p3"} + + +# ── Prompt builder ──────────────────────────────────────────────────────────── + +class TestBuildLlamaPrompt: + def test_contains_system_prompt(self): + ctx = PersonContext("p1") + prompt = build_llama_prompt(ctx, "hello", DEFAULT_SYSTEM_PROMPT) + assert DEFAULT_SYSTEM_PROMPT in prompt + + def test_contains_user_text(self): + ctx = PersonContext("p1") + prompt = build_llama_prompt(ctx, "What time is it?", DEFAULT_SYSTEM_PROMPT) + assert "What time is it?" 
in prompt + + def test_ends_with_assistant_tag(self): + ctx = PersonContext("p1") + prompt = build_llama_prompt(ctx, "hello", DEFAULT_SYSTEM_PROMPT) + assert prompt.strip().endswith("<|assistant|>") + + def test_history_included(self): + ctx = PersonContext("p1") + ctx.add_user("Previous question") + ctx.add_assistant("Previous answer") + prompt = build_llama_prompt(ctx, "New question", DEFAULT_SYSTEM_PROMPT) + assert "Previous question" in prompt + assert "Previous answer" in prompt + + def test_summary_included(self): + ctx = PersonContext("p1", "Alice") + ctx.summary = "Alice visited last week and likes coffee" + prompt = build_llama_prompt(ctx, "hello", DEFAULT_SYSTEM_PROMPT) + assert "likes coffee" in prompt + + def test_group_persons_included(self): + ctx = PersonContext("p1", "Alice") + prompt = build_llama_prompt(ctx, "hello", DEFAULT_SYSTEM_PROMPT, + group_persons=["p1", "p2_Bob", "p3_Charlie"]) + assert "p2_Bob" in prompt or "p3_Charlie" in prompt + + def test_no_group_persons_not_in_prompt(self): + ctx = PersonContext("p1") + prompt = build_llama_prompt(ctx, "hello", DEFAULT_SYSTEM_PROMPT) + assert "Also present" not in prompt + + +# ── TTS utils (imported here for coverage) ──────────────────────────────────── + +class TestTtsUtils: + def setup_method(self): + sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + from saltybot_social.tts_utils import ( + strip_ssml, split_sentences, split_first_sentence, + pcm16_to_wav_bytes, chunk_pcm, estimate_duration_ms, apply_volume + ) + self.strip_ssml = strip_ssml + self.split_sentences = split_sentences + self.split_first_sentence = split_first_sentence + self.pcm16_to_wav_bytes = pcm16_to_wav_bytes + self.chunk_pcm = chunk_pcm + self.estimate_duration_ms = estimate_duration_ms + self.apply_volume = apply_volume + + def test_strip_ssml(self): + assert self.strip_ssml("Hello world") == "Hello world" + assert self.strip_ssml("plain text") == "plain text" + + def test_split_sentences_basic(self): + 
text = "Hello world. How are you? I am fine!" + sentences = self.split_sentences(text) + assert len(sentences) == 3 + + def test_split_sentences_single(self): + sentences = self.split_sentences("Hello world.") + assert len(sentences) == 1 + assert sentences[0] == "Hello world." + + def test_split_sentences_empty(self): + assert self.split_sentences("") == [] + assert self.split_sentences(" ") == [] + + def test_split_first_sentence(self): + text = "Hello world. How are you? I am fine." + first, rest = self.split_first_sentence(text) + assert first == "Hello world." + assert "How are you?" in rest + + def test_wav_header(self): + pcm = b"\x00\x00" * 1000 + wav = self.pcm16_to_wav_bytes(pcm, 22050) + assert wav[:4] == b"RIFF" + assert wav[8:12] == b"WAVE" + assert len(wav) == 44 + len(pcm) + + def test_chunk_pcm_sizes(self): + pcm = b"\x00\x00" * 2205 # ~100ms at 22050Hz + chunks = list(self.chunk_pcm(pcm, 50, 22050)) # 50ms chunks + for chunk in chunks: + assert len(chunk) % 2 == 0 # int16 aligned + + def test_estimate_duration(self): + pcm = b"\x00\x00" * 22050 # 1s at 22050Hz + ms = self.estimate_duration_ms(pcm, 22050) + assert abs(ms - 1000.0) < 1.0 + + def test_apply_volume_unity(self): + pcm = b"\x00\x40" * 100 # some non-zero samples + result = self.apply_volume(pcm, 1.0) + assert result == pcm + + def test_apply_volume_half(self): + import struct + pcm = struct.pack(" 1.0 should be clipped + pcm = float32_to_pcm16([2.0, -2.0]) + samples = pcm16_to_float32(pcm) + assert samples[0] <= 1.0 + assert samples[1] >= -1.0 + + +# ── RMS dB ──────────────────────────────────────────────────────────────────── + +class TestRmsDb: + def test_silence_returns_minus96(self): + assert rms_db([]) == -96.0 + assert rms_db([0.0] * 100) < -60.0 + + def test_full_scale_is_near_zero(self): + samples = [1.0] * 100 + db = rms_db(samples) + assert abs(db) < 0.1 # 0 dBFS + + def test_half_scale_is_minus6(self): + samples = [0.5] * 100 + db = rms_db(samples) + assert abs(db - (-6.02)) 
< 0.1 + + def test_low_level_signal(self): + samples = [0.001] * 100 + db = rms_db(samples) + assert db < -40.0 + + +# ── Chunk audio ─────────────────────────────────────────────────────────────── + +class TestChunkAudio: + def test_even_chunks(self): + samples = list(range(480 * 4)) + chunks = list(chunk_audio(samples, 480)) + assert len(chunks) == 4 + assert all(len(c) == 480 for c in chunks) + + def test_remainder_dropped(self): + samples = list(range(500)) + chunks = list(chunk_audio(samples, 480)) + assert len(chunks) == 1 + + def test_empty_input(self): + assert list(chunk_audio([], 480)) == [] + + +# ── Energy VAD ──────────────────────────────────────────────────────────────── + +class TestEnergyVad: + def test_silence_not_active(self): + vad = EnergyVad(threshold_db=-35.0, onset_frames=2, offset_frames=3) + silence = [0.0001] * CHUNK_SAMPLES + for _ in range(5): + active = vad.process(silence) + assert not active + + def test_loud_signal_activates(self): + vad = EnergyVad(threshold_db=-35.0, onset_frames=2, offset_frames=8) + loud = [0.5] * CHUNK_SAMPLES + results = [vad.process(loud) for _ in range(3)] + assert results[-1] is True + + def test_onset_requires_n_frames(self): + vad = EnergyVad(threshold_db=-35.0, onset_frames=3, offset_frames=8) + loud = [0.5] * CHUNK_SAMPLES + assert not vad.process(loud) # frame 1 + assert not vad.process(loud) # frame 2 + assert vad.process(loud) # frame 3 → activates + + def test_offset_deactivates(self): + vad = EnergyVad(threshold_db=-35.0, onset_frames=1, offset_frames=2) + loud = [0.5] * CHUNK_SAMPLES + silence = [0.0001] * CHUNK_SAMPLES + vad.process(loud) + assert vad.is_active + vad.process(silence) + assert vad.is_active # offset_frames=2, need 2 + vad.process(silence) + assert not vad.is_active + + def test_reset(self): + vad = EnergyVad(threshold_db=-35.0, onset_frames=1, offset_frames=8) + loud = [0.5] * CHUNK_SAMPLES + vad.process(loud) + assert vad.is_active + vad.reset() + assert not vad.is_active + + +# 
── Utterance segmenter ─────────────────────────────────────────────────────── + +class TestUtteranceSegmenter: + def _make_segmenter(self): + vad = EnergyVad(threshold_db=-35.0, onset_frames=1, offset_frames=2) + return UtteranceSegmenter(vad, pre_roll_frames=2, max_duration_s=5.0) + + def test_silence_yields_nothing(self): + seg = self._make_segmenter() + silence = [0.0001] * CHUNK_SAMPLES + for _ in range(10): + results = seg.push(silence) + assert results == [] + + def test_speech_then_silence_yields_utterance(self): + seg = self._make_segmenter() + loud = [0.5] * CHUNK_SAMPLES + silence = [0.0001] * CHUNK_SAMPLES + + # 3 speech frames + for _ in range(3): + assert seg.push(loud) == [] + + # 2 silence frames to trigger offset + all_results = [] + for _ in range(3): + all_results.extend(seg.push(silence)) + + assert len(all_results) == 1 + utt_samples, duration = all_results[0] + assert duration > 0 + assert len(utt_samples) > 0 + + def test_preroll_included(self): + vad = EnergyVad(threshold_db=-35.0, onset_frames=1, offset_frames=2) + seg = UtteranceSegmenter(vad, pre_roll_frames=3, max_duration_s=5.0) + silence = [0.0001] * CHUNK_SAMPLES + loud = [0.5] * CHUNK_SAMPLES + + # Push 3 silence frames (become pre-roll buffer) + for _ in range(3): + seg.push(silence) + + # Push 1 speech frame (triggers onset, pre-roll included) + seg.push(loud) + + # End utterance + all_results = [] + for _ in range(3): + all_results.extend(seg.push(silence)) + + assert len(all_results) == 1 + # Utterance should include pre-roll + speech + trailing silence + utt_samples, _ = all_results[0] + assert len(utt_samples) >= 4 * CHUNK_SAMPLES + + +# ── Speaker identification ──────────────────────────────────────────────────── + +class TestSpeakerIdentification: + def test_cosine_identical(self): + v = [1.0, 0.0, 0.0] + assert abs(cosine_similarity(v, v) - 1.0) < 1e-6 + + def test_cosine_orthogonal(self): + a = [1.0, 0.0] + b = [0.0, 1.0] + assert abs(cosine_similarity(a, b)) < 1e-6 + + 
def test_cosine_opposite(self): + a = [1.0, 0.0] + b = [-1.0, 0.0] + assert abs(cosine_similarity(a, b) + 1.0) < 1e-6 + + def test_cosine_zero_vector(self): + assert cosine_similarity([0.0, 0.0], [1.0, 0.0]) == 0.0 + + def test_cosine_length_mismatch(self): + assert cosine_similarity([1.0], [1.0, 0.0]) == 0.0 + + def test_identify_speaker_match(self): + known = { + "alice": [1.0, 0.0, 0.0], + "bob": [0.0, 1.0, 0.0], + } + result = identify_speaker([1.0, 0.0, 0.0], known, threshold=0.5) + assert result == "alice" + + def test_identify_speaker_unknown(self): + known = {"alice": [1.0, 0.0, 0.0]} + result = identify_speaker([0.0, 1.0, 0.0], known, threshold=0.5) + assert result == "unknown" + + def test_identify_speaker_empty_db(self): + result = identify_speaker([1.0, 0.0], {}, threshold=0.5) + assert result == "unknown" + + def test_identify_best_match(self): + known = { + "alice": [1.0, 0.0, 0.0], + "bob": [0.8, 0.6, 0.0], # closer to [1,0,0] than bob is + } + # Query close to alice + result = identify_speaker([0.99, 0.01, 0.0], known, threshold=0.5) + assert result == "alice" diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/CMakeLists.txt b/jetson/ros2_ws/src/saltybot_social_msgs/CMakeLists.txt index 544ca7e..c7eb9b4 100644 --- a/jetson/ros2_ws/src/saltybot_social_msgs/CMakeLists.txt +++ b/jetson/ros2_ws/src/saltybot_social_msgs/CMakeLists.txt @@ -28,6 +28,11 @@ rosidl_generate_interfaces(${PROJECT_NAME} "srv/QueryMood.srv" # Issue #92 — multi-modal tracking fusion "msg/FusedTarget.msg" + # Issue #81 — speech pipeline + "msg/SpeechTranscript.msg" + "msg/VadState.msg" + # Issue #83 — conversation engine + "msg/ConversationResponse.msg" DEPENDENCIES std_msgs geometry_msgs builtin_interfaces ) diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/msg/ConversationResponse.msg b/jetson/ros2_ws/src/saltybot_social_msgs/msg/ConversationResponse.msg new file mode 100644 index 0000000..e35b331 --- /dev/null +++ 
b/jetson/ros2_ws/src/saltybot_social_msgs/msg/ConversationResponse.msg @@ -0,0 +1,9 @@ +# ConversationResponse.msg — LLM response, supports streaming token output. +# Published by conversation_node on /social/conversation/response + +std_msgs/Header header + +string text # Full or partial response text +string speaker_id # Who the response is addressed to +bool is_partial # true = streaming token chunk, false = final response +int32 turn_id # Conversation turn counter (for deduplication) diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/msg/SpeechTranscript.msg b/jetson/ros2_ws/src/saltybot_social_msgs/msg/SpeechTranscript.msg new file mode 100644 index 0000000..8f4cb5f --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_msgs/msg/SpeechTranscript.msg @@ -0,0 +1,10 @@ +# SpeechTranscript.msg — Result of STT with speaker identification. +# Published by speech_pipeline_node on /social/speech/transcript + +std_msgs/Header header + +string text # Transcribed text (UTF-8) +string speaker_id # e.g. "person_42" or "unknown" +float32 confidence # ASR confidence 0..1 +float32 audio_duration # Duration of the utterance in seconds +bool is_partial # true = intermediate streaming result, false = final diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/msg/VadState.msg b/jetson/ros2_ws/src/saltybot_social_msgs/msg/VadState.msg new file mode 100644 index 0000000..5a2cff6 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social_msgs/msg/VadState.msg @@ -0,0 +1,8 @@ +# VadState.msg — Voice Activity Detection state. +# Published by speech_pipeline_node on /social/speech/vad_state + +std_msgs/Header header + +bool speech_active # true = speech detected, false = silence +float32 energy_db # RMS energy in dBFS +string state # "silence" | "speech" | "wake_word" -- 2.47.2