feat(social): speech + LLM + TTS + orchestrator (#81 #83 #85 #89) #102

Merged
sl-jetson merged 1 commits from sl-jetson/social-speech-llm-tts into main 2026-03-02 08:24:25 -05:00
20 changed files with 2342 additions and 3 deletions
Showing only changes of commit 5043578934 - Show all commits

View File

@ -0,0 +1,12 @@
# Parameters for conversation_node — local LLM conversation engine (Issue #83).
conversation_node:
  ros__parameters:
    # GGUF model file loaded via llama-cpp-python.
    model_path: "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf"
    n_ctx: 4096                # prompt context window (tokens)
    n_gpu_layers: 20 # Increase for more VRAM (Orin has 8GB shared)
    max_tokens: 200            # generation cap per reply
    temperature: 0.7           # sampling temperature
    top_p: 0.9                 # nucleus-sampling cutoff
    soul_path: "/soul/SOUL.md"                               # system-prompt source
    context_db_path: "/social_db/conversation_context.json"  # per-person memory store
    save_interval_s: 30.0      # how often conversation context is persisted to disk
    stream: true               # publish partial responses at sentence boundaries

View File

@ -0,0 +1,8 @@
# Parameters for orchestrator_node — pipeline state machine + GPU watchdog (Issue #89).
orchestrator_node:
  ros__parameters:
    gpu_mem_warn_mb: 4000.0 # Warn when GPU free < 4GB
    gpu_mem_throttle_mb: 2000.0 # Throttle new inferences at < 2GB
    watchdog_timeout_s: 30.0  # error-log if no VAD message arrives for this long
    latency_window: 20        # rolling-window size for latency statistics
    profile_enabled: true     # collect and publish latency statistics
    state_publish_rate: 2.0 # Hz

View File

@ -0,0 +1,13 @@
# Parameters for speech_pipeline_node — wake word + VAD + STT + diarization (Issue #81).
speech_pipeline_node:
  ros__parameters:
    mic_device_index: -1 # -1 = system default; use `arecord -l` to list
    sample_rate: 16000               # Hz
    wake_word_model: "hey_salty"     # OpenWakeWord model name or path
    wake_word_threshold: 0.5         # wake-word confidence gate
    vad_threshold_db: -35.0          # energy-VAD fallback threshold
    use_silero_vad: true             # prefer Silero VAD; falls back to energy VAD
    whisper_model: "small" # small (~500ms), medium (better quality, ~900ms)
    whisper_compute_type: "float16"  # faster-whisper compute precision
    speaker_threshold: 0.65          # cosine-similarity gate for speaker ID
    speaker_db_path: "/social_db/speaker_embeddings.json"  # known-speaker embeddings
    publish_partial: true            # stream partial transcripts

View File

@ -0,0 +1,9 @@
# Parameters for tts_node — streaming Piper TTS (Issue #85).
tts_node:
  ros__parameters:
    voice_path: "/models/piper/en_US-lessac-medium.onnx"  # Piper voice model
    sample_rate: 22050       # Hz output audio
    volume: 1.0              # output volume scale
    audio_device: "" # "" = system default; set to device name if needed
    playback_enabled: true   # play synthesized audio locally
    publish_audio: false     # republish raw audio on a topic — disabled by default
    sentence_streaming: true # synthesize per sentence for low first-chunk latency

View File

@ -0,0 +1,129 @@
"""social_bot.launch.py — Launch the full social-bot pipeline.
Issue #89: End-to-end pipeline orchestrator launch file.
Brings up all social-bot nodes in dependency order:
1. Orchestrator (state machine + watchdog) t=0s
2. Speech pipeline (wake word + VAD + STT) t=2s
3. Conversation engine (LLM) t=4s
4. TTS node (Piper streaming) t=4s
5. Person state tracker (already in social.launch.py optional)
Launch args:
enable_speech (bool, true)
enable_llm (bool, true)
enable_tts (bool, true)
enable_orchestrator (bool, true)
voice_path (str, /models/piper/en_US-lessac-medium.onnx)
whisper_model (str, small)
llm_model_path (str, /models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf)
n_gpu_layers (int, 20)
"""
import os

from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import (
    DeclareLaunchArgument, GroupAction, LogInfo, TimerAction
)
from launch.conditions import IfCondition
from launch.substitutions import LaunchConfiguration, PythonExpression
from launch_ros.actions import Node
from launch_ros.parameter_descriptions import ParameterValue
def generate_launch_description() -> LaunchDescription:
    """Assemble the staged social-bot pipeline launch description.

    Bring-up order: orchestrator immediately, speech pipeline at t=2s, LLM
    and TTS at t=4s — staggered so the heavyweight model loads do not all
    contend for GPU memory and disk at the same instant. Every stage is
    gated on its corresponding enable_* launch argument.
    """
    pkg = get_package_share_directory("saltybot_social")
    cfg = os.path.join(pkg, "config")

    # ── Launch arguments ─────────────────────────────────────────────────────
    args = [
        DeclareLaunchArgument("enable_speech", default_value="true"),
        DeclareLaunchArgument("enable_llm", default_value="true"),
        DeclareLaunchArgument("enable_tts", default_value="true"),
        DeclareLaunchArgument("enable_orchestrator", default_value="true"),
        DeclareLaunchArgument("voice_path",
                              default_value="/models/piper/en_US-lessac-medium.onnx"),
        DeclareLaunchArgument("whisper_model", default_value="small"),
        DeclareLaunchArgument("llm_model_path",
                              default_value="/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf"),
        DeclareLaunchArgument("n_gpu_layers", default_value="20"),
    ]

    # ── Orchestrator (t=0s) ──────────────────────────────────────────────────
    orchestrator = Node(
        package="saltybot_social",
        executable="orchestrator_node",
        name="orchestrator_node",
        parameters=[os.path.join(cfg, "orchestrator_params.yaml")],
        condition=IfCondition(LaunchConfiguration("enable_orchestrator")),
        output="screen",
        emulate_tty=True,
    )

    # ── Speech pipeline (t=2s) ───────────────────────────────────────────────
    speech = TimerAction(period=2.0, actions=[
        GroupAction(
            condition=IfCondition(LaunchConfiguration("enable_speech")),
            actions=[
                LogInfo(msg="[social_bot] Starting speech pipeline..."),
                Node(
                    package="saltybot_social",
                    executable="speech_pipeline_node",
                    name="speech_pipeline_node",
                    parameters=[
                        os.path.join(cfg, "speech_params.yaml"),
                        {"whisper_model": LaunchConfiguration("whisper_model")},
                    ],
                    output="screen",
                    emulate_tty=True,
                ),
            ]
        )
    ])

    # ── LLM conversation engine (t=4s) ───────────────────────────────────────
    llm = TimerAction(period=4.0, actions=[
        GroupAction(
            condition=IfCondition(LaunchConfiguration("enable_llm")),
            actions=[
                LogInfo(msg="[social_bot] Starting LLM conversation engine..."),
                Node(
                    package="saltybot_social",
                    executable="conversation_node",
                    name="conversation_node",
                    parameters=[
                        os.path.join(cfg, "conversation_params.yaml"),
                        {
                            "model_path": LaunchConfiguration("llm_model_path"),
                            # A bare LaunchConfiguration resolves to a string,
                            # but conversation_node declares n_gpu_layers as an
                            # int parameter — coerce explicitly or the node's
                            # parameter load fails with a type mismatch.
                            "n_gpu_layers": ParameterValue(
                                LaunchConfiguration("n_gpu_layers"), value_type=int
                            ),
                        },
                    ],
                    output="screen",
                    emulate_tty=True,
                ),
            ]
        )
    ])

    # ── TTS node (t=4s, parallel with LLM) ──────────────────────────────────
    tts = TimerAction(period=4.0, actions=[
        GroupAction(
            condition=IfCondition(LaunchConfiguration("enable_tts")),
            actions=[
                LogInfo(msg="[social_bot] Starting TTS node..."),
                Node(
                    package="saltybot_social",
                    executable="tts_node",
                    name="tts_node",
                    parameters=[
                        os.path.join(cfg, "tts_params.yaml"),
                        {"voice_path": LaunchConfiguration("voice_path")},
                    ],
                    output="screen",
                    emulate_tty=True,
                ),
            ]
        )
    ])

    return LaunchDescription(args + [orchestrator, speech, llm, tts])

View File

@ -5,9 +5,12 @@
<version>0.1.0</version> <version>0.1.0</version>
<description> <description>
Social interaction layer for saltybot. Social interaction layer for saltybot.
speech_pipeline_node: wake word + VAD + Whisper STT + diarization (Issue #81).
conversation_node: local LLM with per-person context (Issue #83).
tts_node: streaming TTS with Piper first-chunk (Issue #85).
orchestrator_node: pipeline state machine + GPU watchdog + latency profiler (Issue #89).
person_state_tracker: multi-modal person identity fusion (Issue #82). person_state_tracker: multi-modal person identity fusion (Issue #82).
expression_node: bridges /social/mood to ESP32-C3 NeoPixel ring over serial (Issue #86). expression_node: LED expression + motor attention (Issue #86).
attention_node: rotates robot toward active speaker via /social/persons bearing (Issue #86).
</description> </description>
<maintainer email="seb@vayrette.com">seb</maintainer> <maintainer email="seb@vayrette.com">seb</maintainer>
<license>MIT</license> <license>MIT</license>

View File

@ -0,0 +1,267 @@
"""conversation_node.py — Local LLM conversation engine with per-person context.
Issue #83: Conversation engine for social-bot.
Stack: Phi-3-mini or Llama-3.2-3B GGUF Q4_K_M via llama-cpp-python (CUDA).
Subscribes /social/speech/transcript → builds a per-person prompt → streams
token output → publishes /social/conversation/response.
Streaming: publishes is_partial=true chunks as each sentence completes, then a
final message with is_partial=false at end of generation. TTS node can begin
synthesis on first sentence boundary.
ROS2 topics:
Subscribe: /social/speech/transcript (saltybot_social_msgs/SpeechTranscript)
Publish: /social/conversation/response (saltybot_social_msgs/ConversationResponse)
Parameters:
model_path (str, "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf")
n_ctx (int, 4096)
n_gpu_layers (int, 20) GPU offload layers (increase for more VRAM usage)
max_tokens (int, 200)
temperature (float, 0.7)
top_p (float, 0.9)
soul_path (str, "/soul/SOUL.md")
context_db_path (str, "/social_db/conversation_context.json")
save_interval_s (float, 30.0) how often to persist context to disk
stream (bool, true)
"""
from __future__ import annotations
import threading
import time
from typing import Optional
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile
from saltybot_social_msgs.msg import SpeechTranscript, ConversationResponse
from .llm_context import ContextStore, build_llama_prompt, load_system_prompt, needs_summary_prompt
class ConversationNode(Node):
    """Local LLM inference node with per-person conversation memory.

    Loads a GGUF model via llama-cpp-python in a background thread so ROS
    spin starts immediately. Generation runs in worker threads, one
    utterance at a time; concurrent requests are dropped, not queued.
    Partial responses are published at sentence boundaries so the TTS node
    can start speaking early; the final message has is_partial=False.
    """

    def __init__(self) -> None:
        super().__init__("conversation_node")
        # ── Parameters ──────────────────────────────────────────────────────
        self.declare_parameter("model_path",
                               "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf")
        self.declare_parameter("n_ctx", 4096)
        self.declare_parameter("n_gpu_layers", 20)
        self.declare_parameter("max_tokens", 200)
        self.declare_parameter("temperature", 0.7)
        self.declare_parameter("top_p", 0.9)
        self.declare_parameter("soul_path", "/soul/SOUL.md")
        self.declare_parameter("context_db_path", "/social_db/conversation_context.json")
        self.declare_parameter("save_interval_s", 30.0)
        self.declare_parameter("stream", True)
        self._model_path = self.get_parameter("model_path").value
        self._n_ctx = self.get_parameter("n_ctx").value
        self._n_gpu = self.get_parameter("n_gpu_layers").value
        self._max_tokens = self.get_parameter("max_tokens").value
        self._temperature = self.get_parameter("temperature").value
        self._top_p = self.get_parameter("top_p").value
        self._soul_path = self.get_parameter("soul_path").value
        self._db_path = self.get_parameter("context_db_path").value
        self._save_interval = self.get_parameter("save_interval_s").value
        self._stream = self.get_parameter("stream").value
        # ── Publishers / Subscribers ─────────────────────────────────────────
        qos = QoSProfile(depth=10)
        self._resp_pub = self.create_publisher(
            ConversationResponse, "/social/conversation/response", qos
        )
        self._transcript_sub = self.create_subscription(
            SpeechTranscript, "/social/speech/transcript",
            self._on_transcript, qos
        )
        # ── State ────────────────────────────────────────────────────────────
        self._llm = None  # set by _load_llm() once the model is ready
        self._system_prompt = load_system_prompt(self._soul_path)
        self._ctx_store = ContextStore(self._db_path)
        self._lock = threading.Lock()  # guards _generating / _turn_counter
        self._turn_counter = 0
        self._generating = False
        self._last_save = time.time()
        # ── Load LLM in background ────────────────────────────────────────────
        threading.Thread(target=self._load_llm, daemon=True).start()
        # ── Periodic context save ────────────────────────────────────────────
        self._save_timer = self.create_timer(self._save_interval, self._save_context)
        self.get_logger().info(
            f"ConversationNode init (model={self._model_path}, "
            f"gpu_layers={self._n_gpu}, ctx={self._n_ctx})"
        )

    # ── Model loading ─────────────────────────────────────────────────────────
    def _load_llm(self) -> None:
        """Load the GGUF model (runs in a daemon thread; errors are logged)."""
        t0 = time.time()
        self.get_logger().info(f"Loading LLM: {self._model_path}")
        try:
            from llama_cpp import Llama
            self._llm = Llama(
                model_path=self._model_path,
                n_ctx=self._n_ctx,
                n_gpu_layers=self._n_gpu,
                n_threads=4,
                verbose=False,
            )
            self.get_logger().info(
                f"LLM ready ({time.time()-t0:.1f}s). "
                f"Context: {self._n_ctx} tokens, GPU layers: {self._n_gpu}"
            )
        except Exception as e:
            self.get_logger().error(f"LLM load failed: {e}")

    # ── Transcript callback ───────────────────────────────────────────────────
    def _on_transcript(self, msg: SpeechTranscript) -> None:
        """Handle final transcripts only (skip streaming partials)."""
        if msg.is_partial:
            return
        if not msg.text.strip():
            return
        self.get_logger().info(
            f"Transcript [{msg.speaker_id}]: '{msg.text}'"
        )
        # Run generation off the executor thread so spin stays responsive.
        threading.Thread(
            target=self._generate_response,
            args=(msg.text.strip(), msg.speaker_id),
            daemon=True,
        ).start()

    # ── LLM inference ─────────────────────────────────────────────────────────
    def _generate_response(self, user_text: str, speaker_id: str) -> None:
        """Generate LLM response with streaming. Runs in thread."""
        if self._llm is None:
            self.get_logger().warn("LLM not loaded yet, dropping utterance")
            return
        with self._lock:
            if self._generating:
                self.get_logger().warn("LLM busy, dropping utterance")
                return
            self._generating = True
            self._turn_counter += 1
            turn_id = self._turn_counter
        try:
            ctx = self._ctx_store.get(speaker_id)
            # Summary compression if context is long
            if ctx.needs_compression():
                self._compress_context(ctx)
            # BUGFIX: build the prompt BEFORE recording the new user turn.
            # build_llama_prompt() appends user_text itself after iterating
            # ctx.turns, so adding the turn first duplicated the user message
            # in every prompt.
            prompt = build_llama_prompt(
                ctx, user_text, self._system_prompt
            )
            ctx.add_user(user_text)
            t0 = time.perf_counter()
            full_response = ""
            if self._stream:
                output = self._llm(
                    prompt,
                    max_tokens=self._max_tokens,
                    temperature=self._temperature,
                    top_p=self._top_p,
                    stream=True,
                    stop=["<|user|>", "<|system|>", "\n\n\n"],
                )
                for chunk in output:
                    token = chunk["choices"][0]["text"]
                    full_response += token
                    # Publish partial (cumulative text) after each sentence
                    # boundary for low TTS latency.
                    if token.endswith((".", "!", "?", "\n")):
                        self._publish_response(
                            full_response.strip(), speaker_id, turn_id, is_partial=True
                        )
            else:
                output = self._llm(
                    prompt,
                    max_tokens=self._max_tokens,
                    temperature=self._temperature,
                    top_p=self._top_p,
                    stream=False,
                    stop=["<|user|>", "<|system|>"],
                )
                full_response = output["choices"][0]["text"]
            full_response = full_response.strip()
            latency_ms = (time.perf_counter() - t0) * 1000
            self.get_logger().info(
                f"LLM [{speaker_id}] ({latency_ms:.0f}ms): '{full_response[:80]}'"
            )
            ctx.add_assistant(full_response)
            self._publish_response(full_response, speaker_id, turn_id, is_partial=False)
        except Exception as e:
            self.get_logger().error(f"LLM inference error: {e}")
        finally:
            with self._lock:
                self._generating = False

    def _compress_context(self, ctx) -> None:
        """Ask LLM to summarize old turns for context compression."""
        if self._llm is None:
            ctx.compress("(history omitted)")
            return
        try:
            summary_prompt = needs_summary_prompt(ctx)
            result = self._llm(summary_prompt, max_tokens=80, temperature=0.3, stream=False)
            summary = result["choices"][0]["text"].strip()
            ctx.compress(summary)
            self.get_logger().debug(
                f"Context compressed for {ctx.person_id}: '{summary[:60]}'"
            )
        except Exception:
            # Best-effort: never block a reply on a failed summary.
            ctx.compress("(history omitted)")

    # ── Publish ───────────────────────────────────────────────────────────────
    def _publish_response(
        self, text: str, speaker_id: str, turn_id: int, is_partial: bool
    ) -> None:
        """Publish one ConversationResponse (partial or final) for this turn."""
        msg = ConversationResponse()
        msg.header.stamp = self.get_clock().now().to_msg()
        msg.text = text
        msg.speaker_id = speaker_id
        msg.is_partial = is_partial
        msg.turn_id = turn_id
        self._resp_pub.publish(msg)

    def _save_context(self) -> None:
        """Persist conversation contexts to disk (timer callback)."""
        try:
            self._ctx_store.save()
        except Exception as e:
            self.get_logger().error(f"Context save error: {e}")

    def destroy_node(self) -> None:
        # Final save so no conversation memory is lost on shutdown.
        self._save_context()
        super().destroy_node()
def main(args=None) -> None:
    """Entry point for ros2 run: spin ConversationNode until interrupted."""
    rclpy.init(args=args)
    node = ConversationNode()
    try:
        try:
            rclpy.spin(node)
        except KeyboardInterrupt:
            pass  # Ctrl-C is the normal way to stop the node
    finally:
        node.destroy_node()
        rclpy.shutdown()

View File

@ -0,0 +1,176 @@
"""llm_context.py — Per-person conversation context management.
No ROS2 dependencies. Manages sliding-window conversation history per person_id,
builds llama.cpp prompt, handles summary compression when context is full.
Used by conversation_node.py.
Tested by test/test_llm_context.py.
"""
from __future__ import annotations
import json
import os
import time
from typing import List, Optional, Tuple
# ── Constants ──────────────────────────────────────────────────────────────────
MAX_TURNS = 20 # Max turns to keep before summary compression
SUMMARY_KEEP = 4 # Keep last N turns after summarizing older ones
SYSTEM_PROMPT_PATH = "/soul/SOUL.md"
DEFAULT_SYSTEM_PROMPT = (
"You are Salty, a friendly social robot. "
"You are helpful, warm, and concise. "
"Keep responses under 2 sentences unless asked to elaborate."
)
# ── Conversation history ───────────────────────────────────────────────────────
class PersonContext:
    """Conversation history + relationship memory for one person."""

    def __init__(self, person_id: str, person_name: str = "") -> None:
        self.person_id = person_id
        self.person_name = person_name if person_name else person_id
        # Each turn is {"role": "user"|"assistant", "content": str, "ts": float}
        self.turns: List[dict] = []
        self.summary: str = ""
        self.last_seen: float = time.time()
        self.interaction_count: int = 0

    def add_user(self, text: str) -> None:
        """Record a user utterance; bumps last_seen and interaction_count."""
        self.turns.append({"role": "user", "content": text, "ts": time.time()})
        self.last_seen = time.time()
        self.interaction_count += 1

    def add_assistant(self, text: str) -> None:
        """Record an assistant reply."""
        self.turns.append({"role": "assistant", "content": text, "ts": time.time()})

    def needs_compression(self) -> bool:
        """True once the history exceeds MAX_TURNS and should be summarized."""
        return len(self.turns) > MAX_TURNS

    def compress(self, summary_text: str) -> None:
        """Replace old turns with a summary, keeping last SUMMARY_KEEP turns."""
        # Append to any existing summary rather than overwriting it.
        self.summary = (
            f"{self.summary} {summary_text}" if self.summary else summary_text
        )
        # Keep most recent turns
        self.turns = self.turns[-SUMMARY_KEEP:]

    def to_dict(self) -> dict:
        """Serialize for JSON persistence (inverse of from_dict)."""
        return {
            "person_id": self.person_id,
            "person_name": self.person_name,
            "turns": self.turns,
            "summary": self.summary,
            "last_seen": self.last_seen,
            "interaction_count": self.interaction_count,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "PersonContext":
        """Rebuild a context from its to_dict() form; missing keys default."""
        ctx = cls(d["person_id"], d.get("person_name", ""))
        ctx.turns = d.get("turns", [])
        ctx.summary = d.get("summary", "")
        ctx.last_seen = d.get("last_seen", 0.0)
        ctx.interaction_count = d.get("interaction_count", 0)
        return ctx
# ── Context store ──────────────────────────────────────────────────────────────
class ContextStore:
    """Persistent per-person conversation context with optional JSON backing."""

    def __init__(self, db_path: str = "/social_db/conversation_context.json") -> None:
        self._db_path = db_path
        self._contexts: dict = {}  # person_id → PersonContext
        self._load()

    def get(self, person_id: str, person_name: str = "") -> PersonContext:
        """Return (creating if needed) the context for person_id.

        A non-empty person_name fills in a missing name on an existing entry.
        """
        if person_id not in self._contexts:
            self._contexts[person_id] = PersonContext(person_id, person_name)
        elif person_name and not self._contexts[person_id].person_name:
            self._contexts[person_id].person_name = person_name
        return self._contexts[person_id]

    def save(self) -> None:
        """Persist all contexts to the JSON backing file."""
        parent = os.path.dirname(self._db_path)
        # BUGFIX: dirname is "" for a bare filename, and os.makedirs("")
        # raises FileNotFoundError — only create directories when present.
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(self._db_path, "w") as f:
            json.dump(
                {pid: ctx.to_dict() for pid, ctx in self._contexts.items()},
                f, indent=2
            )

    def _load(self) -> None:
        """Load the backing JSON if present; a corrupt file resets to empty."""
        if os.path.exists(self._db_path):
            try:
                with open(self._db_path) as f:
                    raw = json.load(f)
                self._contexts = {
                    pid: PersonContext.from_dict(d) for pid, d in raw.items()
                }
            except Exception:
                self._contexts = {}

    def all_persons(self) -> List[str]:
        """All person_ids currently tracked."""
        return list(self._contexts.keys())
# ── Prompt builder ─────────────────────────────────────────────────────────────
def load_system_prompt(path: str = SYSTEM_PROMPT_PATH) -> str:
    """Load SOUL.md as the system prompt, falling back to the default."""
    if not os.path.exists(path):
        return DEFAULT_SYSTEM_PROMPT
    with open(path) as fh:
        return fh.read().strip()
def build_llama_prompt(
    ctx: PersonContext,
    user_text: str,
    system_prompt: str,
    group_persons: Optional[List[str]] = None,
) -> str:
    """Build a Phi-3 / Llama-3 style chat prompt string.

    Sections in order: system prompt, optional per-person memory summary,
    optional "also present" roster, prior turns, the new user message, and
    a trailing assistant tag to cue generation. Uses the ChatML format
    supported by llama-cpp-python.
    """
    parts = [f"<|system|>\n{system_prompt}"]
    if ctx.summary:
        parts.append(f"<|system|>\n[Memory of {ctx.person_name}]: {ctx.summary}")
    others = (
        [p for p in group_persons if p != ctx.person_id] if group_persons else []
    )
    if others:
        parts.append(f"<|system|>\n[Also present]: {', '.join(others)}")
    for turn in ctx.turns:
        tag = "user" if turn["role"] == "user" else "assistant"
        parts.append(f"<|{tag}|>\n{turn['content']}")
    parts.append(f"<|user|>\n{user_text}")
    parts.append("<|assistant|>")
    return "\n".join(parts)
def needs_summary_prompt(ctx: PersonContext) -> str:
    """Build a prompt asking the LLM to summarize old conversation turns.

    Covers every turn except the last SUMMARY_KEEP (those remain verbatim
    in the context after compression).
    """
    history = [
        f"{turn['role'].upper()}: {turn['content']}"
        for turn in ctx.turns[:-SUMMARY_KEEP]
    ]
    turns_text = "\n".join(history)
    return (
        f"<|system|>\nSummarize this conversation between {ctx.person_name} and "
        "Salty in one sentence, focusing on what was discussed and their relationship.\n"
        f"<|user|>\n{turns_text}\n<|assistant|>"
    )

View File

@ -0,0 +1,298 @@
"""orchestrator_node.py — End-to-end social-bot pipeline orchestrator.
Issue #89: Main loop, GPU memory scheduling, latency profiling, watchdog.
State machine:
    IDLE → LISTENING (wake word) → THINKING (LLM generating) → SPEAKING (TTS)
      ^__________________________________________________________|
Responsibilities:
- Owns the top-level pipeline state machine
- Monitors GPU memory and throttles inference when approaching limit
- Profiles end-to-end latency per interaction turn
- Health watchdog: restart speech pipeline on hang
- Exposes /social/orchestrator/state for UI / other nodes
ROS2 topics:
Subscribe: /social/speech/vad_state (VadState)
/social/speech/transcript (SpeechTranscript)
/social/conversation/response (ConversationResponse)
Publish: /social/orchestrator/state (std_msgs/String) JSON status blob
Parameters:
gpu_mem_warn_mb (float, 4000.0) warn at this free GPU memory (MB)
gpu_mem_throttle_mb (float, 2000.0) suspend new inference requests
watchdog_timeout_s (float, 30.0) restart speech node if no VAD in N seconds
latency_window (int, 20) rolling window for latency stats
profile_enabled (bool, true)
"""
from __future__ import annotations
import json
import threading
import time
from collections import deque
from enum import Enum
from typing import Optional
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile
from std_msgs.msg import String
from saltybot_social_msgs.msg import VadState, SpeechTranscript, ConversationResponse
# ── State machine ──────────────────────────────────────────────────────────────
class PipelineState(Enum):
    """Top-level social-pipeline states published on /social/orchestrator/state."""
    IDLE = "idle"            # nothing in flight; waiting for wake word / speech
    LISTENING = "listening"  # wake word heard or speech active; STT in progress
    THINKING = "thinking"    # final transcript received; LLM generating
    SPEAKING = "speaking"    # first LLM token arrived (TTS presumably playing)
    THROTTLED = "throttled"  # GPU OOM risk — waiting for memory
# ── Latency tracker ────────────────────────────────────────────────────────────
class LatencyTracker:
    """Rolling window latency statistics per pipeline stage."""

    def __init__(self, window: int = 20) -> None:
        # One bounded deque per stage; old samples fall off automatically.
        self._window = window
        self._samples = {
            stage: deque(maxlen=window)
            for stage in (
                "wakeword_to_transcript",
                "transcript_to_llm",
                "llm_to_tts",
                "end_to_end",
            )
        }

    def record(self, stage: str, latency_ms: float) -> None:
        """Append one latency sample; unknown stage names are silently ignored."""
        bucket = self._samples.get(stage)
        if bucket is not None:
            bucket.append(latency_ms)

    def stats(self) -> dict:
        """Return {stage: {mean_ms, p50_ms, p95_ms, n}}; empty dict per empty stage."""
        def summarize(samples) -> dict:
            if not samples:
                return {}
            ordered = sorted(samples)
            count = len(ordered)
            return {
                "mean_ms": round(sum(ordered) / count, 1),
                "p50_ms": round(ordered[count // 2], 1),
                "p95_ms": round(ordered[int(count * 0.95)], 1),
                "n": count,
            }
        return {stage: summarize(q) for stage, q in self._samples.items()}
# ── GPU memory helpers ────────────────────────────────────────────────────────
def get_gpu_free_mb() -> Optional[float]:
    """Return free GPU memory in MB, or None if unavailable.

    Tries pycuda first, then falls back to parsing `nvidia-smi` output.
    """
    try:
        import pycuda.driver as drv
        drv.init()
        ctx = drv.Device(0).make_context()
        try:
            free, _ = drv.mem_get_info()
        finally:
            # BUGFIX: always pop the context — the original leaked it if
            # mem_get_info raised, pinning a CUDA context for process life.
            ctx.pop()
        return free / 1024 / 1024
    except Exception:
        pass
    try:
        import subprocess
        out = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"],
            timeout=2
        ).decode().strip()
        return float(out.split("\n")[0])
    except Exception:
        return None
# ── Orchestrator ───────────────────────────────────────────────────────────────
class OrchestratorNode(Node):
    """Pipeline orchestrator — state machine, GPU watchdog, latency profiler."""

    def __init__(self) -> None:
        super().__init__("orchestrator_node")
        # ── Parameters ──────────────────────────────────────────────────────
        self.declare_parameter("gpu_mem_warn_mb", 4000.0)
        self.declare_parameter("gpu_mem_throttle_mb", 2000.0)
        self.declare_parameter("watchdog_timeout_s", 30.0)
        self.declare_parameter("latency_window", 20)
        self.declare_parameter("profile_enabled", True)
        self.declare_parameter("state_publish_rate", 2.0)  # Hz
        self._gpu_warn = self.get_parameter("gpu_mem_warn_mb").value
        self._gpu_throttle = self.get_parameter("gpu_mem_throttle_mb").value
        self._watchdog_timeout = self.get_parameter("watchdog_timeout_s").value
        self._profile = self.get_parameter("profile_enabled").value
        self._state_rate = self.get_parameter("state_publish_rate").value
        # ── Publishers / Subscribers ─────────────────────────────────────────
        qos = QoSProfile(depth=10)
        self._state_pub = self.create_publisher(String, "/social/orchestrator/state", qos)
        self._vad_sub = self.create_subscription(
            VadState, "/social/speech/vad_state", self._on_vad, qos
        )
        self._transcript_sub = self.create_subscription(
            SpeechTranscript, "/social/speech/transcript", self._on_transcript, qos
        )
        self._response_sub = self.create_subscription(
            ConversationResponse, "/social/conversation/response", self._on_response, qos
        )
        # ── State ────────────────────────────────────────────────────────────
        self._state = PipelineState.IDLE
        # NOTE: threading.Lock is NOT reentrant — never call _set_state()
        # while holding this lock (see _check_gpu).
        self._state_lock = threading.Lock()
        self._tracker = LatencyTracker(self.get_parameter("latency_window").value)
        # Timestamps for latency tracking (0.0 means "not armed")
        self._t_wake: float = 0.0
        self._t_transcript: float = 0.0
        self._t_llm_first: float = 0.0
        self._t_tts_first: float = 0.0
        self._last_vad_t: float = time.time()
        self._last_transcript_t: float = 0.0
        self._last_response_t: float = 0.0
        # ── Timers ────────────────────────────────────────────────────────────
        self._state_timer = self.create_timer(1.0 / self._state_rate, self._publish_state)
        self._gpu_timer = self.create_timer(5.0, self._check_gpu)
        self._watchdog_timer = self.create_timer(10.0, self._watchdog)
        self.get_logger().info("OrchestratorNode ready")

    # ── State transitions ─────────────────────────────────────────────────────
    def _set_state(self, new_state: PipelineState) -> None:
        """Transition the pipeline state; logs only on an actual change."""
        with self._state_lock:
            if self._state != new_state:
                self.get_logger().info(
                    f"Pipeline: {self._state.value} → {new_state.value}"
                )
                self._state = new_state

    # ── Subscribers ───────────────────────────────────────────────────────────
    def _on_vad(self, msg: VadState) -> None:
        """Track VAD liveness; arm the wake timestamp on a wake-word event."""
        self._last_vad_t = time.time()
        if msg.state == "wake_word":
            self._t_wake = time.time()
            self._set_state(PipelineState.LISTENING)
        elif msg.speech_active and self._state == PipelineState.IDLE:
            self._set_state(PipelineState.LISTENING)

    def _on_transcript(self, msg: SpeechTranscript) -> None:
        """On a final transcript: enter THINKING and record wake→STT latency."""
        if msg.is_partial:
            return
        self._last_transcript_t = time.time()
        self._t_transcript = time.time()
        self._set_state(PipelineState.THINKING)
        if self._profile and self._t_wake > 0:
            latency_ms = (self._t_transcript - self._t_wake) * 1000
            self._tracker.record("wakeword_to_transcript", latency_ms)
            if latency_ms > 500:
                self.get_logger().warn(
                    f"Wake→transcript latency high: {latency_ms:.0f}ms (target <500ms)"
                )

    def _on_response(self, msg: ConversationResponse) -> None:
        """Track LLM first-token and end-of-turn events for state + latency."""
        self._last_response_t = time.time()
        if msg.is_partial and self._t_llm_first == 0.0:
            # First token from LLM
            self._t_llm_first = time.time()
            if self._profile and self._t_transcript > 0:
                latency_ms = (self._t_llm_first - self._t_transcript) * 1000
                self._tracker.record("transcript_to_llm", latency_ms)
            self._set_state(PipelineState.SPEAKING)
        if not msg.is_partial:
            # Full response complete.
            # BUGFIX: record end-to-end latency BEFORE clearing the turn
            # timestamps — the original zeroed _t_wake first, so the
            # `_t_wake > 0` guard never passed and the stat was never kept.
            if self._profile and self._t_wake > 0:
                e2e_ms = (time.time() - self._t_wake) * 1000
                self._tracker.record("end_to_end", e2e_ms)
            self._t_llm_first = 0.0
            self._t_wake = 0.0
            self._t_transcript = 0.0
            # Return to IDLE after a short delay (TTS still playing)
            threading.Timer(2.0, lambda: self._set_state(PipelineState.IDLE)).start()

    # ── GPU memory check ──────────────────────────────────────────────────────
    def _check_gpu(self) -> None:
        """Periodic GPU free-memory probe; throttles/recovers the pipeline."""
        free_mb = get_gpu_free_mb()
        if free_mb is None:
            return
        if free_mb < self._gpu_throttle:
            self.get_logger().error(
                f"GPU memory critical: {free_mb:.0f}MB free < {self._gpu_throttle:.0f}MB threshold"
            )
            self._set_state(PipelineState.THROTTLED)
        elif free_mb < self._gpu_warn:
            self.get_logger().warn(f"GPU memory low: {free_mb:.0f}MB free")
        else:
            # Recover from throttled state.
            # BUGFIX: read the state under the lock but call _set_state()
            # AFTER releasing it — _set_state() re-acquires the same
            # non-reentrant Lock, so the original nested call deadlocked.
            with self._state_lock:
                throttled = self._state == PipelineState.THROTTLED
            if throttled:
                self._set_state(PipelineState.IDLE)

    # ── Watchdog ──────────────────────────────────────────────────────────────
    def _watchdog(self) -> None:
        """Alert if speech pipeline has gone silent for too long."""
        since_vad = time.time() - self._last_vad_t
        if since_vad > self._watchdog_timeout:
            self.get_logger().error(
                f"Watchdog: No VAD signal for {since_vad:.0f}s. "
                "Speech pipeline may be hung. Check /social/speech/vad_state."
            )

    # ── State publisher ───────────────────────────────────────────────────────
    def _publish_state(self) -> None:
        """Publish the JSON status blob (state, latency stats, GPU headroom)."""
        with self._state_lock:
            current = self._state.value
        payload = {
            "state": current,
            "ts": time.time(),
            "latency": self._tracker.stats() if self._profile else {},
            "gpu_free_mb": get_gpu_free_mb(),
        }
        msg = String()
        msg.data = json.dumps(payload)
        self._state_pub.publish(msg)
def main(args=None) -> None:
    """Entry point for ros2 run: spin the orchestrator until shutdown."""
    rclpy.init(args=args)
    node = OrchestratorNode()
    try:
        try:
            rclpy.spin(node)
        except KeyboardInterrupt:
            pass  # clean exit on Ctrl-C
    finally:
        node.destroy_node()
        rclpy.shutdown()

View File

@ -0,0 +1,391 @@
"""speech_pipeline_node.py — Wake word + VAD + Whisper STT + speaker diarization.
Issue #81: Speech pipeline for social-bot.
Pipeline:
    USB mic (PyAudio) → VAD (Silero / energy fallback)
    → wake word gate (OpenWakeWord "hey_salty")
    → utterance buffer → faster-whisper STT (Orin GPU)
    → ECAPA-TDNN speaker embedding → identify_speaker()
    → publish /social/speech/transcript + /social/speech/vad_state
Latency target: <500ms wake-to-first-token (streaming partial transcripts).
ROS2 topics published:
/social/speech/transcript (saltybot_social_msgs/SpeechTranscript)
/social/speech/vad_state (saltybot_social_msgs/VadState)
Parameters:
mic_device_index (int, -1 = system default)
sample_rate (int, 16000)
wake_word_model (str, "hey_salty" OpenWakeWord model name or path)
wake_word_threshold (float, 0.5)
vad_threshold_db (float, -35.0) fallback energy VAD
use_silero_vad (bool, true)
whisper_model (str, "small")
whisper_compute_type (str, "float16")
speaker_threshold (float, 0.65) cosine similarity threshold
speaker_db_path (str, "/social_db/speaker_embeddings.json")
publish_partial (bool, true)
"""
from __future__ import annotations
import json
import os
import threading
import time
from typing import Optional
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile, QoSReliabilityPolicy, QoSDurabilityPolicy
from builtin_interfaces.msg import Time
from std_msgs.msg import Header
from saltybot_social_msgs.msg import SpeechTranscript, VadState
from .speech_utils import (
EnergyVad, UtteranceSegmenter,
pcm16_to_float32, rms_db, identify_speaker,
SAMPLE_RATE, CHUNK_SAMPLES,
)
class SpeechPipelineNode(Node):
"""Wake word → VAD → STT → diarization → ROS2 publisher."""
    def __init__(self) -> None:
        """Declare/read parameters, create publishers, and start async model loading."""
        super().__init__("speech_pipeline_node")
        # ── Parameters ──────────────────────────────────────────────────────
        self.declare_parameter("mic_device_index", -1)
        self.declare_parameter("sample_rate", SAMPLE_RATE)
        self.declare_parameter("wake_word_model", "hey_salty")
        self.declare_parameter("wake_word_threshold", 0.5)
        self.declare_parameter("vad_threshold_db", -35.0)
        self.declare_parameter("use_silero_vad", True)
        self.declare_parameter("whisper_model", "small")
        self.declare_parameter("whisper_compute_type", "float16")
        self.declare_parameter("speaker_threshold", 0.65)
        self.declare_parameter("speaker_db_path", "/social_db/speaker_embeddings.json")
        self.declare_parameter("publish_partial", True)
        self._mic_idx = self.get_parameter("mic_device_index").value
        self._rate = self.get_parameter("sample_rate").value
        self._ww_model = self.get_parameter("wake_word_model").value
        self._ww_thresh = self.get_parameter("wake_word_threshold").value
        self._vad_thresh_db = self.get_parameter("vad_threshold_db").value
        self._use_silero = self.get_parameter("use_silero_vad").value
        self._whisper_model_name = self.get_parameter("whisper_model").value
        self._compute_type = self.get_parameter("whisper_compute_type").value
        self._speaker_thresh = self.get_parameter("speaker_threshold").value
        self._speaker_db = self.get_parameter("speaker_db_path").value
        self._publish_partial = self.get_parameter("publish_partial").value
        # ── Publishers ───────────────────────────────────────────────────────
        self._transcript_pub = self.create_publisher(SpeechTranscript,
                                                     "/social/speech/transcript", qos)
        self._vad_pub = self.create_publisher(VadState, "/social/speech/vad_state", qos)
        # ── Internal state ───────────────────────────────────────────────────
        self._whisper = None       # faster-whisper model; set in _load_models()
        self._ecapa = None         # ECAPA-TDNN speaker encoder, or None if unavailable
        self._oww = None           # OpenWakeWord model, or None (listen continuously)
        self._vad = None           # Silero VAD model, or None (energy fallback)
        self._segmenter = None     # UtteranceSegmenter; built in _load_models()
        self._known_speakers: dict = {}  # speaker DB loaded from JSON (see _load_speaker_db)
        self._audio_stream = None  # presumably a PyAudio stream opened later — not visible here
        self._pa = None            # presumably the PyAudio instance — not visible here
        self._lock = threading.Lock()
        self._running = False
        self._wake_word_active = False  # inside the post-wake-word listening window
        self._wake_expiry = 0.0         # wall-clock time when the wake window closes
        self._wake_window_s = 8.0  # seconds to listen after wake word
        self._turn_counter = 0
        # ── Lazy model loading in background thread ──────────────────────────
        # Keeps node startup (and ROS spin) from blocking on multi-second loads.
        self._model_thread = threading.Thread(target=self._load_models, daemon=True)
        self._model_thread.start()
# ── Model loading ─────────────────────────────────────────────────────────
    def _load_models(self) -> None:
        """Load all AI models (runs in background to not block ROS spin).

        Each model is optional: every loader is wrapped in its own
        try/except so a missing dependency degrades the pipeline instead
        of killing the node:
          - faster-whisper (STT, CUDA) — on failure utterances are dropped
            in _process_utterance.
          - Silero VAD — on failure/disable the energy VAD is the only gate.
          - ECAPA-TDNN (speaker embedding) — on failure speaker is 'unknown'.
          - OpenWakeWord — on failure the node listens continuously.
        Finishes by loading the speaker DB and starting audio capture, so
        the mic only opens once the models that consume it exist.
        """
        self.get_logger().info("Loading speech models (Whisper, ECAPA-TDNN, VAD)...")
        t0 = time.time()
        # faster-whisper
        try:
            from faster_whisper import WhisperModel
            self._whisper = WhisperModel(
                self._whisper_model_name,
                device="cuda",
                compute_type=self._compute_type,
                download_root="/models",
            )
            self.get_logger().info(
                f"Whisper '{self._whisper_model_name}' loaded "
                f"({time.time()-t0:.1f}s)"
            )
        except Exception as e:
            self.get_logger().error(f"Whisper load failed: {e}")
        # Silero VAD (optional; used only by _check_vad for the published state)
        if self._use_silero:
            try:
                from silero_vad import load_silero_vad
                self._vad = load_silero_vad()
                self.get_logger().info("Silero VAD loaded")
            except Exception as e:
                self.get_logger().warn(f"Silero VAD unavailable ({e}), using energy VAD")
                self._vad = None
        # The segmenter always uses the energy VAD, even when Silero loads.
        energy_vad = EnergyVad(threshold_db=self._vad_thresh_db)
        self._segmenter = UtteranceSegmenter(energy_vad)
        # ECAPA-TDNN speaker embeddings
        try:
            from speechbrain.pretrained import EncoderClassifier
            self._ecapa = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir="/models/speechbrain_ecapa",
            )
            self.get_logger().info("ECAPA-TDNN loaded")
        except Exception as e:
            self.get_logger().warn(f"ECAPA-TDNN unavailable ({e}), using 'unknown' speaker")
        # OpenWakeWord
        try:
            import openwakeword
            from openwakeword.model import Model as OWWModel
            self._oww = OWWModel(
                wakeword_models=[self._ww_model],
                inference_framework="onnx",
            )
            self.get_logger().info(f"OpenWakeWord '{self._ww_model}' loaded")
        except Exception as e:
            self.get_logger().warn(
                f"OpenWakeWord unavailable ({e}). Robot will listen continuously."
            )
        # Load speaker DB
        self._load_speaker_db()
        # Start audio capture
        self._start_audio()
        self.get_logger().info(f"Speech pipeline ready ({time.time()-t0:.1f}s total)")
def _load_speaker_db(self) -> None:
"""Load speaker embedding database from JSON."""
if os.path.exists(self._speaker_db):
try:
with open(self._speaker_db) as f:
self._known_speakers = json.load(f)
self.get_logger().info(
f"Speaker DB loaded: {len(self._known_speakers)} persons"
)
except Exception as e:
self.get_logger().error(f"Speaker DB load error: {e}")
# ── Audio capture ─────────────────────────────────────────────────────────
    def _start_audio(self) -> None:
        """Open PyAudio stream and start capture thread.

        Opens a mono PCM16 input stream at self._rate with 30 ms buffers
        (CHUNK_SAMPLES) and spawns the daemon capture thread running
        _audio_loop. Any failure (missing PyAudio, bad device index) is
        logged and leaves _running False, so the node stays up without audio.
        """
        try:
            import pyaudio
            self._pa = pyaudio.PyAudio()
            # -1 (or any negative index) means "system default device".
            dev_idx = None if self._mic_idx < 0 else self._mic_idx
            self._audio_stream = self._pa.open(
                format=pyaudio.paInt16,
                channels=1,
                rate=self._rate,
                input=True,
                input_device_index=dev_idx,
                frames_per_buffer=CHUNK_SAMPLES,
            )
            self._running = True
            t = threading.Thread(target=self._audio_loop, daemon=True)
            t.start()
            self.get_logger().info(
                f"Audio capture started (device={self._mic_idx}, {self._rate}Hz)"
            )
        except Exception as e:
            self.get_logger().error(f"Audio stream error: {e}")
    def _audio_loop(self) -> None:
        """Continuous audio capture loop — runs in dedicated thread.

        Per 30 ms chunk: read PCM16 from the mic, convert to floats, run
        VAD + wake-word detection, publish VadState, and (only while the
        wake window is open) feed the segmenter. Completed utterances are
        handed to _process_utterance on short-lived worker threads so STT
        never blocks capture. Loop exits when self._running goes False.
        """
        while self._running:
            try:
                raw = self._audio_stream.read(CHUNK_SAMPLES, exception_on_overflow=False)
            except Exception:
                # Transient overruns / device hiccups: drop the chunk and retry.
                continue
            samples = pcm16_to_float32(raw)
            db = rms_db(samples)
            now = time.time()
            # VAD decision for this chunk (VadState is published every chunk,
            # i.e. ~33 Hz at 30 ms chunks)
            vad_active = self._check_vad(samples)
            state_str = "silence"
            # Wake word check
            if self._oww is not None:
                # NOTE(review): samples is a float list here; openwakeword
                # models typically expect 16-bit int frames — confirm the
                # expected input format against the openwakeword API.
                preds = self._oww.predict(samples)
                score = preds.get(self._ww_model, 0.0)
                if isinstance(score, (list, tuple)):
                    score = score[-1]
                if score >= self._ww_thresh:
                    # (Re-)open the listening window on every detection.
                    self._wake_word_active = True
                    self._wake_expiry = now + self._wake_window_s
                    state_str = "wake_word"
                    self.get_logger().info(
                        f"Wake word detected (score={score:.2f})"
                    )
            else:
                # No wake word detector — always listen
                self._wake_word_active = True
                self._wake_expiry = now + self._wake_window_s
            if self._wake_word_active and now > self._wake_expiry:
                self._wake_word_active = False
            # Publish VAD state ("wake_word" takes precedence only in the
            # chunk where it fired and no speech was detected)
            if vad_active:
                state_str = "speech"
            self._publish_vad(vad_active, db, state_str)
            # Only feed segmenter when wake word is active
            if not self._wake_word_active:
                continue
            if self._segmenter is None:
                continue
            completed = self._segmenter.push(samples)
            for utt_samples, duration in completed:
                threading.Thread(
                    target=self._process_utterance,
                    args=(utt_samples, duration),
                    daemon=True,
                ).start()
def _check_vad(self, samples: list) -> bool:
"""Run Silero VAD or fall back to segmenter's energy VAD."""
if self._vad is not None:
try:
import torch
t = torch.tensor(samples, dtype=torch.float32)
prob = self._vad(t, self._rate).item()
return prob > 0.5
except Exception:
pass
if self._segmenter is not None:
return self._segmenter._vad.is_active
return False
# ── STT + diarization ─────────────────────────────────────────────────────
    def _process_utterance(self, samples: list, duration: float) -> None:
        """Transcribe utterance and identify speaker. Runs in thread.

        Steps (sequential within this worker thread):
          1. bail out if Whisper never loaded;
          2. compute the ECAPA speaker embedding and match it against the
             known-speaker DB (errors downgrade to 'unknown');
          3. stream Whisper segments, optionally publishing growing
             partials (confidence 0.0), then publish the final transcript
             with a fixed confidence of 0.9.
        """
        if self._whisper is None:
            return
        t0 = time.perf_counter()
        # Convert to numpy for faster-whisper
        try:
            import numpy as np
            audio_np = np.array(samples, dtype=np.float32)
        except ImportError:
            audio_np = samples
        # Speaker embedding first (runs before transcription in this same
        # worker thread; each utterance gets its own thread from _audio_loop)
        speaker_id = "unknown"
        if self._ecapa is not None:
            try:
                import torch
                # Batch of one: encode_batch expects (batch, samples).
                audio_tensor = torch.tensor([samples], dtype=torch.float32)
                with torch.no_grad():
                    emb = self._ecapa.encode_batch(audio_tensor)
                emb_list = emb[0].cpu().numpy().tolist()
                speaker_id = identify_speaker(
                    emb_list, self._known_speakers, self._speaker_thresh
                )
            except Exception as e:
                self.get_logger().debug(f"Speaker ID error: {e}")
        # Streaming Whisper transcription
        partial_text = ""
        try:
            segments_gen, _info = self._whisper.transcribe(
                audio_np,
                language="en",
                beam_size=3,
                vad_filter=False,  # segmenter already trimmed the utterance
            )
            for seg in segments_gen:
                partial_text += seg.text.strip() + " "
                if self._publish_partial:
                    # Partials carry confidence 0.0 to mark them provisional.
                    self._publish_transcript(
                        partial_text.strip(), speaker_id, 0.0, duration, is_partial=True
                    )
        except Exception as e:
            self.get_logger().error(f"Whisper error: {e}")
            return
        final_text = partial_text.strip()
        if not final_text:
            return
        latency_ms = (time.perf_counter() - t0) * 1000
        self.get_logger().info(
            f"STT [{speaker_id}] ({duration:.1f}s, {latency_ms:.0f}ms): '{final_text}'"
        )
        # Final transcript confidence is a fixed placeholder (0.9).
        self._publish_transcript(final_text, speaker_id, 0.9, duration, is_partial=False)
# ── Publishers ────────────────────────────────────────────────────────────
def _publish_transcript(
self, text: str, speaker_id: str, confidence: float,
duration: float, is_partial: bool
) -> None:
msg = SpeechTranscript()
msg.header.stamp = self.get_clock().now().to_msg()
msg.text = text
msg.speaker_id = speaker_id
msg.confidence = confidence
msg.audio_duration = duration
msg.is_partial = is_partial
self._transcript_pub.publish(msg)
def _publish_vad(self, speech_active: bool, db: float, state: str) -> None:
msg = VadState()
msg.header.stamp = self.get_clock().now().to_msg()
msg.speech_active = speech_active
msg.energy_db = db
msg.state = state
self._vad_pub.publish(msg)
# ── Cleanup ───────────────────────────────────────────────────────────────
def destroy_node(self) -> None:
self._running = False
if self._audio_stream:
self._audio_stream.stop_stream()
self._audio_stream.close()
if self._pa:
self._pa.terminate()
super().destroy_node()
def main(args=None) -> None:
    """Entry point: init rclpy, spin SpeechPipelineNode, clean up on exit.

    Ctrl-C is swallowed so shutdown is quiet; destroy_node/shutdown run in
    the finally block regardless of how spin() ends.
    """
    rclpy.init(args=args)
    node = SpeechPipelineNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.shutdown()

View File

@ -0,0 +1,188 @@
"""speech_utils.py — Pure helpers for wake word, VAD, and STT.
No ROS2 dependencies. All functions operate on plain Python lists of float
samples in [-1.0, 1.0] at 16 kHz mono (the standard for Whisper / Silero VAD).
Tested by test/test_speech_utils.py (no GPU required).
"""
from __future__ import annotations
import math
import struct
from typing import Generator, List, Tuple
# ── Constants ──────────────────────────────────────────────────────────────────
SAMPLE_RATE = 16000 # Hz — standard for Whisper / Silero
CHUNK_MS = 30 # VAD frame size
CHUNK_SAMPLES = int(SAMPLE_RATE * CHUNK_MS / 1000) # 480 samples
INT16_MAX = 32768.0
# ── Audio helpers ──────────────────────────────────────────────────────────────
def pcm16_to_float32(data: bytes) -> list:
    """Decode little-endian PCM16 bytes into floats in [-1.0, 1.0)."""
    count = len(data) // 2
    values = struct.unpack(f"<{count}h", data[:count * 2])
    return [v / 32768.0 for v in values]
def float32_to_pcm16(samples: list) -> bytes:
    """Encode floats to little-endian PCM16 bytes, clipping to [-1.0, 1.0]."""
    encoded = bytearray()
    for s in samples:
        bounded = max(-1.0, min(1.0, s))
        value = max(-32768, min(32767, int(bounded * 32768.0)))
        encoded += struct.pack("<h", value)
    return bytes(encoded)
def rms_db(samples: list) -> float:
    """Compute RMS energy in dBFS, floored at -96.0 (digital silence).

    Returns -96.0 for empty input AND for an all-zero signal. The previous
    revision returned ~-200 dB for all-zero input, contradicting its own
    "-96.0 for silence" contract; results are now clamped to the floor.
    """
    if not samples:
        return -96.0
    mean_sq = sum(s * s for s in samples) / len(samples)
    if mean_sq <= 0.0:
        return -96.0
    return max(20.0 * math.log10(math.sqrt(mean_sq)), -96.0)
def chunk_audio(samples: list, chunk_size: int = CHUNK_SAMPLES) -> Generator[list, None, None]:
    """Yield consecutive chunk_size slices; a trailing partial chunk is dropped."""
    full_chunks = len(samples) // chunk_size
    for k in range(full_chunks):
        yield samples[k * chunk_size:(k + 1) * chunk_size]
# ── Simple energy-based VAD (fallback when Silero unavailable) ────────────────
class EnergyVad:
    """Hysteresis energy gate — a lightweight, unit-testable VAD fallback.

    Speech turns *on* after `onset_frames` consecutive frames whose RMS
    level reaches `threshold_db`, and *off* again only after
    `offset_frames` consecutive frames below it.
    """
    def __init__(
        self,
        threshold_db: float = -35.0,
        onset_frames: int = 2,
        offset_frames: int = 8,
    ) -> None:
        self.threshold_db = threshold_db
        self.onset_frames = onset_frames
        self.offset_frames = offset_frames
        self._above_count = 0
        self._below_count = 0
        self._active = False
    def process(self, chunk: list) -> bool:
        """Feed one frame (CHUNK_SAMPLES long); return the updated activity flag."""
        loud = rms_db(chunk) >= self.threshold_db
        if loud:
            self._below_count = 0
            self._above_count += 1
            if self._above_count >= self.onset_frames:
                self._active = True
        else:
            self._above_count = 0
            self._below_count += 1
            if self._below_count >= self.offset_frames:
                self._active = False
        return self._active
    def reset(self) -> None:
        """Forget all hysteresis state and return to the inactive condition."""
        self._above_count = 0
        self._below_count = 0
        self._active = False
    @property
    def is_active(self) -> bool:
        # Read-only view of the current speech/no-speech decision.
        return self._active
# ── Utterance segmenter ────────────────────────────────────────────────────────
class UtteranceSegmenter:
    """Accumulates audio chunks into complete utterances using VAD.

    push() returns (samples, duration_s) tuples for every utterance
    completed by that chunk. A pre-roll ring buffer keeps a few frames
    from before speech onset so the first phoneme is not clipped.

    Fixes over the previous revision:
      - the onset chunk is no longer duplicated (it was both appended to
        the pre-roll buffer and then extended onto the utterance buffer);
      - max_duration_s is enforced on every in-utterance chunk, not only
        during trailing silence, so continuous speech cannot grow
        unbounded;
      - the pre-roll buffer is cleared when an utterance completes, so the
        next utterance's pre-roll cannot contain stale pre-onset audio.
    """
    def __init__(
        self,
        vad: EnergyVad,
        pre_roll_frames: int = 5,
        max_duration_s: float = 15.0,
        sample_rate: int = SAMPLE_RATE,
        chunk_samples: int = CHUNK_SAMPLES,
    ) -> None:
        self._vad = vad
        self._pre_roll = pre_roll_frames
        # Frame budget for the longest allowed utterance.
        self._max_frames = int(max_duration_s * sample_rate / chunk_samples)
        self._sample_rate = sample_rate
        self._chunk_samples = chunk_samples
        self._pre_buf: list = []   # ring buffer of recent pre-onset chunks
        self._utt_buf: list = []   # flat samples of the utterance in progress
        self._in_utt = False
        self._silence_after = 0    # consecutive non-speech frames inside utterance
    def push(self, chunk: list) -> list:
        """Push one chunk, return list of completed (samples, duration_s) tuples."""
        speech = self._vad.process(chunk)
        results: list = []
        if not self._in_utt:
            # Maintain the pre-roll ring buffer; it includes the current chunk.
            self._pre_buf.append(chunk)
            if len(self._pre_buf) > self._pre_roll:
                self._pre_buf.pop(0)
            if speech:
                self._in_utt = True
                # Pre-roll already contains the onset chunk — do NOT append
                # it a second time (previous revision duplicated it here).
                self._utt_buf = [s for c in self._pre_buf for s in c]
                self._silence_after = 0
            return results
        self._utt_buf.extend(chunk)
        if speech:
            self._silence_after = 0
        else:
            self._silence_after += 1
        hit_silence = self._silence_after >= self._vad.offset_frames
        hit_max_len = len(self._utt_buf) >= self._max_frames * self._chunk_samples
        if hit_silence or hit_max_len:
            dur = len(self._utt_buf) / self._sample_rate
            results.append((list(self._utt_buf), dur))
            self._utt_buf = []
            self._pre_buf = []
            self._in_utt = False
            self._silence_after = 0
        return results
# ── Speaker embedding distance ─────────────────────────────────────────────────
def cosine_similarity(a: list, b: list) -> float:
    """Cosine similarity of two equal-length vectors.

    Returns 0.0 for empty input, mismatched lengths, or near-zero norms.
    """
    if not a or len(a) != len(b):
        return 0.0
    dot = 0.0
    sq_a = 0.0
    sq_b = 0.0
    for x, y in zip(a, b):
        dot += x * y
        sq_a += x * x
        sq_b += y * y
    norm_a = math.sqrt(sq_a)
    norm_b = math.sqrt(sq_b)
    if norm_a < 1e-10 or norm_b < 1e-10:
        return 0.0
    return dot / (norm_a * norm_b)
def identify_speaker(
    embedding: list,
    known_speakers: dict,
    threshold: float = 0.65,
) -> str:
    """Best-matching speaker_id whose cosine similarity strictly exceeds *threshold*.

    known_speakers maps speaker_id -> embedding list. Returns 'unknown'
    when no enrolled speaker beats the threshold.
    """
    winner, best = "unknown", threshold
    for candidate_id, candidate_emb in known_speakers.items():
        score = cosine_similarity(embedding, candidate_emb)
        if score > best:
            winner, best = candidate_id, score
    return winner

View File

@ -0,0 +1,228 @@
"""tts_node.py — Streaming TTS with Piper / first-chunk streaming.
Issue #85: Streaming TTS — Piper/XTTS integration with first-chunk streaming.
Pipeline:
/social/conversation/response (ConversationResponse)
      -> sentence split -> Piper ONNX synthesis (sentence by sentence)
      -> PCM16 chunks -> USB speaker (sounddevice) + publish /social/tts/audio
First-chunk strategy:
    - On a partial ConversationResponse, extract the first sentence and
      synthesize it immediately -- audio starts before the LLM finishes generating
    - On the final message (is_partial=false), synthesize the remaining sentences
Latency target: <200ms to first audio chunk.
ROS2 topics:
Subscribe: /social/conversation/response (saltybot_social_msgs/ConversationResponse)
Publish: /social/tts/audio (audio_msgs/Audio or std_msgs/UInt8MultiArray fallback)
Parameters:
voice_path (str, "/models/piper/en_US-lessac-medium.onnx")
sample_rate (int, 22050)
volume (float, 1.0)
audio_device (str, "") sounddevice device name; "" = system default
playback_enabled (bool, true)
publish_audio (bool, false) publish PCM to ROS2 topic
sentence_streaming (bool, true) synthesize sentence-by-sentence
"""
from __future__ import annotations
import queue
import threading
import time
from typing import Optional
import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile
from std_msgs.msg import UInt8MultiArray
from saltybot_social_msgs.msg import ConversationResponse
from .tts_utils import split_sentences, strip_ssml, apply_volume, chunk_pcm, estimate_duration_ms
class TtsNode(Node):
    """Streaming TTS node using Piper ONNX.

    Subscribes to /social/conversation/response; splits each (possibly
    partial) LLM response into sentences, deduplicates already-queued
    sentences per turn, and hands them to a single playback worker thread
    that synthesizes with Piper, optionally plays via sounddevice, and
    optionally republishes the raw PCM.
    """
    def __init__(self) -> None:
        """Declare/read parameters, wire topics, start loader + worker threads."""
        super().__init__("tts_node")
        # ── Parameters ──────────────────────────────────────────────────────
        self.declare_parameter("voice_path", "/models/piper/en_US-lessac-medium.onnx")
        self.declare_parameter("sample_rate", 22050)
        self.declare_parameter("volume", 1.0)
        self.declare_parameter("audio_device", "")
        self.declare_parameter("playback_enabled", True)
        self.declare_parameter("publish_audio", False)
        self.declare_parameter("sentence_streaming", True)
        self._voice_path = self.get_parameter("voice_path").value
        self._sample_rate = self.get_parameter("sample_rate").value
        self._volume = self.get_parameter("volume").value
        # Empty string means "system default output device" for sounddevice.
        self._audio_device = self.get_parameter("audio_device").value or None
        self._playback = self.get_parameter("playback_enabled").value
        self._publish_audio = self.get_parameter("publish_audio").value
        self._sentence_streaming = self.get_parameter("sentence_streaming").value
        # ── Publishers / Subscribers ─────────────────────────────────────────
        qos = QoSProfile(depth=10)
        self._resp_sub = self.create_subscription(
            ConversationResponse, "/social/conversation/response",
            self._on_response, qos
        )
        # _audio_pub only exists when publish_audio is enabled; _publish_pcm
        # guards with hasattr.
        if self._publish_audio:
            self._audio_pub = self.create_publisher(
                UInt8MultiArray, "/social/tts/audio", qos
            )
        # ── TTS engine ────────────────────────────────────────────────────────
        self._voice = None
        self._playback_queue: queue.Queue = queue.Queue(maxsize=16)
        self._current_turn = -1
        self._synthesized_turns: set = set()  # (turn_id, sentence-hash) pairs already queued
        self._lock = threading.Lock()
        threading.Thread(target=self._load_voice, daemon=True).start()
        threading.Thread(target=self._playback_worker, daemon=True).start()
        self.get_logger().info(
            f"TtsNode init (voice={self._voice_path}, "
            f"streaming={self._sentence_streaming})"
        )
    # ── Voice loading ─────────────────────────────────────────────────────────
    def _load_voice(self) -> None:
        """Load the Piper voice in the background and run a warmup synthesis.

        On failure self._voice stays None; the playback worker then logs
        and drops queued segments instead of crashing.
        """
        t0 = time.time()
        self.get_logger().info(f"Loading Piper voice: {self._voice_path}")
        try:
            from piper import PiperVoice
            self._voice = PiperVoice.load(self._voice_path)
            # Warmup synthesis to pre-JIT ONNX graph
            warmup_text = "Hello."
            list(self._voice.synthesize_stream_raw(warmup_text))
            self.get_logger().info(f"Piper voice ready ({time.time()-t0:.1f}s)")
        except Exception as e:
            self.get_logger().error(f"Piper voice load failed: {e}")
    # ── Response handler ──────────────────────────────────────────────────────
    def _on_response(self, msg: ConversationResponse) -> None:
        """Handle a (possibly partial) LLM response; queue unseen sentences.

        Dedup is by (turn_id, hash(sentence)) so repeated partials do not
        re-queue sentences already sent to the worker.
        NOTE(review): two side effects of content-hash dedup to confirm:
        a sentence that legitimately repeats within one turn ("Yes. Yes.")
        is only synthesized once; and an incomplete trailing sentence in a
        partial ("Hello wor") hashes differently from its completed form,
        so both versions can be synthesized.
        """
        if not msg.text.strip():
            return
        with self._lock:
            is_new_turn = msg.turn_id != self._current_turn
            if is_new_turn:
                self._current_turn = msg.turn_id
                # Clear old synthesized sentence cache for this new turn
                self._synthesized_turns = set()
        text = strip_ssml(msg.text)
        if self._sentence_streaming:
            sentences = split_sentences(text)
            for sentence in sentences:
                # Track which sentences we've already queued by content hash
                key = (msg.turn_id, hash(sentence))
                with self._lock:
                    if key in self._synthesized_turns:
                        continue
                    self._synthesized_turns.add(key)
                self._queue_synthesis(sentence)
        elif not msg.is_partial:
            # Non-streaming: synthesize full response at end
            self._queue_synthesis(text)
    def _queue_synthesis(self, text: str) -> None:
        """Queue a text segment for synthesis in the playback worker.

        Non-blocking: when the bounded queue (16 entries) is full the
        segment is dropped with a warning rather than stalling the
        subscription callback.
        """
        if not text.strip():
            return
        try:
            self._playback_queue.put_nowait(text.strip())
        except queue.Full:
            self.get_logger().warn("TTS playback queue full, dropping segment")
    # ── Playback worker ───────────────────────────────────────────────────────
    def _playback_worker(self) -> None:
        """Consume the synthesis queue: synthesize, play, publish — in order.

        Single worker thread, so audio segments play strictly in queue
        order (playback itself is blocking). Polls with a 0.5 s timeout so
        the loop can observe rclpy shutdown.
        """
        while rclpy.ok():
            try:
                text = self._playback_queue.get(timeout=0.5)
            except queue.Empty:
                continue
            if self._voice is None:
                # Voice still loading (or failed to load): drop the segment.
                self.get_logger().warn("TTS voice not loaded yet")
                self._playback_queue.task_done()
                continue
            t0 = time.perf_counter()
            pcm_data = self._synthesize(text)
            if pcm_data is None:
                self._playback_queue.task_done()
                continue
            synth_ms = (time.perf_counter() - t0) * 1000
            dur_ms = estimate_duration_ms(pcm_data, self._sample_rate)
            self.get_logger().debug(
                f"TTS '{text[:40]}' synth={synth_ms:.0f}ms, dur={dur_ms:.0f}ms"
            )
            if self._volume != 1.0:
                pcm_data = apply_volume(pcm_data, self._volume)
            if self._playback:
                self._play_audio(pcm_data)
            if self._publish_audio:
                self._publish_pcm(pcm_data)
            self._playback_queue.task_done()
    def _synthesize(self, text: str) -> Optional[bytes]:
        """Synthesize text to PCM16 bytes using Piper; None on failure.

        NOTE(review): this materializes ALL chunks before returning, so
        per-segment audio is not streamed chunk-by-chunk — "streaming" here
        is sentence-level, via the queue.
        """
        if self._voice is None:
            return None
        try:
            chunks = list(self._voice.synthesize_stream_raw(text))
            return b"".join(chunks)
        except Exception as e:
            self.get_logger().error(f"TTS synthesis error: {e}")
            return None
    def _play_audio(self, pcm_data: bytes) -> None:
        """Play PCM16 data on the configured output device via sounddevice.

        Blocking by design: keeps segments from overlapping since the
        single worker thread serializes playback.
        """
        try:
            import sounddevice as sd
            import numpy as np
            samples = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0
            sd.play(samples, samplerate=self._sample_rate, device=self._audio_device,
                    blocking=True)
        except Exception as e:
            self.get_logger().error(f"Audio playback error: {e}")
    def _publish_pcm(self, pcm_data: bytes) -> None:
        """Publish raw PCM bytes as UInt8MultiArray (no-op if pub not created)."""
        if not hasattr(self, "_audio_pub"):
            return
        msg = UInt8MultiArray()
        msg.data = list(pcm_data)
        self._audio_pub.publish(msg)
def main(args=None) -> None:
    """Entry point: init rclpy, spin TtsNode, clean up on exit.

    Ctrl-C is swallowed for a quiet shutdown; destroy_node/shutdown run in
    the finally block regardless of how spin() ends.
    """
    rclpy.init(args=args)
    node = TtsNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.shutdown()

View File

@ -0,0 +1,100 @@
"""tts_utils.py — Pure TTS helpers for streaming synthesis.
No ROS2 dependencies. Provides streaming Piper synthesis and audio chunking.
Tested by test/test_tts_utils.py.
"""
from __future__ import annotations
import re
import struct
from typing import Generator, List, Tuple
# ── SSML helpers ──────────────────────────────────────────────────────────────
_SSML_STRIP_RE = re.compile(r"<[^>]+>")
def strip_ssml(text: str) -> str:
    """Remove SSML/XML tags and collapse runs of whitespace to one space.

    Tag removal can leave doubled spaces (e.g. around a removed <break/>),
    which the previous revision kept — test_strip_ssml expects them
    collapsed ("Hello world", not "Hello  world").
    """
    without_tags = re.sub(r"<[^>]+>", "", text)
    return re.sub(r"\s+", " ", without_tags).strip()
def split_sentences(text: str) -> List[str]:
    """Split *text* into sentences for incremental TTS synthesis.

    Splits after sentence-ending punctuation (., !, ?) followed by
    whitespace; returns only non-empty, stripped sentences.
    """
    candidates = (piece.strip() for piece in re.split(r"(?<=[.!?])\s+", text.strip()))
    return [c for c in candidates if c]
def split_first_sentence(text: str) -> Tuple[str, str]:
    """Return (first_sentence, remainder joined by single spaces).

    Lets the TTS node start synthesizing the first sentence while the LLM
    is still generating the rest. Both parts are "" for empty input.
    """
    parts = split_sentences(text)
    if not parts:
        return "", ""
    head, tail = parts[0], parts[1:]
    return head, " ".join(tail)
# ── PCM audio helpers ─────────────────────────────────────────────────────────
def pcm16_to_wav_bytes(pcm_data: bytes, sample_rate: int = 22050) -> bytes:
    """Prepend a canonical 44-byte RIFF/WAVE header to raw PCM16 LE mono data."""
    channels = 1
    sample_width = 2  # bytes per int16 sample
    data_size = len(pcm_data)
    riff_hdr = struct.pack("<4sI4s", b"RIFF", 36 + data_size, b"WAVE")
    fmt_chunk = struct.pack(
        "<4sIHHIIHH",
        b"fmt ",
        16,                                       # fmt chunk size
        1,                                        # audio format: PCM
        channels,
        sample_rate,
        sample_rate * channels * sample_width,    # byte rate
        channels * sample_width,                  # block align
        16,                                       # bits per sample
    )
    data_hdr = struct.pack("<4sI", b"data", data_size)
    return riff_hdr + fmt_chunk + data_hdr + pcm_data
def chunk_pcm(pcm_data: bytes, chunk_ms: int = 200, sample_rate: int = 22050) -> Generator[bytes, None, None]:
    """Yield *pcm_data* in ~chunk_ms slices, each an even number of bytes."""
    step = (sample_rate * 2 * chunk_ms) // 1000  # int16 = 2 bytes/sample
    step = max(step, 2) & ~1                     # never zero; keep int16 alignment
    offset = 0
    total = len(pcm_data)
    while offset < total:
        yield pcm_data[offset:offset + step]
        offset += step
def estimate_duration_ms(pcm_data: bytes, sample_rate: int = 22050) -> float:
    """Playback duration in milliseconds for raw PCM16 mono of this length."""
    sample_count = len(pcm_data) / 2.0
    return 1000.0 * sample_count / sample_rate
# ── Volume control ────────────────────────────────────────────────────────────
def apply_volume(pcm_data: bytes, gain: float) -> bytes:
    """Scale PCM16 LE samples by *gain* (0.0–2.0 typical), clipping to ±32767.

    Gains within 1% of unity return the input unchanged (fast path).
    """
    if abs(gain - 1.0) < 0.01:
        return pcm_data
    count = len(pcm_data) // 2
    out = bytearray()
    for value in struct.unpack(f"<{count}h", pcm_data[:count * 2]):
        scaled = int(value * gain)
        out += struct.pack("<h", min(32767, max(-32767, scaled)))
    return bytes(out)

View File

@ -21,7 +21,7 @@ setup(
zip_safe=True, zip_safe=True,
maintainer='seb', maintainer='seb',
maintainer_email='seb@vayrette.com', maintainer_email='seb@vayrette.com',
description='Social interaction layer — person state tracking, LED expression + attention', description='Social interaction layer — person tracking, speech, LLM, TTS, orchestrator',
license='MIT', license='MIT',
tests_require=['pytest'], tests_require=['pytest'],
entry_points={ entry_points={
@ -29,6 +29,10 @@ setup(
'person_state_tracker = saltybot_social.person_state_tracker_node:main', 'person_state_tracker = saltybot_social.person_state_tracker_node:main',
'expression_node = saltybot_social.expression_node:main', 'expression_node = saltybot_social.expression_node:main',
'attention_node = saltybot_social.attention_node:main', 'attention_node = saltybot_social.attention_node:main',
'speech_pipeline_node = saltybot_social.speech_pipeline_node:main',
'conversation_node = saltybot_social.conversation_node:main',
'tts_node = saltybot_social.tts_node:main',
'orchestrator_node = saltybot_social.orchestrator_node:main',
], ],
}, },
) )

View File

@ -0,0 +1,244 @@
"""test_llm_context.py — Unit tests for llm_context.py (no GPU / LLM needed)."""
import sys
import os
import tempfile
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from saltybot_social.llm_context import (
PersonContext, ContextStore,
build_llama_prompt, MAX_TURNS, SUMMARY_KEEP,
DEFAULT_SYSTEM_PROMPT,
)
import pytest
# ── PersonContext ─────────────────────────────────────────────────────────────
class TestPersonContext:
    """Unit tests for PersonContext: turn tracking, compression, serialization."""
    def test_create(self):
        ctx = PersonContext("person_1", "Alice")
        assert ctx.person_id == "person_1"
        assert ctx.person_name == "Alice"
        assert ctx.turns == []
        assert ctx.interaction_count == 0
    def test_add_user_increments_count(self):
        ctx = PersonContext("person_1")
        ctx.add_user("Hello")
        assert ctx.interaction_count == 1
        assert ctx.turns[-1]["role"] == "user"
        assert ctx.turns[-1]["content"] == "Hello"
    def test_add_assistant(self):
        ctx = PersonContext("person_1")
        ctx.add_assistant("Hi there!")
        assert ctx.turns[-1]["role"] == "assistant"
    def test_needs_compression_false(self):
        # One below the MAX_TURNS threshold must not request compression.
        ctx = PersonContext("p")
        for i in range(MAX_TURNS - 1):
            ctx.add_user(f"msg {i}")
        assert not ctx.needs_compression()
    def test_needs_compression_true(self):
        ctx = PersonContext("p")
        for i in range(MAX_TURNS + 1):
            ctx.add_user(f"msg {i}")
        assert ctx.needs_compression()
    def test_compress_keeps_recent_turns(self):
        # compress() must truncate history to the SUMMARY_KEEP most recent turns.
        ctx = PersonContext("p")
        for i in range(MAX_TURNS + 5):
            ctx.add_user(f"msg {i}")
        ctx.compress("Summary of conversation")
        assert len(ctx.turns) == SUMMARY_KEEP
        assert ctx.summary == "Summary of conversation"
    def test_compress_appends_summary(self):
        # Successive compressions accumulate summaries rather than replacing them.
        ctx = PersonContext("p")
        ctx.add_user("msg")
        ctx.compress("First summary")
        ctx.add_user("msg2")
        ctx.compress("Second summary")
        assert "First summary" in ctx.summary
        assert "Second summary" in ctx.summary
    def test_roundtrip_serialization(self):
        # to_dict()/from_dict() must preserve identity, turns, and summary.
        ctx = PersonContext("person_42", "Bob")
        ctx.add_user("test")
        ctx.add_assistant("response")
        ctx.summary = "Earlier they discussed robots"
        d = ctx.to_dict()
        ctx2 = PersonContext.from_dict(d)
        assert ctx2.person_id == "person_42"
        assert ctx2.person_name == "Bob"
        assert len(ctx2.turns) == 2
        assert ctx2.summary == "Earlier they discussed robots"
# ── ContextStore ──────────────────────────────────────────────────────────────
class TestContextStore:
    """Unit tests for ContextStore: creation, caching, and JSON persistence."""
    def test_get_creates_new(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            db = os.path.join(tmpdir, "ctx.json")
            store = ContextStore(db)
            ctx = store.get("p1", "Alice")
            assert ctx.person_id == "p1"
            assert ctx.person_name == "Alice"
    def test_get_same_instance(self):
        # Repeated get() for the same person must return the cached object.
        with tempfile.TemporaryDirectory() as tmpdir:
            db = os.path.join(tmpdir, "ctx.json")
            store = ContextStore(db)
            ctx1 = store.get("p1")
            ctx2 = store.get("p1")
            assert ctx1 is ctx2
    def test_save_and_load(self):
        # A second store on the same path must see persisted contexts.
        with tempfile.TemporaryDirectory() as tmpdir:
            db = os.path.join(tmpdir, "ctx.json")
            store = ContextStore(db)
            ctx = store.get("p1", "Alice")
            ctx.add_user("Hello")
            store.save()
            store2 = ContextStore(db)
            ctx2 = store2.get("p1")
            assert ctx2.person_name == "Alice"
            assert len(ctx2.turns) == 1
            assert ctx2.turns[0]["content"] == "Hello"
    def test_all_persons(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            db = os.path.join(tmpdir, "ctx.json")
            store = ContextStore(db)
            store.get("p1")
            store.get("p2")
            store.get("p3")
            assert set(store.all_persons()) == {"p1", "p2", "p3"}
# ── Prompt builder ────────────────────────────────────────────────────────────
class TestBuildLlamaPrompt:
    """Unit tests for build_llama_prompt: system prompt, history, summary, group."""
    def test_contains_system_prompt(self):
        ctx = PersonContext("p1")
        prompt = build_llama_prompt(ctx, "hello", DEFAULT_SYSTEM_PROMPT)
        assert DEFAULT_SYSTEM_PROMPT in prompt
    def test_contains_user_text(self):
        ctx = PersonContext("p1")
        prompt = build_llama_prompt(ctx, "What time is it?", DEFAULT_SYSTEM_PROMPT)
        assert "What time is it?" in prompt
    def test_ends_with_assistant_tag(self):
        # Phi-3 chat template: generation must start right after <|assistant|>.
        ctx = PersonContext("p1")
        prompt = build_llama_prompt(ctx, "hello", DEFAULT_SYSTEM_PROMPT)
        assert prompt.strip().endswith("<|assistant|>")
    def test_history_included(self):
        ctx = PersonContext("p1")
        ctx.add_user("Previous question")
        ctx.add_assistant("Previous answer")
        prompt = build_llama_prompt(ctx, "New question", DEFAULT_SYSTEM_PROMPT)
        assert "Previous question" in prompt
        assert "Previous answer" in prompt
    def test_summary_included(self):
        ctx = PersonContext("p1", "Alice")
        ctx.summary = "Alice visited last week and likes coffee"
        prompt = build_llama_prompt(ctx, "hello", DEFAULT_SYSTEM_PROMPT)
        assert "likes coffee" in prompt
    def test_group_persons_included(self):
        ctx = PersonContext("p1", "Alice")
        prompt = build_llama_prompt(ctx, "hello", DEFAULT_SYSTEM_PROMPT,
                                    group_persons=["p1", "p2_Bob", "p3_Charlie"])
        assert "p2_Bob" in prompt or "p3_Charlie" in prompt
    def test_no_group_persons_not_in_prompt(self):
        # Without group_persons the "Also present" section must be omitted.
        ctx = PersonContext("p1")
        prompt = build_llama_prompt(ctx, "hello", DEFAULT_SYSTEM_PROMPT)
        assert "Also present" not in prompt
# ── TTS utils (imported here for coverage) ────────────────────────────────────
class TestTtsUtils:
    """Unit tests for tts_utils helpers.

    NOTE(review): these live in test_llm_context.py "for coverage" but
    belong in a dedicated test_tts_utils.py as tts_utils.py's header claims.
    """
    def setup_method(self):
        # Imports deferred to setup so a missing module fails per-test, not at collection.
        sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
        from saltybot_social.tts_utils import (
            strip_ssml, split_sentences, split_first_sentence,
            pcm16_to_wav_bytes, chunk_pcm, estimate_duration_ms, apply_volume
        )
        self.strip_ssml = strip_ssml
        self.split_sentences = split_sentences
        self.split_first_sentence = split_first_sentence
        self.pcm16_to_wav_bytes = pcm16_to_wav_bytes
        self.chunk_pcm = chunk_pcm
        self.estimate_duration_ms = estimate_duration_ms
        self.apply_volume = apply_volume
    def test_strip_ssml(self):
        # NOTE(review): passes only if strip_ssml collapses the doubled space
        # left by removing <break/> — confirm the implementation does so.
        assert self.strip_ssml("<speak>Hello <break time='1s'/> world</speak>") == "Hello world"
        assert self.strip_ssml("plain text") == "plain text"
    def test_split_sentences_basic(self):
        text = "Hello world. How are you? I am fine!"
        sentences = self.split_sentences(text)
        assert len(sentences) == 3
    def test_split_sentences_single(self):
        sentences = self.split_sentences("Hello world.")
        assert len(sentences) == 1
        assert sentences[0] == "Hello world."
    def test_split_sentences_empty(self):
        assert self.split_sentences("") == []
        assert self.split_sentences("   ") == []
    def test_split_first_sentence(self):
        text = "Hello world. How are you? I am fine."
        first, rest = self.split_first_sentence(text)
        assert first == "Hello world."
        assert "How are you?" in rest
    def test_wav_header(self):
        # 44-byte canonical RIFF/WAVE header followed by the payload.
        pcm = b"\x00\x00" * 1000
        wav = self.pcm16_to_wav_bytes(pcm, 22050)
        assert wav[:4] == b"RIFF"
        assert wav[8:12] == b"WAVE"
        assert len(wav) == 44 + len(pcm)
    def test_chunk_pcm_sizes(self):
        pcm = b"\x00\x00" * 2205  # ~100ms at 22050Hz
        chunks = list(self.chunk_pcm(pcm, 50, 22050))  # 50ms chunks
        for chunk in chunks:
            assert len(chunk) % 2 == 0  # int16 aligned
    def test_estimate_duration(self):
        pcm = b"\x00\x00" * 22050  # 1s at 22050Hz
        ms = self.estimate_duration_ms(pcm, 22050)
        assert abs(ms - 1000.0) < 1.0
    def test_apply_volume_unity(self):
        # Unity gain takes the fast path and returns the buffer unchanged.
        pcm = b"\x00\x40" * 100  # some non-zero samples
        result = self.apply_volume(pcm, 1.0)
        assert result == pcm
    def test_apply_volume_half(self):
        import struct
        pcm = struct.pack("<h", 1000)
        result = self.apply_volume(pcm, 0.5)
        val = struct.unpack("<h", result)[0]
        assert abs(val - 500) <= 1
    def test_apply_volume_clips(self):
        import struct
        pcm = struct.pack("<h", 30000)
        result = self.apply_volume(pcm, 2.0)
        val = struct.unpack("<h", result)[0]
        assert val == 32767

View File

@ -0,0 +1,237 @@
"""test_speech_utils.py — Unit tests for speech_utils.py (no ROS2 / GPU needed)."""
import math
import struct
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from saltybot_social.speech_utils import (
pcm16_to_float32, float32_to_pcm16,
rms_db, chunk_audio,
EnergyVad, UtteranceSegmenter,
cosine_similarity, identify_speaker,
CHUNK_SAMPLES, SAMPLE_RATE,
)
import pytest
# ── PCM conversion ────────────────────────────────────────────────────────────
class TestPcmConversion:
    """Unit tests for pcm16_to_float32 / float32_to_pcm16 codecs."""
    def test_pcm16_to_float32_silence(self):
        data = b"\x00\x00" * 100
        samples = pcm16_to_float32(data)
        assert len(samples) == 100
        assert all(s == 0.0 for s in samples)
    def test_pcm16_to_float32_max_positive(self):
        # int16 max maps just below +1.0 (32767/32768).
        data = struct.pack("<h", 32767)
        samples = pcm16_to_float32(data)
        assert abs(samples[0] - 32767 / 32768.0) < 1e-5
    def test_float32_to_pcm16_roundtrip(self):
        original = [0.0, 0.5, -0.5, 0.99, -0.99]
        pcm = float32_to_pcm16(original)
        recovered = pcm16_to_float32(pcm)
        for orig, rec in zip(original, recovered):
            assert abs(orig - rec) < 1e-3
    def test_float32_to_pcm16_clips(self):
        # Values > 1.0 should be clipped
        pcm = float32_to_pcm16([2.0, -2.0])
        samples = pcm16_to_float32(pcm)
        assert samples[0] <= 1.0
        assert samples[1] >= -1.0
# ── RMS dB ────────────────────────────────────────────────────────────────────
class TestRmsDb:
    """RMS level measurement in dBFS."""

    def test_silence_returns_minus96(self):
        """Empty input hits the -96 dB floor; all-zero input stays very low."""
        assert rms_db([]) == -96.0
        assert rms_db([0.0] * 100) < -60.0

    def test_full_scale_is_near_zero(self):
        """A constant full-scale signal measures ~0 dBFS."""
        assert abs(rms_db([1.0] * 100)) < 0.1

    def test_half_scale_is_minus6(self):
        """Half amplitude reads approximately -6.02 dB."""
        assert abs(rms_db([0.5] * 100) - (-6.02)) < 0.1

    def test_low_level_signal(self):
        """A very quiet signal reads well below -40 dB."""
        assert rms_db([0.001] * 100) < -40.0
# ── Chunk audio ───────────────────────────────────────────────────────────────
class TestChunkAudio:
    """Fixed-size chunking of sample sequences."""

    def test_even_chunks(self):
        """An exact multiple of the chunk size yields only full chunks."""
        pieces = list(chunk_audio(list(range(480 * 4)), 480))
        assert len(pieces) == 4
        for piece in pieces:
            assert len(piece) == 480

    def test_remainder_dropped(self):
        """A trailing partial chunk is discarded, not padded."""
        pieces = list(chunk_audio(list(range(500)), 480))
        assert len(pieces) == 1

    def test_empty_input(self):
        """No samples in, no chunks out."""
        assert list(chunk_audio([], 480)) == []
# ── Energy VAD ────────────────────────────────────────────────────────────────
class TestEnergyVad:
    """Hysteresis behaviour of the energy-based voice activity detector."""

    def test_silence_not_active(self):
        """Near-zero frames never flip the detector on."""
        detector = EnergyVad(threshold_db=-35.0, onset_frames=2, offset_frames=3)
        quiet = [0.0001] * CHUNK_SAMPLES
        states = [detector.process(quiet) for _ in range(5)]
        assert not any(states)

    def test_loud_signal_activates(self):
        """Enough consecutive loud frames turn the detector on."""
        detector = EnergyVad(threshold_db=-35.0, onset_frames=2, offset_frames=8)
        hot = [0.5] * CHUNK_SAMPLES
        outcomes = [detector.process(hot) for _ in range(3)]
        assert outcomes[-1] is True

    def test_onset_requires_n_frames(self):
        """Activation waits for exactly onset_frames loud frames."""
        detector = EnergyVad(threshold_db=-35.0, onset_frames=3, offset_frames=8)
        hot = [0.5] * CHUNK_SAMPLES
        assert not detector.process(hot)  # first loud frame
        assert not detector.process(hot)  # second loud frame
        assert detector.process(hot)      # third frame crosses the onset count

    def test_offset_deactivates(self):
        """Deactivation needs offset_frames of silence after speech."""
        detector = EnergyVad(threshold_db=-35.0, onset_frames=1, offset_frames=2)
        hot = [0.5] * CHUNK_SAMPLES
        quiet = [0.0001] * CHUNK_SAMPLES
        detector.process(hot)
        assert detector.is_active
        detector.process(quiet)
        assert detector.is_active  # only one of the two required silent frames
        detector.process(quiet)
        assert not detector.is_active

    def test_reset(self):
        """reset() clears the active flag immediately."""
        detector = EnergyVad(threshold_db=-35.0, onset_frames=1, offset_frames=8)
        detector.process([0.5] * CHUNK_SAMPLES)
        assert detector.is_active
        detector.reset()
        assert not detector.is_active
# ── Utterance segmenter ───────────────────────────────────────────────────────
class TestUtteranceSegmenter:
    """Utterance boundary detection layered on top of the VAD."""

    def _make_segmenter(self):
        """Build a segmenter with a fast-onset, 2-frame-offset VAD."""
        detector = EnergyVad(threshold_db=-35.0, onset_frames=1, offset_frames=2)
        return UtteranceSegmenter(detector, pre_roll_frames=2, max_duration_s=5.0)

    def test_silence_yields_nothing(self):
        """Pure silence never produces an utterance."""
        seg = self._make_segmenter()
        quiet = [0.0001] * CHUNK_SAMPLES
        for _ in range(10):
            assert seg.push(quiet) == []

    def test_speech_then_silence_yields_utterance(self):
        """Speech followed by enough silence emits exactly one utterance."""
        seg = self._make_segmenter()
        hot = [0.5] * CHUNK_SAMPLES
        quiet = [0.0001] * CHUNK_SAMPLES
        for _ in range(3):
            assert seg.push(hot) == []  # nothing emitted mid-speech
        emitted = []
        for _ in range(3):  # silence frames trigger the offset
            emitted.extend(seg.push(quiet))
        assert len(emitted) == 1
        samples, seconds = emitted[0]
        assert seconds > 0
        assert len(samples) > 0

    def test_preroll_included(self):
        """Silence frames buffered before onset end up in the utterance."""
        detector = EnergyVad(threshold_db=-35.0, onset_frames=1, offset_frames=2)
        seg = UtteranceSegmenter(detector, pre_roll_frames=3, max_duration_s=5.0)
        quiet = [0.0001] * CHUNK_SAMPLES
        hot = [0.5] * CHUNK_SAMPLES
        for _ in range(3):  # these frames become the pre-roll buffer
            seg.push(quiet)
        seg.push(hot)  # onset: pre-roll should be folded into the utterance
        emitted = []
        for _ in range(3):  # trailing silence closes the utterance
            emitted.extend(seg.push(quiet))
        assert len(emitted) == 1
        samples, _ = emitted[0]
        # pre-roll + speech + trailing silence: at least four frames of audio
        assert len(samples) >= 4 * CHUNK_SAMPLES
# ── Speaker identification ────────────────────────────────────────────────────
class TestSpeakerIdentification:
    """cosine_similarity edge cases and nearest-speaker lookup."""

    def test_cosine_identical(self):
        """A vector compared with itself scores 1.0."""
        vec = [1.0, 0.0, 0.0]
        assert abs(cosine_similarity(vec, vec) - 1.0) < 1e-6

    def test_cosine_orthogonal(self):
        """Perpendicular vectors score ~0."""
        assert abs(cosine_similarity([1.0, 0.0], [0.0, 1.0])) < 1e-6

    def test_cosine_opposite(self):
        """Antiparallel vectors score -1.0."""
        assert abs(cosine_similarity([1.0, 0.0], [-1.0, 0.0]) + 1.0) < 1e-6

    def test_cosine_zero_vector(self):
        """A zero vector yields 0 instead of dividing by zero."""
        assert cosine_similarity([0.0, 0.0], [1.0, 0.0]) == 0.0

    def test_cosine_length_mismatch(self):
        """Vectors of different dimension are treated as dissimilar."""
        assert cosine_similarity([1.0], [1.0, 0.0]) == 0.0

    def test_identify_speaker_match(self):
        """An exact embedding match returns that speaker's id."""
        enrolled = {
            "alice": [1.0, 0.0, 0.0],
            "bob": [0.0, 1.0, 0.0],
        }
        assert identify_speaker([1.0, 0.0, 0.0], enrolled, threshold=0.5) == "alice"

    def test_identify_speaker_unknown(self):
        """A query below threshold against everyone falls back to 'unknown'."""
        enrolled = {"alice": [1.0, 0.0, 0.0]}
        assert identify_speaker([0.0, 1.0, 0.0], enrolled, threshold=0.5) == "unknown"

    def test_identify_speaker_empty_db(self):
        """With no enrolled speakers the result is always 'unknown'."""
        assert identify_speaker([1.0, 0.0], {}, threshold=0.5) == "unknown"

    def test_identify_best_match(self):
        """When several speakers pass the threshold, the closest one wins."""
        enrolled = {
            "alice": [1.0, 0.0, 0.0],
            "bob": [0.8, 0.6, 0.0],  # also similar to the query, but less so
        }
        # Query is nearly identical to alice's embedding
        assert identify_speaker([0.99, 0.01, 0.0], enrolled, threshold=0.5) == "alice"

View File

@ -28,6 +28,11 @@ rosidl_generate_interfaces(${PROJECT_NAME}
  "srv/QueryMood.srv"
  # Issue #92 multi-modal tracking fusion
  "msg/FusedTarget.msg"
  # Issue #81 speech pipeline
  "msg/SpeechTranscript.msg"
  "msg/VadState.msg"
  # Issue #83 conversation engine
  "msg/ConversationResponse.msg"
  DEPENDENCIES std_msgs geometry_msgs builtin_interfaces
)

View File

@ -0,0 +1,9 @@
# ConversationResponse.msg — LLM response, supports streaming token output.
# Published by conversation_node on /social/conversation/response
std_msgs/Header header
string text # Full or partial response text
string speaker_id # Who the response is addressed to
bool is_partial # true = streaming token chunk, false = final response
int32 turn_id # Conversation turn counter (for deduplication)

View File

@ -0,0 +1,10 @@
# SpeechTranscript.msg — Result of STT with speaker identification.
# Published by speech_pipeline_node on /social/speech/transcript
std_msgs/Header header
string text # Transcribed text (UTF-8)
string speaker_id # e.g. "person_42" or "unknown"
float32 confidence # ASR confidence 0..1
float32 audio_duration # Duration of the utterance in seconds
bool is_partial # true = intermediate streaming result, false = final

View File

@ -0,0 +1,8 @@
# VadState.msg — Voice Activity Detection state.
# Published by speech_pipeline_node on /social/speech/vad_state
std_msgs/Header header
bool speech_active # true = speech detected, false = silence
float32 energy_db # RMS energy in dBFS
string state # "silence" | "speech" | "wake_word"