Merge pull request 'feat(social): multi-language support - Whisper LID + per-lang Piper TTS (Issue #167)' (#187) from sl-jetson/issue-167-multilang into main
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (pull_request) Failing after 2s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (pull_request) Has been skipped
social-bot integration tests / Lint (flake8 + pep257) (push) Failing after 8s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (push) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (push) Has been cancelled
social-bot integration tests / Latency profiling (GPU, Orin) (pull_request) Has been cancelled
Some checks failed
social-bot integration tests / Lint (flake8 + pep257) (pull_request) Failing after 2s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (pull_request) Has been skipped
social-bot integration tests / Lint (flake8 + pep257) (push) Failing after 8s
social-bot integration tests / Core integration tests (mock sensors, no GPU) (push) Has been skipped
social-bot integration tests / Latency profiling (GPU, Orin) (push) Has been cancelled
social-bot integration tests / Latency profiling (GPU, Orin) (pull_request) Has been cancelled
This commit is contained in:
commit
e3e4bd70a4
@ -8,6 +8,7 @@ speech_pipeline_node:
|
|||||||
use_silero_vad: true
|
use_silero_vad: true
|
||||||
whisper_model: "small" # small (~500ms), medium (better quality, ~900ms)
|
whisper_model: "small" # small (~500ms), medium (better quality, ~900ms)
|
||||||
whisper_compute_type: "float16"
|
whisper_compute_type: "float16"
|
||||||
|
whisper_language: "" # "" = auto-detect; set e.g. "fr" to force
|
||||||
speaker_threshold: 0.65
|
speaker_threshold: 0.65
|
||||||
speaker_db_path: "/social_db/speaker_embeddings.json"
|
speaker_db_path: "/social_db/speaker_embeddings.json"
|
||||||
publish_partial: true
|
publish_partial: true
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
tts_node:
|
tts_node:
|
||||||
ros__parameters:
|
ros__parameters:
|
||||||
voice_path: "/models/piper/en_US-lessac-medium.onnx"
|
voice_path: "/models/piper/en_US-lessac-medium.onnx"
|
||||||
|
voice_map_json: "{}"
|
||||||
|
default_language: "en"
|
||||||
sample_rate: 22050
|
sample_rate: 22050
|
||||||
volume: 1.0
|
volume: 1.0
|
||||||
audio_device: "" # "" = system default; set to device name if needed
|
audio_device: "" # "" = system default; set to device name if needed
|
||||||
|
|||||||
@ -1,54 +1,30 @@
|
|||||||
"""conversation_node.py — Local LLM conversation engine with per-person context.
|
"""conversation_node.py — Local LLM conversation engine with per-person context.
|
||||||
|
Issue #83/#161/#167
|
||||||
Issue #83: Conversation engine for social-bot.
|
|
||||||
|
|
||||||
Stack: Phi-3-mini or Llama-3.2-3B GGUF Q4_K_M via llama-cpp-python (CUDA).
|
|
||||||
Subscribes /social/speech/transcript → builds per-person prompt → streams
|
|
||||||
token output → publishes /social/conversation/response.
|
|
||||||
|
|
||||||
Streaming: publishes partial=true tokens as they arrive, then final=false
|
|
||||||
at end of generation. TTS node can begin synthesis on first sentence boundary.
|
|
||||||
|
|
||||||
ROS2 topics:
|
|
||||||
Subscribe: /social/speech/transcript (saltybot_social_msgs/SpeechTranscript)
|
|
||||||
Publish: /social/conversation/response (saltybot_social_msgs/ConversationResponse)
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
model_path (str, "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf")
|
|
||||||
n_ctx (int, 4096)
|
|
||||||
n_gpu_layers (int, 20) — GPU offload layers (increase for more VRAM usage)
|
|
||||||
max_tokens (int, 200)
|
|
||||||
temperature (float, 0.7)
|
|
||||||
top_p (float, 0.9)
|
|
||||||
soul_path (str, "/soul/SOUL.md")
|
|
||||||
context_db_path (str, "/social_db/conversation_context.json")
|
|
||||||
save_interval_s (float, 30.0) — how often to persist context to disk
|
|
||||||
stream (bool, true)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import json, threading, time
|
||||||
import threading
|
from typing import Dict, Optional
|
||||||
import time
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import rclpy
|
import rclpy
|
||||||
from rclpy.node import Node
|
from rclpy.node import Node
|
||||||
from rclpy.qos import QoSProfile
|
from rclpy.qos import QoSProfile
|
||||||
|
from std_msgs.msg import String
|
||||||
from saltybot_social_msgs.msg import SpeechTranscript, ConversationResponse
|
from saltybot_social_msgs.msg import SpeechTranscript, ConversationResponse
|
||||||
from .llm_context import ContextStore, build_llama_prompt, load_system_prompt, needs_summary_prompt
|
from .llm_context import ContextStore, build_llama_prompt, load_system_prompt, needs_summary_prompt
|
||||||
|
|
||||||
|
_LANG_NAMES: Dict[str, str] = {
|
||||||
|
"en": "English", "fr": "French", "es": "Spanish", "de": "German",
|
||||||
|
"it": "Italian", "pt": "Portuguese", "ja": "Japanese", "zh": "Chinese",
|
||||||
|
"ko": "Korean", "ar": "Arabic", "ru": "Russian", "nl": "Dutch",
|
||||||
|
"pl": "Polish", "sv": "Swedish", "da": "Danish", "fi": "Finnish",
|
||||||
|
"no": "Norwegian", "tr": "Turkish", "hi": "Hindi", "uk": "Ukrainian",
|
||||||
|
"cs": "Czech", "ro": "Romanian", "hu": "Hungarian", "el": "Greek",
|
||||||
|
}
|
||||||
|
|
||||||
class ConversationNode(Node):
|
class ConversationNode(Node):
|
||||||
"""Local LLM inference node with per-person conversation memory."""
|
"""Local LLM inference node with per-person conversation memory."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__("conversation_node")
|
super().__init__("conversation_node")
|
||||||
|
self.declare_parameter("model_path", "/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf")
|
||||||
# ── Parameters ──────────────────────────────────────────────────────
|
|
||||||
self.declare_parameter("model_path",
|
|
||||||
"/models/gguf/phi-3-mini-4k-instruct-q4_k_m.gguf")
|
|
||||||
self.declare_parameter("n_ctx", 4096)
|
self.declare_parameter("n_ctx", 4096)
|
||||||
self.declare_parameter("n_gpu_layers", 20)
|
self.declare_parameter("n_gpu_layers", 20)
|
||||||
self.declare_parameter("max_tokens", 200)
|
self.declare_parameter("max_tokens", 200)
|
||||||
@ -58,7 +34,6 @@ class ConversationNode(Node):
|
|||||||
self.declare_parameter("context_db_path", "/social_db/conversation_context.json")
|
self.declare_parameter("context_db_path", "/social_db/conversation_context.json")
|
||||||
self.declare_parameter("save_interval_s", 30.0)
|
self.declare_parameter("save_interval_s", 30.0)
|
||||||
self.declare_parameter("stream", True)
|
self.declare_parameter("stream", True)
|
||||||
|
|
||||||
self._model_path = self.get_parameter("model_path").value
|
self._model_path = self.get_parameter("model_path").value
|
||||||
self._n_ctx = self.get_parameter("n_ctx").value
|
self._n_ctx = self.get_parameter("n_ctx").value
|
||||||
self._n_gpu = self.get_parameter("n_gpu_layers").value
|
self._n_gpu = self.get_parameter("n_gpu_layers").value
|
||||||
@ -69,18 +44,9 @@ class ConversationNode(Node):
|
|||||||
self._db_path = self.get_parameter("context_db_path").value
|
self._db_path = self.get_parameter("context_db_path").value
|
||||||
self._save_interval = self.get_parameter("save_interval_s").value
|
self._save_interval = self.get_parameter("save_interval_s").value
|
||||||
self._stream = self.get_parameter("stream").value
|
self._stream = self.get_parameter("stream").value
|
||||||
|
|
||||||
# ── Publishers / Subscribers ─────────────────────────────────────────
|
|
||||||
qos = QoSProfile(depth=10)
|
qos = QoSProfile(depth=10)
|
||||||
self._resp_pub = self.create_publisher(
|
self._resp_pub = self.create_publisher(ConversationResponse, "/social/conversation/response", qos)
|
||||||
ConversationResponse, "/social/conversation/response", qos
|
self._transcript_sub = self.create_subscription(SpeechTranscript, "/social/speech/transcript", self._on_transcript, qos)
|
||||||
)
|
|
||||||
self._transcript_sub = self.create_subscription(
|
|
||||||
SpeechTranscript, "/social/speech/transcript",
|
|
||||||
self._on_transcript, qos
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── State ────────────────────────────────────────────────────────────
|
|
||||||
self._llm = None
|
self._llm = None
|
||||||
self._system_prompt = load_system_prompt(self._soul_path)
|
self._system_prompt = load_system_prompt(self._soul_path)
|
||||||
self._ctx_store = ContextStore(self._db_path)
|
self._ctx_store = ContextStore(self._db_path)
|
||||||
@ -88,187 +54,114 @@ class ConversationNode(Node):
|
|||||||
self._turn_counter = 0
|
self._turn_counter = 0
|
||||||
self._generating = False
|
self._generating = False
|
||||||
self._last_save = time.time()
|
self._last_save = time.time()
|
||||||
|
self._speaker_lang: Dict[str, str] = {}
|
||||||
# ── Load LLM in background ────────────────────────────────────────────
|
self._emotions: Dict[str, str] = {}
|
||||||
|
self.create_subscription(String, "/social/emotion/context", self._on_emotion_context, 10)
|
||||||
threading.Thread(target=self._load_llm, daemon=True).start()
|
threading.Thread(target=self._load_llm, daemon=True).start()
|
||||||
|
|
||||||
# ── Periodic context save ────────────────────────────────────────────
|
|
||||||
self._save_timer = self.create_timer(self._save_interval, self._save_context)
|
self._save_timer = self.create_timer(self._save_interval, self._save_context)
|
||||||
|
|
||||||
# ── Emotion context (Issue #161) ──────────────────────────────────────
|
|
||||||
self._emotions: Dict[int, str] = {}
|
|
||||||
self.create_subscription(
|
|
||||||
String, "/social/emotion/context",
|
|
||||||
self._on_emotion_context, 10
|
|
||||||
)
|
|
||||||
|
|
||||||
self.get_logger().info(
|
self.get_logger().info(
|
||||||
f"ConversationNode init (model={self._model_path}, "
|
f"ConversationNode init (model={self._model_path}, "
|
||||||
f"gpu_layers={self._n_gpu}, ctx={self._n_ctx})"
|
f"gpu_layers={self._n_gpu}, ctx={self._n_ctx})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── Model loading ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _load_llm(self) -> None:
|
def _load_llm(self) -> None:
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
self.get_logger().info(f"Loading LLM: {self._model_path}")
|
|
||||||
try:
|
try:
|
||||||
from llama_cpp import Llama
|
from llama_cpp import Llama
|
||||||
self._llm = Llama(
|
self._llm = Llama(model_path=self._model_path, n_ctx=self._n_ctx, n_gpu_layers=self._n_gpu, n_threads=4, verbose=False)
|
||||||
model_path=self._model_path,
|
self.get_logger().info(f"LLM ready ({time.time()-t0:.1f}s)")
|
||||||
n_ctx=self._n_ctx,
|
|
||||||
n_gpu_layers=self._n_gpu,
|
|
||||||
n_threads=4,
|
|
||||||
verbose=False,
|
|
||||||
)
|
|
||||||
self.get_logger().info(
|
|
||||||
f"LLM ready ({time.time()-t0:.1f}s). "
|
|
||||||
f"Context: {self._n_ctx} tokens, GPU layers: {self._n_gpu}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.get_logger().error(f"LLM load failed: {e}")
|
self.get_logger().error(f"LLM load failed: {e}")
|
||||||
|
|
||||||
# ── Transcript callback ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _on_transcript(self, msg: SpeechTranscript) -> None:
|
def _on_transcript(self, msg: SpeechTranscript) -> None:
|
||||||
"""Handle final transcripts only (skip streaming partials)."""
|
if msg.is_partial or not msg.text.strip():
|
||||||
if msg.is_partial:
|
|
||||||
return
|
return
|
||||||
if not msg.text.strip():
|
if msg.language:
|
||||||
return
|
self._speaker_lang[msg.speaker_id] = msg.language
|
||||||
|
self.get_logger().info(f"Transcript [{msg.speaker_id}/{msg.language or '?'}]: '{msg.text}'")
|
||||||
self.get_logger().info(
|
threading.Thread(target=self._generate_response, args=(msg.text.strip(), msg.speaker_id), daemon=True).start()
|
||||||
f"Transcript [{msg.speaker_id}]: '{msg.text}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
threading.Thread(
|
|
||||||
target=self._generate_response,
|
|
||||||
args=(msg.text.strip(), msg.speaker_id),
|
|
||||||
daemon=True,
|
|
||||||
).start()
|
|
||||||
|
|
||||||
# ── LLM inference ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _generate_response(self, user_text: str, speaker_id: str) -> None:
|
def _generate_response(self, user_text: str, speaker_id: str) -> None:
|
||||||
"""Generate LLM response with streaming. Runs in thread."""
|
|
||||||
if self._llm is None:
|
if self._llm is None:
|
||||||
self.get_logger().warn("LLM not loaded yet, dropping utterance")
|
self.get_logger().warn("LLM not loaded yet, dropping utterance"); return
|
||||||
return
|
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self._generating:
|
if self._generating:
|
||||||
self.get_logger().warn("LLM busy, dropping utterance")
|
self.get_logger().warn("LLM busy, dropping utterance"); return
|
||||||
return
|
|
||||||
self._generating = True
|
self._generating = True
|
||||||
self._turn_counter += 1
|
self._turn_counter += 1
|
||||||
turn_id = self._turn_counter
|
turn_id = self._turn_counter
|
||||||
|
lang = self._speaker_lang.get(speaker_id, "en")
|
||||||
try:
|
try:
|
||||||
ctx = self._ctx_store.get(speaker_id)
|
ctx = self._ctx_store.get(speaker_id)
|
||||||
|
|
||||||
# Summary compression if context is long
|
|
||||||
if ctx.needs_compression():
|
if ctx.needs_compression():
|
||||||
self._compress_context(ctx)
|
self._compress_context(ctx)
|
||||||
|
emotion_hint = self._emotion_hint(speaker_id)
|
||||||
ctx.add_user(user_text)
|
lang_hint = self._language_hint(speaker_id)
|
||||||
|
hints = " ".join(h for h in (emotion_hint, lang_hint) if h)
|
||||||
prompt = build_llama_prompt(
|
annotated = f"{user_text} {hints}".rstrip() if hints else user_text
|
||||||
ctx, user_text, self._system_prompt
|
ctx.add_user(annotated)
|
||||||
)
|
prompt = build_llama_prompt(ctx, annotated, self._system_prompt)
|
||||||
|
|
||||||
t0 = time.perf_counter()
|
t0 = time.perf_counter()
|
||||||
full_response = ""
|
full_response = ""
|
||||||
|
|
||||||
if self._stream:
|
if self._stream:
|
||||||
output = self._llm(
|
output = self._llm(prompt, max_tokens=self._max_tokens, temperature=self._temperature, top_p=self._top_p, stream=True, stop=["<|user|>", "<|system|>", "\n\n\n"])
|
||||||
prompt,
|
|
||||||
max_tokens=self._max_tokens,
|
|
||||||
temperature=self._temperature,
|
|
||||||
top_p=self._top_p,
|
|
||||||
stream=True,
|
|
||||||
stop=["<|user|>", "<|system|>", "\n\n\n"],
|
|
||||||
)
|
|
||||||
for chunk in output:
|
for chunk in output:
|
||||||
token = chunk["choices"][0]["text"]
|
token = chunk["choices"][0]["text"]
|
||||||
full_response += token
|
full_response += token
|
||||||
# Publish partial after each sentence boundary for low TTS latency
|
|
||||||
if token.endswith((".", "!", "?", "\n")):
|
if token.endswith((".", "!", "?", "\n")):
|
||||||
self._publish_response(
|
self._publish_response(full_response.strip(), speaker_id, turn_id, language=lang, is_partial=True)
|
||||||
full_response.strip(), speaker_id, turn_id, is_partial=True
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
output = self._llm(
|
output = self._llm(prompt, max_tokens=self._max_tokens, temperature=self._temperature, top_p=self._top_p, stream=False, stop=["<|user|>", "<|system|>"])
|
||||||
prompt,
|
|
||||||
max_tokens=self._max_tokens,
|
|
||||||
temperature=self._temperature,
|
|
||||||
top_p=self._top_p,
|
|
||||||
stream=False,
|
|
||||||
stop=["<|user|>", "<|system|>"],
|
|
||||||
)
|
|
||||||
full_response = output["choices"][0]["text"]
|
full_response = output["choices"][0]["text"]
|
||||||
|
|
||||||
full_response = full_response.strip()
|
full_response = full_response.strip()
|
||||||
latency_ms = (time.perf_counter() - t0) * 1000
|
self.get_logger().info(f"LLM [{speaker_id}/{lang}] ({(time.perf_counter()-t0)*1000:.0f}ms): '{full_response[:80]}'")
|
||||||
self.get_logger().info(
|
|
||||||
f"LLM [{speaker_id}] ({latency_ms:.0f}ms): '{full_response[:80]}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx.add_assistant(full_response)
|
ctx.add_assistant(full_response)
|
||||||
self._publish_response(full_response, speaker_id, turn_id, is_partial=False)
|
self._publish_response(full_response, speaker_id, turn_id, language=lang, is_partial=False)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.get_logger().error(f"LLM inference error: {e}")
|
self.get_logger().error(f"LLM inference error: {e}")
|
||||||
finally:
|
finally:
|
||||||
with self._lock:
|
with self._lock: self._generating = False
|
||||||
self._generating = False
|
|
||||||
|
|
||||||
def _compress_context(self, ctx) -> None:
|
def _compress_context(self, ctx) -> None:
|
||||||
"""Ask LLM to summarize old turns for context compression."""
|
if self._llm is None: ctx.compress("(history omitted)"); return
|
||||||
if self._llm is None:
|
|
||||||
ctx.compress("(history omitted)")
|
|
||||||
return
|
|
||||||
try:
|
try:
|
||||||
summary_prompt = needs_summary_prompt(ctx)
|
result = self._llm(needs_summary_prompt(ctx), max_tokens=80, temperature=0.3, stream=False)
|
||||||
result = self._llm(summary_prompt, max_tokens=80, temperature=0.3, stream=False)
|
ctx.compress(result["choices"][0]["text"].strip())
|
||||||
summary = result["choices"][0]["text"].strip()
|
except Exception: ctx.compress("(history omitted)")
|
||||||
ctx.compress(summary)
|
|
||||||
self.get_logger().debug(
|
|
||||||
f"Context compressed for {ctx.person_id}: '{summary[:60]}'"
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
ctx.compress("(history omitted)")
|
|
||||||
|
|
||||||
# ── Publish ───────────────────────────────────────────────────────────────
|
def _language_hint(self, speaker_id: str) -> str:
|
||||||
|
lang = self._speaker_lang.get(speaker_id, "en")
|
||||||
|
if lang and lang != "en":
|
||||||
|
return f"[Please respond in {_LANG_NAMES.get(lang, lang)}.]"
|
||||||
|
return ""
|
||||||
|
|
||||||
def _publish_response(
|
def _on_emotion_context(self, msg: String) -> None:
|
||||||
self, text: str, speaker_id: str, turn_id: int, is_partial: bool
|
try:
|
||||||
) -> None:
|
for k, v in json.loads(msg.data).get("emotions", {}).items():
|
||||||
|
self._emotions[k] = v
|
||||||
|
except Exception: pass
|
||||||
|
|
||||||
|
def _emotion_hint(self, speaker_id: str) -> str:
|
||||||
|
emo = self._emotions.get(speaker_id, "")
|
||||||
|
return f"[The person seems {emo} right now.]" if emo and emo != "neutral" else ""
|
||||||
|
|
||||||
|
def _publish_response(self, text: str, speaker_id: str, turn_id: int, language: str = "en", is_partial: bool = False) -> None:
|
||||||
msg = ConversationResponse()
|
msg = ConversationResponse()
|
||||||
msg.header.stamp = self.get_clock().now().to_msg()
|
msg.header.stamp = self.get_clock().now().to_msg()
|
||||||
msg.text = text
|
msg.text = text; msg.speaker_id = speaker_id; msg.is_partial = is_partial
|
||||||
msg.speaker_id = speaker_id
|
msg.turn_id = turn_id; msg.language = language
|
||||||
msg.is_partial = is_partial
|
|
||||||
msg.turn_id = turn_id
|
|
||||||
self._resp_pub.publish(msg)
|
self._resp_pub.publish(msg)
|
||||||
|
|
||||||
def _save_context(self) -> None:
|
def _save_context(self) -> None:
|
||||||
try:
|
try: self._ctx_store.save()
|
||||||
self._ctx_store.save()
|
except Exception as e: self.get_logger().error(f"Context save error: {e}")
|
||||||
except Exception as e:
|
|
||||||
self.get_logger().error(f"Context save error: {e}")
|
|
||||||
|
|
||||||
def destroy_node(self) -> None:
|
def destroy_node(self) -> None:
|
||||||
self._save_context()
|
self._save_context(); super().destroy_node()
|
||||||
super().destroy_node()
|
|
||||||
|
|
||||||
|
|
||||||
def main(args=None) -> None:
|
def main(args=None) -> None:
|
||||||
rclpy.init(args=args)
|
rclpy.init(args=args)
|
||||||
node = ConversationNode()
|
node = ConversationNode()
|
||||||
try:
|
try: rclpy.spin(node)
|
||||||
rclpy.spin(node)
|
except KeyboardInterrupt: pass
|
||||||
except KeyboardInterrupt:
|
finally: node.destroy_node(); rclpy.shutdown()
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
node.destroy_node()
|
|
||||||
rclpy.shutdown()
|
|
||||||
|
|||||||
@ -66,6 +66,7 @@ class SpeechPipelineNode(Node):
|
|||||||
self.declare_parameter("use_silero_vad", True)
|
self.declare_parameter("use_silero_vad", True)
|
||||||
self.declare_parameter("whisper_model", "small")
|
self.declare_parameter("whisper_model", "small")
|
||||||
self.declare_parameter("whisper_compute_type", "float16")
|
self.declare_parameter("whisper_compute_type", "float16")
|
||||||
|
self.declare_parameter("whisper_language", "")
|
||||||
self.declare_parameter("speaker_threshold", 0.65)
|
self.declare_parameter("speaker_threshold", 0.65)
|
||||||
self.declare_parameter("speaker_db_path", "/social_db/speaker_embeddings.json")
|
self.declare_parameter("speaker_db_path", "/social_db/speaker_embeddings.json")
|
||||||
self.declare_parameter("publish_partial", True)
|
self.declare_parameter("publish_partial", True)
|
||||||
@ -78,6 +79,7 @@ class SpeechPipelineNode(Node):
|
|||||||
self._use_silero = self.get_parameter("use_silero_vad").value
|
self._use_silero = self.get_parameter("use_silero_vad").value
|
||||||
self._whisper_model_name = self.get_parameter("whisper_model").value
|
self._whisper_model_name = self.get_parameter("whisper_model").value
|
||||||
self._compute_type = self.get_parameter("whisper_compute_type").value
|
self._compute_type = self.get_parameter("whisper_compute_type").value
|
||||||
|
self._whisper_language = self.get_parameter("whisper_language").value or None
|
||||||
self._speaker_thresh = self.get_parameter("speaker_threshold").value
|
self._speaker_thresh = self.get_parameter("speaker_threshold").value
|
||||||
self._speaker_db = self.get_parameter("speaker_db_path").value
|
self._speaker_db = self.get_parameter("speaker_db_path").value
|
||||||
self._publish_partial = self.get_parameter("publish_partial").value
|
self._publish_partial = self.get_parameter("publish_partial").value
|
||||||
@ -315,20 +317,24 @@ class SpeechPipelineNode(Node):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.get_logger().debug(f"Speaker ID error: {e}")
|
self.get_logger().debug(f"Speaker ID error: {e}")
|
||||||
|
|
||||||
# Streaming Whisper transcription
|
# Streaming Whisper transcription with language detection
|
||||||
partial_text = ""
|
partial_text = ""
|
||||||
|
detected_lang = self._whisper_language or "en"
|
||||||
try:
|
try:
|
||||||
segments_gen, _info = self._whisper.transcribe(
|
segments_gen, info = self._whisper.transcribe(
|
||||||
audio_np,
|
audio_np,
|
||||||
language="en",
|
language=self._whisper_language, # None = auto-detect
|
||||||
beam_size=3,
|
beam_size=3,
|
||||||
vad_filter=False,
|
vad_filter=False,
|
||||||
)
|
)
|
||||||
|
if hasattr(info, "language") and info.language:
|
||||||
|
detected_lang = info.language
|
||||||
for seg in segments_gen:
|
for seg in segments_gen:
|
||||||
partial_text += seg.text.strip() + " "
|
partial_text += seg.text.strip() + " "
|
||||||
if self._publish_partial:
|
if self._publish_partial:
|
||||||
self._publish_transcript(
|
self._publish_transcript(
|
||||||
partial_text.strip(), speaker_id, 0.0, duration, is_partial=True
|
partial_text.strip(), speaker_id, 0.0, duration,
|
||||||
|
language=detected_lang, is_partial=True,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.get_logger().error(f"Whisper error: {e}")
|
self.get_logger().error(f"Whisper error: {e}")
|
||||||
@ -340,15 +346,19 @@ class SpeechPipelineNode(Node):
|
|||||||
|
|
||||||
latency_ms = (time.perf_counter() - t0) * 1000
|
latency_ms = (time.perf_counter() - t0) * 1000
|
||||||
self.get_logger().info(
|
self.get_logger().info(
|
||||||
f"STT [{speaker_id}] ({duration:.1f}s, {latency_ms:.0f}ms): '{final_text}'"
|
f"STT [{speaker_id}/{detected_lang}] ({duration:.1f}s, {latency_ms:.0f}ms): "
|
||||||
|
f"'{final_text}'"
|
||||||
|
)
|
||||||
|
self._publish_transcript(
|
||||||
|
final_text, speaker_id, 0.9, duration,
|
||||||
|
language=detected_lang, is_partial=False,
|
||||||
)
|
)
|
||||||
self._publish_transcript(final_text, speaker_id, 0.9, duration, is_partial=False)
|
|
||||||
|
|
||||||
# ── Publishers ────────────────────────────────────────────────────────────
|
# ── Publishers ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
def _publish_transcript(
|
def _publish_transcript(
|
||||||
self, text: str, speaker_id: str, confidence: float,
|
self, text: str, speaker_id: str, confidence: float,
|
||||||
duration: float, is_partial: bool
|
duration: float, language: str = "en", is_partial: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
msg = SpeechTranscript()
|
msg = SpeechTranscript()
|
||||||
msg.header.stamp = self.get_clock().now().to_msg()
|
msg.header.stamp = self.get_clock().now().to_msg()
|
||||||
@ -356,6 +366,7 @@ class SpeechPipelineNode(Node):
|
|||||||
msg.speaker_id = speaker_id
|
msg.speaker_id = speaker_id
|
||||||
msg.confidence = confidence
|
msg.confidence = confidence
|
||||||
msg.audio_duration = duration
|
msg.audio_duration = duration
|
||||||
|
msg.language = language
|
||||||
msg.is_partial = is_partial
|
msg.is_partial = is_partial
|
||||||
self._transcript_pub.publish(msg)
|
self._transcript_pub.publish(msg)
|
||||||
|
|
||||||
|
|||||||
@ -1,228 +1,136 @@
|
|||||||
"""tts_node.py — Streaming TTS with Piper / first-chunk streaming.
|
"""tts_node.py -- Streaming TTS with Piper / first-chunk streaming.
|
||||||
|
Issue #85/#167
|
||||||
Issue #85: Streaming TTS — Piper/XTTS integration with first-chunk streaming.
|
|
||||||
|
|
||||||
Pipeline:
|
|
||||||
/social/conversation/response (ConversationResponse)
|
|
||||||
→ sentence split → Piper ONNX synthesis (sentence by sentence)
|
|
||||||
→ PCM16 chunks → USB speaker (sounddevice) + publish /social/tts/audio
|
|
||||||
|
|
||||||
First-chunk strategy:
|
|
||||||
- On partial=true ConversationResponse, extract first sentence and synthesize
|
|
||||||
immediately → audio starts before LLM finishes generating
|
|
||||||
- On final=false, synthesize remaining sentences
|
|
||||||
|
|
||||||
Latency target: <200ms to first audio chunk.
|
|
||||||
|
|
||||||
ROS2 topics:
|
|
||||||
Subscribe: /social/conversation/response (saltybot_social_msgs/ConversationResponse)
|
|
||||||
Publish: /social/tts/audio (audio_msgs/Audio or std_msgs/UInt8MultiArray fallback)
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
voice_path (str, "/models/piper/en_US-lessac-medium.onnx")
|
|
||||||
sample_rate (int, 22050)
|
|
||||||
volume (float, 1.0)
|
|
||||||
audio_device (str, "") — sounddevice device name; "" = system default
|
|
||||||
playback_enabled (bool, true)
|
|
||||||
publish_audio (bool, false) — publish PCM to ROS2 topic
|
|
||||||
sentence_streaming (bool, true) — synthesize sentence-by-sentence
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import json, queue, threading, time
|
||||||
import queue
|
from typing import Any, Dict, Optional
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import rclpy
|
import rclpy
|
||||||
from rclpy.node import Node
|
from rclpy.node import Node
|
||||||
from rclpy.qos import QoSProfile
|
from rclpy.qos import QoSProfile
|
||||||
from std_msgs.msg import UInt8MultiArray
|
from std_msgs.msg import UInt8MultiArray
|
||||||
|
|
||||||
from saltybot_social_msgs.msg import ConversationResponse
|
from saltybot_social_msgs.msg import ConversationResponse
|
||||||
from .tts_utils import split_sentences, strip_ssml, apply_volume, chunk_pcm, estimate_duration_ms
|
from .tts_utils import split_sentences, strip_ssml, apply_volume, chunk_pcm, estimate_duration_ms
|
||||||
|
|
||||||
|
|
||||||
class TtsNode(Node):
|
class TtsNode(Node):
|
||||||
"""Streaming TTS node using Piper ONNX."""
|
"""Streaming TTS node using Piper ONNX with per-language voice switching."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__("tts_node")
|
super().__init__("tts_node")
|
||||||
|
|
||||||
# ── Parameters ──────────────────────────────────────────────────────
|
|
||||||
self.declare_parameter("voice_path", "/models/piper/en_US-lessac-medium.onnx")
|
self.declare_parameter("voice_path", "/models/piper/en_US-lessac-medium.onnx")
|
||||||
|
self.declare_parameter("voice_map_json", "{}")
|
||||||
|
self.declare_parameter("default_language", "en")
|
||||||
self.declare_parameter("sample_rate", 22050)
|
self.declare_parameter("sample_rate", 22050)
|
||||||
self.declare_parameter("volume", 1.0)
|
self.declare_parameter("volume", 1.0)
|
||||||
self.declare_parameter("audio_device", "")
|
self.declare_parameter("audio_device", "")
|
||||||
self.declare_parameter("playback_enabled", True)
|
self.declare_parameter("playback_enabled", True)
|
||||||
self.declare_parameter("publish_audio", False)
|
self.declare_parameter("publish_audio", False)
|
||||||
self.declare_parameter("sentence_streaming", True)
|
self.declare_parameter("sentence_streaming", True)
|
||||||
|
|
||||||
self._voice_path = self.get_parameter("voice_path").value
|
self._voice_path = self.get_parameter("voice_path").value
|
||||||
|
self._voice_map_json = self.get_parameter("voice_map_json").value
|
||||||
|
self._default_language = self.get_parameter("default_language").value or "en"
|
||||||
self._sample_rate = self.get_parameter("sample_rate").value
|
self._sample_rate = self.get_parameter("sample_rate").value
|
||||||
self._volume = self.get_parameter("volume").value
|
self._volume = self.get_parameter("volume").value
|
||||||
self._audio_device = self.get_parameter("audio_device").value or None
|
self._audio_device = self.get_parameter("audio_device").value or None
|
||||||
self._playback = self.get_parameter("playback_enabled").value
|
self._playback = self.get_parameter("playback_enabled").value
|
||||||
self._publish_audio = self.get_parameter("publish_audio").value
|
self._publish_audio = self.get_parameter("publish_audio").value
|
||||||
self._sentence_streaming = self.get_parameter("sentence_streaming").value
|
self._sentence_streaming = self.get_parameter("sentence_streaming").value
|
||||||
|
try:
|
||||||
# ── Publishers / Subscribers ─────────────────────────────────────────
|
extra: Dict[str, str] = json.loads(self._voice_map_json) if self._voice_map_json.strip() not in ("{}","") else {}
|
||||||
|
except Exception as e:
|
||||||
|
self.get_logger().warn(f"voice_map_json parse error: {e}"); extra = {}
|
||||||
|
self._voice_paths: Dict[str, str] = {self._default_language: self._voice_path}
|
||||||
|
self._voice_paths.update(extra)
|
||||||
qos = QoSProfile(depth=10)
|
qos = QoSProfile(depth=10)
|
||||||
self._resp_sub = self.create_subscription(
|
self._resp_sub = self.create_subscription(ConversationResponse, "/social/conversation/response", self._on_response, qos)
|
||||||
ConversationResponse, "/social/conversation/response",
|
|
||||||
self._on_response, qos
|
|
||||||
)
|
|
||||||
if self._publish_audio:
|
if self._publish_audio:
|
||||||
self._audio_pub = self.create_publisher(
|
self._audio_pub = self.create_publisher(UInt8MultiArray, "/social/tts/audio", qos)
|
||||||
UInt8MultiArray, "/social/tts/audio", qos
|
self._voices: Dict[str, Any] = {}
|
||||||
)
|
self._voices_lock = threading.Lock()
|
||||||
|
|
||||||
# ── TTS engine ────────────────────────────────────────────────────────
|
|
||||||
self._voice = None
|
|
||||||
self._playback_queue: queue.Queue = queue.Queue(maxsize=16)
|
self._playback_queue: queue.Queue = queue.Queue(maxsize=16)
|
||||||
self._current_turn = -1
|
self._current_turn = -1
|
||||||
self._synthesized_turns: set = set() # turn_ids already synthesized
|
self._synthesized_turns: set = set()
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
threading.Thread(target=self._load_voice_for_lang, args=(self._default_language,), daemon=True).start()
|
||||||
threading.Thread(target=self._load_voice, daemon=True).start()
|
|
||||||
threading.Thread(target=self._playback_worker, daemon=True).start()
|
threading.Thread(target=self._playback_worker, daemon=True).start()
|
||||||
|
self.get_logger().info(f"TtsNode init (langs={list(self._voice_paths.keys())})")
|
||||||
|
|
||||||
self.get_logger().info(
|
def _load_voice_for_lang(self, lang: str) -> None:
    """Load and warm up the Piper voice configured for ``lang``.

    Looks up the model path in ``self._voice_paths``. When no path is
    configured a warning is logged and nothing is loaded. When the voice is
    already cached in ``self._voices`` the call is a no-op. Load failures
    (including a missing ``piper`` package) are logged, never raised, so
    this is safe to use as a thread target.
    """
    path = self._voice_paths.get(lang)
    if not path:
        self.get_logger().warn(
            f"No voice for '{lang}', fallback to '{self._default_language}'"
        )
        return

    with self._voices_lock:
        if lang in self._voices:
            return

    try:
        from piper import PiperVoice

        voice = PiperVoice.load(path)
        # Warmup synthesis to pre-JIT the ONNX graph; the audio is discarded.
        list(voice.synthesize_stream_raw("Hello."))
        with self._voices_lock:
            self._voices[lang] = voice
        self.get_logger().info(f"Piper [{lang}] ready")
    except Exception as e:
        self.get_logger().error(f"Piper voice load failed [{lang}]: {e}")
|
||||||
|
|
||||||
# ── Response handler ──────────────────────────────────────────────────────
|
def _get_voice(self, lang: str) -> Optional[Any]:
    """Return the voice to use for ``lang``, falling back to the default.

    If the voice for ``lang`` is already loaded it is returned. Otherwise,
    when a model path is configured for ``lang``, a background loader is
    started and the default-language voice (or ``None`` if that is not
    loaded either) is returned for now.

    At most one loader thread per language is kept alive: loaders started
    here are named ``piper-load-<lang>`` and a new one is not spawned while
    a thread with that name is still running. Without this, every queued
    sentence would start another load of the same model.
    """
    loader_name = f"piper-load-{lang}"
    with self._voices_lock:
        voice = self._voices.get(lang)
        if voice is not None:
            return voice

        fallback = self._voices.get(self._default_language)
        already_loading = any(
            t.name == loader_name for t in threading.enumerate()
        )
        if lang in self._voice_paths and not already_loading:
            # Started while holding the lock so check-and-start is atomic;
            # the loader simply waits for the lock to be released.
            threading.Thread(
                target=self._load_voice_for_lang,
                args=(lang,),
                name=loader_name,
                daemon=True,
            ).start()
    return fallback
|
||||||
|
|
||||||
def _on_response(self, msg: ConversationResponse) -> None:
    """Handle a streaming LLM response and queue it for synthesis.

    The voice language is taken from ``msg.language``; an empty value falls
    back to ``self._default_language``. When a new ``turn_id`` arrives the
    per-turn de-duplication set is reset. In sentence-streaming mode every
    sentence not yet queued for this turn is queued; otherwise only the
    final (non-partial) message is synthesized as a whole.
    """
    if not msg.text.strip():
        return

    lang = msg.language or self._default_language

    with self._lock:
        if msg.turn_id != self._current_turn:
            self._current_turn = msg.turn_id
            self._synthesized_turns = set()

    text = strip_ssml(msg.text)

    if self._sentence_streaming:
        for sentence in split_sentences(text):
            # Key on the sentence text itself rather than hash(sentence):
            # a hash collision would silently drop a distinct sentence.
            key = (msg.turn_id, sentence)
            with self._lock:
                if key in self._synthesized_turns:
                    continue
                self._synthesized_turns.add(key)
            self._queue_synthesis(sentence, lang)
    elif not msg.is_partial:
        # Non-streaming: synthesize the full response once it is final.
        self._queue_synthesis(text, lang)
|
|
||||||
|
|
||||||
def _queue_synthesis(self, text: str, lang: str) -> None:
    """Queue a ``(text, lang)`` segment for the playback worker.

    Blank text is ignored. The queue is bounded; when it is full the
    segment is dropped and a warning is logged instead of blocking the
    subscription callback.
    """
    if not text.strip():
        return
    try:
        self._playback_queue.put_nowait((text.strip(), lang))
    except queue.Full:
        self.get_logger().warn("TTS queue full")
|
|
||||||
|
|
||||||
# ── Playback worker ───────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _playback_worker(self) -> None:
    """Consume the synthesis queue: synthesize, then play and/or publish.

    Runs until ``rclpy.ok()`` turns false. Each queue item is a
    ``(text, lang)`` tuple. ``task_done()`` is called exactly once per item
    via ``finally``, and an unexpected error while handling one item is
    logged instead of terminating the worker thread.
    """
    while rclpy.ok():
        try:
            text, lang = self._playback_queue.get(timeout=0.5)
        except queue.Empty:
            continue

        try:
            voice = self._get_voice(lang)
            if voice is None:
                self.get_logger().warn(f"No voice for '{lang}'")
                continue

            t0 = time.perf_counter()
            pcm_data = self._synthesize(text, voice)
            if pcm_data is None:
                continue
            synth_ms = (time.perf_counter() - t0) * 1000
            self.get_logger().debug(
                f"TTS '{text[:40]}' [{lang}] synth={synth_ms:.0f}ms"
            )

            if self._volume != 1.0:
                pcm_data = apply_volume(pcm_data, self._volume)
            if self._playback:
                self._play_audio(pcm_data)
            if self._publish_audio:
                self._publish_pcm(pcm_data)
        except Exception as e:
            # Keep the worker alive; one bad segment must not stop all TTS.
            self.get_logger().error(f"TTS playback worker error: {e}")
        finally:
            self._playback_queue.task_done()
|
||||||
|
|
||||||
def _synthesize(self, text: str, voice) -> Optional[bytes]:
    """Synthesize ``text`` with ``voice`` and return the joined raw chunks.

    ``voice`` must provide ``synthesize_stream_raw(text)`` yielding
    ``bytes`` chunks. Returns ``None`` (after logging) if synthesis raises.
    """
    try:
        return b"".join(voice.synthesize_stream_raw(text))
    except Exception as e:
        self.get_logger().error(f"TTS error: {e}")
        return None
|
|
||||||
|
|
||||||
def _play_audio(self, pcm_data: bytes) -> None:
    """Play PCM16 data via sounddevice, blocking until playback finishes.

    The bytes are interpreted as int16 samples and scaled to float32 by
    1/32768 before playback on ``self._audio_device`` at
    ``self._sample_rate``. Any failure (including missing audio packages)
    is logged rather than raised.
    """
    try:
        import sounddevice as sd
        import numpy as np

        samples = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0
        sd.play(
            samples,
            samplerate=self._sample_rate,
            device=self._audio_device,
            blocking=True,
        )
    except Exception as e:
        self.get_logger().error(f"Playback error: {e}")
|
|
||||||
|
|
||||||
def _publish_pcm(self, pcm_data: bytes) -> None:
    """Publish PCM data as a ``UInt8MultiArray``.

    Does nothing when the node has no ``_audio_pub`` attribute (the
    publisher is only created when audio publishing is enabled).
    """
    if not hasattr(self, "_audio_pub"):
        return
    msg = UInt8MultiArray()
    msg.data = list(pcm_data)
    self._audio_pub.publish(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def main(args=None) -> None:
    """Entry point: initialise rclpy, spin a ``TtsNode``, then clean up.

    The node is constructed inside the ``try`` so that rclpy is still shut
    down if construction fails, and shutdown is guarded with ``rclpy.ok()``
    so it is not attempted on a context that is already shut down.
    """
    rclpy.init(args=args)
    node = None
    try:
        node = TtsNode()
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        if node is not None:
            node.destroy_node()
        if rclpy.ok():
            rclpy.shutdown()
|
|
||||||
|
|||||||
122
jetson/ros2_ws/src/saltybot_social/test/test_multilang.py
Normal file
122
jetson/ros2_ws/src/saltybot_social/test/test_multilang.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
"""test_multilang.py -- Unit tests for Issue #167 multi-language support."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import json, os
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
def _pkg_root():
    """Return the absolute path of the directory two levels above this file."""
    here = os.path.abspath(__file__)
    return os.path.dirname(os.path.dirname(here))
|
||||||
|
|
||||||
|
def _read_src(rel_path):
    """Return the text of ``rel_path``, resolved against the package root.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the locale of the machine running the tests (the sources contain
    non-ASCII characters).
    """
    with open(os.path.join(_pkg_root(), rel_path), encoding="utf-8") as f:
        return f.read()
|
||||||
|
|
||||||
|
def _extract_lang_names():
    """Parse the ``_LANG_NAMES`` dict literal out of conversation_node.py.

    The source is read as text and the literal is evaluated with
    ``ast.literal_eval`` so the node module itself is never imported. The
    literal is assumed to end at the first line consisting of ``}`` after
    the assignment.
    """
    import ast

    src = _read_src("saltybot_social/conversation_node.py")
    start = src.index("_LANG_NAMES: Dict[str, str] = {")
    end = src.index("\n}", start) + 2
    literal = src[start:end].split("=", 1)[1].strip()
    return ast.literal_eval(literal)
|
||||||
|
|
||||||
|
class TestLangNames:
    """Checks on the ``_LANG_NAMES`` table extracted from conversation_node."""

    @pytest.fixture(scope="class")
    def ln(self):
        return _extract_lang_names()

    def test_english(self, ln):
        assert ln["en"] == "English"

    def test_french(self, ln):
        assert ln["fr"] == "French"

    def test_spanish(self, ln):
        assert ln["es"] == "Spanish"

    def test_german(self, ln):
        assert ln["de"] == "German"

    def test_japanese(self, ln):
        assert ln["ja"] == "Japanese"

    def test_chinese(self, ln):
        assert ln["zh"] == "Chinese"

    def test_arabic(self, ln):
        assert ln["ar"] == "Arabic"

    def test_at_least_15(self, ln):
        assert len(ln) >= 15

    def test_lowercase_keys(self, ln):
        for k in ln:
            assert k == k.lower() and 2 <= len(k) <= 3

    def test_nonempty_values(self, ln):
        for k, v in ln.items():
            assert v
|
||||||
|
|
||||||
|
class TestLanguageHint:
    """Checks on the language-hint string built by the local ``_h`` helper."""

    def _h(self, sl, sid, ln):
        # Local re-implementation of the hint logic under test: English and
        # unknown speakers get no hint; other codes map to a language name,
        # falling back to the raw code.
        lang = sl.get(sid, "en")
        if lang and lang != "en":
            return f"[Please respond in {ln.get(lang, lang)}.]"
        return ""

    @pytest.fixture(scope="class")
    def ln(self):
        return _extract_lang_names()

    def test_english_no_hint(self, ln):
        assert self._h({"p": "en"}, "p", ln) == ""

    def test_unknown_no_hint(self, ln):
        assert self._h({}, "p", ln) == ""

    def test_french(self, ln):
        assert self._h({"p": "fr"}, "p", ln) == "[Please respond in French.]"

    def test_spanish(self, ln):
        assert self._h({"p": "es"}, "p", ln) == "[Please respond in Spanish.]"

    def test_unknown_code(self, ln):
        assert "xx" in self._h({"p": "xx"}, "p", ln)

    def test_brackets(self, ln):
        h = self._h({"p": "de"}, "p", ln)
        assert h.startswith("[") and h.endswith("]")
|
||||||
|
|
||||||
|
class TestVoiceMap:
    """Checks on voice-map parsing via a local ``_parse`` helper."""

    def _parse(self, jstr, dl, dp):
        # Blank or "{}" input adds nothing; anything json.loads rejects is
        # treated as an empty map. ``except Exception`` (not a bare except)
        # so KeyboardInterrupt/SystemExit are never swallowed.
        try:
            extra = json.loads(jstr) if jstr.strip() not in ("{}", "") else {}
        except Exception:
            extra = {}
        r = {dl: dp}
        r.update(extra)
        return r

    def test_empty(self):
        assert self._parse("{}", "en", "/e") == {"en": "/e"}

    def test_extra(self):
        vm = self._parse('{"fr":"/f"}', "en", "/e")
        assert vm["fr"] == "/f"

    def test_invalid(self):
        assert self._parse("BAD", "en", "/e") == {"en": "/e"}

    def test_multi(self):
        vm = self._parse(json.dumps({"fr": "/f", "es": "/s"}), "en", "/e")
        assert len(vm) == 3
|
||||||
|
|
||||||
|
class TestVoiceSelect:
    """Checks on exact-match-then-default voice selection."""

    def _s(self, voices, lang, default):
        # Prefer the requested language, else the default language's voice.
        return voices.get(lang) or voices.get(default)

    def test_exact(self):
        assert self._s({"en": "E", "fr": "F"}, "fr", "en") == "F"

    def test_fallback(self):
        assert self._s({"en": "E"}, "fr", "en") == "E"

    def test_none(self):
        assert self._s({}, "fr", "en") is None
|
||||||
|
|
||||||
|
class TestSttFields:
    """Source-text checks that speech_pipeline_node.py carries language support."""

    @pytest.fixture(scope="class")
    def src(self):
        return _read_src("saltybot_social/speech_pipeline_node.py")

    def test_param(self, src):
        assert "whisper_language" in src

    def test_detected_lang(self, src):
        assert "detected_lang" in src

    def test_msg_language(self, src):
        assert "msg.language = language" in src

    def test_auto_detect(self, src):
        assert "language=self._whisper_language" in src
|
||||||
|
|
||||||
|
class TestConvFields:
    """Source-text checks that conversation_node.py carries language support."""

    @pytest.fixture(scope="class")
    def src(self):
        return _read_src("saltybot_social/conversation_node.py")

    def test_speaker_lang(self, src):
        assert "_speaker_lang" in src

    def test_lang_hint_method(self, src):
        assert "_language_hint" in src

    def test_msg_language(self, src):
        assert "msg.language = language" in src

    def test_lang_names(self, src):
        assert "_LANG_NAMES" in src

    def test_please_respond(self, src):
        assert "Please respond in" in src

    def test_emotion_coexists(self, src):
        assert "_emotion_hint" in src
|
||||||
|
|
||||||
|
class TestTtsFields:
    """Source-text checks that tts_node.py carries per-language voice support."""

    @pytest.fixture(scope="class")
    def src(self):
        return _read_src("saltybot_social/tts_node.py")

    def test_voice_map_json(self, src):
        assert "voice_map_json" in src

    def test_default_lang(self, src):
        assert "default_language" in src

    def test_voices_dict(self, src):
        assert "_voices" in src

    def test_get_voice(self, src):
        assert "_get_voice" in src

    def test_load_voice_for_lang(self, src):
        assert "_load_voice_for_lang" in src

    def test_queue_tuple(self, src):
        assert "(text.strip(), lang)" in src

    def test_synthesize_voice_arg(self, src):
        assert "_synthesize(text, voice)" in src
|
||||||
|
|
||||||
|
class TestMsgDefs:
    """Source-text checks that both message definitions gained a language field."""

    @pytest.fixture(scope="class")
    def tr(self):
        return _read_src("../saltybot_social_msgs/msg/SpeechTranscript.msg")

    # Named ``resp`` rather than ``re`` to avoid shadowing the stdlib module name.
    @pytest.fixture(scope="class")
    def resp(self):
        return _read_src("../saltybot_social_msgs/msg/ConversationResponse.msg")

    def test_transcript_lang(self, tr):
        assert "string language" in tr

    def test_transcript_bcp47(self, tr):
        assert "BCP-47" in tr

    def test_response_lang(self, resp):
        assert "string language" in resp
|
||||||
|
|
||||||
|
class TestEdgeCases:
    """Small sanity checks on language defaults and per-speaker bookkeeping."""

    def test_empty_lang_no_hint(self):
        lang = "" or "en"
        assert lang == "en"

    def test_lang_flows(self):
        d: Dict[str, str] = {}
        d["p1"] = "fr"
        assert d.get("p1", "en") == "fr"

    def test_multi_speakers(self):
        d = {"p1": "fr", "p2": "es"}
        assert d["p1"] == "fr" and d["p2"] == "es"

    def test_voice_map_code_in_tts(self):
        src = _read_src("saltybot_social/tts_node.py")
        assert "voice_map_json" in src and "json.loads" in src
|
||||||
@ -7,3 +7,4 @@ string text # Full or partial response text
|
|||||||
string speaker_id # Who the response is addressed to
|
string speaker_id # Who the response is addressed to
|
||||||
bool is_partial # true = streaming token chunk, false = final response
|
bool is_partial # true = streaming token chunk, false = final response
|
||||||
int32 turn_id # Conversation turn counter (for deduplication)
|
int32 turn_id # Conversation turn counter (for deduplication)
|
||||||
|
string language # BCP-47 language code for TTS voice selection e.g. "en" "fr" "es"
|
||||||
|
|||||||
@ -8,3 +8,4 @@ string speaker_id # e.g. "person_42" or "unknown"
|
|||||||
float32 confidence # ASR confidence 0..1
|
float32 confidence # ASR confidence 0..1
|
||||||
float32 audio_duration # Duration of the utterance in seconds
|
float32 audio_duration # Duration of the utterance in seconds
|
||||||
bool is_partial # true = intermediate streaming result, false = final
|
bool is_partial # true = intermediate streaming result, false = final
|
||||||
|
string language # BCP-47 detected language code e.g. "en" "fr" "es" (empty = unknown)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user