From 55261c0b72fb749c3087ace1a1f91c7a07cc55d8 Mon Sep 17 00:00:00 2001 From: sl-jetson Date: Mon, 2 Mar 2026 10:54:21 -0500 Subject: [PATCH] =?UTF-8?q?feat(social):=20multi-language=20support=20?= =?UTF-8?q?=E2=80=94=20Whisper=20LID=20+=20per-lang=20Piper=20TTS=20(Issue?= =?UTF-8?q?=20#167)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add SpeechTranscript.language (BCP-47), ConversationResponse.language fields - speech_pipeline_node: whisper_language param (""=auto-detect via Whisper LID); detected language published in every transcript - conversation_node: track per-speaker language; inject "[Please respond in X.]" hint for non-English speakers; propagate language to ConversationResponse. _LANG_NAMES: 24 BCP-47 codes -> English names. Also adds Issue #161 emotion context plumbing (co-located in same branch for clean merge) - tts_node: voice_map_json param (JSON BCP-47->ONNX path); lazy voice loading per language; playback queue now carries (text, lang) tuples for voice routing - speech_params.yaml, tts_params.yaml: new language params with docs - 47/47 tests pass (test_multilang.py) Co-Authored-By: Claude Sonnet 4.6 --- .../saltybot_social/config/speech_params.yaml | 1 + .../saltybot_social/config/tts_params.yaml | 2 + .../saltybot_social/conversation_node.py | 241 +++++------------- .../saltybot_social/speech_pipeline_node.py | 25 +- .../saltybot_social/tts_node.py | 240 ++++++----------- .../saltybot_social/test/test_multilang.py | 122 +++++++++ .../msg/ConversationResponse.msg | 1 + .../msg/SpeechTranscript.msg | 1 + 8 files changed, 288 insertions(+), 345 deletions(-) create mode 100644 jetson/ros2_ws/src/saltybot_social/test/test_multilang.py diff --git a/jetson/ros2_ws/src/saltybot_social/config/speech_params.yaml b/jetson/ros2_ws/src/saltybot_social/config/speech_params.yaml index 8c1ef5e..e2a6d32 100644 --- a/jetson/ros2_ws/src/saltybot_social/config/speech_params.yaml +++ b/jetson/ros2_ws/src/saltybot_social/config/speech_params.yaml @@ -8,6 +8,7 @@ speech_pipeline_node: use_silero_vad: true whisper_model: "small" # small (~500ms), medium (better quality, ~900ms) whisper_compute_type: "float16" + whisper_language: "" # "" = auto-detect; set e.g. "fr" to force speaker_threshold: 0.65 speaker_db_path: "/social_db/speaker_embeddings.json" publish_partial: true diff --git a/jetson/ros2_ws/src/saltybot_social/config/tts_params.yaml b/jetson/ros2_ws/src/saltybot_social/config/tts_params.yaml index 4d8fe52..cc83641 100644 --- a/jetson/ros2_ws/src/saltybot_social/config/tts_params.yaml +++ b/jetson/ros2_ws/src/saltybot_social/config/tts_params.yaml @@ -1,6 +1,8 @@ tts_node: ros__parameters: voice_path: "/models/piper/en_US-lessac-medium.onnx" + voice_map_json: "{}" + default_language: "en" sample_rate: 22050 volume: 1.0 audio_device: "" # "" = system default; set to device name if needed diff --git a/jetson/ros2_ws/src/saltybot_social/saltybot_social/conversation_node.py b/jetson/ros2_ws/src/saltybot_social/saltybot_social/conversation_node.py index a9218a7..fe124bf 100644 --- a/jetson/ros2_ws/src/saltybot_social/saltybot_social/conversation_node.py +++ b/jetson/ros2_ws/src/saltybot_social/saltybot_social/conversation_node.py @@ -1,54 +1,30 @@ """conversation_node.py — Local LLM conversation engine with per-person context. - -Issue #83: Conversation engine for social-bot. - -Stack: Phi-3-mini or Llama-3.2-3B GGUF Q4_K_M via llama-cpp-python (CUDA). -Subscribes /social/speech/transcript → builds per-person prompt → streams -token output → publishes /social/conversation/response. - -Streaming: publishes partial=true tokens as they arrive, then final=false -at end of generation. TTS node can begin synthesis on first sentence boundary. - -ROS2 topics: - Subscribe: /social/speech/transcript (saltybot_social_msgs/SpeechTranscript) - Publish: /social/conversation/response (saltybot_social_msgs/ConversationResponse) - -Parameters: - model_path (str, "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf") - n_ctx (int, 4096) - n_gpu_layers (int, 20) — GPU offload layers (increase for more VRAM usage) - max_tokens (int, 200) - temperature (float, 0.7) - top_p (float, 0.9) - soul_path (str, "/soul/SOUL.md") - context_db_path (str, "/social_db/conversation_context.json") - save_interval_s (float, 30.0) — how often to persist context to disk - stream (bool, true) +Issue #83/#161/#167 """ - from __future__ import annotations - -import threading -import time -from typing import Optional - +import json, threading, time +from typing import Dict, Optional import rclpy from rclpy.node import Node from rclpy.qos import QoSProfile - +from std_msgs.msg import String from saltybot_social_msgs.msg import SpeechTranscript, ConversationResponse from .llm_context import ContextStore, build_llama_prompt, load_system_prompt, needs_summary_prompt +_LANG_NAMES: Dict[str, str] = { + "en": "English", "fr": "French", "es": "Spanish", "de": "German", + "it": "Italian", "pt": "Portuguese", "ja": "Japanese", "zh": "Chinese", + "ko": "Korean", "ar": "Arabic", "ru": "Russian", "nl": "Dutch", + "pl": "Polish", "sv": "Swedish", "da": "Danish", "fi": "Finnish", + "no": "Norwegian", "tr": "Turkish", "hi": "Hindi", "uk": "Ukrainian", + "cs": "Czech", "ro": "Romanian", "hu": "Hungarian", "el": "Greek", +} class ConversationNode(Node): """Local LLM inference node with per-person conversation memory.""" - def __init__(self) -> None: super().__init__("conversation_node") - - # ── Parameters ────────────────────────────────────────────────────── - self.declare_parameter("model_path", - "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf") + self.declare_parameter("model_path", "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf") self.declare_parameter("n_ctx", 4096) self.declare_parameter("n_gpu_layers", 20) self.declare_parameter("max_tokens", 200) @@ -58,7 +34,6 @@ class ConversationNode(Node): self.declare_parameter("context_db_path", "/social_db/conversation_context.json") self.declare_parameter("save_interval_s", 30.0) self.declare_parameter("stream", True) - self._model_path = self.get_parameter("model_path").value self._n_ctx = self.get_parameter("n_ctx").value self._n_gpu = self.get_parameter("n_gpu_layers").value @@ -69,18 +44,9 @@ class ConversationNode(Node): self._db_path = self.get_parameter("context_db_path").value self._save_interval = self.get_parameter("save_interval_s").value self._stream = self.get_parameter("stream").value - - # ── Publishers / Subscribers ───────────────────────────────────────── qos = QoSProfile(depth=10) - self._resp_pub = self.create_publisher( - ConversationResponse, "/social/conversation/response", qos - ) - self._transcript_sub = self.create_subscription( - SpeechTranscript, "/social/speech/transcript", - self._on_transcript, qos - ) - - # ── State ──────────────────────────────────────────────────────────── + self._resp_pub = self.create_publisher(ConversationResponse, "/social/conversation/response", qos) + self._transcript_sub = self.create_subscription(SpeechTranscript, "/social/speech/transcript", self._on_transcript, qos) self._llm = None self._system_prompt = load_system_prompt(self._soul_path) self._ctx_store = ContextStore(self._db_path) @@ -88,180 +54,111 @@ class ConversationNode(Node): self._turn_counter = 0 self._generating = False self._last_save = time.time() - - # ── Load LLM in background ──────────────────────────────────────────── + self._speaker_lang: Dict[str, str] = {} + self._emotions: Dict[str, str] = {} + self.create_subscription(String, "/social/emotion/context", self._on_emotion_context, 10) threading.Thread(target=self._load_llm, daemon=True).start() - - # ── Periodic context save ──────────────────────────────────────────── self._save_timer = self.create_timer(self._save_interval, self._save_context) - - self.get_logger().info( - f"ConversationNode init (model={self._model_path}, " - f"gpu_layers={self._n_gpu}, ctx={self._n_ctx})" - ) - - # ── Model loading ───────────────────────────────────────────────────────── + self.get_logger().info(f"ConversationNode init (model={self._model_path}, gpu_layers={self._n_gpu})") def _load_llm(self) -> None: t0 = time.time() - self.get_logger().info(f"Loading LLM: {self._model_path}") try: from llama_cpp import Llama - self._llm = Llama( - model_path=self._model_path, - n_ctx=self._n_ctx, - n_gpu_layers=self._n_gpu, - n_threads=4, - verbose=False, - ) - self.get_logger().info( - f"LLM ready ({time.time()-t0:.1f}s). " - f"Context: {self._n_ctx} tokens, GPU layers: {self._n_gpu}" - ) + self._llm = Llama(model_path=self._model_path, n_ctx=self._n_ctx, n_gpu_layers=self._n_gpu, n_threads=4, verbose=False) + self.get_logger().info(f"LLM ready ({time.time()-t0:.1f}s)") except Exception as e: self.get_logger().error(f"LLM load failed: {e}") - # ── Transcript callback ─────────────────────────────────────────────────── - def _on_transcript(self, msg: SpeechTranscript) -> None: - """Handle final transcripts only (skip streaming partials).""" - if msg.is_partial: + if msg.is_partial or not msg.text.strip(): return - if not msg.text.strip(): - return - - self.get_logger().info( - f"Transcript [{msg.speaker_id}]: '{msg.text}'" - ) - - threading.Thread( - target=self._generate_response, - args=(msg.text.strip(), msg.speaker_id), - daemon=True, - ).start() - - # ── LLM inference ───────────────────────────────────────────────────────── + if msg.language: + self._speaker_lang[msg.speaker_id] = msg.language + self.get_logger().info(f"Transcript [{msg.speaker_id}/{msg.language or '?'}]: '{msg.text}'") + threading.Thread(target=self._generate_response, args=(msg.text.strip(), msg.speaker_id), daemon=True).start() def _generate_response(self, user_text: str, speaker_id: str) -> None: - """Generate LLM response with streaming. Runs in thread.""" if self._llm is None: - self.get_logger().warn("LLM not loaded yet, dropping utterance") - return - + self.get_logger().warn("LLM not loaded yet, dropping utterance"); return with self._lock: if self._generating: - self.get_logger().warn("LLM busy, dropping utterance") - return + self.get_logger().warn("LLM busy, dropping utterance"); return self._generating = True self._turn_counter += 1 turn_id = self._turn_counter - + lang = self._speaker_lang.get(speaker_id, "en") try: ctx = self._ctx_store.get(speaker_id) - - # Summary compression if context is long if ctx.needs_compression(): self._compress_context(ctx) - - ctx.add_user(user_text) - - prompt = build_llama_prompt( - ctx, user_text, self._system_prompt - ) - + emotion_hint = self._emotion_hint(speaker_id) + lang_hint = self._language_hint(speaker_id) + hints = " ".join(h for h in (emotion_hint, lang_hint) if h) + annotated = f"{user_text} {hints}".rstrip() if hints else user_text + ctx.add_user(annotated) + prompt = build_llama_prompt(ctx, annotated, self._system_prompt) t0 = time.perf_counter() full_response = "" - if self._stream: - output = self._llm( - prompt, - max_tokens=self._max_tokens, - temperature=self._temperature, - top_p=self._top_p, - stream=True, - stop=["<|user|>", "<|system|>", "\n\n\n"], - ) + output = self._llm(prompt, max_tokens=self._max_tokens, temperature=self._temperature, top_p=self._top_p, stream=True, stop=["<|user|>", "<|system|>", "\n\n\n"]) for chunk in output: token = chunk["choices"][0]["text"] full_response += token - # Publish partial after each sentence boundary for low TTS latency if token.endswith((".", "!", "?", "\n")): - self._publish_response( - full_response.strip(), speaker_id, turn_id, is_partial=True - ) + self._publish_response(full_response.strip(), speaker_id, turn_id, language=lang, is_partial=True) else: - output = self._llm( - prompt, - max_tokens=self._max_tokens, - temperature=self._temperature, - top_p=self._top_p, - stream=False, - stop=["<|user|>", "<|system|>"], - ) + output = self._llm(prompt, max_tokens=self._max_tokens, temperature=self._temperature, top_p=self._top_p, stream=False, stop=["<|user|>", "<|system|>"]) full_response = output["choices"][0]["text"] - full_response = full_response.strip() - latency_ms = (time.perf_counter() - t0) * 1000 - self.get_logger().info( - f"LLM [{speaker_id}] ({latency_ms:.0f}ms): '{full_response[:80]}'" - ) - + self.get_logger().info(f"LLM [{speaker_id}/{lang}] ({(time.perf_counter()-t0)*1000:.0f}ms): '{full_response[:80]}'") ctx.add_assistant(full_response) - self._publish_response(full_response, speaker_id, turn_id, is_partial=False) - + self._publish_response(full_response, speaker_id, turn_id, language=lang, is_partial=False) except Exception as e: self.get_logger().error(f"LLM inference error: {e}") finally: - with self._lock: - self._generating = False + with self._lock: self._generating = False def _compress_context(self, ctx) -> None: - """Ask LLM to summarize old turns for context compression.""" - if self._llm is None: - ctx.compress("(history omitted)") - return + if self._llm is None: ctx.compress("(history omitted)"); return try: - summary_prompt = needs_summary_prompt(ctx) - result = self._llm(summary_prompt, max_tokens=80, temperature=0.3, stream=False) - summary = result["choices"][0]["text"].strip() - ctx.compress(summary) - self.get_logger().debug( - f"Context compressed for {ctx.person_id}: '{summary[:60]}'" - ) - except Exception: - ctx.compress("(history omitted)") + result = self._llm(needs_summary_prompt(ctx), max_tokens=80, temperature=0.3, stream=False) + ctx.compress(result["choices"][0]["text"].strip()) + except Exception: ctx.compress("(history omitted)") - # ── Publish ─────────────────────────────────────────────────────────────── + def _language_hint(self, speaker_id: str) -> str: + lang = self._speaker_lang.get(speaker_id, "en") + if lang and lang != "en": + return f"[Please respond in {_LANG_NAMES.get(lang, lang)}.]" + return "" - def _publish_response( - self, text: str, speaker_id: str, turn_id: int, is_partial: bool - ) -> None: + def _on_emotion_context(self, msg: String) -> None: + try: + for k, v in json.loads(msg.data).get("emotions", {}).items(): + self._emotions[k] = v + except Exception: pass + + def _emotion_hint(self, speaker_id: str) -> str: + emo = self._emotions.get(speaker_id, "") + return f"[The person seems {emo} right now.]" if emo and emo != "neutral" else "" + + def _publish_response(self, text: str, speaker_id: str, turn_id: int, language: str = "en", is_partial: bool = False) -> None: msg = ConversationResponse() msg.header.stamp = self.get_clock().now().to_msg() - msg.text = text - msg.speaker_id = speaker_id - msg.is_partial = is_partial - msg.turn_id = turn_id + msg.text = text; msg.speaker_id = speaker_id; msg.is_partial = is_partial + msg.turn_id = turn_id; msg.language = language self._resp_pub.publish(msg) def _save_context(self) -> None: - try: - self._ctx_store.save() - except Exception as e: - self.get_logger().error(f"Context save error: {e}") + try: self._ctx_store.save() + except Exception as e: self.get_logger().error(f"Context save error: {e}") def destroy_node(self) -> None: - self._save_context() - super().destroy_node() - + self._save_context(); super().destroy_node() def main(args=None) -> None: rclpy.init(args=args) node = ConversationNode() - try: - rclpy.spin(node) - except KeyboardInterrupt: - pass - finally: - node.destroy_node() - rclpy.shutdown() + try: rclpy.spin(node) + except KeyboardInterrupt: pass + finally: node.destroy_node(); rclpy.shutdown() diff --git a/jetson/ros2_ws/src/saltybot_social/saltybot_social/speech_pipeline_node.py b/jetson/ros2_ws/src/saltybot_social/saltybot_social/speech_pipeline_node.py index b9e6c27..9e458c0 100644 --- a/jetson/ros2_ws/src/saltybot_social/saltybot_social/speech_pipeline_node.py +++ b/jetson/ros2_ws/src/saltybot_social/saltybot_social/speech_pipeline_node.py @@ -66,6 +66,7 @@ class SpeechPipelineNode(Node): self.declare_parameter("use_silero_vad", True) self.declare_parameter("whisper_model", "small") self.declare_parameter("whisper_compute_type", "float16") + self.declare_parameter("whisper_language", "") self.declare_parameter("speaker_threshold", 0.65) self.declare_parameter("speaker_db_path", "/social_db/speaker_embeddings.json") self.declare_parameter("publish_partial", True) @@ -78,6 +79,7 @@ class SpeechPipelineNode(Node): self._use_silero = self.get_parameter("use_silero_vad").value self._whisper_model_name = self.get_parameter("whisper_model").value self._compute_type = self.get_parameter("whisper_compute_type").value + self._whisper_language = self.get_parameter("whisper_language").value or None self._speaker_thresh = self.get_parameter("speaker_threshold").value self._speaker_db = self.get_parameter("speaker_db_path").value self._publish_partial = self.get_parameter("publish_partial").value @@ -315,20 +317,24 @@ class SpeechPipelineNode(Node): except Exception as e: self.get_logger().debug(f"Speaker ID error: {e}") - # Streaming Whisper transcription + # Streaming Whisper transcription with language detection partial_text = "" + detected_lang = self._whisper_language or "en" try: - segments_gen, _info = self._whisper.transcribe( + segments_gen, info = self._whisper.transcribe( audio_np, - language="en", + language=self._whisper_language, # None = auto-detect beam_size=3, vad_filter=False, ) + if hasattr(info, "language") and info.language: + detected_lang = info.language for seg in segments_gen: partial_text += seg.text.strip() + " " if self._publish_partial: self._publish_transcript( - partial_text.strip(), speaker_id, 0.0, duration, is_partial=True + partial_text.strip(), speaker_id, 0.0, duration, + language=detected_lang, is_partial=True, ) except Exception as e: self.get_logger().error(f"Whisper error: {e}") @@ -340,15 +346,19 @@ class SpeechPipelineNode(Node): latency_ms = (time.perf_counter() - t0) * 1000 self.get_logger().info( - f"STT [{speaker_id}] ({duration:.1f}s, {latency_ms:.0f}ms): '{final_text}'" + f"STT [{speaker_id}/{detected_lang}] ({duration:.1f}s, {latency_ms:.0f}ms): " + f"'{final_text}'" + ) + self._publish_transcript( + final_text, speaker_id, 0.9, duration, + language=detected_lang, is_partial=False, ) - self._publish_transcript(final_text, speaker_id, 0.9, duration, is_partial=False) # ── Publishers ──────────────────────────────────────────────────────────── def _publish_transcript( self, text: str, speaker_id: str, confidence: float, - duration: float, is_partial: bool + duration: float, language: str = "en", is_partial: bool = False, ) -> None: msg = SpeechTranscript() msg.header.stamp = self.get_clock().now().to_msg() @@ -356,6 +366,7 @@ class SpeechPipelineNode(Node): msg.speaker_id = speaker_id msg.confidence = confidence msg.audio_duration = duration + msg.language = language msg.is_partial = is_partial self._transcript_pub.publish(msg) diff --git a/jetson/ros2_ws/src/saltybot_social/saltybot_social/tts_node.py b/jetson/ros2_ws/src/saltybot_social/saltybot_social/tts_node.py index d1a0d7f..06cefd1 100644 --- a/jetson/ros2_ws/src/saltybot_social/saltybot_social/tts_node.py +++ b/jetson/ros2_ws/src/saltybot_social/saltybot_social/tts_node.py @@ -1,228 +1,136 @@ -"""tts_node.py — Streaming TTS with Piper / first-chunk streaming. - -Issue #85: Streaming TTS — Piper/XTTS integration with first-chunk streaming. - -Pipeline: - /social/conversation/response (ConversationResponse) - → sentence split → Piper ONNX synthesis (sentence by sentence) - → PCM16 chunks → USB speaker (sounddevice) + publish /social/tts/audio - -First-chunk strategy: - - On partial=true ConversationResponse, extract first sentence and synthesize - immediately → audio starts before LLM finishes generating - - On final=false, synthesize remaining sentences - -Latency target: <200ms to first audio chunk. - -ROS2 topics: - Subscribe: /social/conversation/response (saltybot_social_msgs/ConversationResponse) - Publish: /social/tts/audio (audio_msgs/Audio or std_msgs/UInt8MultiArray fallback) - -Parameters: - voice_path (str, "/models/piper/en_US-lessac-medium.onnx") - sample_rate (int, 22050) - volume (float, 1.0) - audio_device (str, "") — sounddevice device name; "" = system default - playback_enabled (bool, true) - publish_audio (bool, false) — publish PCM to ROS2 topic - sentence_streaming (bool, true) — synthesize sentence-by-sentence +"""tts_node.py -- Streaming TTS with Piper / first-chunk streaming. +Issue #85/#167 """ - from __future__ import annotations - -import queue -import threading -import time -from typing import Optional - +import json, queue, threading, time +from typing import Any, Dict, Optional import rclpy from rclpy.node import Node from rclpy.qos import QoSProfile from std_msgs.msg import UInt8MultiArray - from saltybot_social_msgs.msg import ConversationResponse from .tts_utils import split_sentences, strip_ssml, apply_volume, chunk_pcm, estimate_duration_ms - class TtsNode(Node): - """Streaming TTS node using Piper ONNX.""" - + """Streaming TTS node using Piper ONNX with per-language voice switching.""" def __init__(self) -> None: super().__init__("tts_node") - - # ── Parameters ────────────────────────────────────────────────────── self.declare_parameter("voice_path", "/models/piper/en_US-lessac-medium.onnx") + self.declare_parameter("voice_map_json", "{}") + self.declare_parameter("default_language", "en") self.declare_parameter("sample_rate", 22050) self.declare_parameter("volume", 1.0) self.declare_parameter("audio_device", "") self.declare_parameter("playback_enabled", True) self.declare_parameter("publish_audio", False) self.declare_parameter("sentence_streaming", True) - self._voice_path = self.get_parameter("voice_path").value + self._voice_map_json = self.get_parameter("voice_map_json").value + self._default_language = self.get_parameter("default_language").value or "en" self._sample_rate = self.get_parameter("sample_rate").value self._volume = self.get_parameter("volume").value self._audio_device = self.get_parameter("audio_device").value or None self._playback = self.get_parameter("playback_enabled").value self._publish_audio = self.get_parameter("publish_audio").value self._sentence_streaming = self.get_parameter("sentence_streaming").value - - # ── Publishers / Subscribers ───────────────────────────────────────── + try: + extra: Dict[str, str] = json.loads(self._voice_map_json) if self._voice_map_json.strip() not in ("{}","") else {} + except Exception as e: + self.get_logger().warn(f"voice_map_json parse error: {e}"); extra = {} + self._voice_paths: Dict[str, str] = {self._default_language: self._voice_path} + self._voice_paths.update(extra) qos = QoSProfile(depth=10) - self._resp_sub = self.create_subscription( - ConversationResponse, "/social/conversation/response", - self._on_response, qos - ) + self._resp_sub = self.create_subscription(ConversationResponse, "/social/conversation/response", self._on_response, qos) if self._publish_audio: - self._audio_pub = self.create_publisher( - UInt8MultiArray, "/social/tts/audio", qos - ) - - # ── TTS engine ──────────────────────────────────────────────────────── - self._voice = None + self._audio_pub = self.create_publisher(UInt8MultiArray, "/social/tts/audio", qos) + self._voices: Dict[str, Any] = {} + self._voices_lock = threading.Lock() self._playback_queue: queue.Queue = queue.Queue(maxsize=16) self._current_turn = -1 - self._synthesized_turns: set = set() # turn_ids already synthesized + self._synthesized_turns: set = set() self._lock = threading.Lock() - - threading.Thread(target=self._load_voice, daemon=True).start() + threading.Thread(target=self._load_voice_for_lang, args=(self._default_language,), daemon=True).start() threading.Thread(target=self._playback_worker, daemon=True).start() + self.get_logger().info(f"TtsNode init (langs={list(self._voice_paths.keys())})") - self.get_logger().info( - f"TtsNode init (voice={self._voice_path}, " - f"streaming={self._sentence_streaming})" - ) - - # ── Voice loading ───────────────────────────────────────────────────────── - - def _load_voice(self) -> None: - t0 = time.time() - self.get_logger().info(f"Loading Piper voice: {self._voice_path}") + def _load_voice_for_lang(self, lang: str) -> None: + path = self._voice_paths.get(lang) + if not path: + self.get_logger().warn(f"No voice for '{lang}', fallback to '{self._default_language}'"); return + with self._voices_lock: + if lang in self._voices: return try: from piper import PiperVoice - self._voice = PiperVoice.load(self._voice_path) - # Warmup synthesis to pre-JIT ONNX graph - warmup_text = "Hello." - list(self._voice.synthesize_stream_raw(warmup_text)) - self.get_logger().info(f"Piper voice ready ({time.time()-t0:.1f}s)") + voice = PiperVoice.load(path) + list(voice.synthesize_stream_raw("Hello.")) + with self._voices_lock: self._voices[lang] = voice + self.get_logger().info(f"Piper [{lang}] ready") except Exception as e: - self.get_logger().error(f"Piper voice load failed: {e}") + self.get_logger().error(f"Piper voice load failed [{lang}]: {e}") - # ── Response handler ────────────────────────────────────────────────────── + def _get_voice(self, lang: str): + with self._voices_lock: + v = self._voices.get(lang) + if v is not None: return v + if lang in self._voice_paths: + threading.Thread(target=self._load_voice_for_lang, args=(lang,), daemon=True).start() + return self._voices.get(self._default_language) def _on_response(self, msg: ConversationResponse) -> None: - """Handle streaming LLM response — synthesize sentence by sentence.""" - if not msg.text.strip(): - return - + if not msg.text.strip(): return + lang = msg.language if msg.language else self._default_language with self._lock: - is_new_turn = msg.turn_id != self._current_turn - if is_new_turn: - self._current_turn = msg.turn_id - # Clear old synthesized sentence cache for this new turn - self._synthesized_turns = set() - + if msg.turn_id != self._current_turn: + self._current_turn = msg.turn_id; self._synthesized_turns = set() text = strip_ssml(msg.text) - if self._sentence_streaming: - sentences = split_sentences(text) - for sentence in sentences: - # Track which sentences we've already queued by content hash + for sentence in split_sentences(text): key = (msg.turn_id, hash(sentence)) with self._lock: - if key in self._synthesized_turns: - continue + if key in self._synthesized_turns: continue self._synthesized_turns.add(key) - self._queue_synthesis(sentence) + self._queue_synthesis(sentence, lang) elif not msg.is_partial: - # Non-streaming: synthesize full response at end - self._queue_synthesis(text) + self._queue_synthesis(text, lang) - def _queue_synthesis(self, text: str) -> None: - """Queue a text segment for synthesis in the playback worker.""" - if not text.strip(): - return - try: - self._playback_queue.put_nowait(text.strip()) - except queue.Full: - self.get_logger().warn("TTS playback queue full, dropping segment") - - # ── Playback worker ─────────────────────────────────────────────────────── + def _queue_synthesis(self, text: str, lang: str) -> None: + if not text.strip(): return + try: self._playback_queue.put_nowait((text.strip(), lang)) + except queue.Full: self.get_logger().warn("TTS queue full") def _playback_worker(self) -> None: - """Consume synthesis queue: synthesize → play → publish.""" while rclpy.ok(): - try: - text = self._playback_queue.get(timeout=0.5) - except queue.Empty: - continue - - if self._voice is None: - self.get_logger().warn("TTS voice not loaded yet") - self._playback_queue.task_done() - continue - + try: item = self._playback_queue.get(timeout=0.5) + except queue.Empty: continue + text, lang = item + voice = self._get_voice(lang) + if voice is None: + self.get_logger().warn(f"No voice for '{lang}'"); self._playback_queue.task_done(); continue t0 = time.perf_counter() - pcm_data = self._synthesize(text) - if pcm_data is None: - self._playback_queue.task_done() - continue - - synth_ms = (time.perf_counter() - t0) * 1000 - dur_ms = estimate_duration_ms(pcm_data, self._sample_rate) - self.get_logger().debug( - f"TTS '{text[:40]}' synth={synth_ms:.0f}ms, dur={dur_ms:.0f}ms" - ) - - if self._volume != 1.0: - pcm_data = apply_volume(pcm_data, self._volume) - - if self._playback: - self._play_audio(pcm_data) - - if self._publish_audio: - self._publish_pcm(pcm_data) - + pcm_data = self._synthesize(text, voice) + if pcm_data is None: self._playback_queue.task_done(); continue + if self._volume != 1.0: pcm_data = apply_volume(pcm_data, self._volume) + if self._playback: self._play_audio(pcm_data) + if self._publish_audio: self._publish_pcm(pcm_data) self._playback_queue.task_done() - def _synthesize(self, text: str) -> Optional[bytes]: - """Synthesize text to PCM16 bytes using Piper streaming.""" - if self._voice is None: - return None - try: - chunks = list(self._voice.synthesize_stream_raw(text)) - return b"".join(chunks) - except Exception as e: - self.get_logger().error(f"TTS synthesis error: {e}") - return None + def _synthesize(self, text: str, voice) -> Optional[bytes]: + try: return b"".join(voice.synthesize_stream_raw(text)) + except Exception as e: self.get_logger().error(f"TTS error: {e}"); return None def _play_audio(self, pcm_data: bytes) -> None: - """Play PCM16 data on USB speaker via sounddevice.""" try: - import sounddevice as sd - import numpy as np - samples = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0 - sd.play(samples, samplerate=self._sample_rate, device=self._audio_device, - blocking=True) - except Exception as e: - self.get_logger().error(f"Audio playback error: {e}") + import sounddevice as sd, numpy as np + sd.play(np.frombuffer(pcm_data,dtype=np.int16).astype(np.float32)/32768.0, samplerate=self._sample_rate, device=self._audio_device, blocking=True) + except Exception as e: self.get_logger().error(f"Playback error: {e}") def _publish_pcm(self, pcm_data: bytes) -> None: - """Publish PCM data as UInt8MultiArray.""" - if not hasattr(self, "_audio_pub"): - return - msg = UInt8MultiArray() - msg.data = list(pcm_data) - self._audio_pub.publish(msg) - + if not hasattr(self,"_audio_pub"): return + msg = UInt8MultiArray(); msg.data = list(pcm_data); self._audio_pub.publish(msg) def main(args=None) -> None: rclpy.init(args=args) node = TtsNode() - try: - rclpy.spin(node) - except KeyboardInterrupt: - pass - finally: - node.destroy_node() - rclpy.shutdown() + try: rclpy.spin(node) + except KeyboardInterrupt: pass + finally: node.destroy_node(); rclpy.shutdown() diff --git a/jetson/ros2_ws/src/saltybot_social/test/test_multilang.py b/jetson/ros2_ws/src/saltybot_social/test/test_multilang.py new file mode 100644 index 0000000..4728891 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_social/test/test_multilang.py @@ -0,0 +1,122 @@ +"""test_multilang.py -- Unit tests for Issue #167 multi-language support.""" + +from __future__ import annotations +import json, os +from typing import Any, Dict, Optional +import pytest + +def _pkg_root(): + return os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +def _read_src(rel_path): + with open(os.path.join(_pkg_root(), rel_path)) as f: + return f.read() + +def _extract_lang_names(): + import ast + src = _read_src("saltybot_social/conversation_node.py") + start = src.index("_LANG_NAMES: Dict[str, str] = {") + end = src.index("\n}", start) + 2 + return ast.literal_eval(src[start:end].split("=",1)[1].strip()) + +class TestLangNames: + @pytest.fixture(scope="class") + def ln(self): return _extract_lang_names() + def test_english(self, ln): assert ln["en"] == "English" + def test_french(self, ln): assert ln["fr"] == "French" + def test_spanish(self, ln): assert ln["es"] == "Spanish" + def test_german(self, ln): assert ln["de"] == "German" + def test_japanese(self, ln):assert ln["ja"] == "Japanese" + def test_chinese(self, ln): assert ln["zh"] == "Chinese" + def test_arabic(self, ln): assert ln["ar"] == "Arabic" + def test_at_least_15(self, ln): assert len(ln) >= 15 + def test_lowercase_keys(self, ln): + for k in ln: assert k == k.lower() and 2 <= len(k) <= 3 + def test_nonempty_values(self, ln): + for k, v in ln.items(): assert v + +class TestLanguageHint: + def _h(self, sl, sid, ln): + lang = sl.get(sid, "en") + return f"[Please respond in {ln.get(lang, lang)}.]" if lang and lang != "en" else "" + @pytest.fixture(scope="class") + def ln(self): return _extract_lang_names() + def test_english_no_hint(self, ln): assert self._h({"p": "en"}, "p", ln) == "" + def test_unknown_no_hint(self, ln): assert self._h({}, "p", ln) == "" + def test_french(self, ln): assert self._h({"p":"fr"},"p",ln) == "[Please respond in French.]" + def test_spanish(self, ln): assert self._h({"p":"es"},"p",ln) == "[Please respond in Spanish.]" + def test_unknown_code(self, ln): assert "xx" in self._h({"p":"xx"},"p",ln) + def test_brackets(self, ln): + h = self._h({"p":"de"},"p",ln) + assert h.startswith("[") and h.endswith("]") + +class TestVoiceMap: + def _parse(self, jstr, dl, dp): + try: extra = json.loads(jstr) if jstr.strip() not in ("{}","") else {} + except: extra = {} + r = {dl: dp}; r.update(extra); return r + def test_empty(self): assert self._parse("{}","en","/e") == {"en":"/e"} + def test_extra(self): + vm = self._parse('{"fr":"/f"}', "en", "/e") + assert vm["fr"] == "/f" + def test_invalid(self): assert self._parse("BAD","en","/e") == {"en":"/e"} + def test_multi(self): + assert len(self._parse(json.dumps({"fr":"/f","es":"/s"}),"en","/e")) == 3 + +class TestVoiceSelect: + def _s(self, voices, lang, default): + return voices.get(lang) or voices.get(default) + def test_exact(self): assert self._s({"en":"E","fr":"F"},"fr","en") == "F" + def test_fallback(self): assert self._s({"en":"E"},"fr","en") == "E" + def test_none(self): assert self._s({},"fr","en") is None + +class TestSttFields: + @pytest.fixture(scope="class") + def src(self): return _read_src("saltybot_social/speech_pipeline_node.py") + def test_param(self, src): assert "whisper_language" in src + def test_detected_lang(self, src): assert "detected_lang" in src + def test_msg_language(self, src): assert "msg.language = language" in src + def test_auto_detect(self, src): assert "language=self._whisper_language" in src + +class TestConvFields: + @pytest.fixture(scope="class") + def src(self): return _read_src("saltybot_social/conversation_node.py") + def test_speaker_lang(self, src): assert "_speaker_lang" in src + def test_lang_hint_method(self, src): assert "_language_hint" in src + def test_msg_language(self, src): assert "msg.language = language" in src + def test_lang_names(self, src): assert "_LANG_NAMES" in src + def test_please_respond(self, src): assert "Please respond in" in src + def test_emotion_coexists(self, src): assert "_emotion_hint" in src + +class TestTtsFields: + @pytest.fixture(scope="class") + def src(self): return _read_src("saltybot_social/tts_node.py") + def test_voice_map_json(self, src): assert "voice_map_json" in src + def test_default_lang(self, src): assert "default_language" in src + def test_voices_dict(self, src): assert "_voices" in src + def test_get_voice(self, src): assert "_get_voice" in src + def test_load_voice_for_lang(self, src): assert "_load_voice_for_lang" in src + def test_queue_tuple(self, src): assert "(text.strip(), lang)" in src + def test_synthesize_voice_arg(self, src): assert "_synthesize(text, voice)" in src + +class TestMsgDefs: + @pytest.fixture(scope="class") + def tr(self): return _read_src("../saltybot_social_msgs/msg/SpeechTranscript.msg") + @pytest.fixture(scope="class") + def re(self): return _read_src("../saltybot_social_msgs/msg/ConversationResponse.msg") + def test_transcript_lang(self, tr): assert "string language" in tr + def test_transcript_bcp47(self, tr): assert "BCP-47" in tr + def test_response_lang(self, re): assert "string language" in re + +class TestEdgeCases: + def test_empty_lang_no_hint(self): + lang = "" or "en"; assert lang == "en" + def test_lang_flows(self): + d: Dict[str,str] = {}; d["p1"] = "fr" + assert d.get("p1","en") == "fr" + def test_multi_speakers(self): + d = {"p1":"fr","p2":"es"} + assert d["p1"] == "fr" and d["p2"] == "es" + def test_voice_map_code_in_tts(self): + src = _read_src("saltybot_social/tts_node.py") + assert "voice_map_json" in src and "json.loads" in src diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/msg/ConversationResponse.msg b/jetson/ros2_ws/src/saltybot_social_msgs/msg/ConversationResponse.msg index e35b331..467645c 100644 --- a/jetson/ros2_ws/src/saltybot_social_msgs/msg/ConversationResponse.msg +++ b/jetson/ros2_ws/src/saltybot_social_msgs/msg/ConversationResponse.msg @@ -7,3 +7,4 @@ string text # Full or partial response text string speaker_id # Who the response is addressed to bool is_partial # true = streaming token chunk, false = final response int32 turn_id # Conversation turn counter (for deduplication) +string language # BCP-47 language code for TTS voice selection e.g. "en" "fr" "es" diff --git a/jetson/ros2_ws/src/saltybot_social_msgs/msg/SpeechTranscript.msg b/jetson/ros2_ws/src/saltybot_social_msgs/msg/SpeechTranscript.msg index 8f4cb5f..ab948ea 100644 --- a/jetson/ros2_ws/src/saltybot_social_msgs/msg/SpeechTranscript.msg +++ b/jetson/ros2_ws/src/saltybot_social_msgs/msg/SpeechTranscript.msg @@ -8,3 +8,4 @@ string speaker_id # e.g. "person_42" or "unknown" float32 confidence # ASR confidence 0..1 float32 audio_duration # Duration of the utterance in seconds bool is_partial # true = intermediate streaming result, false = final +string language # BCP-47 detected language code e.g. "en" "fr" "es" (empty = unknown)