diff --git a/jetson/ros2_ws/src/saltybot_social/config/conversation_params.yaml b/jetson/ros2_ws/src/saltybot_social/config/conversation_params.yaml
new file mode 100644
index 0000000..b0de63e
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social/config/conversation_params.yaml
@@ -0,0 +1,12 @@
+conversation_node:
+ ros__parameters:
+ model_path: "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf"
+ n_ctx: 4096
+ n_gpu_layers: 20 # Increase for more VRAM (Orin has 8GB shared)
+ max_tokens: 200
+ temperature: 0.7
+ top_p: 0.9
+ soul_path: "/soul/SOUL.md"
+ context_db_path: "/social_db/conversation_context.json"
+ save_interval_s: 30.0
+ stream: true
diff --git a/jetson/ros2_ws/src/saltybot_social/config/orchestrator_params.yaml b/jetson/ros2_ws/src/saltybot_social/config/orchestrator_params.yaml
new file mode 100644
index 0000000..7efd8b8
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social/config/orchestrator_params.yaml
@@ -0,0 +1,8 @@
+orchestrator_node:
+ ros__parameters:
+ gpu_mem_warn_mb: 4000.0 # Warn when GPU free < 4GB
+ gpu_mem_throttle_mb: 2000.0 # Throttle new inferences at < 2GB
+ watchdog_timeout_s: 30.0
+ latency_window: 20
+ profile_enabled: true
+ state_publish_rate: 2.0 # Hz
diff --git a/jetson/ros2_ws/src/saltybot_social/config/speech_params.yaml b/jetson/ros2_ws/src/saltybot_social/config/speech_params.yaml
new file mode 100644
index 0000000..8c1ef5e
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social/config/speech_params.yaml
@@ -0,0 +1,13 @@
+speech_pipeline_node:
+ ros__parameters:
+ mic_device_index: -1 # -1 = system default; use `arecord -l` to list
+ sample_rate: 16000
+ wake_word_model: "hey_salty"
+ wake_word_threshold: 0.5
+ vad_threshold_db: -35.0
+ use_silero_vad: true
+ whisper_model: "small" # small (~500ms), medium (better quality, ~900ms)
+ whisper_compute_type: "float16"
+ speaker_threshold: 0.65
+ speaker_db_path: "/social_db/speaker_embeddings.json"
+ publish_partial: true
diff --git a/jetson/ros2_ws/src/saltybot_social/config/tts_params.yaml b/jetson/ros2_ws/src/saltybot_social/config/tts_params.yaml
new file mode 100644
index 0000000..4d8fe52
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social/config/tts_params.yaml
@@ -0,0 +1,9 @@
+tts_node:
+ ros__parameters:
+ voice_path: "/models/piper/en_US-lessac-medium.onnx"
+ sample_rate: 22050
+ volume: 1.0
+ audio_device: "" # "" = system default; set to device name if needed
+ playback_enabled: true
+ publish_audio: false
+ sentence_streaming: true
diff --git a/jetson/ros2_ws/src/saltybot_social/launch/social_bot.launch.py b/jetson/ros2_ws/src/saltybot_social/launch/social_bot.launch.py
new file mode 100644
index 0000000..88e34ca
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social/launch/social_bot.launch.py
@@ -0,0 +1,129 @@
+"""social_bot.launch.py — Launch the full social-bot pipeline.
+
+Issue #89: End-to-end pipeline orchestrator launch file.
+
+Brings up all social-bot nodes in dependency order:
+ 1. Orchestrator (state machine + watchdog) — t=0s
+ 2. Speech pipeline (wake word + VAD + STT) — t=2s
+ 3. Conversation engine (LLM) — t=4s
+ 4. TTS node (Piper streaming) — t=4s
+ 5. Person state tracker (already in social.launch.py — optional)
+
+Launch args:
+ enable_speech (bool, true)
+ enable_llm (bool, true)
+ enable_tts (bool, true)
+ enable_orchestrator (bool, true)
+ voice_path (str, /models/piper/en_US-lessac-medium.onnx)
+ whisper_model (str, small)
+ llm_model_path (str, /models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf)
+ n_gpu_layers (int, 20)
+"""
+
+import os
+from ament_index_python.packages import get_package_share_directory
+from launch import LaunchDescription
+from launch.actions import (
+ DeclareLaunchArgument, GroupAction, LogInfo, TimerAction
+)
+from launch.conditions import IfCondition
+from launch.substitutions import LaunchConfiguration, PythonExpression
+from launch_ros.actions import Node
+
+
+def generate_launch_description() -> LaunchDescription:
+    """Assemble the social-bot launch graph.
+
+    The orchestrator starts immediately; the speech pipeline starts at t=2s,
+    and the LLM + TTS nodes start in parallel at t=4s via TimerAction so that
+    heavy model loads are staggered. Each group is gated by its enable_*
+    launch argument through IfCondition. Per-node dict overrides after the
+    YAML path take precedence over the file values.
+    """
+    pkg = get_package_share_directory("saltybot_social")
+    cfg = os.path.join(pkg, "config")
+
+    # ── Launch arguments ─────────────────────────────────────────────────────
+    args = [
+        DeclareLaunchArgument("enable_speech", default_value="true"),
+        DeclareLaunchArgument("enable_llm", default_value="true"),
+        DeclareLaunchArgument("enable_tts", default_value="true"),
+        DeclareLaunchArgument("enable_orchestrator", default_value="true"),
+        DeclareLaunchArgument("voice_path",
+            default_value="/models/piper/en_US-lessac-medium.onnx"),
+        DeclareLaunchArgument("whisper_model", default_value="small"),
+        DeclareLaunchArgument("llm_model_path",
+            default_value="/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf"),
+        DeclareLaunchArgument("n_gpu_layers", default_value="20"),
+    ]
+
+    # ── Orchestrator (t=0s) ──────────────────────────────────────────────────
+    orchestrator = Node(
+        package="saltybot_social",
+        executable="orchestrator_node",
+        name="orchestrator_node",
+        parameters=[os.path.join(cfg, "orchestrator_params.yaml")],
+        condition=IfCondition(LaunchConfiguration("enable_orchestrator")),
+        output="screen",
+        emulate_tty=True,
+    )
+
+    # ── Speech pipeline (t=2s) ───────────────────────────────────────────────
+    speech = TimerAction(period=2.0, actions=[
+        GroupAction(
+            condition=IfCondition(LaunchConfiguration("enable_speech")),
+            actions=[
+                LogInfo(msg="[social_bot] Starting speech pipeline..."),
+                Node(
+                    package="saltybot_social",
+                    executable="speech_pipeline_node",
+                    name="speech_pipeline_node",
+                    parameters=[
+                        os.path.join(cfg, "speech_params.yaml"),
+                        # Dict entry overrides the YAML value for this key.
+                        {"whisper_model": LaunchConfiguration("whisper_model")},
+                    ],
+                    output="screen",
+                    emulate_tty=True,
+                ),
+            ]
+        )
+    ])
+
+    # ── LLM conversation engine (t=4s) ───────────────────────────────────────
+    llm = TimerAction(period=4.0, actions=[
+        GroupAction(
+            condition=IfCondition(LaunchConfiguration("enable_llm")),
+            actions=[
+                LogInfo(msg="[social_bot] Starting LLM conversation engine..."),
+                Node(
+                    package="saltybot_social",
+                    executable="conversation_node",
+                    name="conversation_node",
+                    parameters=[
+                        os.path.join(cfg, "conversation_params.yaml"),
+                        {
+                            "model_path": LaunchConfiguration("llm_model_path"),
+                            "n_gpu_layers": LaunchConfiguration("n_gpu_layers"),
+                        },
+                    ],
+                    output="screen",
+                    emulate_tty=True,
+                ),
+            ]
+        )
+    ])
+
+    # ── TTS node (t=4s, parallel with LLM) ──────────────────────────────────
+    tts = TimerAction(period=4.0, actions=[
+        GroupAction(
+            condition=IfCondition(LaunchConfiguration("enable_tts")),
+            actions=[
+                LogInfo(msg="[social_bot] Starting TTS node..."),
+                Node(
+                    package="saltybot_social",
+                    executable="tts_node",
+                    name="tts_node",
+                    parameters=[
+                        os.path.join(cfg, "tts_params.yaml"),
+                        {"voice_path": LaunchConfiguration("voice_path")},
+                    ],
+                    output="screen",
+                    emulate_tty=True,
+                ),
+            ]
+        )
+    ])
+
+    return LaunchDescription(args + [orchestrator, speech, llm, tts])
diff --git a/jetson/ros2_ws/src/saltybot_social/package.xml b/jetson/ros2_ws/src/saltybot_social/package.xml
index ca4b63c..f73d2ce 100644
--- a/jetson/ros2_ws/src/saltybot_social/package.xml
+++ b/jetson/ros2_ws/src/saltybot_social/package.xml
@@ -5,9 +5,12 @@
0.1.0
Social interaction layer for saltybot.
+ speech_pipeline_node: wake word + VAD + Whisper STT + diarization (Issue #81).
+ conversation_node: local LLM with per-person context (Issue #83).
+ tts_node: streaming TTS with Piper first-chunk (Issue #85).
+ orchestrator_node: pipeline state machine + GPU watchdog + latency profiler (Issue #89).
person_state_tracker: multi-modal person identity fusion (Issue #82).
- expression_node: bridges /social/mood to ESP32-C3 NeoPixel ring over serial (Issue #86).
- attention_node: rotates robot toward active speaker via /social/persons bearing (Issue #86).
+ expression_node: LED expression + motor attention (Issue #86).
seb
MIT
diff --git a/jetson/ros2_ws/src/saltybot_social/saltybot_social/conversation_node.py b/jetson/ros2_ws/src/saltybot_social/saltybot_social/conversation_node.py
new file mode 100644
index 0000000..a9218a7
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social/saltybot_social/conversation_node.py
@@ -0,0 +1,267 @@
+"""conversation_node.py — Local LLM conversation engine with per-person context.
+
+Issue #83: Conversation engine for social-bot.
+
+Stack: Phi-3-mini or Llama-3.2-3B GGUF Q4_K_M via llama-cpp-python (CUDA).
+Subscribes /social/speech/transcript → builds per-person prompt → streams
+token output → publishes /social/conversation/response.
+
+Streaming: publishes is_partial=true messages (accumulated text so far) at each
+sentence boundary, then a final message with is_partial=false at the end of
+generation. TTS node can begin synthesis on the first sentence boundary.
+
+ROS2 topics:
+ Subscribe: /social/speech/transcript (saltybot_social_msgs/SpeechTranscript)
+ Publish: /social/conversation/response (saltybot_social_msgs/ConversationResponse)
+
+Parameters:
+ model_path (str, "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf")
+ n_ctx (int, 4096)
+ n_gpu_layers (int, 20) — GPU offload layers (increase for more VRAM usage)
+ max_tokens (int, 200)
+ temperature (float, 0.7)
+ top_p (float, 0.9)
+ soul_path (str, "/soul/SOUL.md")
+ context_db_path (str, "/social_db/conversation_context.json")
+ save_interval_s (float, 30.0) — how often to persist context to disk
+ stream (bool, true)
+"""
+
+from __future__ import annotations
+
+import threading
+import time
+from typing import Optional
+
+import rclpy
+from rclpy.node import Node
+from rclpy.qos import QoSProfile
+
+from saltybot_social_msgs.msg import SpeechTranscript, ConversationResponse
+from .llm_context import ContextStore, build_llama_prompt, load_system_prompt, needs_summary_prompt
+
+
+class ConversationNode(Node):
+ """Local LLM inference node with per-person conversation memory."""
+
+    def __init__(self) -> None:
+        """Declare parameters, wire pub/sub, and kick off background LLM load."""
+        super().__init__("conversation_node")
+
+        # ── Parameters ──────────────────────────────────────────────────────
+        self.declare_parameter("model_path",
+            "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf")
+        self.declare_parameter("n_ctx", 4096)
+        self.declare_parameter("n_gpu_layers", 20)
+        self.declare_parameter("max_tokens", 200)
+        self.declare_parameter("temperature", 0.7)
+        self.declare_parameter("top_p", 0.9)
+        self.declare_parameter("soul_path", "/soul/SOUL.md")
+        self.declare_parameter("context_db_path", "/social_db/conversation_context.json")
+        self.declare_parameter("save_interval_s", 30.0)
+        self.declare_parameter("stream", True)
+
+        self._model_path = self.get_parameter("model_path").value
+        self._n_ctx = self.get_parameter("n_ctx").value
+        self._n_gpu = self.get_parameter("n_gpu_layers").value
+        self._max_tokens = self.get_parameter("max_tokens").value
+        self._temperature = self.get_parameter("temperature").value
+        self._top_p = self.get_parameter("top_p").value
+        self._soul_path = self.get_parameter("soul_path").value
+        self._db_path = self.get_parameter("context_db_path").value
+        self._save_interval = self.get_parameter("save_interval_s").value
+        self._stream = self.get_parameter("stream").value
+
+        # ── Publishers / Subscribers ─────────────────────────────────────────
+        qos = QoSProfile(depth=10)
+        self._resp_pub = self.create_publisher(
+            ConversationResponse, "/social/conversation/response", qos
+        )
+        self._transcript_sub = self.create_subscription(
+            SpeechTranscript, "/social/speech/transcript",
+            self._on_transcript, qos
+        )
+
+        # ── State ────────────────────────────────────────────────────────────
+        # _llm stays None until the background load completes; _generate_response
+        # drops utterances until then. _lock guards _generating / _turn_counter.
+        self._llm = None
+        self._system_prompt = load_system_prompt(self._soul_path)
+        self._ctx_store = ContextStore(self._db_path)
+        self._lock = threading.Lock()
+        self._turn_counter = 0
+        self._generating = False
+        self._last_save = time.time()
+
+        # ── Load LLM in background ────────────────────────────────────────────
+        # Daemon thread so the node comes up (and the executor can spin)
+        # immediately instead of blocking on a multi-second model load.
+        threading.Thread(target=self._load_llm, daemon=True).start()
+
+        # ── Periodic context save ────────────────────────────────────────────
+        self._save_timer = self.create_timer(self._save_interval, self._save_context)
+
+        self.get_logger().info(
+            f"ConversationNode init (model={self._model_path}, "
+            f"gpu_layers={self._n_gpu}, ctx={self._n_ctx})"
+        )
+
+ # ── Model loading ─────────────────────────────────────────────────────────
+
+    def _load_llm(self) -> None:
+        """Load the GGUF model via llama-cpp-python (runs in a daemon thread).
+
+        On failure the error is logged and self._llm stays None, so
+        _generate_response drops utterances until the node is restarted.
+        """
+        t0 = time.time()
+        self.get_logger().info(f"Loading LLM: {self._model_path}")
+        try:
+            from llama_cpp import Llama
+            self._llm = Llama(
+                model_path=self._model_path,
+                n_ctx=self._n_ctx,
+                n_gpu_layers=self._n_gpu,
+                n_threads=4,  # NOTE(review): hard-coded CPU thread count — confirm for Orin
+                verbose=False,
+            )
+            self.get_logger().info(
+                f"LLM ready ({time.time()-t0:.1f}s). "
+                f"Context: {self._n_ctx} tokens, GPU layers: {self._n_gpu}"
+            )
+        except Exception as e:
+            self.get_logger().error(f"LLM load failed: {e}")
+
+ # ── Transcript callback ───────────────────────────────────────────────────
+
+    def _on_transcript(self, msg: SpeechTranscript) -> None:
+        """Handle final transcripts only (skip streaming partials).
+
+        Generation runs in a separate daemon thread so the subscription
+        callback returns immediately; concurrent requests are rejected
+        inside _generate_response via the _generating flag.
+        """
+        if msg.is_partial:
+            return
+        if not msg.text.strip():
+            return
+
+        self.get_logger().info(
+            f"Transcript [{msg.speaker_id}]: '{msg.text}'"
+        )
+
+        threading.Thread(
+            target=self._generate_response,
+            args=(msg.text.strip(), msg.speaker_id),
+            daemon=True,
+        ).start()
+
+ # ── LLM inference ─────────────────────────────────────────────────────────
+
+ def _generate_response(self, user_text: str, speaker_id: str) -> None:
+ """Generate LLM response with streaming. Runs in thread."""
+ if self._llm is None:
+ self.get_logger().warn("LLM not loaded yet, dropping utterance")
+ return
+
+ with self._lock:
+ if self._generating:
+ self.get_logger().warn("LLM busy, dropping utterance")
+ return
+ self._generating = True
+ self._turn_counter += 1
+ turn_id = self._turn_counter
+
+ try:
+ ctx = self._ctx_store.get(speaker_id)
+
+ # Summary compression if context is long
+ if ctx.needs_compression():
+ self._compress_context(ctx)
+
+ ctx.add_user(user_text)
+
+ prompt = build_llama_prompt(
+ ctx, user_text, self._system_prompt
+ )
+
+ t0 = time.perf_counter()
+ full_response = ""
+
+ if self._stream:
+ output = self._llm(
+ prompt,
+ max_tokens=self._max_tokens,
+ temperature=self._temperature,
+ top_p=self._top_p,
+ stream=True,
+ stop=["<|user|>", "<|system|>", "\n\n\n"],
+ )
+ for chunk in output:
+ token = chunk["choices"][0]["text"]
+ full_response += token
+ # Publish partial after each sentence boundary for low TTS latency
+ if token.endswith((".", "!", "?", "\n")):
+ self._publish_response(
+ full_response.strip(), speaker_id, turn_id, is_partial=True
+ )
+ else:
+ output = self._llm(
+ prompt,
+ max_tokens=self._max_tokens,
+ temperature=self._temperature,
+ top_p=self._top_p,
+ stream=False,
+ stop=["<|user|>", "<|system|>"],
+ )
+ full_response = output["choices"][0]["text"]
+
+ full_response = full_response.strip()
+ latency_ms = (time.perf_counter() - t0) * 1000
+ self.get_logger().info(
+ f"LLM [{speaker_id}] ({latency_ms:.0f}ms): '{full_response[:80]}'"
+ )
+
+ ctx.add_assistant(full_response)
+ self._publish_response(full_response, speaker_id, turn_id, is_partial=False)
+
+ except Exception as e:
+ self.get_logger().error(f"LLM inference error: {e}")
+ finally:
+ with self._lock:
+ self._generating = False
+
+ def _compress_context(self, ctx) -> None:
+ """Ask LLM to summarize old turns for context compression."""
+ if self._llm is None:
+ ctx.compress("(history omitted)")
+ return
+ try:
+ summary_prompt = needs_summary_prompt(ctx)
+ result = self._llm(summary_prompt, max_tokens=80, temperature=0.3, stream=False)
+ summary = result["choices"][0]["text"].strip()
+ ctx.compress(summary)
+ self.get_logger().debug(
+ f"Context compressed for {ctx.person_id}: '{summary[:60]}'"
+ )
+ except Exception:
+ ctx.compress("(history omitted)")
+
+ # ── Publish ───────────────────────────────────────────────────────────────
+
+    def _publish_response(
+        self, text: str, speaker_id: str, turn_id: int, is_partial: bool
+    ) -> None:
+        """Publish a ConversationResponse (partial or final) for this turn."""
+        msg = ConversationResponse()
+        msg.header.stamp = self.get_clock().now().to_msg()
+        msg.text = text
+        msg.speaker_id = speaker_id
+        msg.is_partial = is_partial
+        msg.turn_id = turn_id
+        self._resp_pub.publish(msg)
+
+    def _save_context(self) -> None:
+        """Persist conversation contexts to disk; errors are logged, never raised."""
+        try:
+            self._ctx_store.save()
+        except Exception as e:
+            self.get_logger().error(f"Context save error: {e}")
+
+    def destroy_node(self) -> None:
+        """Flush context to disk before shutting the node down."""
+        self._save_context()
+        super().destroy_node()
+
+
+def main(args=None) -> None:
+    """Entry point: spin ConversationNode; save context and shut down cleanly."""
+    rclpy.init(args=args)
+    node = ConversationNode()
+    try:
+        rclpy.spin(node)
+    except KeyboardInterrupt:
+        pass
+    finally:
+        # destroy_node() flushes the context store before rclpy shuts down.
+        node.destroy_node()
+        rclpy.shutdown()
diff --git a/jetson/ros2_ws/src/saltybot_social/saltybot_social/llm_context.py b/jetson/ros2_ws/src/saltybot_social/saltybot_social/llm_context.py
new file mode 100644
index 0000000..a743f54
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social/saltybot_social/llm_context.py
@@ -0,0 +1,176 @@
+"""llm_context.py — Per-person conversation context management.
+
+No ROS2 dependencies. Manages sliding-window conversation history per person_id,
+builds llama.cpp prompt, handles summary compression when context is full.
+
+Used by conversation_node.py.
+Tested by test/test_llm_context.py.
+"""
+
+from __future__ import annotations
+
+import json
+import os
+import time
+from typing import List, Optional, Tuple
+
+# ── Constants ──────────────────────────────────────────────────────────────────
+MAX_TURNS = 20 # Max turns to keep before summary compression
+SUMMARY_KEEP = 4 # Keep last N turns after summarizing older ones
+SYSTEM_PROMPT_PATH = "/soul/SOUL.md"
+DEFAULT_SYSTEM_PROMPT = (
+ "You are Salty, a friendly social robot. "
+ "You are helpful, warm, and concise. "
+ "Keep responses under 2 sentences unless asked to elaborate."
+)
+
+
+# ── Conversation history ───────────────────────────────────────────────────────
+
+class PersonContext:
+    """Conversation history + relationship memory for one person.
+
+    Turns are a flat list of {"role", "content", "ts"} dicts; once they exceed
+    MAX_TURNS the older ones are folded into a running text summary.
+    """
+
+    def __init__(self, person_id: str, person_name: str = "") -> None:
+        self.person_id = person_id
+        self.person_name = person_name or person_id
+        self.turns: List[dict] = []  # {"role": "user"|"assistant", "content": str, "ts": float}
+        self.summary: str = ""
+        self.last_seen: float = time.time()
+        self.interaction_count: int = 0
+
+    def add_user(self, text: str) -> None:
+        """Record a user turn; also bumps last_seen and the interaction count."""
+        self.turns.append({"role": "user", "content": text, "ts": time.time()})
+        self.last_seen = time.time()
+        self.interaction_count += 1
+
+    def add_assistant(self, text: str) -> None:
+        """Record an assistant turn (does not count as an interaction)."""
+        self.turns.append({"role": "assistant", "content": text, "ts": time.time()})
+
+    def needs_compression(self) -> bool:
+        """True once the turn list has grown past MAX_TURNS."""
+        return len(self.turns) > MAX_TURNS
+
+    def compress(self, summary_text: str) -> None:
+        """Replace old turns with a summary, keeping last SUMMARY_KEEP turns.
+
+        Successive summaries are concatenated, so the summary string itself
+        grows without bound over long relationships.
+        """
+        old_summary = self.summary
+        if old_summary:
+            self.summary = old_summary + " " + summary_text
+        else:
+            self.summary = summary_text
+        # Keep most recent turns
+        self.turns = self.turns[-SUMMARY_KEEP:]
+
+    def to_dict(self) -> dict:
+        """JSON-serializable snapshot (inverse of from_dict)."""
+        return {
+            "person_id": self.person_id,
+            "person_name": self.person_name,
+            "turns": self.turns,
+            "summary": self.summary,
+            "last_seen": self.last_seen,
+            "interaction_count": self.interaction_count,
+        }
+
+    @classmethod
+    def from_dict(cls, d: dict) -> "PersonContext":
+        """Rebuild a context from a to_dict() snapshot; missing keys default."""
+        ctx = cls(d["person_id"], d.get("person_name", ""))
+        ctx.turns = d.get("turns", [])
+        ctx.summary = d.get("summary", "")
+        ctx.last_seen = d.get("last_seen", 0.0)
+        ctx.interaction_count = d.get("interaction_count", 0)
+        return ctx
+
+
+# ── Context store ──────────────────────────────────────────────────────────────
+
+class ContextStore:
+ """Persistent per-person conversation context with optional JSON backing."""
+
+ def __init__(self, db_path: str = "/social_db/conversation_context.json") -> None:
+ self._db_path = db_path
+ self._contexts: dict = {} # person_id → PersonContext
+ self._load()
+
+ def get(self, person_id: str, person_name: str = "") -> PersonContext:
+ if person_id not in self._contexts:
+ self._contexts[person_id] = PersonContext(person_id, person_name)
+ elif person_name and not self._contexts[person_id].person_name:
+ self._contexts[person_id].person_name = person_name
+ return self._contexts[person_id]
+
+ def save(self) -> None:
+ os.makedirs(os.path.dirname(self._db_path), exist_ok=True)
+ with open(self._db_path, "w") as f:
+ json.dump(
+ {pid: ctx.to_dict() for pid, ctx in self._contexts.items()},
+ f, indent=2
+ )
+
+ def _load(self) -> None:
+ if os.path.exists(self._db_path):
+ try:
+ with open(self._db_path) as f:
+ raw = json.load(f)
+ self._contexts = {
+ pid: PersonContext.from_dict(d) for pid, d in raw.items()
+ }
+ except Exception:
+ self._contexts = {}
+
+ def all_persons(self) -> List[str]:
+ return list(self._contexts.keys())
+
+
+# ── Prompt builder ─────────────────────────────────────────────────────────────
+
+def load_system_prompt(path: str = SYSTEM_PROMPT_PATH) -> str:
+ """Load SOUL.md as system prompt. Falls back to default."""
+ if os.path.exists(path):
+ with open(path) as f:
+ return f.read().strip()
+ return DEFAULT_SYSTEM_PROMPT
+
+
+def build_llama_prompt(
+    ctx: PersonContext,
+    user_text: str,
+    system_prompt: str,
+    group_persons: Optional[List[str]] = None,
+) -> str:
+    """Build a chat prompt string using Phi-3-style ``<|role|>`` tags.
+
+    (Not ChatML — ChatML uses ``<|im_start|>`` markers; this matches the
+    Phi-3 instruct template the default model expects.)
+
+    Layout: system prompt, optional memory-summary and group-presence system
+    blocks, all stored turns, then user_text and an open ``<|assistant|>`` tag.
+    NOTE: user_text is appended here — callers must not also add it to
+    ctx.turns before calling, or it appears twice.
+    """
+    lines = [f"<|system|>\n{system_prompt}"]
+
+    if ctx.summary:
+        lines.append(
+            f"<|system|>\n[Memory of {ctx.person_name}]: {ctx.summary}"
+        )
+
+    if group_persons:
+        others = [p for p in group_persons if p != ctx.person_id]
+        if others:
+            lines.append(
+                f"<|system|>\n[Also present]: {', '.join(others)}"
+            )
+
+    for turn in ctx.turns:
+        role = "user" if turn["role"] == "user" else "assistant"
+        lines.append(f"<|{role}|>\n{turn['content']}")
+
+    lines.append(f"<|user|>\n{user_text}")
+    lines.append("<|assistant|>")
+
+    return "\n".join(lines)
+
+
+def needs_summary_prompt(ctx: PersonContext) -> str:
+    """Build a prompt asking the LLM to summarize old conversation turns.
+
+    Summarizes everything except the last SUMMARY_KEEP turns (the ones
+    PersonContext.compress keeps verbatim).
+    NOTE(review): the name reads like a boolean predicate but this returns
+    the prompt string — consider renaming to build_summary_prompt.
+    """
+    turns_text = "\n".join(
+        f"{t['role'].upper()}: {t['content']}"
+        for t in ctx.turns[:-SUMMARY_KEEP]
+    )
+    return (
+        f"<|system|>\nSummarize this conversation between {ctx.person_name} and "
+        "Salty in one sentence, focusing on what was discussed and their relationship.\n"
+        f"<|user|>\n{turns_text}\n<|assistant|>"
+    )
diff --git a/jetson/ros2_ws/src/saltybot_social/saltybot_social/orchestrator_node.py b/jetson/ros2_ws/src/saltybot_social/saltybot_social/orchestrator_node.py
new file mode 100644
index 0000000..21bd8a0
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social/saltybot_social/orchestrator_node.py
@@ -0,0 +1,298 @@
+"""orchestrator_node.py — End-to-end social-bot pipeline orchestrator.
+
+Issue #89: Main loop, GPU memory scheduling, latency profiling, watchdog.
+
+State machine:
+ IDLE → LISTENING (wake word) → THINKING (LLM generating) → SPEAKING (TTS)
+ ↑___________________________________________|
+
+Responsibilities:
+ - Owns the top-level pipeline state machine
+ - Monitors GPU memory and throttles inference when approaching limit
+ - Profiles end-to-end latency per interaction turn
+  - Health watchdog: alert when the speech pipeline appears hung (logs only; restart is not automated)
+ - Exposes /social/orchestrator/state for UI / other nodes
+
+ROS2 topics:
+ Subscribe: /social/speech/vad_state (VadState)
+ /social/speech/transcript (SpeechTranscript)
+ /social/conversation/response (ConversationResponse)
+ Publish: /social/orchestrator/state (std_msgs/String) — JSON status blob
+
+Parameters:
+ gpu_mem_warn_mb (float, 4000.0) — warn at this free GPU memory (MB)
+ gpu_mem_throttle_mb (float, 2000.0) — suspend new inference requests
+ watchdog_timeout_s (float, 30.0) — restart speech node if no VAD in N seconds
+ latency_window (int, 20) — rolling window for latency stats
+ profile_enabled (bool, true)
+"""
+
+from __future__ import annotations
+
+import json
+import threading
+import time
+from collections import deque
+from enum import Enum
+from typing import Optional
+
+import rclpy
+from rclpy.node import Node
+from rclpy.qos import QoSProfile
+from std_msgs.msg import String
+
+from saltybot_social_msgs.msg import VadState, SpeechTranscript, ConversationResponse
+
+
+# ── State machine ──────────────────────────────────────────────────────────────
+
+class PipelineState(Enum):
+    """Top-level pipeline states published on /social/orchestrator/state."""
+    IDLE = "idle"
+    LISTENING = "listening"
+    THINKING = "thinking"
+    SPEAKING = "speaking"
+    THROTTLED = "throttled"  # GPU OOM risk — waiting for memory
+
+
+# ── Latency tracker ────────────────────────────────────────────────────────────
+
+class LatencyTracker:
+    """Rolling window latency statistics per pipeline stage.
+
+    Each stage keeps its own bounded deque; maxlen provides the rolling
+    window for free. Not locked — assumes single-threaded executor use.
+    """
+
+    def __init__(self, window: int = 20) -> None:
+        self._window = window
+        self._wakeword_to_transcript: deque = deque(maxlen=window)
+        self._transcript_to_llm_first: deque = deque(maxlen=window)
+        self._llm_first_to_tts_first: deque = deque(maxlen=window)
+        self._end_to_end: deque = deque(maxlen=window)
+
+    def record(self, stage: str, latency_ms: float) -> None:
+        """Append a sample to *stage*; unknown stage names are silently ignored."""
+        mapping = {
+            "wakeword_to_transcript": self._wakeword_to_transcript,
+            "transcript_to_llm": self._transcript_to_llm_first,
+            "llm_to_tts": self._llm_first_to_tts_first,
+            "end_to_end": self._end_to_end,
+        }
+        q = mapping.get(stage)
+        if q is not None:
+            q.append(latency_ms)
+
+    def stats(self) -> dict:
+        """Return {stage: {mean_ms, p50_ms, p95_ms, n}}; empty dict per empty stage."""
+        def _stat(q: deque) -> dict:
+            if not q:
+                return {}
+            s = sorted(q)
+            return {
+                "mean_ms": round(sum(s) / len(s), 1),
+                "p50_ms": round(s[len(s) // 2], 1),
+                # int(len(s) * 0.95) < len(s) always, so the index is in range.
+                "p95_ms": round(s[int(len(s) * 0.95)], 1),
+                "n": len(s),
+            }
+        return {
+            "wakeword_to_transcript": _stat(self._wakeword_to_transcript),
+            "transcript_to_llm": _stat(self._transcript_to_llm_first),
+            "llm_to_tts": _stat(self._llm_first_to_tts_first),
+            "end_to_end": _stat(self._end_to_end),
+        }
+
+
+# ── GPU memory helpers ────────────────────────────────────────────────────────
+
+def get_gpu_free_mb() -> Optional[float]:
+ """Return free GPU memory in MB, or None if unavailable."""
+ try:
+ import pycuda.driver as drv
+ drv.init()
+ ctx = drv.Device(0).make_context()
+ free, _ = drv.mem_get_info()
+ ctx.pop()
+ return free / 1024 / 1024
+ except Exception:
+ pass
+ try:
+ import subprocess
+ out = subprocess.check_output(
+ ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"],
+ timeout=2
+ ).decode().strip()
+ return float(out.split("\n")[0])
+ except Exception:
+ return None
+
+
+# ── Orchestrator ───────────────────────────────────────────────────────────────
+
+class OrchestratorNode(Node):
+ """Pipeline orchestrator — state machine, GPU watchdog, latency profiler."""
+
+    def __init__(self) -> None:
+        """Declare parameters, wire pub/sub, and start the periodic timers."""
+        super().__init__("orchestrator_node")
+
+        # ── Parameters ──────────────────────────────────────────────────────
+        self.declare_parameter("gpu_mem_warn_mb", 4000.0)
+        self.declare_parameter("gpu_mem_throttle_mb", 2000.0)
+        self.declare_parameter("watchdog_timeout_s", 30.0)
+        self.declare_parameter("latency_window", 20)
+        self.declare_parameter("profile_enabled", True)
+        self.declare_parameter("state_publish_rate", 2.0)  # Hz
+
+        self._gpu_warn = self.get_parameter("gpu_mem_warn_mb").value
+        self._gpu_throttle = self.get_parameter("gpu_mem_throttle_mb").value
+        self._watchdog_timeout = self.get_parameter("watchdog_timeout_s").value
+        self._profile = self.get_parameter("profile_enabled").value
+        self._state_rate = self.get_parameter("state_publish_rate").value
+
+        # ── Publishers / Subscribers ─────────────────────────────────────────
+        qos = QoSProfile(depth=10)
+        self._state_pub = self.create_publisher(String, "/social/orchestrator/state", qos)
+
+        self._vad_sub = self.create_subscription(
+            VadState, "/social/speech/vad_state", self._on_vad, qos
+        )
+        self._transcript_sub = self.create_subscription(
+            SpeechTranscript, "/social/speech/transcript", self._on_transcript, qos
+        )
+        self._response_sub = self.create_subscription(
+            ConversationResponse, "/social/conversation/response", self._on_response, qos
+        )
+
+        # ── State ────────────────────────────────────────────────────────────
+        # _state_lock guards _state; it is a plain (non-reentrant) Lock, so
+        # _set_state must never be called while already holding it.
+        self._state = PipelineState.IDLE
+        self._state_lock = threading.Lock()
+        self._tracker = LatencyTracker(self.get_parameter("latency_window").value)
+
+        # Timestamps for latency tracking (0.0 means "no measurement pending")
+        self._t_wake: float = 0.0
+        self._t_transcript: float = 0.0
+        self._t_llm_first: float = 0.0
+        self._t_tts_first: float = 0.0
+
+        self._last_vad_t: float = time.time()
+        self._last_transcript_t: float = 0.0
+        self._last_response_t: float = 0.0
+
+        # ── Timers ────────────────────────────────────────────────────────────
+        self._state_timer = self.create_timer(1.0 / self._state_rate, self._publish_state)
+        self._gpu_timer = self.create_timer(5.0, self._check_gpu)
+        self._watchdog_timer = self.create_timer(10.0, self._watchdog)
+
+        self.get_logger().info("OrchestratorNode ready")
+
+ # ── State transitions ─────────────────────────────────────────────────────
+
+    def _set_state(self, new_state: PipelineState) -> None:
+        """Transition the pipeline state, logging on change.
+
+        Acquires _state_lock (a non-reentrant Lock) — callers must NOT hold
+        the lock when invoking this, or they will deadlock.
+        """
+        with self._state_lock:
+            if self._state != new_state:
+                self.get_logger().info(
+                    f"Pipeline: {self._state.value} → {new_state.value}"
+                )
+                self._state = new_state
+
+ # ── Subscribers ───────────────────────────────────────────────────────────
+
+    def _on_vad(self, msg: VadState) -> None:
+        """Feed the watchdog and enter LISTENING on wake word or fresh speech."""
+        self._last_vad_t = time.time()
+        if msg.state == "wake_word":
+            self._t_wake = time.time()
+            self._set_state(PipelineState.LISTENING)
+        # Reading self._state without the lock here is a benign race: worst
+        # case we skip one transition that the next VAD message will make.
+        elif msg.speech_active and self._state == PipelineState.IDLE:
+            self._set_state(PipelineState.LISTENING)
+
+    def _on_transcript(self, msg: SpeechTranscript) -> None:
+        """On a final transcript: enter THINKING and record wake→transcript latency."""
+        if msg.is_partial:
+            return
+        self._last_transcript_t = time.time()
+        self._t_transcript = time.time()
+        self._set_state(PipelineState.THINKING)
+
+        # _t_wake > 0 means a wake-word timestamp is pending for this turn.
+        if self._profile and self._t_wake > 0:
+            latency_ms = (self._t_transcript - self._t_wake) * 1000
+            self._tracker.record("wakeword_to_transcript", latency_ms)
+            if latency_ms > 500:
+                self.get_logger().warn(
+                    f"Wake→transcript latency high: {latency_ms:.0f}ms (target <500ms)"
+                )
+
+ def _on_response(self, msg: ConversationResponse) -> None:
+ self._last_response_t = time.time()
+
+ if msg.is_partial and self._t_llm_first == 0.0:
+ # First token from LLM
+ self._t_llm_first = time.time()
+ if self._profile and self._t_transcript > 0:
+ latency_ms = (self._t_llm_first - self._t_transcript) * 1000
+ self._tracker.record("transcript_to_llm", latency_ms)
+
+ self._set_state(PipelineState.SPEAKING)
+
+ if not msg.is_partial:
+ # Full response complete
+ self._t_llm_first = 0.0
+ self._t_wake = 0.0
+ self._t_transcript = 0.0
+
+ if self._profile and self._t_wake > 0:
+ e2e_ms = (time.time() - self._t_wake) * 1000
+ self._tracker.record("end_to_end", e2e_ms)
+
+ # Return to IDLE after a short delay (TTS still playing)
+ threading.Timer(2.0, lambda: self._set_state(PipelineState.IDLE)).start()
+
+ # ── GPU memory check ──────────────────────────────────────────────────────
+
+ def _check_gpu(self) -> None:
+ free_mb = get_gpu_free_mb()
+ if free_mb is None:
+ return
+
+ if free_mb < self._gpu_throttle:
+ self.get_logger().error(
+ f"GPU memory critical: {free_mb:.0f}MB free < {self._gpu_throttle:.0f}MB threshold"
+ )
+ self._set_state(PipelineState.THROTTLED)
+ elif free_mb < self._gpu_warn:
+ self.get_logger().warn(f"GPU memory low: {free_mb:.0f}MB free")
+ else:
+ # Recover from throttled state
+ with self._state_lock:
+ if self._state == PipelineState.THROTTLED:
+ self._set_state(PipelineState.IDLE)
+
+ # ── Watchdog ──────────────────────────────────────────────────────────────
+
+    def _watchdog(self) -> None:
+        """Alert if speech pipeline has gone silent for too long.
+
+        Logs only — it does not restart the speech node; recovery is manual
+        (or left to an external process supervisor).
+        """
+        since_vad = time.time() - self._last_vad_t
+        if since_vad > self._watchdog_timeout:
+            self.get_logger().error(
+                f"Watchdog: No VAD signal for {since_vad:.0f}s. "
+                "Speech pipeline may be hung. Check /social/speech/vad_state."
+            )
+
+ # ── State publisher ───────────────────────────────────────────────────────
+
+ def _publish_state(self) -> None:
+ with self._state_lock:
+ current = self._state.value
+
+ payload = {
+ "state": current,
+ "ts": time.time(),
+ "latency": self._tracker.stats() if self._profile else {},
+ "gpu_free_mb": get_gpu_free_mb(),
+ }
+ msg = String()
+ msg.data = json.dumps(payload)
+ self._state_pub.publish(msg)
+
+
+def main(args=None) -> None:
+    """Entry point: spin OrchestratorNode and shut down cleanly on Ctrl-C."""
+    rclpy.init(args=args)
+    node = OrchestratorNode()
+    try:
+        rclpy.spin(node)
+    except KeyboardInterrupt:
+        pass
+    finally:
+        node.destroy_node()
+        rclpy.shutdown()
diff --git a/jetson/ros2_ws/src/saltybot_social/saltybot_social/speech_pipeline_node.py b/jetson/ros2_ws/src/saltybot_social/saltybot_social/speech_pipeline_node.py
new file mode 100644
index 0000000..b9e6c27
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social/saltybot_social/speech_pipeline_node.py
@@ -0,0 +1,391 @@
+"""speech_pipeline_node.py — Wake word + VAD + Whisper STT + speaker diarization.
+
+Issue #81: Speech pipeline for social-bot.
+
+Pipeline:
+ USB mic (PyAudio) → VAD (Silero / energy fallback)
+ → wake word gate (OpenWakeWord "hey_salty")
+ → utterance buffer → faster-whisper STT (Orin GPU)
+ → ECAPA-TDNN speaker embedding → identify_speaker()
+ → publish /social/speech/transcript + /social/speech/vad_state
+
+Latency target: <500ms wake-to-first-token (streaming partial transcripts).
+
+ROS2 topics published:
+ /social/speech/transcript (saltybot_social_msgs/SpeechTranscript)
+ /social/speech/vad_state (saltybot_social_msgs/VadState)
+
+Parameters:
+ mic_device_index (int, -1 = system default)
+ sample_rate (int, 16000)
+ wake_word_model (str, "hey_salty" — OpenWakeWord model name or path)
+ wake_word_threshold (float, 0.5)
+ vad_threshold_db (float, -35.0) — fallback energy VAD
+ use_silero_vad (bool, true)
+ whisper_model (str, "small")
+ whisper_compute_type (str, "float16")
+ speaker_threshold (float, 0.65) — cosine similarity threshold
+ speaker_db_path (str, "/social_db/speaker_embeddings.json")
+ publish_partial (bool, true)
+"""
+
+from __future__ import annotations
+
+import json
+import os
+import threading
+import time
+from typing import Optional
+
+import rclpy
+from rclpy.node import Node
+from rclpy.qos import QoSProfile, QoSReliabilityPolicy, QoSDurabilityPolicy
+from builtin_interfaces.msg import Time
+from std_msgs.msg import Header
+
+from saltybot_social_msgs.msg import SpeechTranscript, VadState
+from .speech_utils import (
+ EnergyVad, UtteranceSegmenter,
+ pcm16_to_float32, rms_db, identify_speaker,
+ SAMPLE_RATE, CHUNK_SAMPLES,
+)
+
+
class SpeechPipelineNode(Node):
    """Wake word → VAD → STT → diarization → ROS2 publisher.

    Threading model:
      * ROS executor thread — spins this node; never blocked by model work.
      * `_model_thread` — lazily loads Whisper / Silero / ECAPA / OpenWakeWord,
        then starts audio capture.
      * audio thread (`_audio_loop`) — reads mic chunks, runs VAD + wake word,
        feeds the utterance segmenter, publishes VAD state.
      * per-utterance threads — run Whisper + speaker ID on each completed
        utterance so a long transcription never stalls capture.
    """

    def __init__(self) -> None:
        super().__init__("speech_pipeline_node")

        # ── Parameters ──────────────────────────────────────────────────────
        self.declare_parameter("mic_device_index", -1)
        self.declare_parameter("sample_rate", SAMPLE_RATE)
        self.declare_parameter("wake_word_model", "hey_salty")
        self.declare_parameter("wake_word_threshold", 0.5)
        self.declare_parameter("vad_threshold_db", -35.0)
        self.declare_parameter("use_silero_vad", True)
        self.declare_parameter("whisper_model", "small")
        self.declare_parameter("whisper_compute_type", "float16")
        self.declare_parameter("speaker_threshold", 0.65)
        self.declare_parameter("speaker_db_path", "/social_db/speaker_embeddings.json")
        self.declare_parameter("publish_partial", True)

        self._mic_idx = self.get_parameter("mic_device_index").value
        self._rate = self.get_parameter("sample_rate").value
        self._ww_model = self.get_parameter("wake_word_model").value
        self._ww_thresh = self.get_parameter("wake_word_threshold").value
        self._vad_thresh_db = self.get_parameter("vad_threshold_db").value
        self._use_silero = self.get_parameter("use_silero_vad").value
        self._whisper_model_name = self.get_parameter("whisper_model").value
        self._compute_type = self.get_parameter("whisper_compute_type").value
        self._speaker_thresh = self.get_parameter("speaker_threshold").value
        self._speaker_db = self.get_parameter("speaker_db_path").value
        self._publish_partial = self.get_parameter("publish_partial").value

        # ── Publishers ───────────────────────────────────────────────────────
        qos = QoSProfile(depth=10)
        self._transcript_pub = self.create_publisher(SpeechTranscript,
                                                     "/social/speech/transcript", qos)
        self._vad_pub = self.create_publisher(VadState, "/social/speech/vad_state", qos)

        # ── Internal state ───────────────────────────────────────────────────
        # Model handles stay None until _load_models() finishes; every user of
        # them checks for None so the node degrades gracefully.
        self._whisper = None            # faster-whisper WhisperModel
        self._ecapa = None              # SpeechBrain ECAPA-TDNN encoder
        self._oww = None                # OpenWakeWord model
        self._vad = None                # Silero VAD (None → energy fallback)
        self._segmenter = None          # UtteranceSegmenter (built in _load_models)
        self._known_speakers: dict = {}  # speaker_id -> embedding list
        self._audio_stream = None
        self._pa = None                 # PyAudio instance (kept for terminate())
        self._lock = threading.Lock()   # NOTE(review): currently unused in this class
        self._running = False           # audio loop run flag
        # Wake-word gate: set/read only on the audio thread.
        self._wake_word_active = False
        self._wake_expiry = 0.0
        self._wake_window_s = 8.0  # seconds to listen after wake word
        self._turn_counter = 0

        # ── Lazy model loading in background thread ──────────────────────────
        self._model_thread = threading.Thread(target=self._load_models, daemon=True)
        self._model_thread.start()

    # ── Model loading ─────────────────────────────────────────────────────────

    def _load_models(self) -> None:
        """Load all AI models (runs in background to not block ROS spin).

        Each model is optional: a failed load logs and leaves the handle None,
        and the pipeline falls back (energy VAD, no speaker ID, always-listen).
        Audio capture is started only after all loads have been attempted.
        """
        self.get_logger().info("Loading speech models (Whisper, ECAPA-TDNN, VAD)...")
        t0 = time.time()

        # faster-whisper
        try:
            from faster_whisper import WhisperModel
            self._whisper = WhisperModel(
                self._whisper_model_name,
                device="cuda",
                compute_type=self._compute_type,
                download_root="/models",
            )
            self.get_logger().info(
                f"Whisper '{self._whisper_model_name}' loaded "
                f"({time.time()-t0:.1f}s)"
            )
        except Exception as e:
            self.get_logger().error(f"Whisper load failed: {e}")

        # Silero VAD
        if self._use_silero:
            try:
                from silero_vad import load_silero_vad
                self._vad = load_silero_vad()
                self.get_logger().info("Silero VAD loaded")
            except Exception as e:
                self.get_logger().warn(f"Silero VAD unavailable ({e}), using energy VAD")
                self._vad = None

        # The energy VAD always exists: it drives utterance segmentation even
        # when Silero handles the published VAD state.
        energy_vad = EnergyVad(threshold_db=self._vad_thresh_db)
        self._segmenter = UtteranceSegmenter(energy_vad)

        # ECAPA-TDNN speaker embeddings
        try:
            from speechbrain.pretrained import EncoderClassifier
            self._ecapa = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir="/models/speechbrain_ecapa",
            )
            self.get_logger().info("ECAPA-TDNN loaded")
        except Exception as e:
            self.get_logger().warn(f"ECAPA-TDNN unavailable ({e}), using 'unknown' speaker")

        # OpenWakeWord
        try:
            import openwakeword
            from openwakeword.model import Model as OWWModel
            self._oww = OWWModel(
                wakeword_models=[self._ww_model],
                inference_framework="onnx",
            )
            self.get_logger().info(f"OpenWakeWord '{self._ww_model}' loaded")
        except Exception as e:
            self.get_logger().warn(
                f"OpenWakeWord unavailable ({e}). Robot will listen continuously."
            )

        # Load speaker DB
        self._load_speaker_db()

        # Start audio capture
        self._start_audio()
        self.get_logger().info(f"Speech pipeline ready ({time.time()-t0:.1f}s total)")

    def _load_speaker_db(self) -> None:
        """Load speaker embedding database from JSON.

        Expected shape: {speaker_id: embedding_list} (the format consumed by
        identify_speaker()). A missing file is not an error — the DB simply
        stays empty and every speaker resolves to 'unknown'.
        """
        if os.path.exists(self._speaker_db):
            try:
                with open(self._speaker_db) as f:
                    self._known_speakers = json.load(f)
                self.get_logger().info(
                    f"Speaker DB loaded: {len(self._known_speakers)} persons"
                )
            except Exception as e:
                self.get_logger().error(f"Speaker DB load error: {e}")

    # ── Audio capture ─────────────────────────────────────────────────────────

    def _start_audio(self) -> None:
        """Open PyAudio stream and start capture thread.

        mic_device_index < 0 selects the system default device.
        """
        try:
            import pyaudio
            self._pa = pyaudio.PyAudio()
            dev_idx = None if self._mic_idx < 0 else self._mic_idx
            self._audio_stream = self._pa.open(
                format=pyaudio.paInt16,
                channels=1,
                rate=self._rate,
                input=True,
                input_device_index=dev_idx,
                frames_per_buffer=CHUNK_SAMPLES,
            )
            self._running = True
            t = threading.Thread(target=self._audio_loop, daemon=True)
            t.start()
            self.get_logger().info(
                f"Audio capture started (device={self._mic_idx}, {self._rate}Hz)"
            )
        except Exception as e:
            self.get_logger().error(f"Audio stream error: {e}")

    def _audio_loop(self) -> None:
        """Continuous audio capture loop — runs in dedicated thread.

        Per 30 ms chunk: decode PCM → VAD → wake-word gate → publish VAD
        state → (if awake) feed the utterance segmenter; each completed
        utterance is handed to a short-lived transcription thread.
        """
        while self._running:
            try:
                raw = self._audio_stream.read(CHUNK_SAMPLES, exception_on_overflow=False)
            except Exception:
                # Transient read error (e.g. overflow / device hiccup): skip
                # the chunk; loop exits via _running once destroy_node() runs.
                continue

            samples = pcm16_to_float32(raw)
            db = rms_db(samples)
            now = time.time()

            # VAD state is published once per chunk (~33 Hz at 30 ms chunks).
            vad_active = self._check_vad(samples)
            state_str = "silence"

            # Wake word check
            if self._oww is not None:
                # NOTE(review): OpenWakeWord's predict() is documented to take
                # 16-bit PCM numpy frames; float32 samples are passed here —
                # confirm the expected input format/scale.
                preds = self._oww.predict(samples)
                score = preds.get(self._ww_model, 0.0)
                if isinstance(score, (list, tuple)):
                    score = score[-1]  # some versions return per-frame scores
                if score >= self._ww_thresh:
                    self._wake_word_active = True
                    self._wake_expiry = now + self._wake_window_s
                    state_str = "wake_word"
                    self.get_logger().info(
                        f"Wake word detected (score={score:.2f})"
                    )
            else:
                # No wake word detector — always listen
                self._wake_word_active = True
                self._wake_expiry = now + self._wake_window_s

            # Close the listening window once the wake period has elapsed.
            if self._wake_word_active and now > self._wake_expiry:
                self._wake_word_active = False

            # Publish VAD state ("speech" takes precedence over "wake_word").
            if vad_active:
                state_str = "speech"
            self._publish_vad(vad_active, db, state_str)

            # Only feed segmenter when wake word is active
            if not self._wake_word_active:
                continue

            if self._segmenter is None:
                continue  # models still loading

            completed = self._segmenter.push(samples)
            for utt_samples, duration in completed:
                # One daemon thread per utterance: STT may take hundreds of ms
                # and must not block capture.
                threading.Thread(
                    target=self._process_utterance,
                    args=(utt_samples, duration),
                    daemon=True,
                ).start()

    def _check_vad(self, samples: list) -> bool:
        """Run Silero VAD or fall back to segmenter's energy VAD.

        Returns True when speech is currently detected in this chunk.
        """
        if self._vad is not None:
            try:
                import torch
                t = torch.tensor(samples, dtype=torch.float32)
                # NOTE(review): the Silero JIT model expects fixed 512-sample
                # frames at 16 kHz; CHUNK_SAMPLES is 480, so this call may
                # raise and silently fall through to the energy VAD — verify.
                prob = self._vad(t, self._rate).item()
                return prob > 0.5
            except Exception:
                pass
        if self._segmenter is not None:
            # Reaches into the segmenter's private VAD; acceptable here since
            # both live in this package.
            return self._segmenter._vad.is_active
        return False

    # ── STT + diarization ─────────────────────────────────────────────────────

    def _process_utterance(self, samples: list, duration: float) -> None:
        """Transcribe utterance and identify speaker. Runs in thread.

        Publishes cumulative partial transcripts (confidence 0.0) while
        Whisper streams segments, then one final transcript (confidence 0.9).
        """
        if self._whisper is None:
            return  # model not loaded (yet) — drop the utterance

        t0 = time.perf_counter()

        # Convert to numpy for faster-whisper
        try:
            import numpy as np
            audio_np = np.array(samples, dtype=np.float32)
        except ImportError:
            audio_np = samples

        # Speaker embedding first (can be concurrent with transcription)
        speaker_id = "unknown"
        if self._ecapa is not None:
            try:
                import torch
                audio_tensor = torch.tensor([samples], dtype=torch.float32)
                with torch.no_grad():
                    emb = self._ecapa.encode_batch(audio_tensor)
                # NOTE(review): encode_batch typically returns shape
                # (batch, 1, emb_dim); emb[0].tolist() may therefore be a
                # nested list, which would never match the flat embeddings in
                # the speaker DB — confirm shapes against the DB format.
                emb_list = emb[0].cpu().numpy().tolist()
                speaker_id = identify_speaker(
                    emb_list, self._known_speakers, self._speaker_thresh
                )
            except Exception as e:
                self.get_logger().debug(f"Speaker ID error: {e}")

        # Streaming Whisper transcription
        partial_text = ""
        try:
            segments_gen, _info = self._whisper.transcribe(
                audio_np,
                language="en",
                beam_size=3,
                vad_filter=False,  # VAD already applied upstream
            )
            for seg in segments_gen:
                partial_text += seg.text.strip() + " "
                if self._publish_partial:
                    self._publish_transcript(
                        partial_text.strip(), speaker_id, 0.0, duration, is_partial=True
                    )
        except Exception as e:
            self.get_logger().error(f"Whisper error: {e}")
            return

        final_text = partial_text.strip()
        if not final_text:
            return  # nothing transcribed (e.g. non-speech noise)

        latency_ms = (time.perf_counter() - t0) * 1000
        self.get_logger().info(
            f"STT [{speaker_id}] ({duration:.1f}s, {latency_ms:.0f}ms): '{final_text}'"
        )
        self._publish_transcript(final_text, speaker_id, 0.9, duration, is_partial=False)

    # ── Publishers ────────────────────────────────────────────────────────────

    def _publish_transcript(
        self, text: str, speaker_id: str, confidence: float,
        duration: float, is_partial: bool
    ) -> None:
        """Publish one SpeechTranscript message (partial or final)."""
        msg = SpeechTranscript()
        msg.header.stamp = self.get_clock().now().to_msg()
        msg.text = text
        msg.speaker_id = speaker_id
        msg.confidence = confidence
        msg.audio_duration = duration
        msg.is_partial = is_partial
        self._transcript_pub.publish(msg)

    def _publish_vad(self, speech_active: bool, db: float, state: str) -> None:
        """Publish current VAD state; `state` is one of silence/wake_word/speech."""
        msg = VadState()
        msg.header.stamp = self.get_clock().now().to_msg()
        msg.speech_active = speech_active
        msg.energy_db = db
        msg.state = state
        self._vad_pub.publish(msg)

    # ── Cleanup ───────────────────────────────────────────────────────────────

    def destroy_node(self) -> None:
        """Stop the capture loop, release audio resources, then tear down ROS.

        _running is cleared first so the audio thread exits its loop on the
        next iteration (a read error on the closing stream is swallowed there).
        """
        self._running = False
        if self._audio_stream:
            self._audio_stream.stop_stream()
            self._audio_stream.close()
        if self._pa:
            self._pa.terminate()
        super().destroy_node()
+
+
def main(args=None) -> None:
    """Entry point: spin the speech pipeline node until shutdown."""
    rclpy.init(args=args)
    node = SpeechPipelineNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass  # clean Ctrl-C exit — fall through to teardown
    finally:
        node.destroy_node()
        rclpy.shutdown()
diff --git a/jetson/ros2_ws/src/saltybot_social/saltybot_social/speech_utils.py b/jetson/ros2_ws/src/saltybot_social/saltybot_social/speech_utils.py
new file mode 100644
index 0000000..900513a
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social/saltybot_social/speech_utils.py
@@ -0,0 +1,188 @@
+"""speech_utils.py — Pure helpers for wake word, VAD, and STT.
+
+No ROS2 dependencies. All functions operate on raw numpy float32 arrays
+at 16 kHz mono (the standard for Whisper / Silero VAD).
+
+Tested by test/test_speech_utils.py (no GPU required).
+"""
+
+from __future__ import annotations
+
+import math
+import struct
+from typing import Generator, List, Tuple
+
+# ── Constants ──────────────────────────────────────────────────────────────────
# ── Constants ──────────────────────────────────────────────────────────────────
SAMPLE_RATE = 16000  # Hz — standard for Whisper / Silero
CHUNK_MS = 30  # VAD frame size
CHUNK_SAMPLES = int(SAMPLE_RATE * CHUNK_MS / 1000)  # 480 samples
INT16_MAX = 32768.0  # scale factor between PCM16 counts and float32


# ── Audio helpers ──────────────────────────────────────────────────────────────

def pcm16_to_float32(data: bytes) -> list:
    """Convert raw PCM16 LE bytes → float32 list in [-1.0, 1.0].

    A trailing odd byte (half a sample) is silently dropped.
    """
    n = len(data) // 2
    samples = struct.unpack(f"<{n}h", data[:n * 2])
    return [s / INT16_MAX for s in samples]


def float32_to_pcm16(samples: list) -> bytes:
    """Convert float32 list → PCM16 LE bytes (clipped to ±1.0).

    +1.0 would scale to 32768, one past the int16 maximum, hence the
    second clamp to the [-32768, 32767] integer range.
    """
    clipped = [max(-1.0, min(1.0, s)) for s in samples]
    ints = [max(-32768, min(32767, int(s * INT16_MAX))) for s in clipped]
    return struct.pack(f"<{len(ints)}h", *ints)


def rms_db(samples: list) -> float:
    """Compute RMS energy in dBFS, floored at -96.0.

    Returns -96.0 for empty input AND for digital silence (all zeros).
    Fix: the previous epsilon-based guard returned ~-200 dB for all-zero
    input, contradicting the documented "-96.0 for silence" contract.
    """
    if not samples:
        return -96.0
    mean_sq = sum(s * s for s in samples) / len(samples)
    if mean_sq <= 0.0:
        return -96.0
    return max(20.0 * math.log10(math.sqrt(mean_sq)), -96.0)


def chunk_audio(samples: list, chunk_size: int = CHUNK_SAMPLES) -> Generator[list, None, None]:
    """Yield fixed-size chunks from a flat sample list.

    An incomplete tail shorter than chunk_size is dropped.
    """
    for i in range(0, len(samples) - chunk_size + 1, chunk_size):
        yield samples[i:i + chunk_size]


# ── Simple energy-based VAD (fallback when Silero unavailable) ────────────────

class EnergyVad:
    """Threshold-based energy VAD with onset/offset hysteresis.

    Speech switches ON after `onset_frames` consecutive frames at or above
    `threshold_db`, and OFF after `offset_frames` consecutive frames below.
    Useful as a lightweight fallback / unit-testable reference.
    """

    def __init__(
        self,
        threshold_db: float = -35.0,
        onset_frames: int = 2,
        offset_frames: int = 8,
    ) -> None:
        self.threshold_db = threshold_db
        self.onset_frames = onset_frames
        self.offset_frames = offset_frames
        self._above_count = 0  # consecutive frames at/above threshold
        self._below_count = 0  # consecutive frames below threshold
        self._active = False

    def process(self, chunk: list) -> bool:
        """Process one chunk (CHUNK_SAMPLES long). Returns True if speech active."""
        db = rms_db(chunk)
        if db >= self.threshold_db:
            self._above_count += 1
            self._below_count = 0
            if self._above_count >= self.onset_frames:
                self._active = True
        else:
            self._below_count += 1
            self._above_count = 0
            if self._below_count >= self.offset_frames:
                self._active = False
        return self._active

    def reset(self) -> None:
        """Return to the initial (inactive) state."""
        self._above_count = 0
        self._below_count = 0
        self._active = False

    @property
    def is_active(self) -> bool:
        """Current hysteresis state (True while speech is considered active)."""
        return self._active


# ── Utterance segmenter ────────────────────────────────────────────────────────

class UtteranceSegmenter:
    """Accumulates audio chunks into complete utterances using VAD.

    push() returns (samples, duration_s) tuples for utterances completed by
    that chunk. A pre-roll ring buffer keeps audio from just before speech
    onset so the first phoneme is not clipped.

    Fixes vs. the original draft:
      * The onset chunk is no longer included twice (the flattened pre-roll
        buffer already ends with it; it was also extend()-ed separately).
      * The max-duration cap is enforced on every in-utterance chunk, not
        only on silent frames, so continuous speech cannot grow unbounded.
      * The pre-roll buffer is cleared on completion so stale audio is not
        prepended to the next utterance.
    """

    def __init__(
        self,
        vad: EnergyVad,
        pre_roll_frames: int = 5,
        max_duration_s: float = 15.0,
        sample_rate: int = SAMPLE_RATE,
        chunk_samples: int = CHUNK_SAMPLES,
    ) -> None:
        self._vad = vad
        self._pre_roll = pre_roll_frames
        self._max_frames = int(max_duration_s * sample_rate / chunk_samples)
        self._sample_rate = sample_rate
        self._chunk_samples = chunk_samples
        self._pre_buf: list = []  # ring buffer of chunks for pre-roll
        self._utt_buf: list = []  # growing flat utterance buffer (samples)
        self._in_utt = False
        self._silence_after = 0   # consecutive non-speech chunks inside an utterance

    def push(self, chunk: list) -> list:
        """Push one chunk, return list of completed (samples, duration_s) tuples."""
        speech = self._vad.process(chunk)
        results: list = []

        if not self._in_utt:
            # Idle: maintain the pre-roll ring buffer.
            self._pre_buf.append(chunk)
            if len(self._pre_buf) > self._pre_roll:
                self._pre_buf.pop(0)
            if speech:
                # Onset: seed the utterance with the pre-roll audio. The
                # current chunk is already the last entry of the ring buffer.
                self._in_utt = True
                self._utt_buf = [s for c in self._pre_buf for s in c]
                self._silence_after = 0
            return results

        # In an utterance: always accumulate, tracking trailing silence.
        self._utt_buf.extend(chunk)
        if speech:
            self._silence_after = 0
        else:
            self._silence_after += 1

        max_samples = self._max_frames * self._chunk_samples
        if (self._silence_after >= self._vad.offset_frames
                or len(self._utt_buf) >= max_samples):
            dur = len(self._utt_buf) / self._sample_rate
            results.append((list(self._utt_buf), dur))
            self._utt_buf = []
            self._pre_buf = []  # drop stale pre-roll from before this utterance
            self._in_utt = False
            self._silence_after = 0

        return results
+
+
+# ── Speaker embedding distance ─────────────────────────────────────────────────
+
def cosine_similarity(a: list, b: list) -> float:
    """Cosine similarity of two equal-length vectors.

    Returns 0.0 on length mismatch, empty input, or (near-)zero vectors.
    """
    if not a or len(a) != len(b):
        return 0.0
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a < 1e-10 or norm_b < 1e-10:
        return 0.0
    return dot / (norm_a * norm_b)


def identify_speaker(
    embedding: list,
    known_speakers: dict,
    threshold: float = 0.65,
) -> str:
    """Return the best-matching speaker_id, or 'unknown'.

    A match must be STRICTLY greater than `threshold` cosine similarity.
    known_speakers maps speaker_id -> embedding list.
    """
    best_id, best_sim = "unknown", threshold
    for candidate_id, candidate_emb in known_speakers.items():
        score = cosine_similarity(embedding, candidate_emb)
        if score > best_sim:
            best_id, best_sim = candidate_id, score
    return best_id
diff --git a/jetson/ros2_ws/src/saltybot_social/saltybot_social/tts_node.py b/jetson/ros2_ws/src/saltybot_social/saltybot_social/tts_node.py
new file mode 100644
index 0000000..d1a0d7f
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social/saltybot_social/tts_node.py
@@ -0,0 +1,228 @@
+"""tts_node.py — Streaming TTS with Piper / first-chunk streaming.
+
+Issue #85: Streaming TTS — Piper/XTTS integration with first-chunk streaming.
+
+Pipeline:
+ /social/conversation/response (ConversationResponse)
+ → sentence split → Piper ONNX synthesis (sentence by sentence)
+ → PCM16 chunks → USB speaker (sounddevice) + publish /social/tts/audio
+
+First-chunk strategy:
+ - On partial=true ConversationResponse, extract first sentence and synthesize
+ immediately → audio starts before LLM finishes generating
+ - On final=false, synthesize remaining sentences
+
+Latency target: <200ms to first audio chunk.
+
+ROS2 topics:
+ Subscribe: /social/conversation/response (saltybot_social_msgs/ConversationResponse)
+ Publish: /social/tts/audio (audio_msgs/Audio or std_msgs/UInt8MultiArray fallback)
+
+Parameters:
+ voice_path (str, "/models/piper/en_US-lessac-medium.onnx")
+ sample_rate (int, 22050)
+ volume (float, 1.0)
+ audio_device (str, "") — sounddevice device name; "" = system default
+ playback_enabled (bool, true)
+ publish_audio (bool, false) — publish PCM to ROS2 topic
+ sentence_streaming (bool, true) — synthesize sentence-by-sentence
+"""
+
+from __future__ import annotations
+
+import queue
+import threading
+import time
+from typing import Optional
+
+import rclpy
+from rclpy.node import Node
+from rclpy.qos import QoSProfile
+from std_msgs.msg import UInt8MultiArray
+
+from saltybot_social_msgs.msg import ConversationResponse
+from .tts_utils import split_sentences, strip_ssml, apply_volume, chunk_pcm, estimate_duration_ms
+
+
class TtsNode(Node):
    """Streaming TTS node using Piper ONNX.

    Subscribes to streaming LLM responses and synthesizes them sentence by
    sentence, so audio playback can begin before the LLM finishes the turn.
    Synthesis + playback run on a dedicated worker thread fed by a bounded
    queue; the voice model loads on another background thread.
    """

    def __init__(self) -> None:
        super().__init__("tts_node")

        # ── Parameters ──────────────────────────────────────────────────────
        self.declare_parameter("voice_path", "/models/piper/en_US-lessac-medium.onnx")
        self.declare_parameter("sample_rate", 22050)
        self.declare_parameter("volume", 1.0)
        self.declare_parameter("audio_device", "")
        self.declare_parameter("playback_enabled", True)
        self.declare_parameter("publish_audio", False)
        self.declare_parameter("sentence_streaming", True)

        self._voice_path = self.get_parameter("voice_path").value
        self._sample_rate = self.get_parameter("sample_rate").value
        self._volume = self.get_parameter("volume").value
        # "" means system default device (sounddevice expects None for that).
        self._audio_device = self.get_parameter("audio_device").value or None
        self._playback = self.get_parameter("playback_enabled").value
        self._publish_audio = self.get_parameter("publish_audio").value
        self._sentence_streaming = self.get_parameter("sentence_streaming").value

        # ── Publishers / Subscribers ─────────────────────────────────────────
        qos = QoSProfile(depth=10)
        self._resp_sub = self.create_subscription(
            ConversationResponse, "/social/conversation/response",
            self._on_response, qos
        )
        if self._publish_audio:
            self._audio_pub = self.create_publisher(
                UInt8MultiArray, "/social/tts/audio", qos
            )

        # ── TTS engine ────────────────────────────────────────────────────────
        self._voice = None  # PiperVoice, loaded in background
        # Bounded so a runaway LLM cannot pile up minutes of pending audio.
        self._playback_queue: queue.Queue = queue.Queue(maxsize=16)
        self._current_turn = -1
        # (turn_id, hash(sentence)) keys already queued — dedups the
        # cumulative partial responses re-sent while the LLM streams.
        self._synthesized_turns: set = set()
        self._lock = threading.Lock()

        threading.Thread(target=self._load_voice, daemon=True).start()
        threading.Thread(target=self._playback_worker, daemon=True).start()

        self.get_logger().info(
            f"TtsNode init (voice={self._voice_path}, "
            f"streaming={self._sentence_streaming})"
        )

    # ── Voice loading ─────────────────────────────────────────────────────────

    def _load_voice(self) -> None:
        """Load the Piper voice and run a warmup synthesis (background thread)."""
        t0 = time.time()
        self.get_logger().info(f"Loading Piper voice: {self._voice_path}")
        try:
            from piper import PiperVoice
            self._voice = PiperVoice.load(self._voice_path)
            # Warmup synthesis to pre-JIT ONNX graph
            warmup_text = "Hello."
            list(self._voice.synthesize_stream_raw(warmup_text))
            self.get_logger().info(f"Piper voice ready ({time.time()-t0:.1f}s)")
        except Exception as e:
            self.get_logger().error(f"Piper voice load failed: {e}")

    # ── Response handler ──────────────────────────────────────────────────────

    def _on_response(self, msg: ConversationResponse) -> None:
        """Handle a (possibly partial, cumulative) LLM response.

        Sentences already queued for this turn are deduplicated by content
        hash. Fix: the trailing sentence of a partial message may still be
        growing, so it is held back until the final (non-partial) message —
        previously the incomplete fragment was synthesized AND its completed
        version was synthesized again later (different hash), producing
        stuttered duplicate audio.
        """
        if not msg.text.strip():
            return

        with self._lock:
            if msg.turn_id != self._current_turn:
                self._current_turn = msg.turn_id
                # New turn: forget dedup keys from previous turns.
                self._synthesized_turns = set()

        text = strip_ssml(msg.text)

        if self._sentence_streaming:
            sentences = split_sentences(text)
            if msg.is_partial and sentences:
                # Last sentence of a partial message may be incomplete —
                # it will be queued once the final message arrives.
                sentences = sentences[:-1]
            for sentence in sentences:
                key = (msg.turn_id, hash(sentence))
                with self._lock:
                    if key in self._synthesized_turns:
                        continue
                    self._synthesized_turns.add(key)
                self._queue_synthesis(sentence)
        elif not msg.is_partial:
            # Non-streaming: synthesize full response at end
            self._queue_synthesis(text)

    def _queue_synthesis(self, text: str) -> None:
        """Queue a text segment for synthesis in the playback worker.

        Drops the segment (with a warning) when the bounded queue is full.
        """
        if not text.strip():
            return
        try:
            self._playback_queue.put_nowait(text.strip())
        except queue.Full:
            self.get_logger().warn("TTS playback queue full, dropping segment")

    # ── Playback worker ───────────────────────────────────────────────────────

    def _playback_worker(self) -> None:
        """Consume synthesis queue: synthesize → play → publish."""
        while rclpy.ok():
            try:
                text = self._playback_queue.get(timeout=0.5)
            except queue.Empty:
                continue

            if self._voice is None:
                # Voice still loading — segment is dropped, not re-queued.
                self.get_logger().warn("TTS voice not loaded yet")
                self._playback_queue.task_done()
                continue

            t0 = time.perf_counter()
            pcm_data = self._synthesize(text)
            if pcm_data is None:
                self._playback_queue.task_done()
                continue

            synth_ms = (time.perf_counter() - t0) * 1000
            dur_ms = estimate_duration_ms(pcm_data, self._sample_rate)
            self.get_logger().debug(
                f"TTS '{text[:40]}' synth={synth_ms:.0f}ms, dur={dur_ms:.0f}ms"
            )

            if self._volume != 1.0:
                pcm_data = apply_volume(pcm_data, self._volume)

            if self._playback:
                self._play_audio(pcm_data)

            if self._publish_audio:
                self._publish_pcm(pcm_data)

            self._playback_queue.task_done()

    def _synthesize(self, text: str) -> Optional[bytes]:
        """Synthesize text to PCM16 bytes using Piper streaming; None on error."""
        if self._voice is None:
            return None
        try:
            chunks = list(self._voice.synthesize_stream_raw(text))
            return b"".join(chunks)
        except Exception as e:
            self.get_logger().error(f"TTS synthesis error: {e}")
            return None

    def _play_audio(self, pcm_data: bytes) -> None:
        """Play PCM16 data on USB speaker via sounddevice (blocks until done)."""
        try:
            import sounddevice as sd
            import numpy as np
            samples = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0
            sd.play(samples, samplerate=self._sample_rate, device=self._audio_device,
                    blocking=True)
        except Exception as e:
            self.get_logger().error(f"Audio playback error: {e}")

    def _publish_pcm(self, pcm_data: bytes) -> None:
        """Publish raw PCM bytes as UInt8MultiArray (no-op if publisher absent)."""
        if not hasattr(self, "_audio_pub"):
            return
        msg = UInt8MultiArray()
        msg.data = list(pcm_data)
        self._audio_pub.publish(msg)
+
+
def main(args=None) -> None:
    """Entry point: spin the TTS node until shutdown."""
    rclpy.init(args=args)
    node = TtsNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass  # clean Ctrl-C exit — fall through to teardown
    finally:
        node.destroy_node()
        rclpy.shutdown()
diff --git a/jetson/ros2_ws/src/saltybot_social/saltybot_social/tts_utils.py b/jetson/ros2_ws/src/saltybot_social/saltybot_social/tts_utils.py
new file mode 100644
index 0000000..e0660a5
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social/saltybot_social/tts_utils.py
@@ -0,0 +1,100 @@
+"""tts_utils.py — Pure TTS helpers for streaming synthesis.
+
+No ROS2 dependencies. Provides streaming Piper synthesis and audio chunking.
+
+Tested by test/test_tts_utils.py.
+"""
+
+from __future__ import annotations
+
+import re
+import struct
+from typing import Generator, List, Tuple
+
+
+# ── SSML helpers ──────────────────────────────────────────────────────────────
+
_SSML_STRIP_RE = re.compile(r"<[^>]+>")


def strip_ssml(text: str) -> str:
    """Drop anything that looks like an SSML/XML tag and trim whitespace."""
    without_tags = _SSML_STRIP_RE.sub("", text)
    return without_tags.strip()


def split_sentences(text: str) -> List[str]:
    """Split text into sentences for streaming TTS synthesis.

    Splits after '.', '!' or '?' followed by whitespace; returns only
    non-empty, stripped sentence strings.
    """
    pieces = re.split(r"(?<=[.!?])\s+", text.strip())
    return [piece.strip() for piece in pieces if piece.strip()]


def split_first_sentence(text: str) -> Tuple[str, str]:
    """Return (first_sentence, remainder).

    Lets the TTS node begin synthesizing the first sentence immediately
    while the LLM continues generating the rest. Both parts are "" when
    the input contains no sentences.
    """
    pieces = split_sentences(text)
    if not pieces:
        return "", ""
    return pieces[0], " ".join(pieces[1:])
+
+
+# ── PCM audio helpers ─────────────────────────────────────────────────────────
+
def pcm16_to_wav_bytes(pcm_data: bytes, sample_rate: int = 22050) -> bytes:
    """Wrap raw PCM16 LE mono data with a standard 44-byte WAV header."""
    channels = 1
    bits_per_sample = 16
    bytes_per_frame = channels * bits_per_sample // 8  # block align
    data_size = len(pcm_data)

    # RIFF size field counts everything after the 8-byte "RIFF<size>" prefix:
    # 44-byte header - 8 + payload = 36 + data_size.
    header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF",
        36 + data_size,
        b"WAVE",
        b"fmt ",
        16,  # fmt chunk size
        1,   # audio format: PCM
        channels,
        sample_rate,
        sample_rate * bytes_per_frame,  # byte rate
        bytes_per_frame,
        bits_per_sample,
        b"data",
        data_size,
    )
    return header + pcm_data
+
+
def chunk_pcm(pcm_data: bytes, chunk_ms: int = 200, sample_rate: int = 22050) -> Generator[bytes, None, None]:
    """Yield PCM data in ~chunk_ms slices for streaming playback.

    The slice size is forced even (and at least one int16 sample) so a
    sample is never split across chunks; the final slice may be shorter.
    """
    step = (sample_rate * 2 * chunk_ms) // 1000  # int16 = 2 bytes/sample
    step = max(step, 2) & ~1
    offset = 0
    while offset < len(pcm_data):
        yield pcm_data[offset:offset + step]
        offset += step
+
+
def estimate_duration_ms(pcm_data: bytes, sample_rate: int = 22050) -> float:
    """Estimate playback duration in ms of mono PCM16 data (2 bytes/sample)."""
    bytes_per_second = sample_rate * 2
    return len(pcm_data) / bytes_per_second * 1000.0
+
+
+# ── Volume control ────────────────────────────────────────────────────────────
+
def apply_volume(pcm_data: bytes, gain: float) -> bytes:
    """Scale PCM16 LE samples by `gain`, clipping to ±32767.

    Gains within 1% of unity return the input unchanged (fast path).
    """
    if abs(gain - 1.0) < 0.01:
        return pcm_data
    n = len(pcm_data) // 2
    scaled = []
    for sample in struct.unpack(f"<{n}h", pcm_data[:n * 2]):
        scaled.append(max(-32767, min(32767, int(sample * gain))))
    return struct.pack(f"<{n}h", *scaled)
diff --git a/jetson/ros2_ws/src/saltybot_social/setup.py b/jetson/ros2_ws/src/saltybot_social/setup.py
index 4f55aeb..e8f040c 100644
--- a/jetson/ros2_ws/src/saltybot_social/setup.py
+++ b/jetson/ros2_ws/src/saltybot_social/setup.py
@@ -21,7 +21,7 @@ setup(
zip_safe=True,
maintainer='seb',
maintainer_email='seb@vayrette.com',
- description='Social interaction layer — person state tracking, LED expression + attention',
+ description='Social interaction layer — person tracking, speech, LLM, TTS, orchestrator',
license='MIT',
tests_require=['pytest'],
entry_points={
@@ -29,6 +29,10 @@ setup(
'person_state_tracker = saltybot_social.person_state_tracker_node:main',
'expression_node = saltybot_social.expression_node:main',
'attention_node = saltybot_social.attention_node:main',
+ 'speech_pipeline_node = saltybot_social.speech_pipeline_node:main',
+ 'conversation_node = saltybot_social.conversation_node:main',
+ 'tts_node = saltybot_social.tts_node:main',
+ 'orchestrator_node = saltybot_social.orchestrator_node:main',
],
},
)
diff --git a/jetson/ros2_ws/src/saltybot_social/test/test_llm_context.py b/jetson/ros2_ws/src/saltybot_social/test/test_llm_context.py
new file mode 100644
index 0000000..9806f65
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social/test/test_llm_context.py
@@ -0,0 +1,244 @@
+"""test_llm_context.py — Unit tests for llm_context.py (no GPU / LLM needed)."""
+
+import sys
+import os
+import tempfile
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+from saltybot_social.llm_context import (
+ PersonContext, ContextStore,
+ build_llama_prompt, MAX_TURNS, SUMMARY_KEEP,
+ DEFAULT_SYSTEM_PROMPT,
+)
+import pytest
+
+
+# ── PersonContext ─────────────────────────────────────────────────────────────
+
class TestPersonContext:
    """Unit tests for the PersonContext turn buffer and compression logic."""

    def test_create(self):
        fresh = PersonContext("person_1", "Alice")
        assert fresh.person_id == "person_1"
        assert fresh.person_name == "Alice"
        assert fresh.turns == []
        assert fresh.interaction_count == 0

    def test_add_user_increments_count(self):
        fresh = PersonContext("person_1")
        fresh.add_user("Hello")
        last = fresh.turns[-1]
        assert fresh.interaction_count == 1
        assert last["role"] == "user"
        assert last["content"] == "Hello"

    def test_add_assistant(self):
        fresh = PersonContext("person_1")
        fresh.add_assistant("Hi there!")
        assert fresh.turns[-1]["role"] == "assistant"

    def test_needs_compression_false(self):
        fresh = PersonContext("p")
        for idx in range(MAX_TURNS - 1):
            fresh.add_user(f"msg {idx}")
        assert not fresh.needs_compression()

    def test_needs_compression_true(self):
        fresh = PersonContext("p")
        for idx in range(MAX_TURNS + 1):
            fresh.add_user(f"msg {idx}")
        assert fresh.needs_compression()

    def test_compress_keeps_recent_turns(self):
        fresh = PersonContext("p")
        for idx in range(MAX_TURNS + 5):
            fresh.add_user(f"msg {idx}")
        fresh.compress("Summary of conversation")
        assert len(fresh.turns) == SUMMARY_KEEP
        assert fresh.summary == "Summary of conversation"

    def test_compress_appends_summary(self):
        fresh = PersonContext("p")
        fresh.add_user("msg")
        fresh.compress("First summary")
        fresh.add_user("msg2")
        fresh.compress("Second summary")
        # Older summaries must survive later compressions.
        assert "First summary" in fresh.summary
        assert "Second summary" in fresh.summary

    def test_roundtrip_serialization(self):
        fresh = PersonContext("person_42", "Bob")
        fresh.add_user("test")
        fresh.add_assistant("response")
        fresh.summary = "Earlier they discussed robots"
        restored = PersonContext.from_dict(fresh.to_dict())
        assert restored.person_id == "person_42"
        assert restored.person_name == "Bob"
        assert len(restored.turns) == 2
        assert restored.summary == "Earlier they discussed robots"
+
+
+# ── ContextStore ──────────────────────────────────────────────────────────────
+
class TestContextStore:
    """Tests for the JSON-backed per-person ContextStore."""

    def test_get_creates_new(self):
        with tempfile.TemporaryDirectory() as tmp:
            cs = ContextStore(os.path.join(tmp, "ctx.json"))
            created = cs.get("p1", "Alice")
            assert created.person_id == "p1"
            assert created.person_name == "Alice"

    def test_get_same_instance(self):
        with tempfile.TemporaryDirectory() as tmp:
            cs = ContextStore(os.path.join(tmp, "ctx.json"))
            # Repeated lookups must return the cached object, not a copy.
            assert cs.get("p1") is cs.get("p1")

    def test_save_and_load(self):
        with tempfile.TemporaryDirectory() as tmp:
            db_path = os.path.join(tmp, "ctx.json")
            writer = ContextStore(db_path)
            writer.get("p1", "Alice").add_user("Hello")
            writer.save()

            reader = ContextStore(db_path)
            loaded = reader.get("p1")
            assert loaded.person_name == "Alice"
            assert len(loaded.turns) == 1
            assert loaded.turns[0]["content"] == "Hello"

    def test_all_persons(self):
        with tempfile.TemporaryDirectory() as tmp:
            cs = ContextStore(os.path.join(tmp, "ctx.json"))
            for pid in ("p1", "p2", "p3"):
                cs.get(pid)
            assert set(cs.all_persons()) == {"p1", "p2", "p3"}
+
+
+# ── Prompt builder ────────────────────────────────────────────────────────────
+
class TestBuildLlamaPrompt:
    """Tests for the chat-template prompt assembly."""

    def test_contains_system_prompt(self):
        rendered = build_llama_prompt(PersonContext("p1"), "hello", DEFAULT_SYSTEM_PROMPT)
        assert DEFAULT_SYSTEM_PROMPT in rendered

    def test_contains_user_text(self):
        rendered = build_llama_prompt(PersonContext("p1"), "What time is it?", DEFAULT_SYSTEM_PROMPT)
        assert "What time is it?" in rendered

    def test_ends_with_assistant_tag(self):
        rendered = build_llama_prompt(PersonContext("p1"), "hello", DEFAULT_SYSTEM_PROMPT)
        # Generation continues right after the assistant tag.
        assert rendered.strip().endswith("<|assistant|>")

    def test_history_included(self):
        convo = PersonContext("p1")
        convo.add_user("Previous question")
        convo.add_assistant("Previous answer")
        rendered = build_llama_prompt(convo, "New question", DEFAULT_SYSTEM_PROMPT)
        assert "Previous question" in rendered
        assert "Previous answer" in rendered

    def test_summary_included(self):
        convo = PersonContext("p1", "Alice")
        convo.summary = "Alice visited last week and likes coffee"
        rendered = build_llama_prompt(convo, "hello", DEFAULT_SYSTEM_PROMPT)
        assert "likes coffee" in rendered

    def test_group_persons_included(self):
        convo = PersonContext("p1", "Alice")
        rendered = build_llama_prompt(
            convo, "hello", DEFAULT_SYSTEM_PROMPT,
            group_persons=["p1", "p2_Bob", "p3_Charlie"],
        )
        assert "p2_Bob" in rendered or "p3_Charlie" in rendered

    def test_no_group_persons_not_in_prompt(self):
        rendered = build_llama_prompt(PersonContext("p1"), "hello", DEFAULT_SYSTEM_PROMPT)
        assert "Also present" not in rendered
+
+
+# ── TTS utils (imported here for coverage) ────────────────────────────────────
+
class TestTtsUtils:
    """Smoke tests for tts_utils helpers (pure Python, no audio device)."""

    _HELPERS = (
        "strip_ssml", "split_sentences", "split_first_sentence",
        "pcm16_to_wav_bytes", "chunk_pcm", "estimate_duration_ms",
        "apply_volume",
    )

    def setup_method(self):
        sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
        from saltybot_social import tts_utils
        # Bind each helper onto the test instance, mirroring a name-by-name import.
        for helper in self._HELPERS:
            setattr(self, helper, getattr(tts_utils, helper))

    def test_strip_ssml(self):
        for passthrough in ("Hello world", "plain text"):
            assert self.strip_ssml(passthrough) == passthrough

    def test_split_sentences_basic(self):
        parts = self.split_sentences("Hello world. How are you? I am fine!")
        assert len(parts) == 3

    def test_split_sentences_single(self):
        assert self.split_sentences("Hello world.") == ["Hello world."]

    def test_split_sentences_empty(self):
        assert self.split_sentences("") == []
        assert self.split_sentences(" ") == []

    def test_split_first_sentence(self):
        first, remainder = self.split_first_sentence("Hello world. How are you? I am fine.")
        assert first == "Hello world."
        assert "How are you?" in remainder

    def test_wav_header(self):
        payload = b"\x00\x00" * 1000
        wav = self.pcm16_to_wav_bytes(payload, 22050)
        assert wav.startswith(b"RIFF")
        assert wav[8:12] == b"WAVE"
        # 44-byte canonical header followed by the raw payload.
        assert len(wav) == 44 + len(payload)

    def test_chunk_pcm_sizes(self):
        payload = b"\x00\x00" * 2205  # 100 ms of samples at 22050 Hz
        for piece in self.chunk_pcm(payload, 50, 22050):  # 50 ms chunks
            assert len(piece) % 2 == 0  # must stay int16-aligned

    def test_estimate_duration(self):
        one_second = b"\x00\x00" * 22050
        assert abs(self.estimate_duration_ms(one_second, 22050) - 1000.0) < 1.0

    def test_apply_volume_unity(self):
        payload = b"\x00\x40" * 100  # non-zero samples
        assert self.apply_volume(payload, 1.0) == payload
+
    def test_apply_volume_half(self):
        import struct
        pcm = struct.pack("<2h", 1000, -1000)
        assert self.apply_volume(pcm, 0.5) == struct.pack("<2h", 500, -500)

# NOTE(review): the patch text is truncated at this point. The remainder of
# test_llm_context.py and the beginning of the speech-pipeline test file —
# including its diff header and the imports of float32_to_pcm16,
# pcm16_to_float32, rms_db, chunk_audio, EnergyVad, CHUNK_SAMPLES,
# UtteranceSegmenter, cosine_similarity and identify_speaker — is missing
# and must be restored from the original commit. The clipping test below
# belongs to that second file.

    def test_clipping(self):
        # Float samples beyond ±1.0 should be clipped on the PCM16 round-trip.
        pcm = float32_to_pcm16([2.0, -2.0])
        samples = pcm16_to_float32(pcm)
        assert samples[0] <= 1.0
        assert samples[1] >= -1.0
+
+
+# ── RMS dB ────────────────────────────────────────────────────────────────────
+
class TestRmsDb:
    """rms_db must report dBFS relative to full scale (1.0)."""

    def test_silence_returns_minus96(self):
        assert rms_db([]) == -96.0
        assert rms_db([0.0] * 100) < -60.0

    def test_full_scale_is_near_zero(self):
        # A constant 1.0 signal sits at 0 dBFS.
        assert abs(rms_db([1.0] * 100)) < 0.1

    def test_half_scale_is_minus6(self):
        # Halving amplitude costs about 6.02 dB.
        assert abs(rms_db([0.5] * 100) + 6.02) < 0.1

    def test_low_level_signal(self):
        assert rms_db([0.001] * 100) < -40.0
+
+
+# ── Chunk audio ───────────────────────────────────────────────────────────────
+
class TestChunkAudio:
    """chunk_audio slices sample lists into fixed-size frames."""

    def test_even_chunks(self):
        frames = list(chunk_audio(list(range(480 * 4)), 480))
        assert len(frames) == 4
        assert all(len(frame) == 480 for frame in frames)

    def test_remainder_dropped(self):
        # 500 samples yield one 480-sample frame; the 20-sample tail is discarded.
        assert len(list(chunk_audio(list(range(500)), 480))) == 1

    def test_empty_input(self):
        assert list(chunk_audio([], 480)) == []
+
+
+# ── Energy VAD ────────────────────────────────────────────────────────────────
+
class TestEnergyVad:
    """Hysteresis behaviour of the energy-threshold VAD."""

    def test_silence_not_active(self):
        vad = EnergyVad(threshold_db=-35.0, onset_frames=2, offset_frames=3)
        quiet = [0.0001] * CHUNK_SAMPLES
        for _ in range(5):
            assert not vad.process(quiet)

    def test_loud_signal_activates(self):
        vad = EnergyVad(threshold_db=-35.0, onset_frames=2, offset_frames=8)
        speech = [0.5] * CHUNK_SAMPLES
        vad.process(speech)
        vad.process(speech)
        # Third loud frame is past the 2-frame onset requirement.
        assert vad.process(speech) is True

    def test_onset_requires_n_frames(self):
        vad = EnergyVad(threshold_db=-35.0, onset_frames=3, offset_frames=8)
        speech = [0.5] * CHUNK_SAMPLES
        # Frames 1 and 2 fall short of the onset count; frame 3 flips active.
        assert not vad.process(speech)
        assert not vad.process(speech)
        assert vad.process(speech)

    def test_offset_deactivates(self):
        vad = EnergyVad(threshold_db=-35.0, onset_frames=1, offset_frames=2)
        speech = [0.5] * CHUNK_SAMPLES
        quiet = [0.0001] * CHUNK_SAMPLES
        vad.process(speech)
        assert vad.is_active
        vad.process(quiet)
        assert vad.is_active  # only one quiet frame of the required two
        vad.process(quiet)
        assert not vad.is_active

    def test_reset(self):
        vad = EnergyVad(threshold_db=-35.0, onset_frames=1, offset_frames=8)
        vad.process([0.5] * CHUNK_SAMPLES)
        assert vad.is_active
        vad.reset()
        assert not vad.is_active
+
+
+# ── Utterance segmenter ───────────────────────────────────────────────────────
+
class TestUtteranceSegmenter:
    """Segmentation: speech bracketed by silence becomes one utterance."""

    def _make_segmenter(self):
        vad = EnergyVad(threshold_db=-35.0, onset_frames=1, offset_frames=2)
        return UtteranceSegmenter(vad, pre_roll_frames=2, max_duration_s=5.0)

    def test_silence_yields_nothing(self):
        seg = self._make_segmenter()
        quiet = [0.0001] * CHUNK_SAMPLES
        for _ in range(10):
            assert seg.push(quiet) == []

    def test_speech_then_silence_yields_utterance(self):
        seg = self._make_segmenter()
        speech = [0.5] * CHUNK_SAMPLES
        quiet = [0.0001] * CHUNK_SAMPLES

        # Speech frames: nothing emitted while the utterance is open.
        for _ in range(3):
            assert seg.push(speech) == []

        # Enough quiet frames to trip the offset and close the utterance.
        emitted = []
        for _ in range(3):
            emitted.extend(seg.push(quiet))

        assert len(emitted) == 1
        utt_samples, duration = emitted[0]
        assert duration > 0
        assert len(utt_samples) > 0

    def test_preroll_included(self):
        vad = EnergyVad(threshold_db=-35.0, onset_frames=1, offset_frames=2)
        seg = UtteranceSegmenter(vad, pre_roll_frames=3, max_duration_s=5.0)
        quiet = [0.0001] * CHUNK_SAMPLES
        speech = [0.5] * CHUNK_SAMPLES

        # These quiet frames land in the pre-roll ring buffer.
        for _ in range(3):
            seg.push(quiet)

        # Onset — the buffered pre-roll gets prepended to the utterance.
        seg.push(speech)

        # Trailing silence closes the utterance.
        emitted = []
        for _ in range(3):
            emitted.extend(seg.push(quiet))

        assert len(emitted) == 1
        utt_samples, _ = emitted[0]
        # Pre-roll (3 frames) + speech (1) + some trailing silence.
        assert len(utt_samples) >= 4 * CHUNK_SAMPLES
+
+
+# ── Speaker identification ────────────────────────────────────────────────────
+
class TestSpeakerIdentification:
    """cosine_similarity and identify_speaker nearest-neighbour matching."""

    def test_cosine_identical(self):
        vec = [1.0, 0.0, 0.0]
        assert abs(cosine_similarity(vec, vec) - 1.0) < 1e-6

    def test_cosine_orthogonal(self):
        assert abs(cosine_similarity([1.0, 0.0], [0.0, 1.0])) < 1e-6

    def test_cosine_opposite(self):
        assert abs(cosine_similarity([1.0, 0.0], [-1.0, 0.0]) + 1.0) < 1e-6

    def test_cosine_zero_vector(self):
        # Degenerate input: a zero vector has no direction, so similarity is 0.
        assert cosine_similarity([0.0, 0.0], [1.0, 0.0]) == 0.0

    def test_cosine_length_mismatch(self):
        assert cosine_similarity([1.0], [1.0, 0.0]) == 0.0

    def test_identify_speaker_match(self):
        enrolled = {
            "alice": [1.0, 0.0, 0.0],
            "bob": [0.0, 1.0, 0.0],
        }
        assert identify_speaker([1.0, 0.0, 0.0], enrolled, threshold=0.5) == "alice"

    def test_identify_speaker_unknown(self):
        enrolled = {"alice": [1.0, 0.0, 0.0]}
        assert identify_speaker([0.0, 1.0, 0.0], enrolled, threshold=0.5) == "unknown"

    def test_identify_speaker_empty_db(self):
        assert identify_speaker([1.0, 0.0], {}, threshold=0.5) == "unknown"

    def test_identify_best_match(self):
        enrolled = {
            "alice": [1.0, 0.0, 0.0],
            "bob": [0.8, 0.6, 0.0],  # cos-sim 0.8 against the query axis — weaker than alice
        }
        # A query nearly parallel to alice's embedding must pick alice over bob.
        assert identify_speaker([0.99, 0.01, 0.0], enrolled, threshold=0.5) == "alice"
diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/CMakeLists.txt b/jetson/ros2_ws/src/saltybot_social_msgs/CMakeLists.txt
index 544ca7e..c7eb9b4 100644
--- a/jetson/ros2_ws/src/saltybot_social_msgs/CMakeLists.txt
+++ b/jetson/ros2_ws/src/saltybot_social_msgs/CMakeLists.txt
@@ -28,6 +28,11 @@ rosidl_generate_interfaces(${PROJECT_NAME}
"srv/QueryMood.srv"
# Issue #92 — multi-modal tracking fusion
"msg/FusedTarget.msg"
+ # Issue #81 — speech pipeline
+ "msg/SpeechTranscript.msg"
+ "msg/VadState.msg"
+ # Issue #83 — conversation engine
+ "msg/ConversationResponse.msg"
DEPENDENCIES std_msgs geometry_msgs builtin_interfaces
)
diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/msg/ConversationResponse.msg b/jetson/ros2_ws/src/saltybot_social_msgs/msg/ConversationResponse.msg
new file mode 100644
index 0000000..e35b331
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_msgs/msg/ConversationResponse.msg
@@ -0,0 +1,9 @@
+# ConversationResponse.msg — LLM response, supports streaming token output.
+# Published by conversation_node on /social/conversation/response
+
+std_msgs/Header header
+
+string text # Full or partial response text
+string speaker_id # Who the response is addressed to
+bool is_partial # true = streaming token chunk, false = final response
+int32 turn_id # Conversation turn counter (for deduplication)
diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/msg/SpeechTranscript.msg b/jetson/ros2_ws/src/saltybot_social_msgs/msg/SpeechTranscript.msg
new file mode 100644
index 0000000..8f4cb5f
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_msgs/msg/SpeechTranscript.msg
@@ -0,0 +1,10 @@
+# SpeechTranscript.msg — Result of STT with speaker identification.
+# Published by speech_pipeline_node on /social/speech/transcript
+
+std_msgs/Header header
+
+string text # Transcribed text (UTF-8)
+string speaker_id # e.g. "person_42" or "unknown"
+float32 confidence # ASR confidence 0..1
+float32 audio_duration # Duration of the utterance in seconds
+bool is_partial # true = intermediate streaming result, false = final
diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/msg/VadState.msg b/jetson/ros2_ws/src/saltybot_social_msgs/msg/VadState.msg
new file mode 100644
index 0000000..5a2cff6
--- /dev/null
+++ b/jetson/ros2_ws/src/saltybot_social_msgs/msg/VadState.msg
@@ -0,0 +1,8 @@
+# VadState.msg — Voice Activity Detection state.
+# Published by speech_pipeline_node on /social/speech/vad_state
+
+std_msgs/Header header
+
+bool speech_active # true = speech detected, false = silence
+float32 energy_db # RMS energy in dBFS
+string state # "silence" | "speech" | "wake_word"