Merge pull request 'feat(social): speech + LLM + TTS + orchestrator (#81 #83 #85 #89)' (#102) from sl-jetson/social-speech-llm-tts into main
This commit is contained in:
commit
0f2ea7931b
@ -0,0 +1,12 @@
|
|||||||
|
# Parameters for the LLM conversation engine (Issue #83).
conversation_node:
  ros__parameters:
    # GGUF model served through llama-cpp-python.
    model_path: "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf"
    n_ctx: 4096
    n_gpu_layers: 20  # Increase for more VRAM (Orin has 8GB shared)
    max_tokens: 200
    temperature: 0.7
    top_p: 0.9
    # Persona / system prompt source and persisted per-person memory.
    soul_path: "/soul/SOUL.md"
    context_db_path: "/social_db/conversation_context.json"
    save_interval_s: 30.0
    stream: true
@ -0,0 +1,8 @@
|
|||||||
|
# Parameters for the pipeline orchestrator (Issue #89).
orchestrator_node:
  ros__parameters:
    gpu_mem_warn_mb: 4000.0      # Warn when GPU free < 4GB
    gpu_mem_throttle_mb: 2000.0  # Throttle new inferences at < 2GB
    watchdog_timeout_s: 30.0
    latency_window: 20
    profile_enabled: true
    state_publish_rate: 2.0      # Hz
13
jetson/ros2_ws/src/saltybot_social/config/speech_params.yaml
Normal file
13
jetson/ros2_ws/src/saltybot_social/config/speech_params.yaml
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# Parameters for the speech pipeline: wake word + VAD + Whisper STT (Issue #81).
speech_pipeline_node:
  ros__parameters:
    mic_device_index: -1   # -1 = system default; use `arecord -l` to list
    sample_rate: 16000
    # Wake-word detection.
    wake_word_model: "hey_salty"
    wake_word_threshold: 0.5
    # Voice activity detection.
    vad_threshold_db: -35.0
    use_silero_vad: true
    # Speech-to-text.
    whisper_model: "small"  # small (~500ms), medium (better quality, ~900ms)
    whisper_compute_type: "float16"
    # Speaker diarization.
    speaker_threshold: 0.65
    speaker_db_path: "/social_db/speaker_embeddings.json"
    publish_partial: true
@ -0,0 +1,9 @@
|
|||||||
|
# Parameters for the streaming Piper TTS node (Issue #85).
tts_node:
  ros__parameters:
    voice_path: "/models/piper/en_US-lessac-medium.onnx"
    sample_rate: 22050
    volume: 1.0
    audio_device: ""       # "" = system default; set to device name if needed
    playback_enabled: true
    publish_audio: false
    sentence_streaming: true
129
jetson/ros2_ws/src/saltybot_social/launch/social_bot.launch.py
Normal file
129
jetson/ros2_ws/src/saltybot_social/launch/social_bot.launch.py
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
"""social_bot.launch.py — Launch the full social-bot pipeline.
|
||||||
|
|
||||||
|
Issue #89: End-to-end pipeline orchestrator launch file.
|
||||||
|
|
||||||
|
Brings up all social-bot nodes in dependency order:
|
||||||
|
1. Orchestrator (state machine + watchdog) — t=0s
|
||||||
|
2. Speech pipeline (wake word + VAD + STT) — t=2s
|
||||||
|
3. Conversation engine (LLM) — t=4s
|
||||||
|
4. TTS node (Piper streaming) — t=4s
|
||||||
|
5. Person state tracker (already in social.launch.py — optional)
|
||||||
|
|
||||||
|
Launch args:
|
||||||
|
enable_speech (bool, true)
|
||||||
|
enable_llm (bool, true)
|
||||||
|
enable_tts (bool, true)
|
||||||
|
enable_orchestrator (bool, true)
|
||||||
|
voice_path (str, /models/piper/en_US-lessac-medium.onnx)
|
||||||
|
whisper_model (str, small)
|
||||||
|
llm_model_path (str, /models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf)
|
||||||
|
n_gpu_layers (int, 20)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from ament_index_python.packages import get_package_share_directory
|
||||||
|
from launch import LaunchDescription
|
||||||
|
from launch.actions import (
|
||||||
|
DeclareLaunchArgument, GroupAction, LogInfo, TimerAction
|
||||||
|
)
|
||||||
|
from launch.conditions import IfCondition
|
||||||
|
from launch.substitutions import LaunchConfiguration, PythonExpression
|
||||||
|
from launch_ros.actions import Node
|
||||||
|
|
||||||
|
|
||||||
|
def generate_launch_description() -> LaunchDescription:
|
||||||
|
pkg = get_package_share_directory("saltybot_social")
|
||||||
|
cfg = os.path.join(pkg, "config")
|
||||||
|
|
||||||
|
# ── Launch arguments ─────────────────────────────────────────────────────
|
||||||
|
args = [
|
||||||
|
DeclareLaunchArgument("enable_speech", default_value="true"),
|
||||||
|
DeclareLaunchArgument("enable_llm", default_value="true"),
|
||||||
|
DeclareLaunchArgument("enable_tts", default_value="true"),
|
||||||
|
DeclareLaunchArgument("enable_orchestrator", default_value="true"),
|
||||||
|
DeclareLaunchArgument("voice_path",
|
||||||
|
default_value="/models/piper/en_US-lessac-medium.onnx"),
|
||||||
|
DeclareLaunchArgument("whisper_model", default_value="small"),
|
||||||
|
DeclareLaunchArgument("llm_model_path",
|
||||||
|
default_value="/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf"),
|
||||||
|
DeclareLaunchArgument("n_gpu_layers", default_value="20"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# ── Orchestrator (t=0s) ──────────────────────────────────────────────────
|
||||||
|
orchestrator = Node(
|
||||||
|
package="saltybot_social",
|
||||||
|
executable="orchestrator_node",
|
||||||
|
name="orchestrator_node",
|
||||||
|
parameters=[os.path.join(cfg, "orchestrator_params.yaml")],
|
||||||
|
condition=IfCondition(LaunchConfiguration("enable_orchestrator")),
|
||||||
|
output="screen",
|
||||||
|
emulate_tty=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Speech pipeline (t=2s) ───────────────────────────────────────────────
|
||||||
|
speech = TimerAction(period=2.0, actions=[
|
||||||
|
GroupAction(
|
||||||
|
condition=IfCondition(LaunchConfiguration("enable_speech")),
|
||||||
|
actions=[
|
||||||
|
LogInfo(msg="[social_bot] Starting speech pipeline..."),
|
||||||
|
Node(
|
||||||
|
package="saltybot_social",
|
||||||
|
executable="speech_pipeline_node",
|
||||||
|
name="speech_pipeline_node",
|
||||||
|
parameters=[
|
||||||
|
os.path.join(cfg, "speech_params.yaml"),
|
||||||
|
{"whisper_model": LaunchConfiguration("whisper_model")},
|
||||||
|
],
|
||||||
|
output="screen",
|
||||||
|
emulate_tty=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
# ── LLM conversation engine (t=4s) ───────────────────────────────────────
|
||||||
|
llm = TimerAction(period=4.0, actions=[
|
||||||
|
GroupAction(
|
||||||
|
condition=IfCondition(LaunchConfiguration("enable_llm")),
|
||||||
|
actions=[
|
||||||
|
LogInfo(msg="[social_bot] Starting LLM conversation engine..."),
|
||||||
|
Node(
|
||||||
|
package="saltybot_social",
|
||||||
|
executable="conversation_node",
|
||||||
|
name="conversation_node",
|
||||||
|
parameters=[
|
||||||
|
os.path.join(cfg, "conversation_params.yaml"),
|
||||||
|
{
|
||||||
|
"model_path": LaunchConfiguration("llm_model_path"),
|
||||||
|
"n_gpu_layers": LaunchConfiguration("n_gpu_layers"),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
output="screen",
|
||||||
|
emulate_tty=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
# ── TTS node (t=4s, parallel with LLM) ──────────────────────────────────
|
||||||
|
tts = TimerAction(period=4.0, actions=[
|
||||||
|
GroupAction(
|
||||||
|
condition=IfCondition(LaunchConfiguration("enable_tts")),
|
||||||
|
actions=[
|
||||||
|
LogInfo(msg="[social_bot] Starting TTS node..."),
|
||||||
|
Node(
|
||||||
|
package="saltybot_social",
|
||||||
|
executable="tts_node",
|
||||||
|
name="tts_node",
|
||||||
|
parameters=[
|
||||||
|
os.path.join(cfg, "tts_params.yaml"),
|
||||||
|
{"voice_path": LaunchConfiguration("voice_path")},
|
||||||
|
],
|
||||||
|
output="screen",
|
||||||
|
emulate_tty=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
return LaunchDescription(args + [orchestrator, speech, llm, tts])
|
||||||
@ -5,9 +5,12 @@
|
|||||||
<version>0.1.0</version>
|
<version>0.1.0</version>
|
||||||
<description>
|
<description>
|
||||||
Social interaction layer for saltybot.
|
Social interaction layer for saltybot.
|
||||||
|
speech_pipeline_node: wake word + VAD + Whisper STT + diarization (Issue #81).
|
||||||
|
conversation_node: local LLM with per-person context (Issue #83).
|
||||||
|
tts_node: streaming TTS with Piper first-chunk (Issue #85).
|
||||||
|
orchestrator_node: pipeline state machine + GPU watchdog + latency profiler (Issue #89).
|
||||||
person_state_tracker: multi-modal person identity fusion (Issue #82).
|
person_state_tracker: multi-modal person identity fusion (Issue #82).
|
||||||
expression_node: bridges /social/mood to ESP32-C3 NeoPixel ring over serial (Issue #86).
|
expression_node: LED expression + motor attention (Issue #86).
|
||||||
attention_node: rotates robot toward active speaker via /social/persons bearing (Issue #86).
|
|
||||||
</description>
|
</description>
|
||||||
<maintainer email="seb@vayrette.com">seb</maintainer>
|
<maintainer email="seb@vayrette.com">seb</maintainer>
|
||||||
<license>MIT</license>
|
<license>MIT</license>
|
||||||
|
|||||||
@ -0,0 +1,267 @@
|
|||||||
|
"""conversation_node.py — Local LLM conversation engine with per-person context.
|
||||||
|
|
||||||
|
Issue #83: Conversation engine for social-bot.
|
||||||
|
|
||||||
|
Stack: Phi-3-mini or Llama-3.2-3B GGUF Q4_K_M via llama-cpp-python (CUDA).
|
||||||
|
Subscribes /social/speech/transcript → builds per-person prompt → streams
|
||||||
|
token output → publishes /social/conversation/response.
|
||||||
|
|
||||||
|
Streaming: publishes partial=true tokens as they arrive, then final=false
|
||||||
|
at end of generation. TTS node can begin synthesis on first sentence boundary.
|
||||||
|
|
||||||
|
ROS2 topics:
|
||||||
|
Subscribe: /social/speech/transcript (saltybot_social_msgs/SpeechTranscript)
|
||||||
|
Publish: /social/conversation/response (saltybot_social_msgs/ConversationResponse)
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
model_path (str, "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf")
|
||||||
|
n_ctx (int, 4096)
|
||||||
|
n_gpu_layers (int, 20) — GPU offload layers (increase for more VRAM usage)
|
||||||
|
max_tokens (int, 200)
|
||||||
|
temperature (float, 0.7)
|
||||||
|
top_p (float, 0.9)
|
||||||
|
soul_path (str, "/soul/SOUL.md")
|
||||||
|
context_db_path (str, "/social_db/conversation_context.json")
|
||||||
|
save_interval_s (float, 30.0) — how often to persist context to disk
|
||||||
|
stream (bool, true)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile
|
||||||
|
|
||||||
|
from saltybot_social_msgs.msg import SpeechTranscript, ConversationResponse
|
||||||
|
from .llm_context import ContextStore, build_llama_prompt, load_system_prompt, needs_summary_prompt
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationNode(Node):
    """Local LLM inference node with per-person conversation memory.

    Subscribes to final speech transcripts, builds a per-speaker prompt from
    the persisted ContextStore, runs llama-cpp inference (optionally
    streaming), and publishes ConversationResponse messages: partials at
    sentence boundaries so TTS can start early, then the final message with
    ``is_partial=False``.
    """

    def __init__(self) -> None:
        super().__init__("conversation_node")

        # ── Parameters ──────────────────────────────────────────────────────
        self.declare_parameter(
            "model_path", "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf")
        self.declare_parameter("n_ctx", 4096)
        self.declare_parameter("n_gpu_layers", 20)
        self.declare_parameter("max_tokens", 200)
        self.declare_parameter("temperature", 0.7)
        self.declare_parameter("top_p", 0.9)
        self.declare_parameter("soul_path", "/soul/SOUL.md")
        self.declare_parameter("context_db_path", "/social_db/conversation_context.json")
        self.declare_parameter("save_interval_s", 30.0)
        self.declare_parameter("stream", True)

        self._model_path = self.get_parameter("model_path").value
        self._n_ctx = self.get_parameter("n_ctx").value
        self._n_gpu = self.get_parameter("n_gpu_layers").value
        self._max_tokens = self.get_parameter("max_tokens").value
        self._temperature = self.get_parameter("temperature").value
        self._top_p = self.get_parameter("top_p").value
        self._soul_path = self.get_parameter("soul_path").value
        self._db_path = self.get_parameter("context_db_path").value
        self._save_interval = self.get_parameter("save_interval_s").value
        self._stream = self.get_parameter("stream").value

        # ── Publishers / Subscribers ─────────────────────────────────────────
        qos = QoSProfile(depth=10)
        self._resp_pub = self.create_publisher(
            ConversationResponse, "/social/conversation/response", qos
        )
        self._transcript_sub = self.create_subscription(
            SpeechTranscript, "/social/speech/transcript",
            self._on_transcript, qos
        )

        # ── State ────────────────────────────────────────────────────────────
        self._llm = None  # set by _load_llm() once the model is ready
        self._system_prompt = load_system_prompt(self._soul_path)
        self._ctx_store = ContextStore(self._db_path)
        self._lock = threading.Lock()  # guards _generating / _turn_counter
        self._turn_counter = 0
        self._generating = False
        self._last_save = time.time()

        # ── Load LLM in background so node startup is not blocked ────────────
        threading.Thread(target=self._load_llm, daemon=True).start()

        # ── Periodic context save ────────────────────────────────────────────
        self._save_timer = self.create_timer(self._save_interval, self._save_context)

        self.get_logger().info(
            f"ConversationNode init (model={self._model_path}, "
            f"gpu_layers={self._n_gpu}, ctx={self._n_ctx})"
        )

    # ── Model loading ─────────────────────────────────────────────────────────

    def _load_llm(self) -> None:
        """Load the GGUF model via llama-cpp-python (runs in a daemon thread)."""
        t0 = time.time()
        self.get_logger().info(f"Loading LLM: {self._model_path}")
        try:
            from llama_cpp import Llama
            self._llm = Llama(
                model_path=self._model_path,
                n_ctx=self._n_ctx,
                n_gpu_layers=self._n_gpu,
                n_threads=4,
                verbose=False,
            )
            self.get_logger().info(
                f"LLM ready ({time.time()-t0:.1f}s). "
                f"Context: {self._n_ctx} tokens, GPU layers: {self._n_gpu}"
            )
        except Exception as e:
            # Node stays alive; _generate_response drops utterances until loaded.
            self.get_logger().error(f"LLM load failed: {e}")

    # ── Transcript callback ───────────────────────────────────────────────────

    def _on_transcript(self, msg: SpeechTranscript) -> None:
        """Handle final transcripts only (skip streaming partials)."""
        if msg.is_partial:
            return
        if not msg.text.strip():
            return

        self.get_logger().info(
            f"Transcript [{msg.speaker_id}]: '{msg.text}'"
        )

        # Inference runs in its own thread so the executor stays responsive.
        threading.Thread(
            target=self._generate_response,
            args=(msg.text.strip(), msg.speaker_id),
            daemon=True,
        ).start()

    # ── LLM inference ─────────────────────────────────────────────────────────

    def _generate_response(self, user_text: str, speaker_id: str) -> None:
        """Generate an LLM response with optional streaming. Runs in a thread."""
        if self._llm is None:
            self.get_logger().warn("LLM not loaded yet, dropping utterance")
            return

        with self._lock:
            if self._generating:
                self.get_logger().warn("LLM busy, dropping utterance")
                return
            self._generating = True
            self._turn_counter += 1
            turn_id = self._turn_counter

        try:
            ctx = self._ctx_store.get(speaker_id)

            # Summary compression if context is long
            if ctx.needs_compression():
                self._compress_context(ctx)

            # BUGFIX: build the prompt BEFORE recording the user turn.
            # build_llama_prompt() iterates ctx.turns AND appends user_text
            # itself, so calling ctx.add_user() first duplicated the user's
            # utterance in every prompt.
            prompt = build_llama_prompt(
                ctx, user_text, self._system_prompt
            )
            ctx.add_user(user_text)

            t0 = time.perf_counter()
            full_response = ""

            if self._stream:
                output = self._llm(
                    prompt,
                    max_tokens=self._max_tokens,
                    temperature=self._temperature,
                    top_p=self._top_p,
                    stream=True,
                    stop=["<|user|>", "<|system|>", "\n\n\n"],
                )
                for chunk in output:
                    token = chunk["choices"][0]["text"]
                    full_response += token
                    # Publish partial after each sentence boundary for low TTS latency
                    if token.endswith((".", "!", "?", "\n")):
                        self._publish_response(
                            full_response.strip(), speaker_id, turn_id, is_partial=True
                        )
            else:
                output = self._llm(
                    prompt,
                    max_tokens=self._max_tokens,
                    temperature=self._temperature,
                    top_p=self._top_p,
                    stream=False,
                    stop=["<|user|>", "<|system|>"],
                )
                full_response = output["choices"][0]["text"]

            full_response = full_response.strip()
            latency_ms = (time.perf_counter() - t0) * 1000
            self.get_logger().info(
                f"LLM [{speaker_id}] ({latency_ms:.0f}ms): '{full_response[:80]}'"
            )

            ctx.add_assistant(full_response)
            self._publish_response(full_response, speaker_id, turn_id, is_partial=False)

        except Exception as e:
            self.get_logger().error(f"LLM inference error: {e}")
        finally:
            with self._lock:
                self._generating = False

    def _compress_context(self, ctx) -> None:
        """Ask LLM to summarize old turns for context compression."""
        if self._llm is None:
            # Cannot summarize without the model — still trim the history.
            ctx.compress("(history omitted)")
            return
        try:
            summary_prompt = needs_summary_prompt(ctx)
            result = self._llm(summary_prompt, max_tokens=80, temperature=0.3, stream=False)
            summary = result["choices"][0]["text"].strip()
            ctx.compress(summary)
            self.get_logger().debug(
                f"Context compressed for {ctx.person_id}: '{summary[:60]}'"
            )
        except Exception:
            # Best-effort: a failed summary must not block the conversation.
            ctx.compress("(history omitted)")

    # ── Publish ───────────────────────────────────────────────────────────────

    def _publish_response(
        self, text: str, speaker_id: str, turn_id: int, is_partial: bool
    ) -> None:
        """Publish a ConversationResponse (partial or final) for this turn."""
        msg = ConversationResponse()
        msg.header.stamp = self.get_clock().now().to_msg()
        msg.text = text
        msg.speaker_id = speaker_id
        msg.is_partial = is_partial
        msg.turn_id = turn_id
        self._resp_pub.publish(msg)

    def _save_context(self) -> None:
        """Timer callback: persist conversation context to disk (best effort)."""
        try:
            self._ctx_store.save()
        except Exception as e:
            self.get_logger().error(f"Context save error: {e}")

    def destroy_node(self) -> None:
        """Flush context to disk before tearing the node down."""
        self._save_context()
        super().destroy_node()
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None) -> None:
    """Entry point: spin ConversationNode until interrupted, then clean up."""
    rclpy.init(args=args)
    node = ConversationNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass  # Ctrl-C is a normal shutdown path, not an error
    finally:
        node.destroy_node()
        rclpy.shutdown()
|
||||||
@ -0,0 +1,176 @@
|
|||||||
|
"""llm_context.py — Per-person conversation context management.
|
||||||
|
|
||||||
|
No ROS2 dependencies. Manages sliding-window conversation history per person_id,
|
||||||
|
builds llama.cpp prompt, handles summary compression when context is full.
|
||||||
|
|
||||||
|
Used by conversation_node.py.
|
||||||
|
Tested by test/test_llm_context.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
# ── Constants ──────────────────────────────────────────────────────────────────
|
||||||
|
MAX_TURNS = 20 # Max turns to keep before summary compression
|
||||||
|
SUMMARY_KEEP = 4 # Keep last N turns after summarizing older ones
|
||||||
|
SYSTEM_PROMPT_PATH = "/soul/SOUL.md"
|
||||||
|
DEFAULT_SYSTEM_PROMPT = (
|
||||||
|
"You are Salty, a friendly social robot. "
|
||||||
|
"You are helpful, warm, and concise. "
|
||||||
|
"Keep responses under 2 sentences unless asked to elaborate."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Conversation history ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class PersonContext:
    """Conversation history + relationship memory for a single person."""

    def __init__(self, person_id: str, person_name: str = "") -> None:
        self.person_id = person_id
        # Fall back to the id when no human-readable name is known yet.
        self.person_name = person_name if person_name else person_id
        # Each turn: {"role": "user"|"assistant", "content": str, "ts": float}
        self.turns: List[dict] = []
        self.summary: str = ""  # compressed memory of older turns
        self.last_seen: float = time.time()
        self.interaction_count: int = 0

    def add_user(self, text: str) -> None:
        """Record a user utterance and refresh presence bookkeeping."""
        self.turns.append({"role": "user", "content": text, "ts": time.time()})
        self.last_seen = time.time()
        self.interaction_count += 1

    def add_assistant(self, text: str) -> None:
        """Record an assistant reply."""
        self.turns.append({"role": "assistant", "content": text, "ts": time.time()})

    def needs_compression(self) -> bool:
        """True once the history exceeds MAX_TURNS and should be summarized."""
        return len(self.turns) > MAX_TURNS

    def compress(self, summary_text: str) -> None:
        """Fold older turns into the running summary, keeping the last SUMMARY_KEEP."""
        if self.summary:
            self.summary = self.summary + " " + summary_text
        else:
            self.summary = summary_text
        self.turns = self.turns[-SUMMARY_KEEP:]

    def to_dict(self) -> dict:
        """JSON-serializable snapshot of this context."""
        return {
            "person_id": self.person_id,
            "person_name": self.person_name,
            "turns": self.turns,
            "summary": self.summary,
            "last_seen": self.last_seen,
            "interaction_count": self.interaction_count,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "PersonContext":
        """Rebuild a context from a to_dict() snapshot (alternate constructor)."""
        restored = cls(d["person_id"], d.get("person_name", ""))
        restored.turns = d.get("turns", [])
        restored.summary = d.get("summary", "")
        restored.last_seen = d.get("last_seen", 0.0)
        restored.interaction_count = d.get("interaction_count", 0)
        return restored
|
||||||
|
|
||||||
|
|
||||||
|
# ── Context store ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class ContextStore:
    """Persistent per-person conversation context with optional JSON backing.

    Contexts live in memory keyed by person_id; save() serializes them to a
    JSON file and the constructor reloads them (best effort).
    """

    def __init__(self, db_path: str = "/social_db/conversation_context.json") -> None:
        self._db_path = db_path
        self._contexts: dict = {}  # person_id → PersonContext
        self._load()

    def get(self, person_id: str, person_name: str = "") -> PersonContext:
        """Return (creating if needed) the context for person_id.

        A non-empty person_name fills in a previously-unknown name but never
        overwrites one that is already set.
        """
        if person_id not in self._contexts:
            self._contexts[person_id] = PersonContext(person_id, person_name)
        elif person_name and not self._contexts[person_id].person_name:
            self._contexts[person_id].person_name = person_name
        return self._contexts[person_id]

    def save(self) -> None:
        """Serialize all contexts to the backing JSON file."""
        parent = os.path.dirname(self._db_path)
        # BUGFIX: os.path.dirname() is "" for a bare filename, and
        # os.makedirs("") raises FileNotFoundError — only create the
        # directory when there is one.
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(self._db_path, "w") as f:
            json.dump(
                {pid: ctx.to_dict() for pid, ctx in self._contexts.items()},
                f, indent=2
            )

    def _load(self) -> None:
        """Best-effort load of persisted contexts; a corrupt file starts fresh."""
        if os.path.exists(self._db_path):
            try:
                with open(self._db_path) as f:
                    raw = json.load(f)
                self._contexts = {
                    pid: PersonContext.from_dict(d) for pid, d in raw.items()
                }
            except Exception:
                # Unreadable/corrupt DB — start empty rather than crash the node.
                self._contexts = {}

    def all_persons(self) -> List[str]:
        """Return the ids of all persons with stored context."""
        return list(self._contexts.keys())
|
||||||
|
|
||||||
|
|
||||||
|
# ── Prompt builder ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def load_system_prompt(path: str = SYSTEM_PROMPT_PATH) -> str:
    """Load SOUL.md as the system prompt, falling back to the built-in default."""
    if not os.path.exists(path):
        return DEFAULT_SYSTEM_PROMPT
    with open(path) as f:
        return f.read().strip()
|
||||||
|
|
||||||
|
|
||||||
|
def build_llama_prompt(
    ctx: PersonContext,
    user_text: str,
    system_prompt: str,
    group_persons: Optional[List[str]] = None,
) -> str:
    """Build a Phi-3 / Llama-3 style chat prompt string.

    Uses the ChatML format supported by llama-cpp-python: system segments
    (persona, per-person memory, bystanders), the stored turn history, the
    new user utterance, and a trailing assistant tag for generation.
    """
    segments = [f"<|system|>\n{system_prompt}"]

    # Long-term memory of this person, if any has been compressed so far.
    if ctx.summary:
        segments.append(
            f"<|system|>\n[Memory of {ctx.person_name}]: {ctx.summary}"
        )

    # Mention other people in the room (excluding the speaker themselves).
    if group_persons:
        bystanders = [pid for pid in group_persons if pid != ctx.person_id]
        if bystanders:
            segments.append(
                f"<|system|>\n[Also present]: {', '.join(bystanders)}"
            )

    # Replay the stored dialogue history.
    for turn in ctx.turns:
        speaker = "user" if turn["role"] == "user" else "assistant"
        segments.append(f"<|{speaker}|>\n{turn['content']}")

    segments.append(f"<|user|>\n{user_text}")
    segments.append("<|assistant|>")

    return "\n".join(segments)
|
||||||
|
|
||||||
|
|
||||||
|
def needs_summary_prompt(ctx: PersonContext) -> str:
    """Build a prompt asking the LLM to summarize old conversation turns.

    Only the turns older than the last SUMMARY_KEEP are included — those are
    the ones compress() will drop afterwards.
    """
    older_turns = ctx.turns[:-SUMMARY_KEEP]
    transcript = "\n".join(
        f"{turn['role'].upper()}: {turn['content']}" for turn in older_turns
    )
    return (
        f"<|system|>\nSummarize this conversation between {ctx.person_name} and "
        "Salty in one sentence, focusing on what was discussed and their relationship.\n"
        f"<|user|>\n{transcript}\n<|assistant|>"
    )
|
||||||
@ -0,0 +1,298 @@
|
|||||||
|
"""orchestrator_node.py — End-to-end social-bot pipeline orchestrator.
|
||||||
|
|
||||||
|
Issue #89: Main loop, GPU memory scheduling, latency profiling, watchdog.
|
||||||
|
|
||||||
|
State machine:
|
||||||
|
IDLE → LISTENING (wake word) → THINKING (LLM generating) → SPEAKING (TTS)
|
||||||
|
↑___________________________________________|
|
||||||
|
|
||||||
|
Responsibilities:
|
||||||
|
- Owns the top-level pipeline state machine
|
||||||
|
- Monitors GPU memory and throttles inference when approaching limit
|
||||||
|
- Profiles end-to-end latency per interaction turn
|
||||||
|
- Health watchdog: restart speech pipeline on hang
|
||||||
|
- Exposes /social/orchestrator/state for UI / other nodes
|
||||||
|
|
||||||
|
ROS2 topics:
|
||||||
|
Subscribe: /social/speech/vad_state (VadState)
|
||||||
|
/social/speech/transcript (SpeechTranscript)
|
||||||
|
/social/conversation/response (ConversationResponse)
|
||||||
|
Publish: /social/orchestrator/state (std_msgs/String) — JSON status blob
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
gpu_mem_warn_mb (float, 4000.0) — warn at this free GPU memory (MB)
|
||||||
|
gpu_mem_throttle_mb (float, 2000.0) — suspend new inference requests
|
||||||
|
watchdog_timeout_s (float, 30.0) — restart speech node if no VAD in N seconds
|
||||||
|
latency_window (int, 20) — rolling window for latency stats
|
||||||
|
profile_enabled (bool, true)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from collections import deque
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile
|
||||||
|
from std_msgs.msg import String
|
||||||
|
|
||||||
|
from saltybot_social_msgs.msg import VadState, SpeechTranscript, ConversationResponse
|
||||||
|
|
||||||
|
|
||||||
|
# ── State machine ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class PipelineState(Enum):
    """Top-level social-bot pipeline states published by the orchestrator."""

    IDLE = "idle"
    LISTENING = "listening"
    THINKING = "thinking"
    SPEAKING = "speaking"
    # GPU memory is critically low — new inference is suspended until it frees.
    THROTTLED = "throttled"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Latency tracker ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class LatencyTracker:
    """Rolling-window latency statistics per pipeline stage.

    Samples for unknown stage names are silently ignored; each stage keeps at
    most `window` most-recent samples.
    """

    def __init__(self, window: int = 20) -> None:
        self._window = window
        # One bounded sample buffer per tracked pipeline stage.
        self._samples = {
            stage: deque(maxlen=window)
            for stage in (
                "wakeword_to_transcript",
                "transcript_to_llm",
                "llm_to_tts",
                "end_to_end",
            )
        }

    def record(self, stage: str, latency_ms: float) -> None:
        """Append one latency sample (ms) to the named stage, if it exists."""
        bucket = self._samples.get(stage)
        if bucket is not None:
            bucket.append(latency_ms)

    @staticmethod
    def _summarize(samples: deque) -> dict:
        # Empty stage → empty dict, so consumers can tell "no data" apart.
        if not samples:
            return {}
        ordered = sorted(samples)
        count = len(ordered)
        return {
            "mean_ms": round(sum(ordered) / count, 1),
            "p50_ms": round(ordered[count // 2], 1),
            "p95_ms": round(ordered[int(count * 0.95)], 1),
            "n": count,
        }

    def stats(self) -> dict:
        """Return {stage: {mean_ms, p50_ms, p95_ms, n}} for all stages."""
        return {
            stage: self._summarize(samples)
            for stage, samples in self._samples.items()
        }
|
||||||
|
|
||||||
|
|
||||||
|
# ── GPU memory helpers ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def get_gpu_free_mb() -> Optional[float]:
    """Return free GPU memory in MB, or None if unavailable.

    Tries pycuda first; falls back to parsing nvidia-smi output. Any failure
    in either path is swallowed, so the caller only ever sees a float or None.
    """
    # Preferred path: query CUDA directly via pycuda.
    try:
        import pycuda.driver as cuda
        cuda.init()
        context = cuda.Device(0).make_context()
        free_bytes, _total = cuda.mem_get_info()
        context.pop()
        return free_bytes / 1024 / 1024
    except Exception:
        pass
    # Fallback: shell out to nvidia-smi (first GPU only).
    try:
        import subprocess
        raw = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"],
            timeout=2
        )
        first_line = raw.decode().strip().split("\n")[0]
        return float(first_line)
    except Exception:
        return None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Orchestrator ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class OrchestratorNode(Node):
    """Pipeline orchestrator — state machine, GPU watchdog, latency profiler.

    Tracks a conversational turn across the speech → LLM → TTS pipeline via
    the social topics, records stage latencies, publishes a JSON status blob
    on /social/orchestrator/state, throttles on low GPU memory, and alerts
    when the speech pipeline goes silent.
    """

    def __init__(self) -> None:
        super().__init__("orchestrator_node")

        # ── Parameters ──────────────────────────────────────────────────────
        self.declare_parameter("gpu_mem_warn_mb", 4000.0)
        self.declare_parameter("gpu_mem_throttle_mb", 2000.0)
        self.declare_parameter("watchdog_timeout_s", 30.0)
        self.declare_parameter("latency_window", 20)
        self.declare_parameter("profile_enabled", True)
        self.declare_parameter("state_publish_rate", 2.0)  # Hz

        self._gpu_warn = self.get_parameter("gpu_mem_warn_mb").value
        self._gpu_throttle = self.get_parameter("gpu_mem_throttle_mb").value
        self._watchdog_timeout = self.get_parameter("watchdog_timeout_s").value
        self._profile = self.get_parameter("profile_enabled").value
        self._state_rate = self.get_parameter("state_publish_rate").value

        # ── Publishers / Subscribers ─────────────────────────────────────────
        qos = QoSProfile(depth=10)
        self._state_pub = self.create_publisher(String, "/social/orchestrator/state", qos)

        self._vad_sub = self.create_subscription(
            VadState, "/social/speech/vad_state", self._on_vad, qos
        )
        self._transcript_sub = self.create_subscription(
            SpeechTranscript, "/social/speech/transcript", self._on_transcript, qos
        )
        self._response_sub = self.create_subscription(
            ConversationResponse, "/social/conversation/response", self._on_response, qos
        )

        # ── State ────────────────────────────────────────────────────────────
        self._state = PipelineState.IDLE
        # Guards _state. threading.Lock is NOT reentrant — never call
        # _set_state() while holding it (see _check_gpu).
        self._state_lock = threading.Lock()
        self._tracker = LatencyTracker(self.get_parameter("latency_window").value)

        # Timestamps for latency tracking (0.0 = not observed this turn)
        self._t_wake: float = 0.0
        self._t_transcript: float = 0.0
        self._t_llm_first: float = 0.0
        self._t_tts_first: float = 0.0

        # Liveness timestamps for the watchdog / diagnostics
        self._last_vad_t: float = time.time()
        self._last_transcript_t: float = 0.0
        self._last_response_t: float = 0.0

        # ── Timers ────────────────────────────────────────────────────────────
        self._state_timer = self.create_timer(1.0 / self._state_rate, self._publish_state)
        self._gpu_timer = self.create_timer(5.0, self._check_gpu)
        self._watchdog_timer = self.create_timer(10.0, self._watchdog)

        self.get_logger().info("OrchestratorNode ready")

    # ── State transitions ─────────────────────────────────────────────────────

    def _set_state(self, new_state: PipelineState) -> None:
        """Thread-safely transition the pipeline state, logging on change."""
        with self._state_lock:
            if self._state != new_state:
                self.get_logger().info(
                    f"Pipeline: {self._state.value} → {new_state.value}"
                )
                self._state = new_state

    # ── Subscribers ───────────────────────────────────────────────────────────

    def _on_vad(self, msg: VadState) -> None:
        """Track VAD liveness; a wake word (or speech while idle) starts a turn."""
        self._last_vad_t = time.time()
        if msg.state == "wake_word":
            self._t_wake = time.time()
            self._set_state(PipelineState.LISTENING)
        elif msg.speech_active and self._state == PipelineState.IDLE:
            self._set_state(PipelineState.LISTENING)

    def _on_transcript(self, msg: SpeechTranscript) -> None:
        """On a final transcript, move to THINKING and profile STT latency."""
        if msg.is_partial:
            return
        self._last_transcript_t = time.time()
        self._t_transcript = time.time()
        self._set_state(PipelineState.THINKING)

        if self._profile and self._t_wake > 0:
            latency_ms = (self._t_transcript - self._t_wake) * 1000
            self._tracker.record("wakeword_to_transcript", latency_ms)
            if latency_ms > 500:
                self.get_logger().warn(
                    f"Wake→transcript latency high: {latency_ms:.0f}ms (target <500ms)"
                )

    def _on_response(self, msg: ConversationResponse) -> None:
        """Track LLM streaming; record latencies and close out the turn."""
        self._last_response_t = time.time()

        if msg.is_partial and self._t_llm_first == 0.0:
            # First token from LLM
            self._t_llm_first = time.time()
            if self._profile and self._t_transcript > 0:
                latency_ms = (self._t_llm_first - self._t_transcript) * 1000
                self._tracker.record("transcript_to_llm", latency_ms)
            self._set_state(PipelineState.SPEAKING)

        if not msg.is_partial:
            # Full response complete. BUGFIX: record end-to-end latency
            # BEFORE clearing the turn timestamps — the original zeroed
            # _t_wake first, so the "end_to_end" metric was never recorded.
            if self._profile and self._t_wake > 0:
                e2e_ms = (time.time() - self._t_wake) * 1000
                self._tracker.record("end_to_end", e2e_ms)

            self._t_llm_first = 0.0
            self._t_wake = 0.0
            self._t_transcript = 0.0

            # Return to IDLE after a short delay (TTS still playing)
            threading.Timer(2.0, lambda: self._set_state(PipelineState.IDLE)).start()

    # ── GPU memory check ──────────────────────────────────────────────────────

    def _check_gpu(self) -> None:
        """Periodic GPU-memory poll: warn, throttle, or recover."""
        free_mb = get_gpu_free_mb()
        if free_mb is None:
            return

        if free_mb < self._gpu_throttle:
            self.get_logger().error(
                f"GPU memory critical: {free_mb:.0f}MB free < {self._gpu_throttle:.0f}MB threshold"
            )
            self._set_state(PipelineState.THROTTLED)
        elif free_mb < self._gpu_warn:
            self.get_logger().warn(f"GPU memory low: {free_mb:.0f}MB free")
        else:
            # Recover from throttled state. BUGFIX: read the state under the
            # lock, but call _set_state() only AFTER releasing it —
            # _set_state() re-acquires the same non-reentrant Lock, so the
            # original's nested call deadlocked on this recovery path.
            with self._state_lock:
                throttled = self._state == PipelineState.THROTTLED
            if throttled:
                self._set_state(PipelineState.IDLE)

    # ── Watchdog ──────────────────────────────────────────────────────────────

    def _watchdog(self) -> None:
        """Alert if speech pipeline has gone silent for too long."""
        since_vad = time.time() - self._last_vad_t
        if since_vad > self._watchdog_timeout:
            self.get_logger().error(
                f"Watchdog: No VAD signal for {since_vad:.0f}s. "
                "Speech pipeline may be hung. Check /social/speech/vad_state."
            )

    # ── State publisher ───────────────────────────────────────────────────────

    def _publish_state(self) -> None:
        """Publish current state, latency stats and GPU headroom as JSON."""
        with self._state_lock:
            current = self._state.value

        payload = {
            "state": current,
            "ts": time.time(),
            "latency": self._tracker.stats() if self._profile else {},
            "gpu_free_mb": get_gpu_free_mb(),
        }
        msg = String()
        msg.data = json.dumps(payload)
        self._state_pub.publish(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None) -> None:
    """Entry point: spin the orchestrator node until interrupted."""
    rclpy.init(args=args)
    node = OrchestratorNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Ctrl-C is a normal way to stop the node — not an error.
        pass
    finally:
        # Always release the node and shut rclpy down, even on failure.
        node.destroy_node()
        rclpy.shutdown()
|
||||||
@ -0,0 +1,391 @@
|
|||||||
|
"""speech_pipeline_node.py — Wake word + VAD + Whisper STT + speaker diarization.
|
||||||
|
|
||||||
|
Issue #81: Speech pipeline for social-bot.
|
||||||
|
|
||||||
|
Pipeline:
|
||||||
|
USB mic (PyAudio) → VAD (Silero / energy fallback)
|
||||||
|
→ wake word gate (OpenWakeWord "hey_salty")
|
||||||
|
→ utterance buffer → faster-whisper STT (Orin GPU)
|
||||||
|
→ ECAPA-TDNN speaker embedding → identify_speaker()
|
||||||
|
→ publish /social/speech/transcript + /social/speech/vad_state
|
||||||
|
|
||||||
|
Latency target: <500ms wake-to-first-token (streaming partial transcripts).
|
||||||
|
|
||||||
|
ROS2 topics published:
|
||||||
|
/social/speech/transcript (saltybot_social_msgs/SpeechTranscript)
|
||||||
|
/social/speech/vad_state (saltybot_social_msgs/VadState)
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
mic_device_index (int, -1 = system default)
|
||||||
|
sample_rate (int, 16000)
|
||||||
|
wake_word_model (str, "hey_salty" — OpenWakeWord model name or path)
|
||||||
|
wake_word_threshold (float, 0.5)
|
||||||
|
vad_threshold_db (float, -35.0) — fallback energy VAD
|
||||||
|
use_silero_vad (bool, true)
|
||||||
|
whisper_model (str, "small")
|
||||||
|
whisper_compute_type (str, "float16")
|
||||||
|
speaker_threshold (float, 0.65) — cosine similarity threshold
|
||||||
|
speaker_db_path (str, "/social_db/speaker_embeddings.json")
|
||||||
|
publish_partial (bool, true)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile, QoSReliabilityPolicy, QoSDurabilityPolicy
|
||||||
|
from builtin_interfaces.msg import Time
|
||||||
|
from std_msgs.msg import Header
|
||||||
|
|
||||||
|
from saltybot_social_msgs.msg import SpeechTranscript, VadState
|
||||||
|
from .speech_utils import (
|
||||||
|
EnergyVad, UtteranceSegmenter,
|
||||||
|
pcm16_to_float32, rms_db, identify_speaker,
|
||||||
|
SAMPLE_RATE, CHUNK_SAMPLES,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechPipelineNode(Node):
    """Wake word → VAD → STT → diarization → ROS2 publisher.

    Heavy models (Whisper, ECAPA-TDNN, Silero VAD, OpenWakeWord) are loaded
    in a daemon thread so the node becomes spinnable immediately; audio
    capture starts only once loading finishes (see _load_models).
    """

    def __init__(self) -> None:
        super().__init__("speech_pipeline_node")

        # ── Parameters ──────────────────────────────────────────────────────
        self.declare_parameter("mic_device_index", -1)
        self.declare_parameter("sample_rate", SAMPLE_RATE)
        self.declare_parameter("wake_word_model", "hey_salty")
        self.declare_parameter("wake_word_threshold", 0.5)
        self.declare_parameter("vad_threshold_db", -35.0)
        self.declare_parameter("use_silero_vad", True)
        self.declare_parameter("whisper_model", "small")
        self.declare_parameter("whisper_compute_type", "float16")
        self.declare_parameter("speaker_threshold", 0.65)
        self.declare_parameter("speaker_db_path", "/social_db/speaker_embeddings.json")
        self.declare_parameter("publish_partial", True)

        self._mic_idx = self.get_parameter("mic_device_index").value
        self._rate = self.get_parameter("sample_rate").value
        self._ww_model = self.get_parameter("wake_word_model").value
        self._ww_thresh = self.get_parameter("wake_word_threshold").value
        self._vad_thresh_db = self.get_parameter("vad_threshold_db").value
        self._use_silero = self.get_parameter("use_silero_vad").value
        self._whisper_model_name = self.get_parameter("whisper_model").value
        self._compute_type = self.get_parameter("whisper_compute_type").value
        self._speaker_thresh = self.get_parameter("speaker_threshold").value
        self._speaker_db = self.get_parameter("speaker_db_path").value
        self._publish_partial = self.get_parameter("publish_partial").value

        # ── Publishers ───────────────────────────────────────────────────────
        qos = QoSProfile(depth=10)
        self._transcript_pub = self.create_publisher(SpeechTranscript,
                                                     "/social/speech/transcript", qos)
        self._vad_pub = self.create_publisher(VadState, "/social/speech/vad_state", qos)

        # ── Internal state ───────────────────────────────────────────────────
        # Models stay None until _load_models finishes (or fails); every
        # consumer below checks for None before use.
        self._whisper = None
        self._ecapa = None
        self._oww = None
        self._vad = None
        self._segmenter = None
        self._known_speakers: dict = {}   # {speaker_id: embedding_list}, see _load_speaker_db
        self._audio_stream = None
        self._pa = None
        self._lock = threading.Lock()     # NOTE(review): declared but never acquired in this class
        self._running = False             # set True in _start_audio; False in destroy_node
        self._wake_word_active = False
        self._wake_expiry = 0.0
        self._wake_window_s = 8.0  # seconds to listen after wake word
        self._turn_counter = 0

        # ── Lazy model loading in background thread ──────────────────────────
        self._model_thread = threading.Thread(target=self._load_models, daemon=True)
        self._model_thread.start()

    # ── Model loading ─────────────────────────────────────────────────────────

    def _load_models(self) -> None:
        """Load all AI models (runs in background to not block ROS spin).

        Each model is optional: a failed import/download is logged and the
        pipeline degrades gracefully (energy VAD, 'unknown' speaker,
        continuous listening). Audio capture is started last so the
        capture thread never sees half-initialized models.
        """
        self.get_logger().info("Loading speech models (Whisper, ECAPA-TDNN, VAD)...")
        t0 = time.time()

        # faster-whisper
        try:
            from faster_whisper import WhisperModel
            self._whisper = WhisperModel(
                self._whisper_model_name,
                device="cuda",
                compute_type=self._compute_type,
                download_root="/models",
            )
            self.get_logger().info(
                f"Whisper '{self._whisper_model_name}' loaded "
                f"({time.time()-t0:.1f}s)"
            )
        except Exception as e:
            # Without Whisper the node still publishes VAD state, but
            # _process_utterance becomes a no-op.
            self.get_logger().error(f"Whisper load failed: {e}")

        # Silero VAD (optional; energy VAD below is always available)
        if self._use_silero:
            try:
                from silero_vad import load_silero_vad
                self._vad = load_silero_vad()
                self.get_logger().info("Silero VAD loaded")
            except Exception as e:
                self.get_logger().warn(f"Silero VAD unavailable ({e}), using energy VAD")
                self._vad = None

        # The segmenter always runs on the energy VAD, regardless of Silero.
        energy_vad = EnergyVad(threshold_db=self._vad_thresh_db)
        self._segmenter = UtteranceSegmenter(energy_vad)

        # ECAPA-TDNN speaker embeddings
        try:
            from speechbrain.pretrained import EncoderClassifier
            self._ecapa = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir="/models/speechbrain_ecapa",
            )
            self.get_logger().info("ECAPA-TDNN loaded")
        except Exception as e:
            self.get_logger().warn(f"ECAPA-TDNN unavailable ({e}), using 'unknown' speaker")

        # OpenWakeWord
        try:
            import openwakeword
            from openwakeword.model import Model as OWWModel
            self._oww = OWWModel(
                wakeword_models=[self._ww_model],
                inference_framework="onnx",
            )
            self.get_logger().info(f"OpenWakeWord '{self._ww_model}' loaded")
        except Exception as e:
            # With no wake-word gate, _audio_loop keeps the wake window open
            # permanently (continuous listening).
            self.get_logger().warn(
                f"OpenWakeWord unavailable ({e}). Robot will listen continuously."
            )

        # Load speaker DB
        self._load_speaker_db()

        # Start audio capture (must be last — capture thread uses the models)
        self._start_audio()
        self.get_logger().info(f"Speech pipeline ready ({time.time()-t0:.1f}s total)")

    def _load_speaker_db(self) -> None:
        """Load speaker embedding database from JSON.

        Missing file is silently fine (empty DB → every speaker is
        'unknown'); a corrupt file is logged and the DB stays empty.
        """
        if os.path.exists(self._speaker_db):
            try:
                with open(self._speaker_db) as f:
                    self._known_speakers = json.load(f)
                self.get_logger().info(
                    f"Speaker DB loaded: {len(self._known_speakers)} persons"
                )
            except Exception as e:
                self.get_logger().error(f"Speaker DB load error: {e}")

    # ── Audio capture ─────────────────────────────────────────────────────────

    def _start_audio(self) -> None:
        """Open PyAudio stream and start capture thread.

        On any failure (no PyAudio, bad device index) the error is logged
        and the node runs without audio input.
        """
        try:
            import pyaudio
            self._pa = pyaudio.PyAudio()
            # -1 means "system default" → PyAudio expects None for that.
            dev_idx = None if self._mic_idx < 0 else self._mic_idx
            self._audio_stream = self._pa.open(
                format=pyaudio.paInt16,
                channels=1,
                rate=self._rate,
                input=True,
                input_device_index=dev_idx,
                frames_per_buffer=CHUNK_SAMPLES,
            )
            self._running = True
            t = threading.Thread(target=self._audio_loop, daemon=True)
            t.start()
            self.get_logger().info(
                f"Audio capture started (device={self._mic_idx}, {self._rate}Hz)"
            )
        except Exception as e:
            self.get_logger().error(f"Audio stream error: {e}")

    def _audio_loop(self) -> None:
        """Continuous audio capture loop — runs in dedicated thread.

        Per 30 ms chunk: VAD → wake-word gate → publish VadState → feed the
        utterance segmenter; each completed utterance is transcribed in its
        own daemon thread so capture never stalls on the GPU.
        """
        while self._running:
            try:
                raw = self._audio_stream.read(CHUNK_SAMPLES, exception_on_overflow=False)
            except Exception:
                # Transient read errors (overflow, device hiccup): drop the
                # chunk and keep capturing.
                continue

            samples = pcm16_to_float32(raw)
            db = rms_db(samples)
            now = time.time()

            # VAD state publish (10Hz — every 3 chunks at 30ms each)
            vad_active = self._check_vad(samples)
            state_str = "silence"

            # Wake word check
            if self._oww is not None:
                preds = self._oww.predict(samples)
                score = preds.get(self._ww_model, 0.0)
                if isinstance(score, (list, tuple)):
                    # Some model outputs are per-frame sequences; use the latest.
                    score = score[-1]
                if score >= self._ww_thresh:
                    self._wake_word_active = True
                    self._wake_expiry = now + self._wake_window_s
                    state_str = "wake_word"
                    self.get_logger().info(
                        f"Wake word detected (score={score:.2f})"
                    )
            else:
                # No wake word detector — always listen
                self._wake_word_active = True
                self._wake_expiry = now + self._wake_window_s

            # Close the wake window once it expires.
            if self._wake_word_active and now > self._wake_expiry:
                self._wake_word_active = False

            # Publish VAD state
            # NOTE(review): if speech is active in the same chunk as a wake
            # word, "speech" overwrites "wake_word" here — downstream
            # consumers matching msg.state == "wake_word" may miss the
            # event. Confirm this is intended.
            if vad_active:
                state_str = "speech"
            self._publish_vad(vad_active, db, state_str)

            # Only feed segmenter when wake word is active
            if not self._wake_word_active:
                continue

            # Segmenter is created in _load_models; guard against the
            # (unlikely) window before it exists.
            if self._segmenter is None:
                continue

            completed = self._segmenter.push(samples)
            for utt_samples, duration in completed:
                # Transcribe off-thread so the capture loop keeps real time.
                threading.Thread(
                    target=self._process_utterance,
                    args=(utt_samples, duration),
                    daemon=True,
                ).start()

    def _check_vad(self, samples: list) -> bool:
        """Run Silero VAD or fall back to segmenter's energy VAD.

        NOTE(review): Silero VAD models typically expect fixed 512-sample
        chunks at 16 kHz, but CHUNK_SAMPLES is 480 — if the call raises,
        this silently falls through to the energy VAD. Verify on-device.
        """
        if self._vad is not None:
            try:
                import torch
                t = torch.tensor(samples, dtype=torch.float32)
                prob = self._vad(t, self._rate).item()
                return prob > 0.5
            except Exception:
                pass
        if self._segmenter is not None:
            # Reuses the segmenter's internal energy VAD (private access).
            return self._segmenter._vad.is_active
        return False

    # ── STT + diarization ─────────────────────────────────────────────────────

    def _process_utterance(self, samples: list, duration: float) -> None:
        """Transcribe utterance and identify speaker. Runs in thread.

        Publishes partial transcripts per Whisper segment (confidence 0.0)
        and one final transcript (fixed placeholder confidence 0.9).
        Empty transcriptions are dropped without publishing.
        """
        if self._whisper is None:
            return

        t0 = time.perf_counter()

        # Convert to numpy for faster-whisper
        try:
            import numpy as np
            audio_np = np.array(samples, dtype=np.float32)
        except ImportError:
            audio_np = samples

        # Speaker embedding first (can be concurrent with transcription)
        speaker_id = "unknown"
        if self._ecapa is not None:
            try:
                import torch
                # Batch of one utterance: shape (1, n_samples).
                audio_tensor = torch.tensor([samples], dtype=torch.float32)
                with torch.no_grad():
                    emb = self._ecapa.encode_batch(audio_tensor)
                emb_list = emb[0].cpu().numpy().tolist()
                speaker_id = identify_speaker(
                    emb_list, self._known_speakers, self._speaker_thresh
                )
            except Exception as e:
                # Diarization is best-effort; keep 'unknown' on failure.
                self.get_logger().debug(f"Speaker ID error: {e}")

        # Streaming Whisper transcription
        partial_text = ""
        try:
            segments_gen, _info = self._whisper.transcribe(
                audio_np,
                language="en",
                beam_size=3,
                vad_filter=False,
            )
            # The generator yields segments as they are decoded — publish
            # a growing partial transcript after each one.
            for seg in segments_gen:
                partial_text += seg.text.strip() + " "
                if self._publish_partial:
                    self._publish_transcript(
                        partial_text.strip(), speaker_id, 0.0, duration, is_partial=True
                    )
        except Exception as e:
            self.get_logger().error(f"Whisper error: {e}")
            return

        final_text = partial_text.strip()
        if not final_text:
            return

        latency_ms = (time.perf_counter() - t0) * 1000
        self.get_logger().info(
            f"STT [{speaker_id}] ({duration:.1f}s, {latency_ms:.0f}ms): '{final_text}'"
        )
        self._publish_transcript(final_text, speaker_id, 0.9, duration, is_partial=False)

    # ── Publishers ────────────────────────────────────────────────────────────

    def _publish_transcript(
        self, text: str, speaker_id: str, confidence: float,
        duration: float, is_partial: bool
    ) -> None:
        """Publish one SpeechTranscript message on /social/speech/transcript."""
        msg = SpeechTranscript()
        msg.header.stamp = self.get_clock().now().to_msg()
        msg.text = text
        msg.speaker_id = speaker_id
        msg.confidence = confidence
        msg.audio_duration = duration
        msg.is_partial = is_partial
        self._transcript_pub.publish(msg)

    def _publish_vad(self, speech_active: bool, db: float, state: str) -> None:
        """Publish one VadState ("silence" / "speech" / "wake_word") message."""
        msg = VadState()
        msg.header.stamp = self.get_clock().now().to_msg()
        msg.speech_active = speech_active
        msg.energy_db = db
        msg.state = state
        self._vad_pub.publish(msg)

    # ── Cleanup ───────────────────────────────────────────────────────────────

    def destroy_node(self) -> None:
        """Stop the capture loop, close the audio stream, release PyAudio."""
        # Setting _running False lets _audio_loop exit on its next iteration.
        self._running = False
        if self._audio_stream:
            self._audio_stream.stop_stream()
            self._audio_stream.close()
        if self._pa:
            self._pa.terminate()
        super().destroy_node()
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None) -> None:
    """Entry point: run the speech pipeline node until shutdown."""
    rclpy.init(args=args)
    node = SpeechPipelineNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Graceful stop on Ctrl-C.
        pass
    finally:
        # Cleanup runs on every exit path (stops audio, frees PyAudio).
        node.destroy_node()
        rclpy.shutdown()
|
||||||
@ -0,0 +1,188 @@
|
|||||||
|
"""speech_utils.py — Pure helpers for wake word, VAD, and STT.
|
||||||
|
|
||||||
|
No ROS2 dependencies. All functions operate on raw numpy float32 arrays
|
||||||
|
at 16 kHz mono (the standard for Whisper / Silero VAD).
|
||||||
|
|
||||||
|
Tested by test/test_speech_utils.py (no GPU required).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
import struct
|
||||||
|
from typing import Generator, List, Tuple
|
||||||
|
|
||||||
|
# ── Constants ──────────────────────────────────────────────────────────────────
SAMPLE_RATE = 16000                                  # Hz — standard for Whisper / Silero
CHUNK_MS = 30                                        # VAD frame size
CHUNK_SAMPLES = int(SAMPLE_RATE * CHUNK_MS / 1000)   # 480 samples
INT16_MAX = 32768.0


# ── Audio helpers ──────────────────────────────────────────────────────────────


def pcm16_to_float32(data: bytes) -> list:
    """Decode little-endian PCM16 bytes into floats in [-1.0, 1.0].

    A trailing odd byte (truncated sample) is ignored.
    """
    count = len(data) // 2
    decoded = struct.unpack(f"<{count}h", data[:count * 2])
    return [value / INT16_MAX for value in decoded]


def float32_to_pcm16(samples: list) -> bytes:
    """Encode floats as little-endian PCM16 bytes, clipping to ±1.0."""
    def encode(sample: float) -> int:
        bounded = max(-1.0, min(1.0, sample))
        # int() truncates toward zero; clamp so +1.0 maps to 32767.
        return max(-32768, min(32767, int(bounded * INT16_MAX)))

    return struct.pack(f"<{len(samples)}h", *[encode(s) for s in samples])


def rms_db(samples: list) -> float:
    """Compute RMS energy in dBFS; returns -96.0 for empty input.

    All-zero (digitally silent) input yields the floor of -200 dBFS.
    """
    if not samples:
        return -96.0
    mean_square = sum(s * s for s in samples) / len(samples)
    # Floor the RMS at 1e-10 so log10 never receives zero.
    root = math.sqrt(mean_square) if mean_square > 0.0 else 1e-10
    return 20.0 * math.log10(max(root, 1e-10))


def chunk_audio(samples: list, chunk_size: int = CHUNK_SAMPLES) -> Generator[list, None, None]:
    """Yield consecutive full chunks; a short trailing remainder is dropped."""
    full_chunks = len(samples) // chunk_size
    for idx in range(full_chunks):
        offset = idx * chunk_size
        yield samples[offset:offset + chunk_size]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Simple energy-based VAD (fallback when Silero unavailable) ────────────────
|
||||||
|
|
||||||
|
class EnergyVad:
    """Threshold-based energy VAD with simple hysteresis.

    Speech turns on after `onset_frames` consecutive frames above
    `threshold_db`, and off after `offset_frames` consecutive frames
    below it. Useful as a lightweight fallback / unit-testable reference.
    """

    def __init__(
        self,
        threshold_db: float = -35.0,
        onset_frames: int = 2,
        offset_frames: int = 8,
    ) -> None:
        self.threshold_db = threshold_db
        self.onset_frames = onset_frames
        self.offset_frames = offset_frames
        self._above_count = 0
        self._below_count = 0
        self._active = False

    def process(self, chunk: list) -> bool:
        """Consume one CHUNK_SAMPLES-long chunk; return the speech flag."""
        loud = rms_db(chunk) >= self.threshold_db
        if loud:
            self._below_count = 0
            self._above_count += 1
            # Latch on once enough consecutive loud frames were seen.
            if self._above_count >= self.onset_frames:
                self._active = True
        else:
            self._above_count = 0
            self._below_count += 1
            # Release only after sustained silence.
            if self._below_count >= self.offset_frames:
                self._active = False
        return self._active

    def reset(self) -> None:
        """Forget all hysteresis state."""
        self._active = False
        self._above_count = 0
        self._below_count = 0

    @property
    def is_active(self) -> bool:
        """Current speech flag without consuming a chunk."""
        return self._active
|
||||||
|
|
||||||
|
|
||||||
|
# ── Utterance segmenter ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class UtteranceSegmenter:
    """Accumulates audio chunks into complete utterances using VAD.

    push() returns (samples, duration_s) tuples when an utterance ends.
    Pre-roll keeps a small ring buffer before speech onset so the first
    phoneme is not clipped.
    """

    def __init__(
        self,
        vad: EnergyVad,
        pre_roll_frames: int = 5,
        max_duration_s: float = 15.0,
        sample_rate: int = SAMPLE_RATE,
        chunk_samples: int = CHUNK_SAMPLES,
    ) -> None:
        self._vad = vad
        self._pre_roll = pre_roll_frames
        # Utterance cap expressed in frames; compared in samples in push().
        self._max_frames = int(max_duration_s * sample_rate / chunk_samples)
        self._sample_rate = sample_rate
        self._chunk_samples = chunk_samples
        self._pre_buf: list = []  # ring buffer for pre-roll
        self._utt_buf: list = []  # growing utterance buffer
        self._in_utt = False
        self._silence_after = 0   # consecutive silent frames inside an utterance

    def push(self, chunk: list) -> list:
        """Push one chunk, return list of completed (samples, duration_s) tuples."""
        speech = self._vad.process(chunk)
        results: list = []

        if not self._in_utt:
            # Maintain the pre-roll ring buffer while idle.
            self._pre_buf.append(chunk)
            if len(self._pre_buf) > self._pre_roll:
                self._pre_buf.pop(0)

            if speech:
                # Utterance starts: seed from pre-roll. BUGFIX: the current
                # chunk was already appended to the ring buffer above — the
                # original additionally extended it again, duplicating the
                # onset chunk (~30ms of audio) in every utterance.
                self._in_utt = True
                self._utt_buf = [s for c in self._pre_buf for s in c]
                self._silence_after = 0
            return results

        # Inside an utterance: accumulate and track trailing silence.
        self._utt_buf.extend(chunk)
        if speech:
            self._silence_after = 0
        else:
            self._silence_after += 1

        # Close on sustained silence OR when the max-duration cap is hit.
        # BUGFIX: the original only checked the cap on silent frames, so a
        # continuous speaker grew the buffer without bound; the cap now
        # applies unconditionally.
        hit_cap = len(self._utt_buf) >= self._max_frames * self._chunk_samples
        if hit_cap or self._silence_after >= self._vad.offset_frames:
            duration = len(self._utt_buf) / self._sample_rate
            results.append((list(self._utt_buf), duration))
            self._utt_buf = []
            self._in_utt = False
            self._silence_after = 0

        return results
|
||||||
|
|
||||||
|
|
||||||
|
# ── Speaker embedding distance ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def cosine_similarity(a: list, b: list) -> float:
    """Cosine similarity between two embedding vectors.

    Returns 0.0 for empty input, mismatched lengths, or (near-)zero
    magnitude vectors.
    """
    if not a or len(a) != len(b):
        return 0.0
    dot = sum(x * y for x, y in zip(a, b))
    magnitude_a = math.sqrt(sum(x * x for x in a))
    magnitude_b = math.sqrt(sum(y * y for y in b))
    # Guard against division by (near-)zero magnitudes.
    if magnitude_a < 1e-10 or magnitude_b < 1e-10:
        return 0.0
    return dot / (magnitude_a * magnitude_b)


def identify_speaker(
    embedding: list,
    known_speakers: dict,
    threshold: float = 0.65,
) -> str:
    """Return the best-matching speaker_id, or 'unknown' if nothing beats threshold.

    known_speakers: {speaker_id: embedding_list}. A candidate must be
    strictly more similar than `threshold` (and than any earlier winner).
    """
    winner = "unknown"
    best = threshold
    for candidate_id, candidate_emb in known_speakers.items():
        similarity = cosine_similarity(embedding, candidate_emb)
        if similarity > best:
            best = similarity
            winner = candidate_id
    return winner
|
||||||
228
jetson/ros2_ws/src/saltybot_social/saltybot_social/tts_node.py
Normal file
228
jetson/ros2_ws/src/saltybot_social/saltybot_social/tts_node.py
Normal file
@ -0,0 +1,228 @@
|
|||||||
|
"""tts_node.py — Streaming TTS with Piper / first-chunk streaming.
|
||||||
|
|
||||||
|
Issue #85: Streaming TTS — Piper/XTTS integration with first-chunk streaming.
|
||||||
|
|
||||||
|
Pipeline:
|
||||||
|
/social/conversation/response (ConversationResponse)
|
||||||
|
→ sentence split → Piper ONNX synthesis (sentence by sentence)
|
||||||
|
→ PCM16 chunks → USB speaker (sounddevice) + publish /social/tts/audio
|
||||||
|
|
||||||
|
First-chunk strategy:
|
||||||
|
- On partial=true ConversationResponse, extract first sentence and synthesize
|
||||||
|
immediately → audio starts before LLM finishes generating
|
||||||
|
- On final=false, synthesize remaining sentences
|
||||||
|
|
||||||
|
Latency target: <200ms to first audio chunk.
|
||||||
|
|
||||||
|
ROS2 topics:
|
||||||
|
Subscribe: /social/conversation/response (saltybot_social_msgs/ConversationResponse)
|
||||||
|
Publish: /social/tts/audio (audio_msgs/Audio or std_msgs/UInt8MultiArray fallback)
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
voice_path (str, "/models/piper/en_US-lessac-medium.onnx")
|
||||||
|
sample_rate (int, 22050)
|
||||||
|
volume (float, 1.0)
|
||||||
|
audio_device (str, "") — sounddevice device name; "" = system default
|
||||||
|
playback_enabled (bool, true)
|
||||||
|
publish_audio (bool, false) — publish PCM to ROS2 topic
|
||||||
|
sentence_streaming (bool, true) — synthesize sentence-by-sentence
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile
|
||||||
|
from std_msgs.msg import UInt8MultiArray
|
||||||
|
|
||||||
|
from saltybot_social_msgs.msg import ConversationResponse
|
||||||
|
from .tts_utils import split_sentences, strip_ssml, apply_volume, chunk_pcm, estimate_duration_ms
|
||||||
|
|
||||||
|
|
||||||
|
class TtsNode(Node):
    """Streaming TTS node using Piper ONNX.

    Subscribes to streaming LLM responses, splits them into sentences and
    synthesizes each sentence as soon as it is complete, so audio playback
    can start before the LLM has finished generating the full reply.
    """

    def __init__(self) -> None:
        super().__init__("tts_node")

        # ── Parameters ──────────────────────────────────────────────────────
        self.declare_parameter("voice_path", "/models/piper/en_US-lessac-medium.onnx")
        self.declare_parameter("sample_rate", 22050)
        self.declare_parameter("volume", 1.0)
        self.declare_parameter("audio_device", "")
        self.declare_parameter("playback_enabled", True)
        self.declare_parameter("publish_audio", False)
        self.declare_parameter("sentence_streaming", True)

        self._voice_path = self.get_parameter("voice_path").value
        self._sample_rate = self.get_parameter("sample_rate").value
        self._volume = self.get_parameter("volume").value
        # Empty string means "system default" for sounddevice → None.
        self._audio_device = self.get_parameter("audio_device").value or None
        self._playback = self.get_parameter("playback_enabled").value
        self._publish_audio = self.get_parameter("publish_audio").value
        self._sentence_streaming = self.get_parameter("sentence_streaming").value

        # ── Publishers / Subscribers ─────────────────────────────────────────
        qos = QoSProfile(depth=10)
        self._resp_sub = self.create_subscription(
            ConversationResponse, "/social/conversation/response",
            self._on_response, qos
        )
        if self._publish_audio:
            self._audio_pub = self.create_publisher(
                UInt8MultiArray, "/social/tts/audio", qos
            )

        # ── TTS engine ────────────────────────────────────────────────────────
        self._voice = None  # PiperVoice instance once the loader thread finishes
        self._playback_queue: queue.Queue = queue.Queue(maxsize=16)
        self._current_turn = -1
        # (turn_id, sentence-hash) pairs already queued — dedup across the
        # repeated partial messages that stream in for the same turn.
        self._synthesized_turns: set = set()
        self._lock = threading.Lock()

        threading.Thread(target=self._load_voice, daemon=True).start()
        threading.Thread(target=self._playback_worker, daemon=True).start()

        self.get_logger().info(
            f"TtsNode init (voice={self._voice_path}, "
            f"streaming={self._sentence_streaming})"
        )

    # ── Voice loading ─────────────────────────────────────────────────────────

    def _load_voice(self) -> None:
        """Load the Piper voice in the background and warm up the ONNX graph."""
        t0 = time.time()
        self.get_logger().info(f"Loading Piper voice: {self._voice_path}")
        try:
            from piper import PiperVoice
            self._voice = PiperVoice.load(self._voice_path)
            # Warmup synthesis to pre-JIT ONNX graph
            warmup_text = "Hello."
            list(self._voice.synthesize_stream_raw(warmup_text))
            self.get_logger().info(f"Piper voice ready ({time.time()-t0:.1f}s)")
        except Exception as e:
            self.get_logger().error(f"Piper voice load failed: {e}")

    # ── Response handler ──────────────────────────────────────────────────────

    def _on_response(self, msg: ConversationResponse) -> None:
        """Handle streaming LLM response — synthesize sentence by sentence."""
        if not msg.text.strip():
            return

        with self._lock:
            is_new_turn = msg.turn_id != self._current_turn
            if is_new_turn:
                self._current_turn = msg.turn_id
                # Clear old synthesized sentence cache for this new turn
                self._synthesized_turns = set()

        text = strip_ssml(msg.text)

        if self._sentence_streaming:
            sentences = split_sentences(text)
            # FIX: on a partial message the trailing "sentence" may still be
            # mid-generation (no sentence boundary seen yet). Hold it back —
            # otherwise we would synthesize the truncated fragment now AND
            # the completed sentence later (its content hash differs),
            # producing garbled, duplicated audio. The final message
            # (is_partial=false) flushes everything.
            if msg.is_partial and sentences:
                sentences = sentences[:-1]
            for sentence in sentences:
                # Track which sentences we've already queued by content hash
                key = (msg.turn_id, hash(sentence))
                with self._lock:
                    if key in self._synthesized_turns:
                        continue
                    self._synthesized_turns.add(key)
                self._queue_synthesis(sentence)
        elif not msg.is_partial:
            # Non-streaming: synthesize full response at end
            self._queue_synthesis(text)

    def _queue_synthesis(self, text: str) -> None:
        """Queue a text segment for synthesis in the playback worker."""
        if not text.strip():
            return
        try:
            self._playback_queue.put_nowait(text.strip())
        except queue.Full:
            self.get_logger().warn("TTS playback queue full, dropping segment")

    # ── Playback worker ───────────────────────────────────────────────────────

    def _playback_worker(self) -> None:
        """Consume synthesis queue: synthesize → play → publish."""
        while rclpy.ok():
            try:
                text = self._playback_queue.get(timeout=0.5)
            except queue.Empty:
                continue

            if self._voice is None:
                # Loader thread hasn't finished; drop the segment rather
                # than block the queue forever.
                self.get_logger().warn("TTS voice not loaded yet")
                self._playback_queue.task_done()
                continue

            t0 = time.perf_counter()
            pcm_data = self._synthesize(text)
            if pcm_data is None:
                self._playback_queue.task_done()
                continue

            synth_ms = (time.perf_counter() - t0) * 1000
            dur_ms = estimate_duration_ms(pcm_data, self._sample_rate)
            self.get_logger().debug(
                f"TTS '{text[:40]}' synth={synth_ms:.0f}ms, dur={dur_ms:.0f}ms"
            )

            if self._volume != 1.0:
                pcm_data = apply_volume(pcm_data, self._volume)

            if self._playback:
                self._play_audio(pcm_data)

            if self._publish_audio:
                self._publish_pcm(pcm_data)

            self._playback_queue.task_done()

    def _synthesize(self, text: str) -> Optional[bytes]:
        """Synthesize text to PCM16 bytes using Piper streaming.

        Returns None when the voice is unavailable or synthesis fails.
        """
        if self._voice is None:
            return None
        try:
            chunks = list(self._voice.synthesize_stream_raw(text))
            return b"".join(chunks)
        except Exception as e:
            self.get_logger().error(f"TTS synthesis error: {e}")
            return None

    def _play_audio(self, pcm_data: bytes) -> None:
        """Play PCM16 data on USB speaker via sounddevice (blocking)."""
        try:
            import sounddevice as sd
            import numpy as np
            # int16 → float32 in [-1, 1) as sounddevice expects.
            samples = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0
            sd.play(samples, samplerate=self._sample_rate, device=self._audio_device,
                    blocking=True)
        except Exception as e:
            self.get_logger().error(f"Audio playback error: {e}")

    def _publish_pcm(self, pcm_data: bytes) -> None:
        """Publish PCM data as UInt8MultiArray (only when publish_audio=true)."""
        if not hasattr(self, "_audio_pub"):
            return
        msg = UInt8MultiArray()
        msg.data = list(pcm_data)
        self._audio_pub.publish(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None) -> None:
    """ROS2 entry point: init rclpy, spin a TtsNode until Ctrl-C, clean up."""
    rclpy.init(args=args)
    node = TtsNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        # Always release node resources and shut rclpy down, even on error.
        node.destroy_node()
        rclpy.shutdown()
|
||||||
100
jetson/ros2_ws/src/saltybot_social/saltybot_social/tts_utils.py
Normal file
100
jetson/ros2_ws/src/saltybot_social/saltybot_social/tts_utils.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
"""tts_utils.py — Pure TTS helpers for streaming synthesis.
|
||||||
|
|
||||||
|
No ROS2 dependencies. Provides streaming Piper synthesis and audio chunking.
|
||||||
|
|
||||||
|
Tested by test/test_tts_utils.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
import struct
|
||||||
|
from typing import Generator, List, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
# ── SSML helpers ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Matches any SSML/XML tag, e.g. <speak>, <break time='1s'/>.
_SSML_STRIP_RE = re.compile(r"<[^>]+>")


def strip_ssml(text: str) -> str:
    """Remove SSML tags and normalize whitespace, returning plain text.

    Tag removal can leave doubled spaces ("Hello <break/> world" →
    "Hello  world"), so runs of internal whitespace are collapsed to a
    single space before stripping the ends.
    """
    without_tags = _SSML_STRIP_RE.sub("", text)
    return re.sub(r"\s+", " ", without_tags).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def split_sentences(text: str) -> List[str]:
    """Break *text* into sentences so TTS can synthesize them one at a time.

    A boundary is terminal punctuation (. ! ?) followed by whitespace.
    Empty fragments are dropped; returns a list of non-empty strings.
    """
    fragments = re.split(r"(?<=[.!?])\s+", text.strip())
    sentences: List[str] = []
    for fragment in fragments:
        cleaned = fragment.strip()
        if cleaned:
            sentences.append(cleaned)
    return sentences
|
||||||
|
|
||||||
|
|
||||||
|
def split_first_sentence(text: str) -> Tuple[str, str]:
    """Return (first_sentence, remainder).

    Lets the TTS node start synthesizing the first sentence immediately
    while the LLM is still generating the rest of the reply.
    """
    parts = split_sentences(text)
    if not parts:
        return "", ""
    head, tail = parts[0], parts[1:]
    return head, " ".join(tail)
|
||||||
|
|
||||||
|
|
||||||
|
# ── PCM audio helpers ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def pcm16_to_wav_bytes(pcm_data: bytes, sample_rate: int = 22050) -> bytes:
    """Wrap raw PCM16 LE mono data with a canonical 44-byte WAV header."""
    channels = 1
    bits = 16
    frame_bytes = channels * bits // 8          # block align
    bytes_per_second = sample_rate * frame_bytes  # byte rate
    payload = len(pcm_data)

    # RIFF container: 36 = header bytes after the RIFF size field.
    header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF", 36 + payload, b"WAVE",
        b"fmt ", 16,            # fmt chunk size
        1,                      # PCM format tag
        channels, sample_rate,
        bytes_per_second, frame_bytes, bits,
        b"data", payload,
    )
    return header + pcm_data
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_pcm(pcm_data: bytes, chunk_ms: int = 200, sample_rate: int = 22050) -> Generator[bytes, None, None]:
    """Yield *pcm_data* in ~chunk_ms pieces for streaming playback.

    Chunk sizes are clamped to at least one int16 sample and forced even
    so no sample is split across chunks.
    """
    step = (sample_rate * 2 * chunk_ms) // 1000  # int16 = 2 bytes/sample
    step = max(step, 2) & ~1                     # even, int16-aligned
    offset = 0
    total = len(pcm_data)
    while offset < total:
        yield pcm_data[offset:offset + step]
        offset += step
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_duration_ms(pcm_data: bytes, sample_rate: int = 22050) -> float:
    """Playback duration in milliseconds for mono PCM16 at *sample_rate*."""
    bytes_per_second = sample_rate * 2  # int16 mono
    return 1000.0 * len(pcm_data) / bytes_per_second
|
||||||
|
|
||||||
|
|
||||||
|
# ── Volume control ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def apply_volume(pcm_data: bytes, gain: float) -> bytes:
    """Scale PCM16 LE samples by *gain*, clipping symmetrically to ±32767.

    Gains within 1% of unity return the input unchanged (fast path).
    """
    if abs(gain - 1.0) < 0.01:
        return pcm_data
    count = len(pcm_data) // 2
    fmt = f"<{count}h"
    adjusted = []
    for sample in struct.unpack(fmt, pcm_data[:count * 2]):
        value = int(sample * gain)
        if value > 32767:
            value = 32767
        elif value < -32767:
            value = -32767
        adjusted.append(value)
    return struct.pack(fmt, *adjusted)
|
||||||
@ -21,7 +21,7 @@ setup(
|
|||||||
zip_safe=True,
|
zip_safe=True,
|
||||||
maintainer='seb',
|
maintainer='seb',
|
||||||
maintainer_email='seb@vayrette.com',
|
maintainer_email='seb@vayrette.com',
|
||||||
description='Social interaction layer — person state tracking, LED expression + attention',
|
description='Social interaction layer — person tracking, speech, LLM, TTS, orchestrator',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
tests_require=['pytest'],
|
tests_require=['pytest'],
|
||||||
entry_points={
|
entry_points={
|
||||||
@ -29,6 +29,10 @@ setup(
|
|||||||
'person_state_tracker = saltybot_social.person_state_tracker_node:main',
|
'person_state_tracker = saltybot_social.person_state_tracker_node:main',
|
||||||
'expression_node = saltybot_social.expression_node:main',
|
'expression_node = saltybot_social.expression_node:main',
|
||||||
'attention_node = saltybot_social.attention_node:main',
|
'attention_node = saltybot_social.attention_node:main',
|
||||||
|
'speech_pipeline_node = saltybot_social.speech_pipeline_node:main',
|
||||||
|
'conversation_node = saltybot_social.conversation_node:main',
|
||||||
|
'tts_node = saltybot_social.tts_node:main',
|
||||||
|
'orchestrator_node = saltybot_social.orchestrator_node:main',
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
244
jetson/ros2_ws/src/saltybot_social/test/test_llm_context.py
Normal file
244
jetson/ros2_ws/src/saltybot_social/test/test_llm_context.py
Normal file
@ -0,0 +1,244 @@
|
|||||||
|
"""test_llm_context.py — Unit tests for llm_context.py (no GPU / LLM needed)."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
|
||||||
|
from saltybot_social.llm_context import (
|
||||||
|
PersonContext, ContextStore,
|
||||||
|
build_llama_prompt, MAX_TURNS, SUMMARY_KEEP,
|
||||||
|
DEFAULT_SYSTEM_PROMPT,
|
||||||
|
)
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# ── PersonContext ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestPersonContext:
    """Unit tests for PersonContext: turn tracking, compression, serialization."""

    def test_create(self):
        # Fresh context starts with no turns and a zero interaction count.
        ctx = PersonContext("person_1", "Alice")
        assert ctx.person_id == "person_1"
        assert ctx.person_name == "Alice"
        assert ctx.turns == []
        assert ctx.interaction_count == 0

    def test_add_user_increments_count(self):
        ctx = PersonContext("person_1")
        ctx.add_user("Hello")
        assert ctx.interaction_count == 1
        assert ctx.turns[-1]["role"] == "user"
        assert ctx.turns[-1]["content"] == "Hello"

    def test_add_assistant(self):
        ctx = PersonContext("person_1")
        ctx.add_assistant("Hi there!")
        assert ctx.turns[-1]["role"] == "assistant"

    def test_needs_compression_false(self):
        # One turn below the MAX_TURNS threshold must not trigger compression.
        ctx = PersonContext("p")
        for i in range(MAX_TURNS - 1):
            ctx.add_user(f"msg {i}")
        assert not ctx.needs_compression()

    def test_needs_compression_true(self):
        ctx = PersonContext("p")
        for i in range(MAX_TURNS + 1):
            ctx.add_user(f"msg {i}")
        assert ctx.needs_compression()

    def test_compress_keeps_recent_turns(self):
        # compress() replaces older turns with the summary, keeping only
        # the SUMMARY_KEEP most recent turns.
        ctx = PersonContext("p")
        for i in range(MAX_TURNS + 5):
            ctx.add_user(f"msg {i}")
        ctx.compress("Summary of conversation")
        assert len(ctx.turns) == SUMMARY_KEEP
        assert ctx.summary == "Summary of conversation"

    def test_compress_appends_summary(self):
        # Successive compressions accumulate: both summaries remain visible.
        ctx = PersonContext("p")
        ctx.add_user("msg")
        ctx.compress("First summary")
        ctx.add_user("msg2")
        ctx.compress("Second summary")
        assert "First summary" in ctx.summary
        assert "Second summary" in ctx.summary

    def test_roundtrip_serialization(self):
        # to_dict()/from_dict() must preserve identity, turns and summary.
        ctx = PersonContext("person_42", "Bob")
        ctx.add_user("test")
        ctx.add_assistant("response")
        ctx.summary = "Earlier they discussed robots"
        d = ctx.to_dict()
        ctx2 = PersonContext.from_dict(d)
        assert ctx2.person_id == "person_42"
        assert ctx2.person_name == "Bob"
        assert len(ctx2.turns) == 2
        assert ctx2.summary == "Earlier they discussed robots"
|
||||||
|
|
||||||
|
|
||||||
|
# ── ContextStore ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestContextStore:
    """Unit tests for ContextStore: creation, caching, JSON persistence."""

    def test_get_creates_new(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            db = os.path.join(tmpdir, "ctx.json")
            store = ContextStore(db)
            ctx = store.get("p1", "Alice")
            assert ctx.person_id == "p1"
            assert ctx.person_name == "Alice"

    def test_get_same_instance(self):
        # Repeated get() for the same id returns the cached object, not a copy.
        with tempfile.TemporaryDirectory() as tmpdir:
            db = os.path.join(tmpdir, "ctx.json")
            store = ContextStore(db)
            ctx1 = store.get("p1")
            ctx2 = store.get("p1")
            assert ctx1 is ctx2

    def test_save_and_load(self):
        # save() writes to disk; a second store built on the same path
        # must see the persisted name and turns.
        with tempfile.TemporaryDirectory() as tmpdir:
            db = os.path.join(tmpdir, "ctx.json")
            store = ContextStore(db)
            ctx = store.get("p1", "Alice")
            ctx.add_user("Hello")
            store.save()

            store2 = ContextStore(db)
            ctx2 = store2.get("p1")
            assert ctx2.person_name == "Alice"
            assert len(ctx2.turns) == 1
            assert ctx2.turns[0]["content"] == "Hello"

    def test_all_persons(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            db = os.path.join(tmpdir, "ctx.json")
            store = ContextStore(db)
            store.get("p1")
            store.get("p2")
            store.get("p3")
            assert set(store.all_persons()) == {"p1", "p2", "p3"}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Prompt builder ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestBuildLlamaPrompt:
    """Unit tests for build_llama_prompt: system prompt, history, summary, group."""

    def test_contains_system_prompt(self):
        ctx = PersonContext("p1")
        prompt = build_llama_prompt(ctx, "hello", DEFAULT_SYSTEM_PROMPT)
        assert DEFAULT_SYSTEM_PROMPT in prompt

    def test_contains_user_text(self):
        ctx = PersonContext("p1")
        prompt = build_llama_prompt(ctx, "What time is it?", DEFAULT_SYSTEM_PROMPT)
        assert "What time is it?" in prompt

    def test_ends_with_assistant_tag(self):
        # Prompt must end with the assistant tag so the LLM continues from it.
        ctx = PersonContext("p1")
        prompt = build_llama_prompt(ctx, "hello", DEFAULT_SYSTEM_PROMPT)
        assert prompt.strip().endswith("<|assistant|>")

    def test_history_included(self):
        ctx = PersonContext("p1")
        ctx.add_user("Previous question")
        ctx.add_assistant("Previous answer")
        prompt = build_llama_prompt(ctx, "New question", DEFAULT_SYSTEM_PROMPT)
        assert "Previous question" in prompt
        assert "Previous answer" in prompt

    def test_summary_included(self):
        # The compressed long-term summary must be surfaced to the model.
        ctx = PersonContext("p1", "Alice")
        ctx.summary = "Alice visited last week and likes coffee"
        prompt = build_llama_prompt(ctx, "hello", DEFAULT_SYSTEM_PROMPT)
        assert "likes coffee" in prompt

    def test_group_persons_included(self):
        ctx = PersonContext("p1", "Alice")
        prompt = build_llama_prompt(ctx, "hello", DEFAULT_SYSTEM_PROMPT,
                                    group_persons=["p1", "p2_Bob", "p3_Charlie"])
        assert "p2_Bob" in prompt or "p3_Charlie" in prompt

    def test_no_group_persons_not_in_prompt(self):
        # Without group_persons the "Also present" section must be omitted.
        ctx = PersonContext("p1")
        prompt = build_llama_prompt(ctx, "hello", DEFAULT_SYSTEM_PROMPT)
        assert "Also present" not in prompt
|
||||||
|
|
||||||
|
|
||||||
|
# ── TTS utils (imported here for coverage) ────────────────────────────────────
|
||||||
|
|
||||||
|
class TestTtsUtils:
    """Unit tests for tts_utils helpers (pure functions, no ROS2/audio HW)."""

    def setup_method(self):
        # Import lazily per test so a missing package fails the test, not
        # collection; bind the helpers onto the instance for use below.
        sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
        from saltybot_social.tts_utils import (
            strip_ssml, split_sentences, split_first_sentence,
            pcm16_to_wav_bytes, chunk_pcm, estimate_duration_ms, apply_volume
        )
        self.strip_ssml = strip_ssml
        self.split_sentences = split_sentences
        self.split_first_sentence = split_first_sentence
        self.pcm16_to_wav_bytes = pcm16_to_wav_bytes
        self.chunk_pcm = chunk_pcm
        self.estimate_duration_ms = estimate_duration_ms
        self.apply_volume = apply_volume

    def test_strip_ssml(self):
        # Tags removed AND the resulting whitespace normalized to one space.
        assert self.strip_ssml("<speak>Hello <break time='1s'/> world</speak>") == "Hello world"
        assert self.strip_ssml("plain text") == "plain text"

    def test_split_sentences_basic(self):
        text = "Hello world. How are you? I am fine!"
        sentences = self.split_sentences(text)
        assert len(sentences) == 3

    def test_split_sentences_single(self):
        sentences = self.split_sentences("Hello world.")
        assert len(sentences) == 1
        assert sentences[0] == "Hello world."

    def test_split_sentences_empty(self):
        assert self.split_sentences("") == []
        assert self.split_sentences("   ") == []

    def test_split_first_sentence(self):
        text = "Hello world. How are you? I am fine."
        first, rest = self.split_first_sentence(text)
        assert first == "Hello world."
        assert "How are you?" in rest

    def test_wav_header(self):
        # 44-byte canonical WAV header prepended to the raw payload.
        pcm = b"\x00\x00" * 1000
        wav = self.pcm16_to_wav_bytes(pcm, 22050)
        assert wav[:4] == b"RIFF"
        assert wav[8:12] == b"WAVE"
        assert len(wav) == 44 + len(pcm)

    def test_chunk_pcm_sizes(self):
        pcm = b"\x00\x00" * 2205  # ~100ms at 22050Hz
        chunks = list(self.chunk_pcm(pcm, 50, 22050))  # 50ms chunks
        for chunk in chunks:
            assert len(chunk) % 2 == 0  # int16 aligned

    def test_estimate_duration(self):
        pcm = b"\x00\x00" * 22050  # 1s at 22050Hz
        ms = self.estimate_duration_ms(pcm, 22050)
        assert abs(ms - 1000.0) < 1.0

    def test_apply_volume_unity(self):
        # Near-unity gain is a no-op fast path: bytes come back unchanged.
        pcm = b"\x00\x40" * 100  # some non-zero samples
        result = self.apply_volume(pcm, 1.0)
        assert result == pcm

    def test_apply_volume_half(self):
        import struct
        pcm = struct.pack("<h", 1000)
        result = self.apply_volume(pcm, 0.5)
        val = struct.unpack("<h", result)[0]
        assert abs(val - 500) <= 1

    def test_apply_volume_clips(self):
        # Overflow clips to the symmetric int16 ceiling (+32767), not wrap.
        import struct
        pcm = struct.pack("<h", 30000)
        result = self.apply_volume(pcm, 2.0)
        val = struct.unpack("<h", result)[0]
        assert val == 32767
|
||||||
237
jetson/ros2_ws/src/saltybot_social/test/test_speech_utils.py
Normal file
237
jetson/ros2_ws/src/saltybot_social/test/test_speech_utils.py
Normal file
@ -0,0 +1,237 @@
|
|||||||
|
"""test_speech_utils.py — Unit tests for speech_utils.py (no ROS2 / GPU needed)."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
|
||||||
|
from saltybot_social.speech_utils import (
|
||||||
|
pcm16_to_float32, float32_to_pcm16,
|
||||||
|
rms_db, chunk_audio,
|
||||||
|
EnergyVad, UtteranceSegmenter,
|
||||||
|
cosine_similarity, identify_speaker,
|
||||||
|
CHUNK_SAMPLES, SAMPLE_RATE,
|
||||||
|
)
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# ── PCM conversion ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestPcmConversion:
    """Unit tests for PCM16 ↔ float32 conversion helpers."""

    def test_pcm16_to_float32_silence(self):
        data = b"\x00\x00" * 100
        samples = pcm16_to_float32(data)
        assert len(samples) == 100
        assert all(s == 0.0 for s in samples)

    def test_pcm16_to_float32_max_positive(self):
        # int16 full scale maps to just under 1.0 (divide by 32768).
        data = struct.pack("<h", 32767)
        samples = pcm16_to_float32(data)
        assert abs(samples[0] - 32767 / 32768.0) < 1e-5

    def test_float32_to_pcm16_roundtrip(self):
        # One int16 quantization step is ~3e-5, so 1e-3 tolerance is ample.
        original = [0.0, 0.5, -0.5, 0.99, -0.99]
        pcm = float32_to_pcm16(original)
        recovered = pcm16_to_float32(pcm)
        for orig, rec in zip(original, recovered):
            assert abs(orig - rec) < 1e-3

    def test_float32_to_pcm16_clips(self):
        # Values > 1.0 should be clipped
        pcm = float32_to_pcm16([2.0, -2.0])
        samples = pcm16_to_float32(pcm)
        assert samples[0] <= 1.0
        assert samples[1] >= -1.0
|
||||||
|
|
||||||
|
|
||||||
|
# ── RMS dB ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestRmsDb:
    """Unit tests for rms_db: dBFS levels for silence, full/half/low scale."""

    def test_silence_returns_minus96(self):
        # Empty input uses the -96 dB floor; all-zero input stays very quiet.
        assert rms_db([]) == -96.0
        assert rms_db([0.0] * 100) < -60.0

    def test_full_scale_is_near_zero(self):
        samples = [1.0] * 100
        db = rms_db(samples)
        assert abs(db) < 0.1  # 0 dBFS

    def test_half_scale_is_minus6(self):
        # 20*log10(0.5) ≈ -6.02 dB.
        samples = [0.5] * 100
        db = rms_db(samples)
        assert abs(db - (-6.02)) < 0.1

    def test_low_level_signal(self):
        samples = [0.001] * 100
        db = rms_db(samples)
        assert db < -40.0
|
||||||
|
|
||||||
|
|
||||||
|
# ── Chunk audio ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestChunkAudio:
    """Unit tests for chunk_audio: fixed-size framing of sample lists."""

    def test_even_chunks(self):
        samples = list(range(480 * 4))
        chunks = list(chunk_audio(samples, 480))
        assert len(chunks) == 4
        assert all(len(c) == 480 for c in chunks)

    def test_remainder_dropped(self):
        # 500 samples at frame size 480: one full frame, 20 leftover dropped.
        samples = list(range(500))
        chunks = list(chunk_audio(samples, 480))
        assert len(chunks) == 1

    def test_empty_input(self):
        assert list(chunk_audio([], 480)) == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── Energy VAD ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestEnergyVad:
    """Unit tests for EnergyVad: onset/offset hysteresis and reset."""

    def test_silence_not_active(self):
        vad = EnergyVad(threshold_db=-35.0, onset_frames=2, offset_frames=3)
        silence = [0.0001] * CHUNK_SAMPLES
        for _ in range(5):
            active = vad.process(silence)
        assert not active

    def test_loud_signal_activates(self):
        # Loud frames exceed onset_frames=2 by the third call.
        vad = EnergyVad(threshold_db=-35.0, onset_frames=2, offset_frames=8)
        loud = [0.5] * CHUNK_SAMPLES
        results = [vad.process(loud) for _ in range(3)]
        assert results[-1] is True

    def test_onset_requires_n_frames(self):
        # Activation is debounced: exactly onset_frames consecutive loud
        # frames are required before the VAD reports active.
        vad = EnergyVad(threshold_db=-35.0, onset_frames=3, offset_frames=8)
        loud = [0.5] * CHUNK_SAMPLES
        assert not vad.process(loud)  # frame 1
        assert not vad.process(loud)  # frame 2
        assert vad.process(loud)      # frame 3 → activates

    def test_offset_deactivates(self):
        vad = EnergyVad(threshold_db=-35.0, onset_frames=1, offset_frames=2)
        loud = [0.5] * CHUNK_SAMPLES
        silence = [0.0001] * CHUNK_SAMPLES
        vad.process(loud)
        assert vad.is_active
        vad.process(silence)
        assert vad.is_active  # only 1 silent frame so far; offset_frames=2 needs 2
        vad.process(silence)
        assert not vad.is_active

    def test_reset(self):
        # reset() must clear the active state regardless of offset counters.
        vad = EnergyVad(threshold_db=-35.0, onset_frames=1, offset_frames=8)
        loud = [0.5] * CHUNK_SAMPLES
        vad.process(loud)
        assert vad.is_active
        vad.reset()
        assert not vad.is_active
|
||||||
|
|
||||||
|
|
||||||
|
# ── Utterance segmenter ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestUtteranceSegmenter:
    """Unit tests for UtteranceSegmenter: utterance boundaries and pre-roll."""

    def _make_segmenter(self):
        # Fast-reacting VAD (onset 1, offset 2) keeps these tests short.
        vad = EnergyVad(threshold_db=-35.0, onset_frames=1, offset_frames=2)
        return UtteranceSegmenter(vad, pre_roll_frames=2, max_duration_s=5.0)

    def test_silence_yields_nothing(self):
        seg = self._make_segmenter()
        silence = [0.0001] * CHUNK_SAMPLES
        for _ in range(10):
            results = seg.push(silence)
            assert results == []

    def test_speech_then_silence_yields_utterance(self):
        # push() returns [] while speech is ongoing; the finished utterance
        # (samples, duration) appears once the VAD offset fires.
        seg = self._make_segmenter()
        loud = [0.5] * CHUNK_SAMPLES
        silence = [0.0001] * CHUNK_SAMPLES

        # 3 speech frames
        for _ in range(3):
            assert seg.push(loud) == []

        # 2 silence frames to trigger offset
        all_results = []
        for _ in range(3):
            all_results.extend(seg.push(silence))

        assert len(all_results) == 1
        utt_samples, duration = all_results[0]
        assert duration > 0
        assert len(utt_samples) > 0

    def test_preroll_included(self):
        vad = EnergyVad(threshold_db=-35.0, onset_frames=1, offset_frames=2)
        seg = UtteranceSegmenter(vad, pre_roll_frames=3, max_duration_s=5.0)
        silence = [0.0001] * CHUNK_SAMPLES
        loud = [0.5] * CHUNK_SAMPLES

        # Push 3 silence frames (become pre-roll buffer)
        for _ in range(3):
            seg.push(silence)

        # Push 1 speech frame (triggers onset, pre-roll included)
        seg.push(loud)

        # End utterance
        all_results = []
        for _ in range(3):
            all_results.extend(seg.push(silence))

        assert len(all_results) == 1
        # Utterance should include pre-roll + speech + trailing silence
        utt_samples, _ = all_results[0]
        assert len(utt_samples) >= 4 * CHUNK_SAMPLES
|
||||||
|
|
||||||
|
|
||||||
|
# ── Speaker identification ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestSpeakerIdentification:
    """Unit tests for cosine_similarity and embedding-based speaker lookup."""

    def test_cosine_identical(self):
        v = [1.0, 0.0, 0.0]
        assert abs(cosine_similarity(v, v) - 1.0) < 1e-6

    def test_cosine_orthogonal(self):
        a = [1.0, 0.0]
        b = [0.0, 1.0]
        assert abs(cosine_similarity(a, b)) < 1e-6

    def test_cosine_opposite(self):
        a = [1.0, 0.0]
        b = [-1.0, 0.0]
        assert abs(cosine_similarity(a, b) + 1.0) < 1e-6

    def test_cosine_zero_vector(self):
        # Degenerate zero vector returns 0.0 instead of dividing by zero.
        assert cosine_similarity([0.0, 0.0], [1.0, 0.0]) == 0.0

    def test_cosine_length_mismatch(self):
        # Mismatched dimensions are treated as "no similarity", not an error.
        assert cosine_similarity([1.0], [1.0, 0.0]) == 0.0

    def test_identify_speaker_match(self):
        known = {
            "alice": [1.0, 0.0, 0.0],
            "bob": [0.0, 1.0, 0.0],
        }
        result = identify_speaker([1.0, 0.0, 0.0], known, threshold=0.5)
        assert result == "alice"

    def test_identify_speaker_unknown(self):
        # Below-threshold similarity must fall back to "unknown".
        known = {"alice": [1.0, 0.0, 0.0]}
        result = identify_speaker([0.0, 1.0, 0.0], known, threshold=0.5)
        assert result == "unknown"

    def test_identify_speaker_empty_db(self):
        result = identify_speaker([1.0, 0.0], {}, threshold=0.5)
        assert result == "unknown"

    def test_identify_best_match(self):
        known = {
            "alice": [1.0, 0.0, 0.0],
            "bob": [0.8, 0.6, 0.0],  # cos sim to the query is ~0.8 — lower than alice's
        }
        # Query close to alice
        result = identify_speaker([0.99, 0.01, 0.0], known, threshold=0.5)
        assert result == "alice"
|
||||||
@ -28,6 +28,11 @@ rosidl_generate_interfaces(${PROJECT_NAME}
|
|||||||
"srv/QueryMood.srv"
|
"srv/QueryMood.srv"
|
||||||
# Issue #92 — multi-modal tracking fusion
|
# Issue #92 — multi-modal tracking fusion
|
||||||
"msg/FusedTarget.msg"
|
"msg/FusedTarget.msg"
|
||||||
|
# Issue #81 — speech pipeline
|
||||||
|
"msg/SpeechTranscript.msg"
|
||||||
|
"msg/VadState.msg"
|
||||||
|
# Issue #83 — conversation engine
|
||||||
|
"msg/ConversationResponse.msg"
|
||||||
DEPENDENCIES std_msgs geometry_msgs builtin_interfaces
|
DEPENDENCIES std_msgs geometry_msgs builtin_interfaces
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,9 @@
|
|||||||
|
# ConversationResponse.msg — LLM response, supports streaming token output.
# Published by conversation_node on /social/conversation/response

std_msgs/Header header

string text           # Full or partial response text
string speaker_id     # Who the response is addressed to
bool is_partial       # true = streaming token chunk, false = final response
int32 turn_id         # Conversation turn counter (for deduplication)
|
||||||
@ -0,0 +1,10 @@
|
|||||||
|
# SpeechTranscript.msg — Result of STT with speaker identification.
# Published by speech_pipeline_node on /social/speech/transcript

std_msgs/Header header

string text             # Transcribed text (UTF-8)
string speaker_id       # e.g. "person_42" or "unknown"
float32 confidence      # ASR confidence 0..1
float32 audio_duration  # Duration of the utterance in seconds
bool is_partial         # true = intermediate streaming result, false = final
|
||||||
8
jetson/ros2_ws/src/saltybot_social_msgs/msg/VadState.msg
Normal file
8
jetson/ros2_ws/src/saltybot_social_msgs/msg/VadState.msg
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
# VadState.msg — Voice Activity Detection state.
# Published by speech_pipeline_node on /social/speech/vad_state

std_msgs/Header header

bool speech_active      # true = speech detected, false = silence
float32 energy_db       # RMS energy in dBFS
string state            # "silence" | "speech" | "wake_word"
|
||||||
Loading…
x
Reference in New Issue
Block a user