Compare commits

..

No commits in common. "e3e4bd70a4782519780cac6e514ad5259bea0ee3" and "077f26d9d6d6cdf0329281bcbbacb730076771f0" have entirely different histories.

8 changed files with 349 additions and 288 deletions

View File

@ -8,7 +8,6 @@ speech_pipeline_node:
use_silero_vad: true use_silero_vad: true
whisper_model: "small" # small (~500ms), medium (better quality, ~900ms) whisper_model: "small" # small (~500ms), medium (better quality, ~900ms)
whisper_compute_type: "float16" whisper_compute_type: "float16"
whisper_language: "" # "" = auto-detect; set e.g. "fr" to force
speaker_threshold: 0.65 speaker_threshold: 0.65
speaker_db_path: "/social_db/speaker_embeddings.json" speaker_db_path: "/social_db/speaker_embeddings.json"
publish_partial: true publish_partial: true

View File

@ -1,8 +1,6 @@
tts_node: tts_node:
ros__parameters: ros__parameters:
voice_path: "/models/piper/en_US-lessac-medium.onnx" voice_path: "/models/piper/en_US-lessac-medium.onnx"
voice_map_json: "{}"
default_language: "en"
sample_rate: 22050 sample_rate: 22050
volume: 1.0 volume: 1.0
audio_device: "" # "" = system default; set to device name if needed audio_device: "" # "" = system default; set to device name if needed

View File

@ -1,30 +1,54 @@
"""conversation_node.py — Local LLM conversation engine with per-person context. """conversation_node.py — Local LLM conversation engine with per-person context.
Issue #83/#161/#167
Issue #83: Conversation engine for social-bot.
Stack: Phi-3-mini or Llama-3.2-3B GGUF Q4_K_M via llama-cpp-python (CUDA).
Subscribes /social/speech/transcript builds per-person prompt streams
token output publishes /social/conversation/response.
Streaming: publishes partial=true tokens as they arrive, then final=false
at end of generation. TTS node can begin synthesis on first sentence boundary.
ROS2 topics:
Subscribe: /social/speech/transcript (saltybot_social_msgs/SpeechTranscript)
Publish: /social/conversation/response (saltybot_social_msgs/ConversationResponse)
Parameters:
model_path (str, "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf")
n_ctx (int, 4096)
n_gpu_layers (int, 20) GPU offload layers (increase for more VRAM usage)
max_tokens (int, 200)
temperature (float, 0.7)
top_p (float, 0.9)
soul_path (str, "/soul/SOUL.md")
context_db_path (str, "/social_db/conversation_context.json")
save_interval_s (float, 30.0) how often to persist context to disk
stream (bool, true)
""" """
from __future__ import annotations from __future__ import annotations
import json, threading, time
from typing import Dict, Optional import threading
import time
from typing import Optional
import rclpy import rclpy
from rclpy.node import Node from rclpy.node import Node
from rclpy.qos import QoSProfile from rclpy.qos import QoSProfile
from std_msgs.msg import String
from saltybot_social_msgs.msg import SpeechTranscript, ConversationResponse from saltybot_social_msgs.msg import SpeechTranscript, ConversationResponse
from .llm_context import ContextStore, build_llama_prompt, load_system_prompt, needs_summary_prompt from .llm_context import ContextStore, build_llama_prompt, load_system_prompt, needs_summary_prompt
_LANG_NAMES: Dict[str, str] = {
"en": "English", "fr": "French", "es": "Spanish", "de": "German",
"it": "Italian", "pt": "Portuguese", "ja": "Japanese", "zh": "Chinese",
"ko": "Korean", "ar": "Arabic", "ru": "Russian", "nl": "Dutch",
"pl": "Polish", "sv": "Swedish", "da": "Danish", "fi": "Finnish",
"no": "Norwegian", "tr": "Turkish", "hi": "Hindi", "uk": "Ukrainian",
"cs": "Czech", "ro": "Romanian", "hu": "Hungarian", "el": "Greek",
}
class ConversationNode(Node): class ConversationNode(Node):
"""Local LLM inference node with per-person conversation memory.""" """Local LLM inference node with per-person conversation memory."""
def __init__(self) -> None: def __init__(self) -> None:
super().__init__("conversation_node") super().__init__("conversation_node")
self.declare_parameter("model_path", "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf")
# ── Parameters ──────────────────────────────────────────────────────
self.declare_parameter("model_path",
"/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf")
self.declare_parameter("n_ctx", 4096) self.declare_parameter("n_ctx", 4096)
self.declare_parameter("n_gpu_layers", 20) self.declare_parameter("n_gpu_layers", 20)
self.declare_parameter("max_tokens", 200) self.declare_parameter("max_tokens", 200)
@ -34,6 +58,7 @@ class ConversationNode(Node):
self.declare_parameter("context_db_path", "/social_db/conversation_context.json") self.declare_parameter("context_db_path", "/social_db/conversation_context.json")
self.declare_parameter("save_interval_s", 30.0) self.declare_parameter("save_interval_s", 30.0)
self.declare_parameter("stream", True) self.declare_parameter("stream", True)
self._model_path = self.get_parameter("model_path").value self._model_path = self.get_parameter("model_path").value
self._n_ctx = self.get_parameter("n_ctx").value self._n_ctx = self.get_parameter("n_ctx").value
self._n_gpu = self.get_parameter("n_gpu_layers").value self._n_gpu = self.get_parameter("n_gpu_layers").value
@ -44,9 +69,18 @@ class ConversationNode(Node):
self._db_path = self.get_parameter("context_db_path").value self._db_path = self.get_parameter("context_db_path").value
self._save_interval = self.get_parameter("save_interval_s").value self._save_interval = self.get_parameter("save_interval_s").value
self._stream = self.get_parameter("stream").value self._stream = self.get_parameter("stream").value
# ── Publishers / Subscribers ─────────────────────────────────────────
qos = QoSProfile(depth=10) qos = QoSProfile(depth=10)
self._resp_pub = self.create_publisher(ConversationResponse, "/social/conversation/response", qos) self._resp_pub = self.create_publisher(
self._transcript_sub = self.create_subscription(SpeechTranscript, "/social/speech/transcript", self._on_transcript, qos) ConversationResponse, "/social/conversation/response", qos
)
self._transcript_sub = self.create_subscription(
SpeechTranscript, "/social/speech/transcript",
self._on_transcript, qos
)
# ── State ────────────────────────────────────────────────────────────
self._llm = None self._llm = None
self._system_prompt = load_system_prompt(self._soul_path) self._system_prompt = load_system_prompt(self._soul_path)
self._ctx_store = ContextStore(self._db_path) self._ctx_store = ContextStore(self._db_path)
@ -54,114 +88,187 @@ class ConversationNode(Node):
self._turn_counter = 0 self._turn_counter = 0
self._generating = False self._generating = False
self._last_save = time.time() self._last_save = time.time()
self._speaker_lang: Dict[str, str] = {}
self._emotions: Dict[str, str] = {} # ── Load LLM in background ────────────────────────────────────────────
self.create_subscription(String, "/social/emotion/context", self._on_emotion_context, 10)
threading.Thread(target=self._load_llm, daemon=True).start() threading.Thread(target=self._load_llm, daemon=True).start()
# ── Periodic context save ────────────────────────────────────────────
self._save_timer = self.create_timer(self._save_interval, self._save_context) self._save_timer = self.create_timer(self._save_interval, self._save_context)
# ── Emotion context (Issue #161) ──────────────────────────────────────
self._emotions: Dict[int, str] = {}
self.create_subscription(
String, "/social/emotion/context",
self._on_emotion_context, 10
)
self.get_logger().info( self.get_logger().info(
f"ConversationNode init (model={self._model_path}, " f"ConversationNode init (model={self._model_path}, "
f"gpu_layers={self._n_gpu}, ctx={self._n_ctx})" f"gpu_layers={self._n_gpu}, ctx={self._n_ctx})"
) )
# ── Model loading ─────────────────────────────────────────────────────────
def _load_llm(self) -> None: def _load_llm(self) -> None:
t0 = time.time() t0 = time.time()
self.get_logger().info(f"Loading LLM: {self._model_path}")
try: try:
from llama_cpp import Llama from llama_cpp import Llama
self._llm = Llama(model_path=self._model_path, n_ctx=self._n_ctx, n_gpu_layers=self._n_gpu, n_threads=4, verbose=False) self._llm = Llama(
self.get_logger().info(f"LLM ready ({time.time()-t0:.1f}s)") model_path=self._model_path,
n_ctx=self._n_ctx,
n_gpu_layers=self._n_gpu,
n_threads=4,
verbose=False,
)
self.get_logger().info(
f"LLM ready ({time.time()-t0:.1f}s). "
f"Context: {self._n_ctx} tokens, GPU layers: {self._n_gpu}"
)
except Exception as e: except Exception as e:
self.get_logger().error(f"LLM load failed: {e}") self.get_logger().error(f"LLM load failed: {e}")
# ── Transcript callback ───────────────────────────────────────────────────
def _on_transcript(self, msg: SpeechTranscript) -> None: def _on_transcript(self, msg: SpeechTranscript) -> None:
if msg.is_partial or not msg.text.strip(): """Handle final transcripts only (skip streaming partials)."""
if msg.is_partial:
return return
if msg.language: if not msg.text.strip():
self._speaker_lang[msg.speaker_id] = msg.language return
self.get_logger().info(f"Transcript [{msg.speaker_id}/{msg.language or '?'}]: '{msg.text}'")
threading.Thread(target=self._generate_response, args=(msg.text.strip(), msg.speaker_id), daemon=True).start() self.get_logger().info(
f"Transcript [{msg.speaker_id}]: '{msg.text}'"
)
threading.Thread(
target=self._generate_response,
args=(msg.text.strip(), msg.speaker_id),
daemon=True,
).start()
# ── LLM inference ─────────────────────────────────────────────────────────
def _generate_response(self, user_text: str, speaker_id: str) -> None: def _generate_response(self, user_text: str, speaker_id: str) -> None:
"""Generate LLM response with streaming. Runs in thread."""
if self._llm is None: if self._llm is None:
self.get_logger().warn("LLM not loaded yet, dropping utterance"); return self.get_logger().warn("LLM not loaded yet, dropping utterance")
return
with self._lock: with self._lock:
if self._generating: if self._generating:
self.get_logger().warn("LLM busy, dropping utterance"); return self.get_logger().warn("LLM busy, dropping utterance")
return
self._generating = True self._generating = True
self._turn_counter += 1 self._turn_counter += 1
turn_id = self._turn_counter turn_id = self._turn_counter
lang = self._speaker_lang.get(speaker_id, "en")
try: try:
ctx = self._ctx_store.get(speaker_id) ctx = self._ctx_store.get(speaker_id)
# Summary compression if context is long
if ctx.needs_compression(): if ctx.needs_compression():
self._compress_context(ctx) self._compress_context(ctx)
emotion_hint = self._emotion_hint(speaker_id)
lang_hint = self._language_hint(speaker_id) ctx.add_user(user_text)
hints = " ".join(h for h in (emotion_hint, lang_hint) if h)
annotated = f"{user_text} {hints}".rstrip() if hints else user_text prompt = build_llama_prompt(
ctx.add_user(annotated) ctx, user_text, self._system_prompt
prompt = build_llama_prompt(ctx, annotated, self._system_prompt) )
t0 = time.perf_counter() t0 = time.perf_counter()
full_response = "" full_response = ""
if self._stream: if self._stream:
output = self._llm(prompt, max_tokens=self._max_tokens, temperature=self._temperature, top_p=self._top_p, stream=True, stop=["<|user|>", "<|system|>", "\n\n\n"]) output = self._llm(
prompt,
max_tokens=self._max_tokens,
temperature=self._temperature,
top_p=self._top_p,
stream=True,
stop=["<|user|>", "<|system|>", "\n\n\n"],
)
for chunk in output: for chunk in output:
token = chunk["choices"][0]["text"] token = chunk["choices"][0]["text"]
full_response += token full_response += token
# Publish partial after each sentence boundary for low TTS latency
if token.endswith((".", "!", "?", "\n")): if token.endswith((".", "!", "?", "\n")):
self._publish_response(full_response.strip(), speaker_id, turn_id, language=lang, is_partial=True) self._publish_response(
full_response.strip(), speaker_id, turn_id, is_partial=True
)
else: else:
output = self._llm(prompt, max_tokens=self._max_tokens, temperature=self._temperature, top_p=self._top_p, stream=False, stop=["<|user|>", "<|system|>"]) output = self._llm(
prompt,
max_tokens=self._max_tokens,
temperature=self._temperature,
top_p=self._top_p,
stream=False,
stop=["<|user|>", "<|system|>"],
)
full_response = output["choices"][0]["text"] full_response = output["choices"][0]["text"]
full_response = full_response.strip() full_response = full_response.strip()
self.get_logger().info(f"LLM [{speaker_id}/{lang}] ({(time.perf_counter()-t0)*1000:.0f}ms): '{full_response[:80]}'") latency_ms = (time.perf_counter() - t0) * 1000
self.get_logger().info(
f"LLM [{speaker_id}] ({latency_ms:.0f}ms): '{full_response[:80]}'"
)
ctx.add_assistant(full_response) ctx.add_assistant(full_response)
self._publish_response(full_response, speaker_id, turn_id, language=lang, is_partial=False) self._publish_response(full_response, speaker_id, turn_id, is_partial=False)
except Exception as e: except Exception as e:
self.get_logger().error(f"LLM inference error: {e}") self.get_logger().error(f"LLM inference error: {e}")
finally: finally:
with self._lock: self._generating = False with self._lock:
self._generating = False
def _compress_context(self, ctx) -> None: def _compress_context(self, ctx) -> None:
if self._llm is None: ctx.compress("(history omitted)"); return """Ask LLM to summarize old turns for context compression."""
if self._llm is None:
ctx.compress("(history omitted)")
return
try: try:
result = self._llm(needs_summary_prompt(ctx), max_tokens=80, temperature=0.3, stream=False) summary_prompt = needs_summary_prompt(ctx)
ctx.compress(result["choices"][0]["text"].strip()) result = self._llm(summary_prompt, max_tokens=80, temperature=0.3, stream=False)
except Exception: ctx.compress("(history omitted)") summary = result["choices"][0]["text"].strip()
ctx.compress(summary)
self.get_logger().debug(
f"Context compressed for {ctx.person_id}: '{summary[:60]}'"
)
except Exception:
ctx.compress("(history omitted)")
def _language_hint(self, speaker_id: str) -> str: # ── Publish ───────────────────────────────────────────────────────────────
lang = self._speaker_lang.get(speaker_id, "en")
if lang and lang != "en":
return f"[Please respond in {_LANG_NAMES.get(lang, lang)}.]"
return ""
def _on_emotion_context(self, msg: String) -> None: def _publish_response(
try: self, text: str, speaker_id: str, turn_id: int, is_partial: bool
for k, v in json.loads(msg.data).get("emotions", {}).items(): ) -> None:
self._emotions[k] = v
except Exception: pass
def _emotion_hint(self, speaker_id: str) -> str:
emo = self._emotions.get(speaker_id, "")
return f"[The person seems {emo} right now.]" if emo and emo != "neutral" else ""
def _publish_response(self, text: str, speaker_id: str, turn_id: int, language: str = "en", is_partial: bool = False) -> None:
msg = ConversationResponse() msg = ConversationResponse()
msg.header.stamp = self.get_clock().now().to_msg() msg.header.stamp = self.get_clock().now().to_msg()
msg.text = text; msg.speaker_id = speaker_id; msg.is_partial = is_partial msg.text = text
msg.turn_id = turn_id; msg.language = language msg.speaker_id = speaker_id
msg.is_partial = is_partial
msg.turn_id = turn_id
self._resp_pub.publish(msg) self._resp_pub.publish(msg)
def _save_context(self) -> None: def _save_context(self) -> None:
try: self._ctx_store.save() try:
except Exception as e: self.get_logger().error(f"Context save error: {e}") self._ctx_store.save()
except Exception as e:
self.get_logger().error(f"Context save error: {e}")
def destroy_node(self) -> None: def destroy_node(self) -> None:
self._save_context(); super().destroy_node() self._save_context()
super().destroy_node()
def main(args=None) -> None: def main(args=None) -> None:
rclpy.init(args=args) rclpy.init(args=args)
node = ConversationNode() node = ConversationNode()
try: rclpy.spin(node) try:
except KeyboardInterrupt: pass rclpy.spin(node)
finally: node.destroy_node(); rclpy.shutdown() except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()

View File

@ -66,7 +66,6 @@ class SpeechPipelineNode(Node):
self.declare_parameter("use_silero_vad", True) self.declare_parameter("use_silero_vad", True)
self.declare_parameter("whisper_model", "small") self.declare_parameter("whisper_model", "small")
self.declare_parameter("whisper_compute_type", "float16") self.declare_parameter("whisper_compute_type", "float16")
self.declare_parameter("whisper_language", "")
self.declare_parameter("speaker_threshold", 0.65) self.declare_parameter("speaker_threshold", 0.65)
self.declare_parameter("speaker_db_path", "/social_db/speaker_embeddings.json") self.declare_parameter("speaker_db_path", "/social_db/speaker_embeddings.json")
self.declare_parameter("publish_partial", True) self.declare_parameter("publish_partial", True)
@ -79,7 +78,6 @@ class SpeechPipelineNode(Node):
self._use_silero = self.get_parameter("use_silero_vad").value self._use_silero = self.get_parameter("use_silero_vad").value
self._whisper_model_name = self.get_parameter("whisper_model").value self._whisper_model_name = self.get_parameter("whisper_model").value
self._compute_type = self.get_parameter("whisper_compute_type").value self._compute_type = self.get_parameter("whisper_compute_type").value
self._whisper_language = self.get_parameter("whisper_language").value or None
self._speaker_thresh = self.get_parameter("speaker_threshold").value self._speaker_thresh = self.get_parameter("speaker_threshold").value
self._speaker_db = self.get_parameter("speaker_db_path").value self._speaker_db = self.get_parameter("speaker_db_path").value
self._publish_partial = self.get_parameter("publish_partial").value self._publish_partial = self.get_parameter("publish_partial").value
@ -317,24 +315,20 @@ class SpeechPipelineNode(Node):
except Exception as e: except Exception as e:
self.get_logger().debug(f"Speaker ID error: {e}") self.get_logger().debug(f"Speaker ID error: {e}")
# Streaming Whisper transcription with language detection # Streaming Whisper transcription
partial_text = "" partial_text = ""
detected_lang = self._whisper_language or "en"
try: try:
segments_gen, info = self._whisper.transcribe( segments_gen, _info = self._whisper.transcribe(
audio_np, audio_np,
language=self._whisper_language, # None = auto-detect language="en",
beam_size=3, beam_size=3,
vad_filter=False, vad_filter=False,
) )
if hasattr(info, "language") and info.language:
detected_lang = info.language
for seg in segments_gen: for seg in segments_gen:
partial_text += seg.text.strip() + " " partial_text += seg.text.strip() + " "
if self._publish_partial: if self._publish_partial:
self._publish_transcript( self._publish_transcript(
partial_text.strip(), speaker_id, 0.0, duration, partial_text.strip(), speaker_id, 0.0, duration, is_partial=True
language=detected_lang, is_partial=True,
) )
except Exception as e: except Exception as e:
self.get_logger().error(f"Whisper error: {e}") self.get_logger().error(f"Whisper error: {e}")
@ -346,19 +340,15 @@ class SpeechPipelineNode(Node):
latency_ms = (time.perf_counter() - t0) * 1000 latency_ms = (time.perf_counter() - t0) * 1000
self.get_logger().info( self.get_logger().info(
f"STT [{speaker_id}/{detected_lang}] ({duration:.1f}s, {latency_ms:.0f}ms): " f"STT [{speaker_id}] ({duration:.1f}s, {latency_ms:.0f}ms): '{final_text}'"
f"'{final_text}'"
)
self._publish_transcript(
final_text, speaker_id, 0.9, duration,
language=detected_lang, is_partial=False,
) )
self._publish_transcript(final_text, speaker_id, 0.9, duration, is_partial=False)
# ── Publishers ──────────────────────────────────────────────────────────── # ── Publishers ────────────────────────────────────────────────────────────
def _publish_transcript( def _publish_transcript(
self, text: str, speaker_id: str, confidence: float, self, text: str, speaker_id: str, confidence: float,
duration: float, language: str = "en", is_partial: bool = False, duration: float, is_partial: bool
) -> None: ) -> None:
msg = SpeechTranscript() msg = SpeechTranscript()
msg.header.stamp = self.get_clock().now().to_msg() msg.header.stamp = self.get_clock().now().to_msg()
@ -366,7 +356,6 @@ class SpeechPipelineNode(Node):
msg.speaker_id = speaker_id msg.speaker_id = speaker_id
msg.confidence = confidence msg.confidence = confidence
msg.audio_duration = duration msg.audio_duration = duration
msg.language = language
msg.is_partial = is_partial msg.is_partial = is_partial
self._transcript_pub.publish(msg) self._transcript_pub.publish(msg)

View File

@ -1,136 +1,228 @@
"""tts_node.py -- Streaming TTS with Piper / first-chunk streaming. """tts_node.py — Streaming TTS with Piper / first-chunk streaming.
Issue #85/#167
Issue #85: Streaming TTS — Piper/XTTS integration with first-chunk streaming.
Pipeline:
/social/conversation/response (ConversationResponse)
sentence split Piper ONNX synthesis (sentence by sentence)
PCM16 chunks USB speaker (sounddevice) + publish /social/tts/audio
First-chunk strategy:
- On partial=true ConversationResponse, extract first sentence and synthesize
immediately audio starts before LLM finishes generating
- On final=false, synthesize remaining sentences
Latency target: <200ms to first audio chunk.
ROS2 topics:
Subscribe: /social/conversation/response (saltybot_social_msgs/ConversationResponse)
Publish: /social/tts/audio (audio_msgs/Audio or std_msgs/UInt8MultiArray fallback)
Parameters:
voice_path (str, "/models/piper/en_US-lessac-medium.onnx")
sample_rate (int, 22050)
volume (float, 1.0)
audio_device (str, "") sounddevice device name; "" = system default
playback_enabled (bool, true)
publish_audio (bool, false) publish PCM to ROS2 topic
sentence_streaming (bool, true) synthesize sentence-by-sentence
""" """
from __future__ import annotations from __future__ import annotations
import json, queue, threading, time
from typing import Any, Dict, Optional import queue
import threading
import time
from typing import Optional
import rclpy import rclpy
from rclpy.node import Node from rclpy.node import Node
from rclpy.qos import QoSProfile from rclpy.qos import QoSProfile
from std_msgs.msg import UInt8MultiArray from std_msgs.msg import UInt8MultiArray
from saltybot_social_msgs.msg import ConversationResponse from saltybot_social_msgs.msg import ConversationResponse
from .tts_utils import split_sentences, strip_ssml, apply_volume, chunk_pcm, estimate_duration_ms from .tts_utils import split_sentences, strip_ssml, apply_volume, chunk_pcm, estimate_duration_ms
class TtsNode(Node): class TtsNode(Node):
"""Streaming TTS node using Piper ONNX with per-language voice switching.""" """Streaming TTS node using Piper ONNX."""
def __init__(self) -> None: def __init__(self) -> None:
super().__init__("tts_node") super().__init__("tts_node")
# ── Parameters ──────────────────────────────────────────────────────
self.declare_parameter("voice_path", "/models/piper/en_US-lessac-medium.onnx") self.declare_parameter("voice_path", "/models/piper/en_US-lessac-medium.onnx")
self.declare_parameter("voice_map_json", "{}")
self.declare_parameter("default_language", "en")
self.declare_parameter("sample_rate", 22050) self.declare_parameter("sample_rate", 22050)
self.declare_parameter("volume", 1.0) self.declare_parameter("volume", 1.0)
self.declare_parameter("audio_device", "") self.declare_parameter("audio_device", "")
self.declare_parameter("playback_enabled", True) self.declare_parameter("playback_enabled", True)
self.declare_parameter("publish_audio", False) self.declare_parameter("publish_audio", False)
self.declare_parameter("sentence_streaming", True) self.declare_parameter("sentence_streaming", True)
self._voice_path = self.get_parameter("voice_path").value self._voice_path = self.get_parameter("voice_path").value
self._voice_map_json = self.get_parameter("voice_map_json").value
self._default_language = self.get_parameter("default_language").value or "en"
self._sample_rate = self.get_parameter("sample_rate").value self._sample_rate = self.get_parameter("sample_rate").value
self._volume = self.get_parameter("volume").value self._volume = self.get_parameter("volume").value
self._audio_device = self.get_parameter("audio_device").value or None self._audio_device = self.get_parameter("audio_device").value or None
self._playback = self.get_parameter("playback_enabled").value self._playback = self.get_parameter("playback_enabled").value
self._publish_audio = self.get_parameter("publish_audio").value self._publish_audio = self.get_parameter("publish_audio").value
self._sentence_streaming = self.get_parameter("sentence_streaming").value self._sentence_streaming = self.get_parameter("sentence_streaming").value
try:
extra: Dict[str, str] = json.loads(self._voice_map_json) if self._voice_map_json.strip() not in ("{}","") else {} # ── Publishers / Subscribers ─────────────────────────────────────────
except Exception as e:
self.get_logger().warn(f"voice_map_json parse error: {e}"); extra = {}
self._voice_paths: Dict[str, str] = {self._default_language: self._voice_path}
self._voice_paths.update(extra)
qos = QoSProfile(depth=10) qos = QoSProfile(depth=10)
self._resp_sub = self.create_subscription(ConversationResponse, "/social/conversation/response", self._on_response, qos) self._resp_sub = self.create_subscription(
ConversationResponse, "/social/conversation/response",
self._on_response, qos
)
if self._publish_audio: if self._publish_audio:
self._audio_pub = self.create_publisher(UInt8MultiArray, "/social/tts/audio", qos) self._audio_pub = self.create_publisher(
self._voices: Dict[str, Any] = {} UInt8MultiArray, "/social/tts/audio", qos
self._voices_lock = threading.Lock() )
# ── TTS engine ────────────────────────────────────────────────────────
self._voice = None
self._playback_queue: queue.Queue = queue.Queue(maxsize=16) self._playback_queue: queue.Queue = queue.Queue(maxsize=16)
self._current_turn = -1 self._current_turn = -1
self._synthesized_turns: set = set() self._synthesized_turns: set = set() # turn_ids already synthesized
self._lock = threading.Lock() self._lock = threading.Lock()
threading.Thread(target=self._load_voice_for_lang, args=(self._default_language,), daemon=True).start()
threading.Thread(target=self._playback_worker, daemon=True).start()
self.get_logger().info(f"TtsNode init (langs={list(self._voice_paths.keys())})")
def _load_voice_for_lang(self, lang: str) -> None: threading.Thread(target=self._load_voice, daemon=True).start()
path = self._voice_paths.get(lang) threading.Thread(target=self._playback_worker, daemon=True).start()
if not path:
self.get_logger().warn(f"No voice for '{lang}', fallback to '{self._default_language}'"); return self.get_logger().info(
with self._voices_lock: f"TtsNode init (voice={self._voice_path}, "
if lang in self._voices: return f"streaming={self._sentence_streaming})"
)
# ── Voice loading ─────────────────────────────────────────────────────────
def _load_voice(self) -> None:
t0 = time.time()
self.get_logger().info(f"Loading Piper voice: {self._voice_path}")
try: try:
from piper import PiperVoice from piper import PiperVoice
voice = PiperVoice.load(path) self._voice = PiperVoice.load(self._voice_path)
list(voice.synthesize_stream_raw("Hello.")) # Warmup synthesis to pre-JIT ONNX graph
with self._voices_lock: self._voices[lang] = voice warmup_text = "Hello."
self.get_logger().info(f"Piper [{lang}] ready") list(self._voice.synthesize_stream_raw(warmup_text))
self.get_logger().info(f"Piper voice ready ({time.time()-t0:.1f}s)")
except Exception as e: except Exception as e:
self.get_logger().error(f"Piper voice load failed [{lang}]: {e}") self.get_logger().error(f"Piper voice load failed: {e}")
def _get_voice(self, lang: str): # ── Response handler ──────────────────────────────────────────────────────
with self._voices_lock:
v = self._voices.get(lang)
if v is not None: return v
if lang in self._voice_paths:
threading.Thread(target=self._load_voice_for_lang, args=(lang,), daemon=True).start()
return self._voices.get(self._default_language)
def _on_response(self, msg: ConversationResponse) -> None: def _on_response(self, msg: ConversationResponse) -> None:
if not msg.text.strip(): return """Handle streaming LLM response — synthesize sentence by sentence."""
lang = msg.language if msg.language else self._default_language if not msg.text.strip():
return
with self._lock: with self._lock:
if msg.turn_id != self._current_turn: is_new_turn = msg.turn_id != self._current_turn
self._current_turn = msg.turn_id; self._synthesized_turns = set() if is_new_turn:
self._current_turn = msg.turn_id
# Clear old synthesized sentence cache for this new turn
self._synthesized_turns = set()
text = strip_ssml(msg.text) text = strip_ssml(msg.text)
if self._sentence_streaming: if self._sentence_streaming:
for sentence in split_sentences(text): sentences = split_sentences(text)
for sentence in sentences:
# Track which sentences we've already queued by content hash
key = (msg.turn_id, hash(sentence)) key = (msg.turn_id, hash(sentence))
with self._lock: with self._lock:
if key in self._synthesized_turns: continue if key in self._synthesized_turns:
continue
self._synthesized_turns.add(key) self._synthesized_turns.add(key)
self._queue_synthesis(sentence, lang) self._queue_synthesis(sentence)
elif not msg.is_partial: elif not msg.is_partial:
self._queue_synthesis(text, lang) # Non-streaming: synthesize full response at end
self._queue_synthesis(text)
def _queue_synthesis(self, text: str, lang: str) -> None: def _queue_synthesis(self, text: str) -> None:
if not text.strip(): return """Queue a text segment for synthesis in the playback worker."""
try: self._playback_queue.put_nowait((text.strip(), lang)) if not text.strip():
except queue.Full: self.get_logger().warn("TTS queue full") return
try:
self._playback_queue.put_nowait(text.strip())
except queue.Full:
self.get_logger().warn("TTS playback queue full, dropping segment")
# ── Playback worker ───────────────────────────────────────────────────────
def _playback_worker(self) -> None: def _playback_worker(self) -> None:
"""Consume synthesis queue: synthesize → play → publish."""
while rclpy.ok(): while rclpy.ok():
try: item = self._playback_queue.get(timeout=0.5) try:
except queue.Empty: continue text = self._playback_queue.get(timeout=0.5)
text, lang = item except queue.Empty:
voice = self._get_voice(lang) continue
if voice is None:
self.get_logger().warn(f"No voice for '{lang}'"); self._playback_queue.task_done(); continue if self._voice is None:
self.get_logger().warn("TTS voice not loaded yet")
self._playback_queue.task_done()
continue
t0 = time.perf_counter() t0 = time.perf_counter()
pcm_data = self._synthesize(text, voice) pcm_data = self._synthesize(text)
if pcm_data is None: self._playback_queue.task_done(); continue if pcm_data is None:
if self._volume != 1.0: pcm_data = apply_volume(pcm_data, self._volume) self._playback_queue.task_done()
if self._playback: self._play_audio(pcm_data) continue
if self._publish_audio: self._publish_pcm(pcm_data)
synth_ms = (time.perf_counter() - t0) * 1000
dur_ms = estimate_duration_ms(pcm_data, self._sample_rate)
self.get_logger().debug(
f"TTS '{text[:40]}' synth={synth_ms:.0f}ms, dur={dur_ms:.0f}ms"
)
if self._volume != 1.0:
pcm_data = apply_volume(pcm_data, self._volume)
if self._playback:
self._play_audio(pcm_data)
if self._publish_audio:
self._publish_pcm(pcm_data)
self._playback_queue.task_done() self._playback_queue.task_done()
def _synthesize(self, text: str, voice) -> Optional[bytes]: def _synthesize(self, text: str) -> Optional[bytes]:
try: return b"".join(voice.synthesize_stream_raw(text)) """Synthesize text to PCM16 bytes using Piper streaming."""
except Exception as e: self.get_logger().error(f"TTS error: {e}"); return None if self._voice is None:
return None
try:
chunks = list(self._voice.synthesize_stream_raw(text))
return b"".join(chunks)
except Exception as e:
self.get_logger().error(f"TTS synthesis error: {e}")
return None
def _play_audio(self, pcm_data: bytes) -> None:
    """Play PCM16 data on the configured output device via sounddevice.

    Blocks until playback completes so queued utterances never overlap.
    Errors (missing device, backend failures) are logged, not raised.
    NOTE(review): reconstructed from fused diff columns in the source.
    """
    try:
        # Imported lazily so the node can start on hosts without audio.
        import sounddevice as sd
        import numpy as np

        # PCM16 -> float32 in [-1, 1), the sample format sounddevice plays.
        samples = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0
        sd.play(samples, samplerate=self._sample_rate,
                device=self._audio_device, blocking=True)
    except Exception as e:
        self.get_logger().error(f"Audio playback error: {e}")
def _publish_pcm(self, pcm_data: bytes) -> None:
    """Publish synthesized PCM16 bytes as a UInt8MultiArray message.

    NOTE(review): reconstructed from fused diff columns in the source.
    """
    # The publisher only exists when audio publishing is enabled — skip quietly.
    if not hasattr(self, "_audio_pub"):
        return
    msg = UInt8MultiArray()
    msg.data = list(pcm_data)
    self._audio_pub.publish(msg)
def main(args=None) -> None:
    """ROS2 entry point: initialize rclpy and spin the TTS node.

    NOTE(review): reconstructed from fused diff columns in the source.
    """
    rclpy.init(args=args)
    node = TtsNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass  # Ctrl-C is a normal shutdown path, not an error.
    finally:
        # Always release the node and ROS context, even on unexpected exit.
        node.destroy_node()
        rclpy.shutdown()

View File

@ -1,122 +0,0 @@
"""test_multilang.py -- Unit tests for Issue #167 multi-language support."""
from __future__ import annotations
import json, os
from typing import Any, Dict, Optional
import pytest
def _pkg_root():
    """Return the package root: the parent of this test file's directory."""
    here = os.path.abspath(__file__)
    return os.path.dirname(os.path.dirname(here))
def _read_src(rel_path):
    """Read and return the text of *rel_path* resolved against the package root."""
    full_path = os.path.join(_pkg_root(), rel_path)
    with open(full_path) as fh:
        return fh.read()
def _extract_lang_names():
    """Extract the _LANG_NAMES dict literal from conversation_node.py.

    Locates the annotated assignment in the source text and evaluates just
    the dict literal with ast.literal_eval (no import of the node module).
    """
    import ast
    src = _read_src("saltybot_social/conversation_node.py")
    start = src.index("_LANG_NAMES: Dict[str, str] = {")
    end = src.index("\n}", start) + 2
    assignment = src[start:end]
    literal = assignment.split("=", 1)[1].strip()
    return ast.literal_eval(literal)
class TestLangNames:
    """Validates the _LANG_NAMES code→name table extracted from the source."""

    @pytest.fixture(scope="class")
    def ln(self):
        return _extract_lang_names()

    def test_english(self, ln):
        assert ln["en"] == "English"

    def test_french(self, ln):
        assert ln["fr"] == "French"

    def test_spanish(self, ln):
        assert ln["es"] == "Spanish"

    def test_german(self, ln):
        assert ln["de"] == "German"

    def test_japanese(self, ln):
        assert ln["ja"] == "Japanese"

    def test_chinese(self, ln):
        assert ln["zh"] == "Chinese"

    def test_arabic(self, ln):
        assert ln["ar"] == "Arabic"

    def test_at_least_15(self, ln):
        assert len(ln) >= 15

    def test_lowercase_keys(self, ln):
        # ISO-639-style codes: lowercase, two or three letters.
        assert all(k == k.lower() and 2 <= len(k) <= 3 for k in ln)

    def test_nonempty_values(self, ln):
        assert all(ln.values())
class TestLanguageHint:
    """Mirrors the conversation node's language-hint string construction."""

    @pytest.fixture(scope="class")
    def ln(self):
        return _extract_lang_names()

    def _h(self, sl, sid, ln):
        # Replicates _language_hint: no hint for English or unknown speakers.
        lang = sl.get(sid, "en")
        if not lang or lang == "en":
            return ""
        return f"[Please respond in {ln.get(lang, lang)}.]"

    def test_english_no_hint(self, ln):
        assert self._h({"p": "en"}, "p", ln) == ""

    def test_unknown_no_hint(self, ln):
        assert self._h({}, "p", ln) == ""

    def test_french(self, ln):
        assert self._h({"p": "fr"}, "p", ln) == "[Please respond in French.]"

    def test_spanish(self, ln):
        assert self._h({"p": "es"}, "p", ln) == "[Please respond in Spanish.]"

    def test_unknown_code(self, ln):
        # Unknown codes fall through verbatim into the hint text.
        assert "xx" in self._h({"p": "xx"}, "p", ln)

    def test_brackets(self, ln):
        hint = self._h({"p": "de"}, "p", ln)
        assert hint.startswith("[")
        assert hint.endswith("]")
class TestVoiceMap:
    """Checks the voice-map JSON parsing logic mirrored from tts_node."""

    def _parse(self, jstr, dl, dp):
        """Merge a JSON voice-map string over the default language→path entry.

        Invalid or non-string JSON degrades to an empty extra map rather
        than raising (mirrors the node's tolerant config parsing).
        """
        # Fix: was a bare `except:`, which also swallows SystemExit and
        # KeyboardInterrupt; catch only the failures json/str can raise.
        try:
            extra = json.loads(jstr) if jstr.strip() not in ("{}", "") else {}
        except (AttributeError, TypeError, ValueError):
            extra = {}
        result = {dl: dp}
        result.update(extra)
        return result

    def test_empty(self):
        assert self._parse("{}", "en", "/e") == {"en": "/e"}

    def test_extra(self):
        vm = self._parse('{"fr":"/f"}', "en", "/e")
        assert vm["fr"] == "/f"

    def test_invalid(self):
        assert self._parse("BAD", "en", "/e") == {"en": "/e"}

    def test_multi(self):
        assert len(self._parse(json.dumps({"fr": "/f", "es": "/s"}), "en", "/e")) == 3
class TestVoiceSelect:
    """Checks voice-selection fallback: exact language first, else default."""

    def _s(self, voices, lang, default):
        # Replicates _get_voice: truthy exact match wins, else default voice.
        chosen = voices.get(lang)
        if chosen:
            return chosen
        return voices.get(default)

    def test_exact(self):
        assert self._s({"en": "E", "fr": "F"}, "fr", "en") == "F"

    def test_fallback(self):
        assert self._s({"en": "E"}, "fr", "en") == "E"

    def test_none(self):
        assert self._s({}, "fr", "en") is None
class TestSttFields:
    """Greps speech_pipeline_node.py for the Whisper language hooks."""

    @pytest.fixture(scope="class")
    def src(self):
        return _read_src("saltybot_social/speech_pipeline_node.py")

    def test_param(self, src):
        assert "whisper_language" in src

    def test_detected_lang(self, src):
        assert "detected_lang" in src

    def test_msg_language(self, src):
        assert "msg.language = language" in src

    def test_auto_detect(self, src):
        assert "language=self._whisper_language" in src
class TestConvFields:
    """Greps conversation_node.py for the per-speaker language plumbing."""

    @pytest.fixture(scope="class")
    def src(self):
        return _read_src("saltybot_social/conversation_node.py")

    def test_speaker_lang(self, src):
        assert "_speaker_lang" in src

    def test_lang_hint_method(self, src):
        assert "_language_hint" in src

    def test_msg_language(self, src):
        assert "msg.language = language" in src

    def test_lang_names(self, src):
        assert "_LANG_NAMES" in src

    def test_please_respond(self, src):
        assert "Please respond in" in src

    def test_emotion_coexists(self, src):
        # Language hints must not have displaced the emotion-hint feature.
        assert "_emotion_hint" in src
class TestTtsFields:
    """Greps tts_node.py for the multi-voice selection machinery."""

    @pytest.fixture(scope="class")
    def src(self):
        return _read_src("saltybot_social/tts_node.py")

    def test_voice_map_json(self, src):
        assert "voice_map_json" in src

    def test_default_lang(self, src):
        assert "default_language" in src

    def test_voices_dict(self, src):
        assert "_voices" in src

    def test_get_voice(self, src):
        assert "_get_voice" in src

    def test_load_voice_for_lang(self, src):
        assert "_load_voice_for_lang" in src

    def test_queue_tuple(self, src):
        # Playback queue items must carry the language alongside the text.
        assert "(text.strip(), lang)" in src

    def test_synthesize_voice_arg(self, src):
        assert "_synthesize(text, voice)" in src
class TestMsgDefs:
    """Checks the ROS message definitions declare the language field."""

    @pytest.fixture(scope="class")
    def tr(self):
        return _read_src("../saltybot_social_msgs/msg/SpeechTranscript.msg")

    @pytest.fixture(scope="class")
    def re(self):
        return _read_src("../saltybot_social_msgs/msg/ConversationResponse.msg")

    def test_transcript_lang(self, tr):
        assert "string language" in tr

    def test_transcript_bcp47(self, tr):
        assert "BCP-47" in tr

    def test_response_lang(self, re):
        assert "string language" in re
class TestEdgeCases:
    """Miscellaneous sanity checks on the language-propagation idioms."""

    def test_empty_lang_no_hint(self):
        # An empty detected language must fall back to English.
        lang = "" or "en"
        assert lang == "en"

    def test_lang_flows(self):
        d: Dict[str, str] = {}
        d["p1"] = "fr"
        assert d.get("p1", "en") == "fr"

    def test_multi_speakers(self):
        d = {"p1": "fr", "p2": "es"}
        assert d["p1"] == "fr"
        assert d["p2"] == "es"

    def test_voice_map_code_in_tts(self):
        src = _read_src("saltybot_social/tts_node.py")
        assert "voice_map_json" in src
        assert "json.loads" in src

View File

@ -7,4 +7,3 @@ string text # Full or partial response text
string speaker_id # Who the response is addressed to string speaker_id # Who the response is addressed to
bool is_partial # true = streaming token chunk, false = final response bool is_partial # true = streaming token chunk, false = final response
int32 turn_id # Conversation turn counter (for deduplication) int32 turn_id # Conversation turn counter (for deduplication)
string language # BCP-47 language code for TTS voice selection e.g. "en" "fr" "es"

View File

@ -8,4 +8,3 @@ string speaker_id # e.g. "person_42" or "unknown"
float32 confidence # ASR confidence 0..1 float32 confidence # ASR confidence 0..1
float32 audio_duration # Duration of the utterance in seconds float32 audio_duration # Duration of the utterance in seconds
bool is_partial # true = intermediate streaming result, false = final bool is_partial # true = intermediate streaming result, false = final
string language # BCP-47 detected language code e.g. "en" "fr" "es" (empty = unknown)