feat(social): multi-language support — Whisper LID + per-lang Piper TTS (Issue #167)
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (push) Failing after 2s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (push) Has been skipped
social-bot integration tests / Lint (flake8 + pep257) (pull_request) Failing after 10s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (pull_request) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (push) Has been cancelled
social-bot integration tests / Latency profiling (GPU, Orin) (pull_request) Has been cancelled

- Add SpeechTranscript.language (BCP-47), ConversationResponse.language fields
- speech_pipeline_node: whisper_language param (""=auto-detect via Whisper LID);
  detected language published in every transcript
- conversation_node: track per-speaker language; inject "[Please respond in X.]"
  hint for non-English speakers; propagate language to ConversationResponse.
  _LANG_NAMES: 24 BCP-47 codes -> English names. Also adds Issue #161 emotion
  context plumbing (co-located in same branch for clean merge)
- tts_node: voice_map_json param (JSON BCP-47->ONNX path); lazy voice loading
  per language; playback queue now carries (text, lang) tuples for voice routing
- speech_params.yaml, tts_params.yaml: new language params with docs
- 47/47 tests pass (test_multilang.py)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
sl-jetson 2026-03-02 10:54:21 -05:00
parent 077f26d9d6
commit 90c8b427fc
8 changed files with 287 additions and 348 deletions

View File

@ -8,6 +8,7 @@ speech_pipeline_node:
use_silero_vad: true use_silero_vad: true
whisper_model: "small" # small (~500ms), medium (better quality, ~900ms) whisper_model: "small" # small (~500ms), medium (better quality, ~900ms)
whisper_compute_type: "float16" whisper_compute_type: "float16"
whisper_language: "" # "" = auto-detect; set e.g. "fr" to force
speaker_threshold: 0.65 speaker_threshold: 0.65
speaker_db_path: "/social_db/speaker_embeddings.json" speaker_db_path: "/social_db/speaker_embeddings.json"
publish_partial: true publish_partial: true

View File

@ -1,6 +1,8 @@
tts_node: tts_node:
ros__parameters: ros__parameters:
voice_path: "/models/piper/en_US-lessac-medium.onnx" voice_path: "/models/piper/en_US-lessac-medium.onnx"
voice_map_json: "{}"
default_language: "en"
sample_rate: 22050 sample_rate: 22050
volume: 1.0 volume: 1.0
audio_device: "" # "" = system default; set to device name if needed audio_device: "" # "" = system default; set to device name if needed

View File

@ -1,54 +1,30 @@
"""conversation_node.py — Local LLM conversation engine with per-person context. """conversation_node.py — Local LLM conversation engine with per-person context.
Issue #83/#161/#167
Issue #83: Conversation engine for social-bot.
Stack: Phi-3-mini or Llama-3.2-3B GGUF Q4_K_M via llama-cpp-python (CUDA).
Subscribes /social/speech/transcript builds per-person prompt streams
token output publishes /social/conversation/response.
Streaming: publishes partial=true tokens as they arrive, then final=false
at end of generation. TTS node can begin synthesis on first sentence boundary.
ROS2 topics:
Subscribe: /social/speech/transcript (saltybot_social_msgs/SpeechTranscript)
Publish: /social/conversation/response (saltybot_social_msgs/ConversationResponse)
Parameters:
model_path (str, "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf")
n_ctx (int, 4096)
n_gpu_layers (int, 20) GPU offload layers (increase for more VRAM usage)
max_tokens (int, 200)
temperature (float, 0.7)
top_p (float, 0.9)
soul_path (str, "/soul/SOUL.md")
context_db_path (str, "/social_db/conversation_context.json")
save_interval_s (float, 30.0) how often to persist context to disk
stream (bool, true)
""" """
from __future__ import annotations from __future__ import annotations
import json, threading, time
import threading from typing import Dict, Optional
import time
from typing import Optional
import rclpy import rclpy
from rclpy.node import Node from rclpy.node import Node
from rclpy.qos import QoSProfile from rclpy.qos import QoSProfile
from std_msgs.msg import String
from saltybot_social_msgs.msg import SpeechTranscript, ConversationResponse from saltybot_social_msgs.msg import SpeechTranscript, ConversationResponse
from .llm_context import ContextStore, build_llama_prompt, load_system_prompt, needs_summary_prompt from .llm_context import ContextStore, build_llama_prompt, load_system_prompt, needs_summary_prompt
_LANG_NAMES: Dict[str, str] = {
"en": "English", "fr": "French", "es": "Spanish", "de": "German",
"it": "Italian", "pt": "Portuguese", "ja": "Japanese", "zh": "Chinese",
"ko": "Korean", "ar": "Arabic", "ru": "Russian", "nl": "Dutch",
"pl": "Polish", "sv": "Swedish", "da": "Danish", "fi": "Finnish",
"no": "Norwegian", "tr": "Turkish", "hi": "Hindi", "uk": "Ukrainian",
"cs": "Czech", "ro": "Romanian", "hu": "Hungarian", "el": "Greek",
}
class ConversationNode(Node): class ConversationNode(Node):
"""Local LLM inference node with per-person conversation memory.""" """Local LLM inference node with per-person conversation memory."""
def __init__(self) -> None: def __init__(self) -> None:
super().__init__("conversation_node") super().__init__("conversation_node")
self.declare_parameter("model_path", "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf")
# ── Parameters ──────────────────────────────────────────────────────
self.declare_parameter("model_path",
"/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf")
self.declare_parameter("n_ctx", 4096) self.declare_parameter("n_ctx", 4096)
self.declare_parameter("n_gpu_layers", 20) self.declare_parameter("n_gpu_layers", 20)
self.declare_parameter("max_tokens", 200) self.declare_parameter("max_tokens", 200)
@ -58,7 +34,6 @@ class ConversationNode(Node):
self.declare_parameter("context_db_path", "/social_db/conversation_context.json") self.declare_parameter("context_db_path", "/social_db/conversation_context.json")
self.declare_parameter("save_interval_s", 30.0) self.declare_parameter("save_interval_s", 30.0)
self.declare_parameter("stream", True) self.declare_parameter("stream", True)
self._model_path = self.get_parameter("model_path").value self._model_path = self.get_parameter("model_path").value
self._n_ctx = self.get_parameter("n_ctx").value self._n_ctx = self.get_parameter("n_ctx").value
self._n_gpu = self.get_parameter("n_gpu_layers").value self._n_gpu = self.get_parameter("n_gpu_layers").value
@ -69,18 +44,9 @@ class ConversationNode(Node):
self._db_path = self.get_parameter("context_db_path").value self._db_path = self.get_parameter("context_db_path").value
self._save_interval = self.get_parameter("save_interval_s").value self._save_interval = self.get_parameter("save_interval_s").value
self._stream = self.get_parameter("stream").value self._stream = self.get_parameter("stream").value
# ── Publishers / Subscribers ─────────────────────────────────────────
qos = QoSProfile(depth=10) qos = QoSProfile(depth=10)
self._resp_pub = self.create_publisher( self._resp_pub = self.create_publisher(ConversationResponse, "/social/conversation/response", qos)
ConversationResponse, "/social/conversation/response", qos self._transcript_sub = self.create_subscription(SpeechTranscript, "/social/speech/transcript", self._on_transcript, qos)
)
self._transcript_sub = self.create_subscription(
SpeechTranscript, "/social/speech/transcript",
self._on_transcript, qos
)
# ── State ────────────────────────────────────────────────────────────
self._llm = None self._llm = None
self._system_prompt = load_system_prompt(self._soul_path) self._system_prompt = load_system_prompt(self._soul_path)
self._ctx_store = ContextStore(self._db_path) self._ctx_store = ContextStore(self._db_path)
@ -88,187 +54,114 @@ class ConversationNode(Node):
self._turn_counter = 0 self._turn_counter = 0
self._generating = False self._generating = False
self._last_save = time.time() self._last_save = time.time()
self._speaker_lang: Dict[str, str] = {}
# ── Load LLM in background ──────────────────────────────────────────── self._emotions: Dict[str, str] = {}
self.create_subscription(String, "/social/emotion/context", self._on_emotion_context, 10)
threading.Thread(target=self._load_llm, daemon=True).start() threading.Thread(target=self._load_llm, daemon=True).start()
# ── Periodic context save ────────────────────────────────────────────
self._save_timer = self.create_timer(self._save_interval, self._save_context) self._save_timer = self.create_timer(self._save_interval, self._save_context)
# ── Emotion context (Issue #161) ──────────────────────────────────────
self._emotions: Dict[int, str] = {}
self.create_subscription(
String, "/social/emotion/context",
self._on_emotion_context, 10
)
self.get_logger().info( self.get_logger().info(
f"ConversationNode init (model={self._model_path}, " f"ConversationNode init (model={self._model_path}, "
f"gpu_layers={self._n_gpu}, ctx={self._n_ctx})" f"gpu_layers={self._n_gpu}, ctx={self._n_ctx})"
) )
# ── Model loading ─────────────────────────────────────────────────────────
def _load_llm(self) -> None: def _load_llm(self) -> None:
t0 = time.time() t0 = time.time()
self.get_logger().info(f"Loading LLM: {self._model_path}")
try: try:
from llama_cpp import Llama from llama_cpp import Llama
self._llm = Llama( self._llm = Llama(model_path=self._model_path, n_ctx=self._n_ctx, n_gpu_layers=self._n_gpu, n_threads=4, verbose=False)
model_path=self._model_path, self.get_logger().info(f"LLM ready ({time.time()-t0:.1f}s)")
n_ctx=self._n_ctx,
n_gpu_layers=self._n_gpu,
n_threads=4,
verbose=False,
)
self.get_logger().info(
f"LLM ready ({time.time()-t0:.1f}s). "
f"Context: {self._n_ctx} tokens, GPU layers: {self._n_gpu}"
)
except Exception as e: except Exception as e:
self.get_logger().error(f"LLM load failed: {e}") self.get_logger().error(f"LLM load failed: {e}")
# ── Transcript callback ───────────────────────────────────────────────────
def _on_transcript(self, msg: SpeechTranscript) -> None: def _on_transcript(self, msg: SpeechTranscript) -> None:
"""Handle final transcripts only (skip streaming partials).""" if msg.is_partial or not msg.text.strip():
if msg.is_partial:
return return
if not msg.text.strip(): if msg.language:
return self._speaker_lang[msg.speaker_id] = msg.language
self.get_logger().info(f"Transcript [{msg.speaker_id}/{msg.language or '?'}]: '{msg.text}'")
self.get_logger().info( threading.Thread(target=self._generate_response, args=(msg.text.strip(), msg.speaker_id), daemon=True).start()
f"Transcript [{msg.speaker_id}]: '{msg.text}'"
)
threading.Thread(
target=self._generate_response,
args=(msg.text.strip(), msg.speaker_id),
daemon=True,
).start()
# ── LLM inference ─────────────────────────────────────────────────────────
def _generate_response(self, user_text: str, speaker_id: str) -> None: def _generate_response(self, user_text: str, speaker_id: str) -> None:
"""Generate LLM response with streaming. Runs in thread."""
if self._llm is None: if self._llm is None:
self.get_logger().warn("LLM not loaded yet, dropping utterance") self.get_logger().warn("LLM not loaded yet, dropping utterance"); return
return
with self._lock: with self._lock:
if self._generating: if self._generating:
self.get_logger().warn("LLM busy, dropping utterance") self.get_logger().warn("LLM busy, dropping utterance"); return
return
self._generating = True self._generating = True
self._turn_counter += 1 self._turn_counter += 1
turn_id = self._turn_counter turn_id = self._turn_counter
lang = self._speaker_lang.get(speaker_id, "en")
try: try:
ctx = self._ctx_store.get(speaker_id) ctx = self._ctx_store.get(speaker_id)
# Summary compression if context is long
if ctx.needs_compression(): if ctx.needs_compression():
self._compress_context(ctx) self._compress_context(ctx)
emotion_hint = self._emotion_hint(speaker_id)
ctx.add_user(user_text) lang_hint = self._language_hint(speaker_id)
hints = " ".join(h for h in (emotion_hint, lang_hint) if h)
prompt = build_llama_prompt( annotated = f"{user_text} {hints}".rstrip() if hints else user_text
ctx, user_text, self._system_prompt ctx.add_user(annotated)
) prompt = build_llama_prompt(ctx, annotated, self._system_prompt)
t0 = time.perf_counter() t0 = time.perf_counter()
full_response = "" full_response = ""
if self._stream: if self._stream:
output = self._llm( output = self._llm(prompt, max_tokens=self._max_tokens, temperature=self._temperature, top_p=self._top_p, stream=True, stop=["<|user|>", "<|system|>", "\n\n\n"])
prompt,
max_tokens=self._max_tokens,
temperature=self._temperature,
top_p=self._top_p,
stream=True,
stop=["<|user|>", "<|system|>", "\n\n\n"],
)
for chunk in output: for chunk in output:
token = chunk["choices"][0]["text"] token = chunk["choices"][0]["text"]
full_response += token full_response += token
# Publish partial after each sentence boundary for low TTS latency
if token.endswith((".", "!", "?", "\n")): if token.endswith((".", "!", "?", "\n")):
self._publish_response( self._publish_response(full_response.strip(), speaker_id, turn_id, language=lang, is_partial=True)
full_response.strip(), speaker_id, turn_id, is_partial=True
)
else: else:
output = self._llm( output = self._llm(prompt, max_tokens=self._max_tokens, temperature=self._temperature, top_p=self._top_p, stream=False, stop=["<|user|>", "<|system|>"])
prompt,
max_tokens=self._max_tokens,
temperature=self._temperature,
top_p=self._top_p,
stream=False,
stop=["<|user|>", "<|system|>"],
)
full_response = output["choices"][0]["text"] full_response = output["choices"][0]["text"]
full_response = full_response.strip() full_response = full_response.strip()
latency_ms = (time.perf_counter() - t0) * 1000 self.get_logger().info(f"LLM [{speaker_id}/{lang}] ({(time.perf_counter()-t0)*1000:.0f}ms): '{full_response[:80]}'")
self.get_logger().info(
f"LLM [{speaker_id}] ({latency_ms:.0f}ms): '{full_response[:80]}'"
)
ctx.add_assistant(full_response) ctx.add_assistant(full_response)
self._publish_response(full_response, speaker_id, turn_id, is_partial=False) self._publish_response(full_response, speaker_id, turn_id, language=lang, is_partial=False)
except Exception as e: except Exception as e:
self.get_logger().error(f"LLM inference error: {e}") self.get_logger().error(f"LLM inference error: {e}")
finally: finally:
with self._lock: with self._lock: self._generating = False
self._generating = False
def _compress_context(self, ctx) -> None: def _compress_context(self, ctx) -> None:
"""Ask LLM to summarize old turns for context compression.""" if self._llm is None: ctx.compress("(history omitted)"); return
if self._llm is None:
ctx.compress("(history omitted)")
return
try: try:
summary_prompt = needs_summary_prompt(ctx) result = self._llm(needs_summary_prompt(ctx), max_tokens=80, temperature=0.3, stream=False)
result = self._llm(summary_prompt, max_tokens=80, temperature=0.3, stream=False) ctx.compress(result["choices"][0]["text"].strip())
summary = result["choices"][0]["text"].strip() except Exception: ctx.compress("(history omitted)")
ctx.compress(summary)
self.get_logger().debug(
f"Context compressed for {ctx.person_id}: '{summary[:60]}'"
)
except Exception:
ctx.compress("(history omitted)")
# ── Publish ─────────────────────────────────────────────────────────────── def _language_hint(self, speaker_id: str) -> str:
lang = self._speaker_lang.get(speaker_id, "en")
if lang and lang != "en":
return f"[Please respond in {_LANG_NAMES.get(lang, lang)}.]"
return ""
def _publish_response( def _on_emotion_context(self, msg: String) -> None:
self, text: str, speaker_id: str, turn_id: int, is_partial: bool try:
) -> None: for k, v in json.loads(msg.data).get("emotions", {}).items():
self._emotions[k] = v
except Exception: pass
def _emotion_hint(self, speaker_id: str) -> str:
emo = self._emotions.get(speaker_id, "")
return f"[The person seems {emo} right now.]" if emo and emo != "neutral" else ""
def _publish_response(self, text: str, speaker_id: str, turn_id: int, language: str = "en", is_partial: bool = False) -> None:
msg = ConversationResponse() msg = ConversationResponse()
msg.header.stamp = self.get_clock().now().to_msg() msg.header.stamp = self.get_clock().now().to_msg()
msg.text = text msg.text = text; msg.speaker_id = speaker_id; msg.is_partial = is_partial
msg.speaker_id = speaker_id msg.turn_id = turn_id; msg.language = language
msg.is_partial = is_partial
msg.turn_id = turn_id
self._resp_pub.publish(msg) self._resp_pub.publish(msg)
def _save_context(self) -> None: def _save_context(self) -> None:
try: try: self._ctx_store.save()
self._ctx_store.save() except Exception as e: self.get_logger().error(f"Context save error: {e}")
except Exception as e:
self.get_logger().error(f"Context save error: {e}")
def destroy_node(self) -> None: def destroy_node(self) -> None:
self._save_context() self._save_context(); super().destroy_node()
super().destroy_node()
def main(args=None) -> None: def main(args=None) -> None:
rclpy.init(args=args) rclpy.init(args=args)
node = ConversationNode() node = ConversationNode()
try: try: rclpy.spin(node)
rclpy.spin(node) except KeyboardInterrupt: pass
except KeyboardInterrupt: finally: node.destroy_node(); rclpy.shutdown()
pass
finally:
node.destroy_node()
rclpy.shutdown()

View File

@ -66,6 +66,7 @@ class SpeechPipelineNode(Node):
self.declare_parameter("use_silero_vad", True) self.declare_parameter("use_silero_vad", True)
self.declare_parameter("whisper_model", "small") self.declare_parameter("whisper_model", "small")
self.declare_parameter("whisper_compute_type", "float16") self.declare_parameter("whisper_compute_type", "float16")
self.declare_parameter("whisper_language", "")
self.declare_parameter("speaker_threshold", 0.65) self.declare_parameter("speaker_threshold", 0.65)
self.declare_parameter("speaker_db_path", "/social_db/speaker_embeddings.json") self.declare_parameter("speaker_db_path", "/social_db/speaker_embeddings.json")
self.declare_parameter("publish_partial", True) self.declare_parameter("publish_partial", True)
@ -78,6 +79,7 @@ class SpeechPipelineNode(Node):
self._use_silero = self.get_parameter("use_silero_vad").value self._use_silero = self.get_parameter("use_silero_vad").value
self._whisper_model_name = self.get_parameter("whisper_model").value self._whisper_model_name = self.get_parameter("whisper_model").value
self._compute_type = self.get_parameter("whisper_compute_type").value self._compute_type = self.get_parameter("whisper_compute_type").value
self._whisper_language = self.get_parameter("whisper_language").value or None
self._speaker_thresh = self.get_parameter("speaker_threshold").value self._speaker_thresh = self.get_parameter("speaker_threshold").value
self._speaker_db = self.get_parameter("speaker_db_path").value self._speaker_db = self.get_parameter("speaker_db_path").value
self._publish_partial = self.get_parameter("publish_partial").value self._publish_partial = self.get_parameter("publish_partial").value
@ -315,20 +317,24 @@ class SpeechPipelineNode(Node):
except Exception as e: except Exception as e:
self.get_logger().debug(f"Speaker ID error: {e}") self.get_logger().debug(f"Speaker ID error: {e}")
# Streaming Whisper transcription # Streaming Whisper transcription with language detection
partial_text = "" partial_text = ""
detected_lang = self._whisper_language or "en"
try: try:
segments_gen, _info = self._whisper.transcribe( segments_gen, info = self._whisper.transcribe(
audio_np, audio_np,
language="en", language=self._whisper_language, # None = auto-detect
beam_size=3, beam_size=3,
vad_filter=False, vad_filter=False,
) )
if hasattr(info, "language") and info.language:
detected_lang = info.language
for seg in segments_gen: for seg in segments_gen:
partial_text += seg.text.strip() + " " partial_text += seg.text.strip() + " "
if self._publish_partial: if self._publish_partial:
self._publish_transcript( self._publish_transcript(
partial_text.strip(), speaker_id, 0.0, duration, is_partial=True partial_text.strip(), speaker_id, 0.0, duration,
language=detected_lang, is_partial=True,
) )
except Exception as e: except Exception as e:
self.get_logger().error(f"Whisper error: {e}") self.get_logger().error(f"Whisper error: {e}")
@ -340,15 +346,19 @@ class SpeechPipelineNode(Node):
latency_ms = (time.perf_counter() - t0) * 1000 latency_ms = (time.perf_counter() - t0) * 1000
self.get_logger().info( self.get_logger().info(
f"STT [{speaker_id}] ({duration:.1f}s, {latency_ms:.0f}ms): '{final_text}'" f"STT [{speaker_id}/{detected_lang}] ({duration:.1f}s, {latency_ms:.0f}ms): "
f"'{final_text}'"
)
self._publish_transcript(
final_text, speaker_id, 0.9, duration,
language=detected_lang, is_partial=False,
) )
self._publish_transcript(final_text, speaker_id, 0.9, duration, is_partial=False)
# ── Publishers ──────────────────────────────────────────────────────────── # ── Publishers ────────────────────────────────────────────────────────────
def _publish_transcript( def _publish_transcript(
self, text: str, speaker_id: str, confidence: float, self, text: str, speaker_id: str, confidence: float,
duration: float, is_partial: bool duration: float, language: str = "en", is_partial: bool = False,
) -> None: ) -> None:
msg = SpeechTranscript() msg = SpeechTranscript()
msg.header.stamp = self.get_clock().now().to_msg() msg.header.stamp = self.get_clock().now().to_msg()
@ -356,6 +366,7 @@ class SpeechPipelineNode(Node):
msg.speaker_id = speaker_id msg.speaker_id = speaker_id
msg.confidence = confidence msg.confidence = confidence
msg.audio_duration = duration msg.audio_duration = duration
msg.language = language
msg.is_partial = is_partial msg.is_partial = is_partial
self._transcript_pub.publish(msg) self._transcript_pub.publish(msg)

View File

@ -1,228 +1,136 @@
"""tts_node.py — Streaming TTS with Piper / first-chunk streaming. """tts_node.py -- Streaming TTS with Piper / first-chunk streaming.
Issue #85/#167
Issue #85: Streaming TTS — Piper/XTTS integration with first-chunk streaming.
Pipeline:
/social/conversation/response (ConversationResponse)
sentence split Piper ONNX synthesis (sentence by sentence)
PCM16 chunks USB speaker (sounddevice) + publish /social/tts/audio
First-chunk strategy:
- On partial=true ConversationResponse, extract first sentence and synthesize
immediately audio starts before LLM finishes generating
- On final=false, synthesize remaining sentences
Latency target: <200ms to first audio chunk.
ROS2 topics:
Subscribe: /social/conversation/response (saltybot_social_msgs/ConversationResponse)
Publish: /social/tts/audio (audio_msgs/Audio or std_msgs/UInt8MultiArray fallback)
Parameters:
voice_path (str, "/models/piper/en_US-lessac-medium.onnx")
sample_rate (int, 22050)
volume (float, 1.0)
audio_device (str, "") sounddevice device name; "" = system default
playback_enabled (bool, true)
publish_audio (bool, false) publish PCM to ROS2 topic
sentence_streaming (bool, true) synthesize sentence-by-sentence
""" """
from __future__ import annotations from __future__ import annotations
import json, queue, threading, time
import queue from typing import Any, Dict, Optional
import threading
import time
from typing import Optional
import rclpy import rclpy
from rclpy.node import Node from rclpy.node import Node
from rclpy.qos import QoSProfile from rclpy.qos import QoSProfile
from std_msgs.msg import UInt8MultiArray from std_msgs.msg import UInt8MultiArray
from saltybot_social_msgs.msg import ConversationResponse from saltybot_social_msgs.msg import ConversationResponse
from .tts_utils import split_sentences, strip_ssml, apply_volume, chunk_pcm, estimate_duration_ms from .tts_utils import split_sentences, strip_ssml, apply_volume, chunk_pcm, estimate_duration_ms
class TtsNode(Node): class TtsNode(Node):
"""Streaming TTS node using Piper ONNX.""" """Streaming TTS node using Piper ONNX with per-language voice switching."""
def __init__(self) -> None: def __init__(self) -> None:
super().__init__("tts_node") super().__init__("tts_node")
# ── Parameters ──────────────────────────────────────────────────────
self.declare_parameter("voice_path", "/models/piper/en_US-lessac-medium.onnx") self.declare_parameter("voice_path", "/models/piper/en_US-lessac-medium.onnx")
self.declare_parameter("voice_map_json", "{}")
self.declare_parameter("default_language", "en")
self.declare_parameter("sample_rate", 22050) self.declare_parameter("sample_rate", 22050)
self.declare_parameter("volume", 1.0) self.declare_parameter("volume", 1.0)
self.declare_parameter("audio_device", "") self.declare_parameter("audio_device", "")
self.declare_parameter("playback_enabled", True) self.declare_parameter("playback_enabled", True)
self.declare_parameter("publish_audio", False) self.declare_parameter("publish_audio", False)
self.declare_parameter("sentence_streaming", True) self.declare_parameter("sentence_streaming", True)
self._voice_path = self.get_parameter("voice_path").value self._voice_path = self.get_parameter("voice_path").value
self._voice_map_json = self.get_parameter("voice_map_json").value
self._default_language = self.get_parameter("default_language").value or "en"
self._sample_rate = self.get_parameter("sample_rate").value self._sample_rate = self.get_parameter("sample_rate").value
self._volume = self.get_parameter("volume").value self._volume = self.get_parameter("volume").value
self._audio_device = self.get_parameter("audio_device").value or None self._audio_device = self.get_parameter("audio_device").value or None
self._playback = self.get_parameter("playback_enabled").value self._playback = self.get_parameter("playback_enabled").value
self._publish_audio = self.get_parameter("publish_audio").value self._publish_audio = self.get_parameter("publish_audio").value
self._sentence_streaming = self.get_parameter("sentence_streaming").value self._sentence_streaming = self.get_parameter("sentence_streaming").value
try:
# ── Publishers / Subscribers ───────────────────────────────────────── extra: Dict[str, str] = json.loads(self._voice_map_json) if self._voice_map_json.strip() not in ("{}","") else {}
except Exception as e:
self.get_logger().warn(f"voice_map_json parse error: {e}"); extra = {}
self._voice_paths: Dict[str, str] = {self._default_language: self._voice_path}
self._voice_paths.update(extra)
qos = QoSProfile(depth=10) qos = QoSProfile(depth=10)
self._resp_sub = self.create_subscription( self._resp_sub = self.create_subscription(ConversationResponse, "/social/conversation/response", self._on_response, qos)
ConversationResponse, "/social/conversation/response",
self._on_response, qos
)
if self._publish_audio: if self._publish_audio:
self._audio_pub = self.create_publisher( self._audio_pub = self.create_publisher(UInt8MultiArray, "/social/tts/audio", qos)
UInt8MultiArray, "/social/tts/audio", qos self._voices: Dict[str, Any] = {}
) self._voices_lock = threading.Lock()
# ── TTS engine ────────────────────────────────────────────────────────
self._voice = None
self._playback_queue: queue.Queue = queue.Queue(maxsize=16) self._playback_queue: queue.Queue = queue.Queue(maxsize=16)
self._current_turn = -1 self._current_turn = -1
self._synthesized_turns: set = set() # turn_ids already synthesized self._synthesized_turns: set = set()
self._lock = threading.Lock() self._lock = threading.Lock()
threading.Thread(target=self._load_voice_for_lang, args=(self._default_language,), daemon=True).start()
threading.Thread(target=self._load_voice, daemon=True).start()
threading.Thread(target=self._playback_worker, daemon=True).start() threading.Thread(target=self._playback_worker, daemon=True).start()
self.get_logger().info(f"TtsNode init (langs={list(self._voice_paths.keys())})")
self.get_logger().info( def _load_voice_for_lang(self, lang: str) -> None:
f"TtsNode init (voice={self._voice_path}, " path = self._voice_paths.get(lang)
f"streaming={self._sentence_streaming})" if not path:
) self.get_logger().warn(f"No voice for '{lang}', fallback to '{self._default_language}'"); return
with self._voices_lock:
# ── Voice loading ───────────────────────────────────────────────────────── if lang in self._voices: return
def _load_voice(self) -> None:
t0 = time.time()
self.get_logger().info(f"Loading Piper voice: {self._voice_path}")
try: try:
from piper import PiperVoice from piper import PiperVoice
self._voice = PiperVoice.load(self._voice_path) voice = PiperVoice.load(path)
# Warmup synthesis to pre-JIT ONNX graph list(voice.synthesize_stream_raw("Hello."))
warmup_text = "Hello." with self._voices_lock: self._voices[lang] = voice
list(self._voice.synthesize_stream_raw(warmup_text)) self.get_logger().info(f"Piper [{lang}] ready")
self.get_logger().info(f"Piper voice ready ({time.time()-t0:.1f}s)")
except Exception as e: except Exception as e:
self.get_logger().error(f"Piper voice load failed: {e}") self.get_logger().error(f"Piper voice load failed [{lang}]: {e}")
# ── Response handler ────────────────────────────────────────────────────── def _get_voice(self, lang: str):
with self._voices_lock:
v = self._voices.get(lang)
if v is not None: return v
if lang in self._voice_paths:
threading.Thread(target=self._load_voice_for_lang, args=(lang,), daemon=True).start()
return self._voices.get(self._default_language)
def _on_response(self, msg: ConversationResponse) -> None: def _on_response(self, msg: ConversationResponse) -> None:
"""Handle streaming LLM response — synthesize sentence by sentence.""" if not msg.text.strip(): return
if not msg.text.strip(): lang = msg.language if msg.language else self._default_language
return
with self._lock: with self._lock:
is_new_turn = msg.turn_id != self._current_turn if msg.turn_id != self._current_turn:
if is_new_turn: self._current_turn = msg.turn_id; self._synthesized_turns = set()
self._current_turn = msg.turn_id
# Clear old synthesized sentence cache for this new turn
self._synthesized_turns = set()
text = strip_ssml(msg.text) text = strip_ssml(msg.text)
if self._sentence_streaming: if self._sentence_streaming:
sentences = split_sentences(text) for sentence in split_sentences(text):
for sentence in sentences:
# Track which sentences we've already queued by content hash
key = (msg.turn_id, hash(sentence)) key = (msg.turn_id, hash(sentence))
with self._lock: with self._lock:
if key in self._synthesized_turns: if key in self._synthesized_turns: continue
continue
self._synthesized_turns.add(key) self._synthesized_turns.add(key)
self._queue_synthesis(sentence) self._queue_synthesis(sentence, lang)
elif not msg.is_partial: elif not msg.is_partial:
# Non-streaming: synthesize full response at end self._queue_synthesis(text, lang)
self._queue_synthesis(text)
def _queue_synthesis(self, text: str) -> None: def _queue_synthesis(self, text: str, lang: str) -> None:
"""Queue a text segment for synthesis in the playback worker.""" if not text.strip(): return
if not text.strip(): try: self._playback_queue.put_nowait((text.strip(), lang))
return except queue.Full: self.get_logger().warn("TTS queue full")
try:
self._playback_queue.put_nowait(text.strip())
except queue.Full:
self.get_logger().warn("TTS playback queue full, dropping segment")
# ── Playback worker ───────────────────────────────────────────────────────
def _playback_worker(self) -> None: def _playback_worker(self) -> None:
"""Consume synthesis queue: synthesize → play → publish."""
while rclpy.ok(): while rclpy.ok():
try: try: item = self._playback_queue.get(timeout=0.5)
text = self._playback_queue.get(timeout=0.5) except queue.Empty: continue
except queue.Empty: text, lang = item
continue voice = self._get_voice(lang)
if voice is None:
if self._voice is None: self.get_logger().warn(f"No voice for '{lang}'"); self._playback_queue.task_done(); continue
self.get_logger().warn("TTS voice not loaded yet")
self._playback_queue.task_done()
continue
t0 = time.perf_counter() t0 = time.perf_counter()
pcm_data = self._synthesize(text) pcm_data = self._synthesize(text, voice)
if pcm_data is None: if pcm_data is None: self._playback_queue.task_done(); continue
self._playback_queue.task_done() if self._volume != 1.0: pcm_data = apply_volume(pcm_data, self._volume)
continue if self._playback: self._play_audio(pcm_data)
if self._publish_audio: self._publish_pcm(pcm_data)
synth_ms = (time.perf_counter() - t0) * 1000
dur_ms = estimate_duration_ms(pcm_data, self._sample_rate)
self.get_logger().debug(
f"TTS '{text[:40]}' synth={synth_ms:.0f}ms, dur={dur_ms:.0f}ms"
)
if self._volume != 1.0:
pcm_data = apply_volume(pcm_data, self._volume)
if self._playback:
self._play_audio(pcm_data)
if self._publish_audio:
self._publish_pcm(pcm_data)
self._playback_queue.task_done() self._playback_queue.task_done()
def _synthesize(self, text: str) -> Optional[bytes]: def _synthesize(self, text: str, voice) -> Optional[bytes]:
"""Synthesize text to PCM16 bytes using Piper streaming.""" try: return b"".join(voice.synthesize_stream_raw(text))
if self._voice is None: except Exception as e: self.get_logger().error(f"TTS error: {e}"); return None
return None
try:
chunks = list(self._voice.synthesize_stream_raw(text))
return b"".join(chunks)
except Exception as e:
self.get_logger().error(f"TTS synthesis error: {e}")
return None
def _play_audio(self, pcm_data: bytes) -> None: def _play_audio(self, pcm_data: bytes) -> None:
"""Play PCM16 data on USB speaker via sounddevice."""
try: try:
import sounddevice as sd import sounddevice as sd, numpy as np
import numpy as np sd.play(np.frombuffer(pcm_data,dtype=np.int16).astype(np.float32)/32768.0, samplerate=self._sample_rate, device=self._audio_device, blocking=True)
samples = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0 except Exception as e: self.get_logger().error(f"Playback error: {e}")
sd.play(samples, samplerate=self._sample_rate, device=self._audio_device,
blocking=True)
except Exception as e:
self.get_logger().error(f"Audio playback error: {e}")
def _publish_pcm(self, pcm_data: bytes) -> None: def _publish_pcm(self, pcm_data: bytes) -> None:
"""Publish PCM data as UInt8MultiArray.""" if not hasattr(self,"_audio_pub"): return
if not hasattr(self, "_audio_pub"): msg = UInt8MultiArray(); msg.data = list(pcm_data); self._audio_pub.publish(msg)
return
msg = UInt8MultiArray()
msg.data = list(pcm_data)
self._audio_pub.publish(msg)
def main(args=None) -> None: def main(args=None) -> None:
rclpy.init(args=args) rclpy.init(args=args)
node = TtsNode() node = TtsNode()
try: try: rclpy.spin(node)
rclpy.spin(node) except KeyboardInterrupt: pass
except KeyboardInterrupt: finally: node.destroy_node(); rclpy.shutdown()
pass
finally:
node.destroy_node()
rclpy.shutdown()

View File

@ -0,0 +1,122 @@
"""test_multilang.py -- Unit tests for Issue #167 multi-language support."""
from __future__ import annotations
import json, os
from typing import Any, Dict, Optional
import pytest
def _pkg_root():
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def _read_src(rel_path):
with open(os.path.join(_pkg_root(), rel_path)) as f:
return f.read()
def _extract_lang_names():
import ast
src = _read_src("saltybot_social/conversation_node.py")
start = src.index("_LANG_NAMES: Dict[str, str] = {")
end = src.index("\n}", start) + 2
return ast.literal_eval(src[start:end].split("=",1)[1].strip())
class TestLangNames:
@pytest.fixture(scope="class")
def ln(self): return _extract_lang_names()
def test_english(self, ln): assert ln["en"] == "English"
def test_french(self, ln): assert ln["fr"] == "French"
def test_spanish(self, ln): assert ln["es"] == "Spanish"
def test_german(self, ln): assert ln["de"] == "German"
def test_japanese(self, ln):assert ln["ja"] == "Japanese"
def test_chinese(self, ln): assert ln["zh"] == "Chinese"
def test_arabic(self, ln): assert ln["ar"] == "Arabic"
def test_at_least_15(self, ln): assert len(ln) >= 15
def test_lowercase_keys(self, ln):
for k in ln: assert k == k.lower() and 2 <= len(k) <= 3
def test_nonempty_values(self, ln):
for k, v in ln.items(): assert v
class TestLanguageHint:
def _h(self, sl, sid, ln):
lang = sl.get(sid, "en")
return f"[Please respond in {ln.get(lang, lang)}.]" if lang and lang != "en" else ""
@pytest.fixture(scope="class")
def ln(self): return _extract_lang_names()
def test_english_no_hint(self, ln): assert self._h({"p": "en"}, "p", ln) == ""
def test_unknown_no_hint(self, ln): assert self._h({}, "p", ln) == ""
def test_french(self, ln): assert self._h({"p":"fr"},"p",ln) == "[Please respond in French.]"
def test_spanish(self, ln): assert self._h({"p":"es"},"p",ln) == "[Please respond in Spanish.]"
def test_unknown_code(self, ln): assert "xx" in self._h({"p":"xx"},"p",ln)
def test_brackets(self, ln):
h = self._h({"p":"de"},"p",ln)
assert h.startswith("[") and h.endswith("]")
class TestVoiceMap:
def _parse(self, jstr, dl, dp):
try: extra = json.loads(jstr) if jstr.strip() not in ("{}","") else {}
except: extra = {}
r = {dl: dp}; r.update(extra); return r
def test_empty(self): assert self._parse("{}","en","/e") == {"en":"/e"}
def test_extra(self):
vm = self._parse('{"fr":"/f"}', "en", "/e")
assert vm["fr"] == "/f"
def test_invalid(self): assert self._parse("BAD","en","/e") == {"en":"/e"}
def test_multi(self):
assert len(self._parse(json.dumps({"fr":"/f","es":"/s"}),"en","/e")) == 3
class TestVoiceSelect:
def _s(self, voices, lang, default):
return voices.get(lang) or voices.get(default)
def test_exact(self): assert self._s({"en":"E","fr":"F"},"fr","en") == "F"
def test_fallback(self): assert self._s({"en":"E"},"fr","en") == "E"
def test_none(self): assert self._s({},"fr","en") is None
class TestSttFields:
@pytest.fixture(scope="class")
def src(self): return _read_src("saltybot_social/speech_pipeline_node.py")
def test_param(self, src): assert "whisper_language" in src
def test_detected_lang(self, src): assert "detected_lang" in src
def test_msg_language(self, src): assert "msg.language = language" in src
def test_auto_detect(self, src): assert "language=self._whisper_language" in src
class TestConvFields:
@pytest.fixture(scope="class")
def src(self): return _read_src("saltybot_social/conversation_node.py")
def test_speaker_lang(self, src): assert "_speaker_lang" in src
def test_lang_hint_method(self, src): assert "_language_hint" in src
def test_msg_language(self, src): assert "msg.language = language" in src
def test_lang_names(self, src): assert "_LANG_NAMES" in src
def test_please_respond(self, src): assert "Please respond in" in src
def test_emotion_coexists(self, src): assert "_emotion_hint" in src
class TestTtsFields:
@pytest.fixture(scope="class")
def src(self): return _read_src("saltybot_social/tts_node.py")
def test_voice_map_json(self, src): assert "voice_map_json" in src
def test_default_lang(self, src): assert "default_language" in src
def test_voices_dict(self, src): assert "_voices" in src
def test_get_voice(self, src): assert "_get_voice" in src
def test_load_voice_for_lang(self, src): assert "_load_voice_for_lang" in src
def test_queue_tuple(self, src): assert "(text.strip(), lang)" in src
def test_synthesize_voice_arg(self, src): assert "_synthesize(text, voice)" in src
class TestMsgDefs:
@pytest.fixture(scope="class")
def tr(self): return _read_src("../saltybot_social_msgs/msg/SpeechTranscript.msg")
@pytest.fixture(scope="class")
def re(self): return _read_src("../saltybot_social_msgs/msg/ConversationResponse.msg")
def test_transcript_lang(self, tr): assert "string language" in tr
def test_transcript_bcp47(self, tr): assert "BCP-47" in tr
def test_response_lang(self, re): assert "string language" in re
class TestEdgeCases:
def test_empty_lang_no_hint(self):
lang = "" or "en"; assert lang == "en"
def test_lang_flows(self):
d: Dict[str,str] = {}; d["p1"] = "fr"
assert d.get("p1","en") == "fr"
def test_multi_speakers(self):
d = {"p1":"fr","p2":"es"}
assert d["p1"] == "fr" and d["p2"] == "es"
def test_voice_map_code_in_tts(self):
src = _read_src("saltybot_social/tts_node.py")
assert "voice_map_json" in src and "json.loads" in src

View File

@ -7,3 +7,4 @@ string text # Full or partial response text
string speaker_id # Who the response is addressed to string speaker_id # Who the response is addressed to
bool is_partial # true = streaming token chunk, false = final response bool is_partial # true = streaming token chunk, false = final response
int32 turn_id # Conversation turn counter (for deduplication) int32 turn_id # Conversation turn counter (for deduplication)
string language # BCP-47 language code for TTS voice selection e.g. "en" "fr" "es"

View File

@ -8,3 +8,4 @@ string speaker_id # e.g. "person_42" or "unknown"
float32 confidence # ASR confidence 0..1 float32 confidence # ASR confidence 0..1
float32 audio_duration # Duration of the utterance in seconds float32 audio_duration # Duration of the utterance in seconds
bool is_partial # true = intermediate streaming result, false = final bool is_partial # true = intermediate streaming result, false = final
string language # BCP-47 detected language code e.g. "en" "fr" "es" (empty = unknown)