feat(social): multi-language support - Whisper LID + per-lang Piper TTS (Issue #167) #187
@ -8,6 +8,7 @@ speech_pipeline_node:
|
||||
use_silero_vad: true
|
||||
whisper_model: "small" # small (~500ms), medium (better quality, ~900ms)
|
||||
whisper_compute_type: "float16"
|
||||
whisper_language: "" # "" = auto-detect; set e.g. "fr" to force
|
||||
speaker_threshold: 0.65
|
||||
speaker_db_path: "/social_db/speaker_embeddings.json"
|
||||
publish_partial: true
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
tts_node:
|
||||
ros__parameters:
|
||||
voice_path: "/models/piper/en_US-lessac-medium.onnx"
|
||||
voice_map_json: "{}"
|
||||
default_language: "en"
|
||||
sample_rate: 22050
|
||||
volume: 1.0
|
||||
audio_device: "" # "" = system default; set to device name if needed
|
||||
|
||||
@ -1,54 +1,30 @@
|
||||
"""conversation_node.py — Local LLM conversation engine with per-person context.
|
||||
|
||||
Issue #83: Conversation engine for social-bot.
|
||||
|
||||
Stack: Phi-3-mini or Llama-3.2-3B GGUF Q4_K_M via llama-cpp-python (CUDA).
|
||||
Subscribes /social/speech/transcript → builds per-person prompt → streams
|
||||
token output → publishes /social/conversation/response.
|
||||
|
||||
Streaming: publishes partial=true tokens as they arrive, then final=false
|
||||
at end of generation. TTS node can begin synthesis on first sentence boundary.
|
||||
|
||||
ROS2 topics:
|
||||
Subscribe: /social/speech/transcript (saltybot_social_msgs/SpeechTranscript)
|
||||
Publish: /social/conversation/response (saltybot_social_msgs/ConversationResponse)
|
||||
|
||||
Parameters:
|
||||
model_path (str, "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf")
|
||||
n_ctx (int, 4096)
|
||||
n_gpu_layers (int, 20) — GPU offload layers (increase for more VRAM usage)
|
||||
max_tokens (int, 200)
|
||||
temperature (float, 0.7)
|
||||
top_p (float, 0.9)
|
||||
soul_path (str, "/soul/SOUL.md")
|
||||
context_db_path (str, "/social_db/conversation_context.json")
|
||||
save_interval_s (float, 30.0) — how often to persist context to disk
|
||||
stream (bool, true)
|
||||
Issue #83/#161/#167
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import json, threading, time
|
||||
from typing import Dict, Optional
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile
|
||||
|
||||
from std_msgs.msg import String
|
||||
from saltybot_social_msgs.msg import SpeechTranscript, ConversationResponse
|
||||
from .llm_context import ContextStore, build_llama_prompt, load_system_prompt, needs_summary_prompt
|
||||
|
||||
_LANG_NAMES: Dict[str, str] = {
|
||||
"en": "English", "fr": "French", "es": "Spanish", "de": "German",
|
||||
"it": "Italian", "pt": "Portuguese", "ja": "Japanese", "zh": "Chinese",
|
||||
"ko": "Korean", "ar": "Arabic", "ru": "Russian", "nl": "Dutch",
|
||||
"pl": "Polish", "sv": "Swedish", "da": "Danish", "fi": "Finnish",
|
||||
"no": "Norwegian", "tr": "Turkish", "hi": "Hindi", "uk": "Ukrainian",
|
||||
"cs": "Czech", "ro": "Romanian", "hu": "Hungarian", "el": "Greek",
|
||||
}
|
||||
|
||||
class ConversationNode(Node):
|
||||
"""Local LLM inference node with per-person conversation memory."""
|
||||
|
||||
def __init__(self) -> None:
    """Set up parameters, topics, state, and background LLM loading.

    Merge-conflict residue previously declared ``model_path`` twice
    (rclpy raises ParameterAlreadyDeclaredException on the second call)
    and created the publisher/subscriber and the emotion subscription
    twice; each is now declared/created exactly once.
    """
    super().__init__("conversation_node")

    # ── Parameters ──────────────────────────────────────────────────────
    self.declare_parameter("model_path", "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf")
    self.declare_parameter("n_ctx", 4096)
    self.declare_parameter("n_gpu_layers", 20)
    self.declare_parameter("max_tokens", 200)
    # NOTE(review): temperature/top_p/soul_path declares sit in a diff hunk
    # not visible here — reconstructed from the module docstring defaults.
    self.declare_parameter("temperature", 0.7)
    self.declare_parameter("top_p", 0.9)
    self.declare_parameter("soul_path", "/soul/SOUL.md")
    self.declare_parameter("context_db_path", "/social_db/conversation_context.json")
    self.declare_parameter("save_interval_s", 30.0)
    self.declare_parameter("stream", True)

    self._model_path = self.get_parameter("model_path").value
    self._n_ctx = self.get_parameter("n_ctx").value
    self._n_gpu = self.get_parameter("n_gpu_layers").value
    self._max_tokens = self.get_parameter("max_tokens").value
    self._temperature = self.get_parameter("temperature").value
    self._top_p = self.get_parameter("top_p").value
    self._soul_path = self.get_parameter("soul_path").value
    self._db_path = self.get_parameter("context_db_path").value
    self._save_interval = self.get_parameter("save_interval_s").value
    self._stream = self.get_parameter("stream").value

    # ── Publishers / Subscribers ─────────────────────────────────────────
    qos = QoSProfile(depth=10)
    self._resp_pub = self.create_publisher(
        ConversationResponse, "/social/conversation/response", qos
    )
    self._transcript_sub = self.create_subscription(
        SpeechTranscript, "/social/speech/transcript", self._on_transcript, qos
    )

    # ── State ────────────────────────────────────────────────────────────
    self._llm = None
    self._system_prompt = load_system_prompt(self._soul_path)
    self._ctx_store = ContextStore(self._db_path)
    # Guards _generating/_turn_counter; used by _generate_response.
    self._lock = threading.Lock()
    self._turn_counter = 0
    self._generating = False
    self._last_save = time.time()

    # Per-speaker state: detected language (Issue #167) and emotion label
    # (Issue #161). Both are keyed by the str speaker_id — the old
    # Dict[int, str] annotation on _emotions was wrong.
    self._speaker_lang: Dict[str, str] = {}
    self._emotions: Dict[str, str] = {}
    self.create_subscription(
        String, "/social/emotion/context", self._on_emotion_context, 10
    )

    # Model load takes seconds — do it off the executor thread.
    threading.Thread(target=self._load_llm, daemon=True).start()

    # Periodically persist conversation context to disk.
    self._save_timer = self.create_timer(self._save_interval, self._save_context)

    self.get_logger().info(
        f"ConversationNode init (model={self._model_path}, "
        f"gpu_layers={self._n_gpu}, ctx={self._n_ctx})"
    )
|
||||
|
||||
# ── Model loading ─────────────────────────────────────────────────────────
|
||||
|
||||
def _load_llm(self) -> None:
    """Load the GGUF model via llama-cpp-python; runs in a daemon thread.

    The merged diff constructed ``Llama`` twice, loading the model into
    memory twice — now built exactly once.
    """
    t0 = time.time()
    self.get_logger().info(f"Loading LLM: {self._model_path}")
    try:
        # Imported lazily so the node can start (and report) even when
        # llama-cpp-python is missing on this machine.
        from llama_cpp import Llama
        self._llm = Llama(
            model_path=self._model_path,
            n_ctx=self._n_ctx,
            n_gpu_layers=self._n_gpu,
            n_threads=4,
            verbose=False,
        )
        self.get_logger().info(
            f"LLM ready ({time.time()-t0:.1f}s). "
            f"Context: {self._n_ctx} tokens, GPU layers: {self._n_gpu}"
        )
    except Exception as e:
        self.get_logger().error(f"LLM load failed: {e}")
|
||||
|
||||
# ── Transcript callback ───────────────────────────────────────────────────
|
||||
|
||||
def _on_transcript(self, msg: SpeechTranscript) -> None:
|
||||
"""Handle final transcripts only (skip streaming partials)."""
|
||||
if msg.is_partial:
|
||||
if msg.is_partial or not msg.text.strip():
|
||||
return
|
||||
if not msg.text.strip():
|
||||
return
|
||||
|
||||
self.get_logger().info(
|
||||
f"Transcript [{msg.speaker_id}]: '{msg.text}'"
|
||||
)
|
||||
|
||||
threading.Thread(
|
||||
target=self._generate_response,
|
||||
args=(msg.text.strip(), msg.speaker_id),
|
||||
daemon=True,
|
||||
).start()
|
||||
|
||||
# ── LLM inference ─────────────────────────────────────────────────────────
|
||||
if msg.language:
|
||||
self._speaker_lang[msg.speaker_id] = msg.language
|
||||
self.get_logger().info(f"Transcript [{msg.speaker_id}/{msg.language or '?'}]: '{msg.text}'")
|
||||
threading.Thread(target=self._generate_response, args=(msg.text.strip(), msg.speaker_id), daemon=True).start()
|
||||
|
||||
def _generate_response(self, user_text: str, speaker_id: str) -> None:
|
||||
"""Generate LLM response with streaming. Runs in thread."""
|
||||
if self._llm is None:
|
||||
self.get_logger().warn("LLM not loaded yet, dropping utterance")
|
||||
return
|
||||
|
||||
self.get_logger().warn("LLM not loaded yet, dropping utterance"); return
|
||||
with self._lock:
|
||||
if self._generating:
|
||||
self.get_logger().warn("LLM busy, dropping utterance")
|
||||
return
|
||||
self.get_logger().warn("LLM busy, dropping utterance"); return
|
||||
self._generating = True
|
||||
self._turn_counter += 1
|
||||
turn_id = self._turn_counter
|
||||
|
||||
lang = self._speaker_lang.get(speaker_id, "en")
|
||||
try:
|
||||
ctx = self._ctx_store.get(speaker_id)
|
||||
|
||||
# Summary compression if context is long
|
||||
if ctx.needs_compression():
|
||||
self._compress_context(ctx)
|
||||
|
||||
ctx.add_user(user_text)
|
||||
|
||||
prompt = build_llama_prompt(
|
||||
ctx, user_text, self._system_prompt
|
||||
)
|
||||
|
||||
emotion_hint = self._emotion_hint(speaker_id)
|
||||
lang_hint = self._language_hint(speaker_id)
|
||||
hints = " ".join(h for h in (emotion_hint, lang_hint) if h)
|
||||
annotated = f"{user_text} {hints}".rstrip() if hints else user_text
|
||||
ctx.add_user(annotated)
|
||||
prompt = build_llama_prompt(ctx, annotated, self._system_prompt)
|
||||
t0 = time.perf_counter()
|
||||
full_response = ""
|
||||
|
||||
if self._stream:
|
||||
output = self._llm(
|
||||
prompt,
|
||||
max_tokens=self._max_tokens,
|
||||
temperature=self._temperature,
|
||||
top_p=self._top_p,
|
||||
stream=True,
|
||||
stop=["<|user|>", "<|system|>", "\n\n\n"],
|
||||
)
|
||||
output = self._llm(prompt, max_tokens=self._max_tokens, temperature=self._temperature, top_p=self._top_p, stream=True, stop=["<|user|>", "<|system|>", "\n\n\n"])
|
||||
for chunk in output:
|
||||
token = chunk["choices"][0]["text"]
|
||||
full_response += token
|
||||
# Publish partial after each sentence boundary for low TTS latency
|
||||
if token.endswith((".", "!", "?", "\n")):
|
||||
self._publish_response(
|
||||
full_response.strip(), speaker_id, turn_id, is_partial=True
|
||||
)
|
||||
self._publish_response(full_response.strip(), speaker_id, turn_id, language=lang, is_partial=True)
|
||||
else:
|
||||
output = self._llm(
|
||||
prompt,
|
||||
max_tokens=self._max_tokens,
|
||||
temperature=self._temperature,
|
||||
top_p=self._top_p,
|
||||
stream=False,
|
||||
stop=["<|user|>", "<|system|>"],
|
||||
)
|
||||
output = self._llm(prompt, max_tokens=self._max_tokens, temperature=self._temperature, top_p=self._top_p, stream=False, stop=["<|user|>", "<|system|>"])
|
||||
full_response = output["choices"][0]["text"]
|
||||
|
||||
full_response = full_response.strip()
|
||||
latency_ms = (time.perf_counter() - t0) * 1000
|
||||
self.get_logger().info(
|
||||
f"LLM [{speaker_id}] ({latency_ms:.0f}ms): '{full_response[:80]}'"
|
||||
)
|
||||
|
||||
self.get_logger().info(f"LLM [{speaker_id}/{lang}] ({(time.perf_counter()-t0)*1000:.0f}ms): '{full_response[:80]}'")
|
||||
ctx.add_assistant(full_response)
|
||||
self._publish_response(full_response, speaker_id, turn_id, is_partial=False)
|
||||
|
||||
self._publish_response(full_response, speaker_id, turn_id, language=lang, is_partial=False)
|
||||
except Exception as e:
|
||||
self.get_logger().error(f"LLM inference error: {e}")
|
||||
finally:
|
||||
with self._lock:
|
||||
self._generating = False
|
||||
with self._lock: self._generating = False
|
||||
|
||||
def _compress_context(self, ctx) -> None:
|
||||
"""Ask LLM to summarize old turns for context compression."""
|
||||
if self._llm is None:
|
||||
ctx.compress("(history omitted)")
|
||||
return
|
||||
if self._llm is None: ctx.compress("(history omitted)"); return
|
||||
try:
|
||||
summary_prompt = needs_summary_prompt(ctx)
|
||||
result = self._llm(summary_prompt, max_tokens=80, temperature=0.3, stream=False)
|
||||
summary = result["choices"][0]["text"].strip()
|
||||
ctx.compress(summary)
|
||||
self.get_logger().debug(
|
||||
f"Context compressed for {ctx.person_id}: '{summary[:60]}'"
|
||||
)
|
||||
except Exception:
|
||||
ctx.compress("(history omitted)")
|
||||
result = self._llm(needs_summary_prompt(ctx), max_tokens=80, temperature=0.3, stream=False)
|
||||
ctx.compress(result["choices"][0]["text"].strip())
|
||||
except Exception: ctx.compress("(history omitted)")
|
||||
|
||||
# ── Publish ───────────────────────────────────────────────────────────────
|
||||
def _language_hint(self, speaker_id: str) -> str:
    """Bracketed reply-language instruction for the prompt.

    Returns "" for English or an unknown speaker; otherwise e.g.
    "[Please respond in French.]" (falls back to the raw code when the
    language is not in _LANG_NAMES).
    """
    code = self._speaker_lang.get(speaker_id, "en")
    if not code or code == "en":
        return ""
    return f"[Please respond in {_LANG_NAMES.get(code, code)}.]"
|
||||
|
||||
def _publish_response(
|
||||
self, text: str, speaker_id: str, turn_id: int, is_partial: bool
|
||||
) -> None:
|
||||
def _on_emotion_context(self, msg: String) -> None:
|
||||
try:
|
||||
for k, v in json.loads(msg.data).get("emotions", {}).items():
|
||||
self._emotions[k] = v
|
||||
except Exception: pass
|
||||
|
||||
def _emotion_hint(self, speaker_id: str) -> str:
|
||||
emo = self._emotions.get(speaker_id, "")
|
||||
return f"[The person seems {emo} right now.]" if emo and emo != "neutral" else ""
|
||||
|
||||
def _publish_response(
    self, text: str, speaker_id: str, turn_id: int,
    language: str = "en", is_partial: bool = False,
) -> None:
    """Publish one ConversationResponse (partial or final) for a turn.

    The merged diff assigned text/speaker_id/is_partial/turn_id twice;
    each field is now set exactly once, plus the new language field.
    """
    msg = ConversationResponse()
    msg.header.stamp = self.get_clock().now().to_msg()
    msg.text = text
    msg.speaker_id = speaker_id
    msg.turn_id = turn_id
    msg.language = language
    msg.is_partial = is_partial
    self._resp_pub.publish(msg)
|
||||
|
||||
def _save_context(self) -> None:
|
||||
try:
|
||||
self._ctx_store.save()
|
||||
except Exception as e:
|
||||
self.get_logger().error(f"Context save error: {e}")
|
||||
try: self._ctx_store.save()
|
||||
except Exception as e: self.get_logger().error(f"Context save error: {e}")
|
||||
|
||||
def destroy_node(self) -> None:
    """Flush context to disk, then tear the node down.

    Residue called _save_context and super().destroy_node() twice;
    each now runs once.
    """
    self._save_context()
    super().destroy_node()
|
||||
|
||||
def main(args=None) -> None:
    """Entry point: spin ConversationNode until interrupted, then clean up."""
    rclpy.init(args=args)
    node = ConversationNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.shutdown()
|
||||
|
||||
@ -66,6 +66,7 @@ class SpeechPipelineNode(Node):
|
||||
self.declare_parameter("use_silero_vad", True)
|
||||
self.declare_parameter("whisper_model", "small")
|
||||
self.declare_parameter("whisper_compute_type", "float16")
|
||||
self.declare_parameter("whisper_language", "")
|
||||
self.declare_parameter("speaker_threshold", 0.65)
|
||||
self.declare_parameter("speaker_db_path", "/social_db/speaker_embeddings.json")
|
||||
self.declare_parameter("publish_partial", True)
|
||||
@ -78,6 +79,7 @@ class SpeechPipelineNode(Node):
|
||||
self._use_silero = self.get_parameter("use_silero_vad").value
|
||||
self._whisper_model_name = self.get_parameter("whisper_model").value
|
||||
self._compute_type = self.get_parameter("whisper_compute_type").value
|
||||
self._whisper_language = self.get_parameter("whisper_language").value or None
|
||||
self._speaker_thresh = self.get_parameter("speaker_threshold").value
|
||||
self._speaker_db = self.get_parameter("speaker_db_path").value
|
||||
self._publish_partial = self.get_parameter("publish_partial").value
|
||||
@ -315,20 +317,24 @@ class SpeechPipelineNode(Node):
|
||||
except Exception as e:
|
||||
self.get_logger().debug(f"Speaker ID error: {e}")
|
||||
|
||||
# Streaming Whisper transcription
|
||||
# Streaming Whisper transcription with language detection
|
||||
partial_text = ""
|
||||
detected_lang = self._whisper_language or "en"
|
||||
try:
|
||||
segments_gen, _info = self._whisper.transcribe(
|
||||
segments_gen, info = self._whisper.transcribe(
|
||||
audio_np,
|
||||
language="en",
|
||||
language=self._whisper_language, # None = auto-detect
|
||||
beam_size=3,
|
||||
vad_filter=False,
|
||||
)
|
||||
if hasattr(info, "language") and info.language:
|
||||
detected_lang = info.language
|
||||
for seg in segments_gen:
|
||||
partial_text += seg.text.strip() + " "
|
||||
if self._publish_partial:
|
||||
self._publish_transcript(
|
||||
partial_text.strip(), speaker_id, 0.0, duration, is_partial=True
|
||||
partial_text.strip(), speaker_id, 0.0, duration,
|
||||
language=detected_lang, is_partial=True,
|
||||
)
|
||||
except Exception as e:
|
||||
self.get_logger().error(f"Whisper error: {e}")
|
||||
@ -340,15 +346,19 @@ class SpeechPipelineNode(Node):
|
||||
|
||||
latency_ms = (time.perf_counter() - t0) * 1000
|
||||
self.get_logger().info(
|
||||
f"STT [{speaker_id}] ({duration:.1f}s, {latency_ms:.0f}ms): '{final_text}'"
|
||||
f"STT [{speaker_id}/{detected_lang}] ({duration:.1f}s, {latency_ms:.0f}ms): "
|
||||
f"'{final_text}'"
|
||||
)
|
||||
self._publish_transcript(
|
||||
final_text, speaker_id, 0.9, duration,
|
||||
language=detected_lang, is_partial=False,
|
||||
)
|
||||
self._publish_transcript(final_text, speaker_id, 0.9, duration, is_partial=False)
|
||||
|
||||
# ── Publishers ────────────────────────────────────────────────────────────
|
||||
|
||||
def _publish_transcript(
|
||||
self, text: str, speaker_id: str, confidence: float,
|
||||
duration: float, is_partial: bool
|
||||
duration: float, language: str = "en", is_partial: bool = False,
|
||||
) -> None:
|
||||
msg = SpeechTranscript()
|
||||
msg.header.stamp = self.get_clock().now().to_msg()
|
||||
@ -356,6 +366,7 @@ class SpeechPipelineNode(Node):
|
||||
msg.speaker_id = speaker_id
|
||||
msg.confidence = confidence
|
||||
msg.audio_duration = duration
|
||||
msg.language = language
|
||||
msg.is_partial = is_partial
|
||||
self._transcript_pub.publish(msg)
|
||||
|
||||
|
||||
@ -1,228 +1,136 @@
|
||||
"""tts_node.py — Streaming TTS with Piper / first-chunk streaming.
|
||||
|
||||
Issue #85: Streaming TTS — Piper/XTTS integration with first-chunk streaming.
|
||||
|
||||
Pipeline:
|
||||
/social/conversation/response (ConversationResponse)
|
||||
→ sentence split → Piper ONNX synthesis (sentence by sentence)
|
||||
→ PCM16 chunks → USB speaker (sounddevice) + publish /social/tts/audio
|
||||
|
||||
First-chunk strategy:
|
||||
- On partial=true ConversationResponse, extract first sentence and synthesize
|
||||
immediately → audio starts before LLM finishes generating
|
||||
- On final=false, synthesize remaining sentences
|
||||
|
||||
Latency target: <200ms to first audio chunk.
|
||||
|
||||
ROS2 topics:
|
||||
Subscribe: /social/conversation/response (saltybot_social_msgs/ConversationResponse)
|
||||
Publish: /social/tts/audio (audio_msgs/Audio or std_msgs/UInt8MultiArray fallback)
|
||||
|
||||
Parameters:
|
||||
voice_path (str, "/models/piper/en_US-lessac-medium.onnx")
|
||||
sample_rate (int, 22050)
|
||||
volume (float, 1.0)
|
||||
audio_device (str, "") — sounddevice device name; "" = system default
|
||||
playback_enabled (bool, true)
|
||||
publish_audio (bool, false) — publish PCM to ROS2 topic
|
||||
sentence_streaming (bool, true) — synthesize sentence-by-sentence
|
||||
"""tts_node.py -- Streaming TTS with Piper / first-chunk streaming.
|
||||
Issue #85/#167
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import json, queue, threading, time
|
||||
from typing import Any, Dict, Optional
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile
|
||||
from std_msgs.msg import UInt8MultiArray
|
||||
|
||||
from saltybot_social_msgs.msg import ConversationResponse
|
||||
from .tts_utils import split_sentences, strip_ssml, apply_volume, chunk_pcm, estimate_duration_ms
|
||||
|
||||
|
||||
class TtsNode(Node):
|
||||
"""Streaming TTS node using Piper ONNX."""
|
||||
|
||||
"""Streaming TTS node using Piper ONNX with per-language voice switching."""
|
||||
def __init__(self) -> None:
    """Set up parameters, voice map, topics, and worker threads.

    Merge residue created the response subscription twice, created
    _audio_pub both conditionally and unconditionally, and started
    both the old and new voice-loader threads; each now happens once.
    _audio_pub stays conditional because _publish_pcm guards on hasattr.
    """
    super().__init__("tts_node")

    # ── Parameters ──────────────────────────────────────────────────────
    self.declare_parameter("voice_path", "/models/piper/en_US-lessac-medium.onnx")
    self.declare_parameter("voice_map_json", "{}")
    self.declare_parameter("default_language", "en")
    self.declare_parameter("sample_rate", 22050)
    self.declare_parameter("volume", 1.0)
    self.declare_parameter("audio_device", "")
    self.declare_parameter("playback_enabled", True)
    self.declare_parameter("publish_audio", False)
    self.declare_parameter("sentence_streaming", True)

    self._voice_path = self.get_parameter("voice_path").value
    self._voice_map_json = self.get_parameter("voice_map_json").value
    self._default_language = self.get_parameter("default_language").value or "en"
    self._sample_rate = self.get_parameter("sample_rate").value
    self._volume = self.get_parameter("volume").value
    self._audio_device = self.get_parameter("audio_device").value or None
    self._playback = self.get_parameter("playback_enabled").value
    self._publish_audio = self.get_parameter("publish_audio").value
    self._sentence_streaming = self.get_parameter("sentence_streaming").value

    # ── Language → voice model path map (Issue #167) ─────────────────────
    try:
        extra: Dict[str, str] = (
            json.loads(self._voice_map_json)
            if self._voice_map_json.strip() not in ("{}", "") else {}
        )
    except Exception as e:
        self.get_logger().warn(f"voice_map_json parse error: {e}")
        extra = {}
    self._voice_paths: Dict[str, str] = {self._default_language: self._voice_path}
    self._voice_paths.update(extra)

    # ── Publishers / Subscribers ─────────────────────────────────────────
    qos = QoSProfile(depth=10)
    self._resp_sub = self.create_subscription(
        ConversationResponse, "/social/conversation/response", self._on_response, qos
    )
    if self._publish_audio:
        self._audio_pub = self.create_publisher(UInt8MultiArray, "/social/tts/audio", qos)

    # ── TTS state ────────────────────────────────────────────────────────
    self._voices: Dict[str, Any] = {}      # lang → loaded PiperVoice
    self._voices_lock = threading.Lock()   # guards _voices
    self._playback_queue: queue.Queue = queue.Queue(maxsize=16)
    self._current_turn = -1
    self._synthesized_turns: set = set()   # (turn_id, sentence hash) already queued
    self._lock = threading.Lock()          # guards turn/sentence dedupe state

    # Pre-load the default voice and start the playback consumer.
    threading.Thread(
        target=self._load_voice_for_lang, args=(self._default_language,), daemon=True
    ).start()
    threading.Thread(target=self._playback_worker, daemon=True).start()

    self.get_logger().info(f"TtsNode init (langs={list(self._voice_paths.keys())})")
|
||||
|
||||
# ── Voice loading ─────────────────────────────────────────────────────────
|
||||
|
||||
def _load_voice_for_lang(self, lang: str) -> None:
    """Load and warm up the Piper voice for *lang*; no-op if loaded.

    The residue contained both the old _load_voice (writing self._voice)
    and the new per-language loader; only the new loader remains.
    """
    path = self._voice_paths.get(lang)
    if not path:
        self.get_logger().warn(
            f"No voice for '{lang}', fallback to '{self._default_language}'"
        )
        return
    with self._voices_lock:
        if lang in self._voices:
            return
    try:
        from piper import PiperVoice
        voice = PiperVoice.load(path)
        # Warmup synthesis pre-JITs the ONNX graph so the first real
        # utterance meets the first-chunk latency target.
        list(voice.synthesize_stream_raw("Hello."))
        with self._voices_lock:
            self._voices[lang] = voice
        self.get_logger().info(f"Piper [{lang}] ready")
    except Exception as e:
        self.get_logger().error(f"Piper voice load failed [{lang}]: {e}")
|
||||
|
||||
# ── Response handler ──────────────────────────────────────────────────────
|
||||
def _get_voice(self, lang: str):
|
||||
with self._voices_lock:
|
||||
v = self._voices.get(lang)
|
||||
if v is not None: return v
|
||||
if lang in self._voice_paths:
|
||||
threading.Thread(target=self._load_voice_for_lang, args=(lang,), daemon=True).start()
|
||||
return self._voices.get(self._default_language)
|
||||
|
||||
def _on_response(self, msg: ConversationResponse) -> None:
    """Synthesize streaming LLM output in the speaker's language.

    The merged diff ran the new-turn reset twice (old and new forms);
    the reset now happens once under the lock.
    """
    if not msg.text.strip():
        return
    lang = msg.language if msg.language else self._default_language
    with self._lock:
        if msg.turn_id != self._current_turn:
            # New turn: forget which sentences of the prior turn we queued.
            self._current_turn = msg.turn_id
            self._synthesized_turns = set()
    text = strip_ssml(msg.text)
    if self._sentence_streaming:
        for sentence in split_sentences(text):
            # Partials repeat earlier sentences — dedupe by content hash
            # so each sentence is synthesized exactly once per turn.
            key = (msg.turn_id, hash(sentence))
            with self._lock:
                if key in self._synthesized_turns:
                    continue
                self._synthesized_turns.add(key)
            self._queue_synthesis(sentence, lang)
    elif not msg.is_partial:
        # Non-streaming mode: synthesize the full final response at once.
        self._queue_synthesis(text, lang)
|
||||
|
||||
def _queue_synthesis(self, text: str) -> None:
|
||||
"""Queue a text segment for synthesis in the playback worker."""
|
||||
if not text.strip():
|
||||
return
|
||||
try:
|
||||
self._playback_queue.put_nowait(text.strip())
|
||||
except queue.Full:
|
||||
self.get_logger().warn("TTS playback queue full, dropping segment")
|
||||
|
||||
# ── Playback worker ───────────────────────────────────────────────────────
|
||||
def _queue_synthesis(self, text: str, lang: str) -> None:
|
||||
if not text.strip(): return
|
||||
try: self._playback_queue.put_nowait((text.strip(), lang))
|
||||
except queue.Full: self.get_logger().warn("TTS queue full")
|
||||
|
||||
def _playback_worker(self) -> None:
    """Consume the synthesis queue: synthesize → volume → play → publish.

    Runs as a daemon thread until ROS shuts down. The residue duplicated
    the synthesize/volume/play/publish sequence; it now runs once per item.
    """
    while rclpy.ok():
        try:
            text, lang = self._playback_queue.get(timeout=0.5)
        except queue.Empty:
            continue

        voice = self._get_voice(lang)
        if voice is None:
            # Neither the requested nor the default voice is ready yet.
            self.get_logger().warn(f"No voice for '{lang}'")
            self._playback_queue.task_done()
            continue

        t0 = time.perf_counter()
        pcm_data = self._synthesize(text, voice)
        if pcm_data is None:
            self._playback_queue.task_done()
            continue

        synth_ms = (time.perf_counter() - t0) * 1000
        dur_ms = estimate_duration_ms(pcm_data, self._sample_rate)
        self.get_logger().debug(
            f"TTS '{text[:40]}' synth={synth_ms:.0f}ms, dur={dur_ms:.0f}ms"
        )

        if self._volume != 1.0:
            pcm_data = apply_volume(pcm_data, self._volume)
        if self._playback:
            self._play_audio(pcm_data)
        if self._publish_audio:
            self._publish_pcm(pcm_data)
        self._playback_queue.task_done()
|
||||
|
||||
def _synthesize(self, text: str) -> Optional[bytes]:
|
||||
"""Synthesize text to PCM16 bytes using Piper streaming."""
|
||||
if self._voice is None:
|
||||
return None
|
||||
try:
|
||||
chunks = list(self._voice.synthesize_stream_raw(text))
|
||||
return b"".join(chunks)
|
||||
except Exception as e:
|
||||
self.get_logger().error(f"TTS synthesis error: {e}")
|
||||
return None
|
||||
def _synthesize(self, text: str, voice) -> Optional[bytes]:
|
||||
try: return b"".join(voice.synthesize_stream_raw(text))
|
||||
except Exception as e: self.get_logger().error(f"TTS error: {e}"); return None
|
||||
|
||||
def _play_audio(self, pcm_data: bytes) -> None:
    """Blocking playback of PCM16 audio via sounddevice.

    Residue issued the play call twice; it now plays once. Errors
    (missing device, missing library) are logged, not raised.
    """
    try:
        import numpy as np
        import sounddevice as sd
        # int16 → float32 in [-1, 1), the format sounddevice expects.
        samples = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0
        sd.play(
            samples, samplerate=self._sample_rate,
            device=self._audio_device, blocking=True,
        )
    except Exception as e:
        self.get_logger().error(f"Audio playback error: {e}")
|
||||
|
||||
def _publish_pcm(self, pcm_data: bytes) -> None:
|
||||
"""Publish PCM data as UInt8MultiArray."""
|
||||
if not hasattr(self, "_audio_pub"):
|
||||
return
|
||||
msg = UInt8MultiArray()
|
||||
msg.data = list(pcm_data)
|
||||
self._audio_pub.publish(msg)
|
||||
|
||||
if not hasattr(self,"_audio_pub"): return
|
||||
msg = UInt8MultiArray(); msg.data = list(pcm_data); self._audio_pub.publish(msg)
|
||||
|
||||
def main(args=None) -> None:
    """Entry point: spin TtsNode until interrupted, then clean up."""
    rclpy.init(args=args)
    node = TtsNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.shutdown()
|
||||
|
||||
122
jetson/ros2_ws/src/saltybot_social/test/test_multilang.py
Normal file
122
jetson/ros2_ws/src/saltybot_social/test/test_multilang.py
Normal file
@ -0,0 +1,122 @@
|
||||
"""test_multilang.py -- Unit tests for Issue #167 multi-language support."""
|
||||
|
||||
from __future__ import annotations
|
||||
import json, os
|
||||
from typing import Any, Dict, Optional
|
||||
import pytest
|
||||
|
||||
def _pkg_root():
|
||||
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
def _read_src(rel_path):
    """Return the text of a source file given its path relative to the package root."""
    full_path = os.path.join(_pkg_root(), rel_path)
    with open(full_path) as fh:
        return fh.read()
||||
def _extract_lang_names():
    """Parse the _LANG_NAMES dict literal out of conversation_node.py.

    Locates the assignment textually and evaluates only the dict literal,
    avoiding an import of the node module (which would pull in rclpy).
    """
    import ast

    src = _read_src("saltybot_social/conversation_node.py")
    begin = src.index("_LANG_NAMES: Dict[str, str] = {")
    finish = src.index("\n}", begin) + 2
    assignment = src[begin:finish]
    _, literal = assignment.split("=", 1)
    return ast.literal_eval(literal.strip())
||||
class TestLangNames:
    """Spot-checks on the _LANG_NAMES language-code -> display-name table."""

    @pytest.fixture(scope="class")
    def ln(self):
        return _extract_lang_names()

    def test_english(self, ln):
        assert ln["en"] == "English"

    def test_french(self, ln):
        assert ln["fr"] == "French"

    def test_spanish(self, ln):
        assert ln["es"] == "Spanish"

    def test_german(self, ln):
        assert ln["de"] == "German"

    def test_japanese(self, ln):
        assert ln["ja"] == "Japanese"

    def test_chinese(self, ln):
        assert ln["zh"] == "Chinese"

    def test_arabic(self, ln):
        assert ln["ar"] == "Arabic"

    def test_at_least_15(self, ln):
        assert len(ln) >= 15

    def test_lowercase_keys(self, ln):
        # Whisper emits lowercase ISO 639-1/-2 codes (2 or 3 chars).
        for code in ln:
            assert code == code.lower()
            assert 2 <= len(code) <= 3

    def test_nonempty_values(self, ln):
        for display_name in ln.values():
            assert display_name
||||
class TestLanguageHint:
    """Mirror of the per-speaker language hint injected into the LLM prompt."""

    def _h(self, sl, sid, ln):
        # Unknown speakers default to English; English (or empty) => no hint.
        lang = sl.get(sid, "en")
        if not lang or lang == "en":
            return ""
        # Fall back to the raw code when the table has no display name.
        return f"[Please respond in {ln.get(lang, lang)}.]"

    @pytest.fixture(scope="class")
    def ln(self):
        return _extract_lang_names()

    def test_english_no_hint(self, ln):
        assert self._h({"p": "en"}, "p", ln) == ""

    def test_unknown_no_hint(self, ln):
        assert self._h({}, "p", ln) == ""

    def test_french(self, ln):
        assert self._h({"p": "fr"}, "p", ln) == "[Please respond in French.]"

    def test_spanish(self, ln):
        assert self._h({"p": "es"}, "p", ln) == "[Please respond in Spanish.]"

    def test_unknown_code(self, ln):
        assert "xx" in self._h({"p": "xx"}, "p", ln)

    def test_brackets(self, ln):
        hint = self._h({"p": "de"}, "p", ln)
        assert hint.startswith("[")
        assert hint.endswith("]")
||||
class TestVoiceMap:
    """Mirror of the tts_node voice_map_json parsing logic."""

    def _parse(self, jstr, dl, dp):
        """Build the language->voice map: default voice plus user JSON overlay.

        Args:
            jstr: JSON object string from the voice_map_json parameter.
            dl: Default language code.
            dp: Default voice path.
        """
        try:
            extra = json.loads(jstr) if jstr.strip() not in ("{}", "") else {}
        # Was a bare `except:` — narrow to the decode failures json.loads
        # actually raises so unrelated bugs are not silently swallowed.
        except (json.JSONDecodeError, TypeError):
            extra = {}
        # Valid JSON that is not an object (e.g. "[1]") would make
        # dict.update raise; treat it like an invalid map instead.
        if not isinstance(extra, dict):
            extra = {}
        result = {dl: dp}
        result.update(extra)
        return result

    def test_empty(self):
        assert self._parse("{}", "en", "/e") == {"en": "/e"}

    def test_extra(self):
        vm = self._parse('{"fr":"/f"}', "en", "/e")
        assert vm["fr"] == "/f"

    def test_invalid(self):
        assert self._parse("BAD", "en", "/e") == {"en": "/e"}

    def test_multi(self):
        assert len(self._parse(json.dumps({"fr": "/f", "es": "/s"}), "en", "/e")) == 3
||||
class TestVoiceSelect:
    """Mirror of the tts_node voice-selection fallback rule."""

    def _s(self, voices, lang, default):
        # Falsy entries (missing key or empty path) fall back to the
        # default language's voice; may still be None if that is absent too.
        chosen = voices.get(lang)
        if chosen:
            return chosen
        return voices.get(default)

    def test_exact(self):
        assert self._s({"en": "E", "fr": "F"}, "fr", "en") == "F"

    def test_fallback(self):
        assert self._s({"en": "E"}, "fr", "en") == "E"

    def test_none(self):
        assert self._s({}, "fr", "en") is None
||||
class TestSttFields:
    """Source-level checks that speech_pipeline_node wires up language detection."""

    @pytest.fixture(scope="class")
    def src(self):
        return _read_src("saltybot_social/speech_pipeline_node.py")

    def test_param(self, src):
        assert "whisper_language" in src

    def test_detected_lang(self, src):
        assert "detected_lang" in src

    def test_msg_language(self, src):
        assert "msg.language = language" in src

    def test_auto_detect(self, src):
        assert "language=self._whisper_language" in src
||||
class TestConvFields:
    """Source-level checks that conversation_node carries per-speaker language."""

    @pytest.fixture(scope="class")
    def src(self):
        return _read_src("saltybot_social/conversation_node.py")

    def test_speaker_lang(self, src):
        assert "_speaker_lang" in src

    def test_lang_hint_method(self, src):
        assert "_language_hint" in src

    def test_msg_language(self, src):
        assert "msg.language = language" in src

    def test_lang_names(self, src):
        assert "_LANG_NAMES" in src

    def test_please_respond(self, src):
        assert "Please respond in" in src

    def test_emotion_coexists(self, src):
        # The language hint must not have displaced the emotion hint.
        assert "_emotion_hint" in src
||||
class TestTtsFields:
    """Source-level checks that tts_node supports per-language voices."""

    @pytest.fixture(scope="class")
    def src(self):
        return _read_src("saltybot_social/tts_node.py")

    def test_voice_map_json(self, src):
        assert "voice_map_json" in src

    def test_default_lang(self, src):
        assert "default_language" in src

    def test_voices_dict(self, src):
        assert "_voices" in src

    def test_get_voice(self, src):
        assert "_get_voice" in src

    def test_load_voice_for_lang(self, src):
        assert "_load_voice_for_lang" in src

    def test_queue_tuple(self, src):
        # The synthesis queue must carry the language alongside the text.
        assert "(text.strip(), lang)" in src

    def test_synthesize_voice_arg(self, src):
        assert "_synthesize(text, voice)" in src
||||
class TestMsgDefs:
    """The .msg definitions must expose the new language field."""

    @pytest.fixture(scope="class")
    def tr(self):
        return _read_src("../saltybot_social_msgs/msg/SpeechTranscript.msg")

    # NOTE: fixture name `re` shadows the stdlib module inside these tests;
    # kept because renaming would change the fixture interface.
    @pytest.fixture(scope="class")
    def re(self):
        return _read_src("../saltybot_social_msgs/msg/ConversationResponse.msg")

    def test_transcript_lang(self, tr):
        assert "string language" in tr

    def test_transcript_bcp47(self, tr):
        assert "BCP-47" in tr

    def test_response_lang(self, re):
        assert "string language" in re
||||
class TestEdgeCases:
|
||||
def test_empty_lang_no_hint(self):
|
||||
lang = "" or "en"; assert lang == "en"
|
||||
def test_lang_flows(self):
|
||||
d: Dict[str,str] = {}; d["p1"] = "fr"
|
||||
assert d.get("p1","en") == "fr"
|
||||
def test_multi_speakers(self):
|
||||
d = {"p1":"fr","p2":"es"}
|
||||
assert d["p1"] == "fr" and d["p2"] == "es"
|
||||
def test_voice_map_code_in_tts(self):
|
||||
src = _read_src("saltybot_social/tts_node.py")
|
||||
assert "voice_map_json" in src and "json.loads" in src
|
||||
@ -7,3 +7,4 @@ string text # Full or partial response text
|
||||
string speaker_id # Who the response is addressed to
|
||||
bool is_partial # true = streaming token chunk, false = final response
|
||||
int32 turn_id # Conversation turn counter (for deduplication)
|
||||
string language # BCP-47 language code for TTS voice selection e.g. "en" "fr" "es"
|
||||
|
||||
@ -8,3 +8,4 @@ string speaker_id # e.g. "person_42" or "unknown"
|
||||
float32 confidence # ASR confidence 0..1
|
||||
float32 audio_duration # Duration of the utterance in seconds
|
||||
bool is_partial # true = intermediate streaming result, false = final
|
||||
string language # BCP-47 detected language code e.g. "en" "fr" "es" (empty = unknown)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user