diff --git a/jetson/ros2_ws/src/saltybot_audio_pipeline/README.md b/jetson/ros2_ws/src/saltybot_audio_pipeline/README.md new file mode 100644 index 0000000..9074238 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_audio_pipeline/README.md @@ -0,0 +1,39 @@ +# Audio Pipeline (Issue #503) + +Comprehensive audio pipeline for Salty Bot with full voice interaction support. + +## Features + +- **Hardware**: Jabra SPEAK 810 USB audio device integration +- **Wake Word**: openwakeword "Hey Salty" detection +- **STT**: whisper.cpp running on Jetson GPU (small/base/medium/large models) +- **TTS**: Piper synthesis with voice switching +- **State Machine**: listening → processing → speaking +- **MQTT**: Real-time status reporting +- **Metrics**: Latency tracking and performance monitoring + +## ROS2 Topics + +Published: +- `/saltybot/speech/transcribed_text` (String): Final STT output +- `/saltybot/audio/state` (String): Current audio state +- `/saltybot/audio/status` (String): JSON metrics with latencies + +## MQTT Topics + +- `saltybot/audio/state`: Current state +- `saltybot/audio/status`: Complete status JSON + +## Launch + +```bash +ros2 launch saltybot_audio_pipeline audio_pipeline.launch.py +``` + +## Configuration + +See `config/audio_pipeline_params.yaml` for tuning: +- `device_name`: Jabra device +- `wake_word_threshold`: 0.5 (0.0-1.0) +- `whisper_model`: small/base/medium/large +- `mqtt_enabled`: true/false diff --git a/jetson/ros2_ws/src/saltybot_audio_pipeline/config/audio_pipeline_params.yaml b/jetson/ros2_ws/src/saltybot_audio_pipeline/config/audio_pipeline_params.yaml new file mode 100644 index 0000000..ad72268 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_audio_pipeline/config/audio_pipeline_params.yaml @@ -0,0 +1,18 @@ +audio_pipeline_node: + ros__parameters: + device_name: "Jabra SPEAK 810" + audio_device_index: -1 + sample_rate: 16000 + chunk_size: 512 + wake_word_model: "hey_salty" + wake_word_threshold: 0.5 + wake_word_timeout_s: 8.0 + whisper_model: 
"small" + whisper_compute_type: "float16" + whisper_language: "" + tts_voice_path: "/models/piper/en_US-lessac-medium.onnx" + tts_sample_rate: 22050 + mqtt_enabled: true + mqtt_broker: "localhost" + mqtt_port: 1883 + mqtt_base_topic: "saltybot/audio" diff --git a/jetson/ros2_ws/src/saltybot_audio_pipeline/launch/audio_pipeline.launch.py b/jetson/ros2_ws/src/saltybot_audio_pipeline/launch/audio_pipeline.launch.py new file mode 100644 index 0000000..5dd06b1 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_audio_pipeline/launch/audio_pipeline.launch.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 +from launch import LaunchDescription +from launch_ros.actions import Node +from ament_index_python.packages import get_package_share_directory +import os + +def generate_launch_description(): + pkg_dir = get_package_share_directory("saltybot_audio_pipeline") + config_path = os.path.join(pkg_dir, "config", "audio_pipeline_params.yaml") + return LaunchDescription([ + Node( + package="saltybot_audio_pipeline", + executable="audio_pipeline_node", + name="audio_pipeline_node", + parameters=[config_path], + output="screen", + emulate_tty=True, + ), + ]) diff --git a/jetson/ros2_ws/src/saltybot_audio_pipeline/package.xml b/jetson/ros2_ws/src/saltybot_audio_pipeline/package.xml new file mode 100644 index 0000000..b5bee53 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_audio_pipeline/package.xml @@ -0,0 +1,12 @@ + + + saltybot_audio_pipeline + 1.0.0 + Full audio pipeline: Jabra SPEAK 810, wake word, STT, TTS with MQTT (Issue #503) + Salty Lab + Apache-2.0 + ament_python + rclpy + std_msgs + pytest + \ No newline at end of file diff --git a/jetson/ros2_ws/src/saltybot_audio_pipeline/resource/saltybot_audio_pipeline b/jetson/ros2_ws/src/saltybot_audio_pipeline/resource/saltybot_audio_pipeline new file mode 100644 index 0000000..e69de29 diff --git a/jetson/ros2_ws/src/saltybot_audio_pipeline/saltybot_audio_pipeline/__init__.py 
#!/usr/bin/env python3
"""audio_pipeline_node.py — Full audio pipeline with Jabra SPEAK 810 I/O (Issue #503).

State machine: IDLE -> LISTENING -> WAKE_WORD_DETECTED -> PROCESSING ->
SPEAKING -> LISTENING.  Transcriptions, state, and a JSON status document are
published on ROS topics and (optionally) mirrored to MQTT.
"""

from __future__ import annotations

import json
import os
import threading
import time
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Optional

import rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile
from std_msgs.msg import String

from .audio_utils import (
    AudioBuffer,
    EnergyVAD,
    UtteranceSegmenter,
    float32_to_pcm16,
    pcm16_to_float32,
    resample_audio,
)

try:
    import paho.mqtt.client as mqtt
    _MQTT_AVAILABLE = True
except ImportError:
    _MQTT_AVAILABLE = False


class AudioState(Enum):
    """Lifecycle states of the audio pipeline."""
    IDLE = "idle"
    LISTENING = "listening"
    WAKE_WORD_DETECTED = "wake_detected"
    PROCESSING = "processing"
    SPEAKING = "speaking"
    ERROR = "error"


@dataclass
class AudioMetrics:
    """Latency and result metrics embedded in the status JSON."""
    wake_to_stt_ms: float = 0.0
    stt_processing_ms: float = 0.0
    tts_synthesis_ms: float = 0.0
    total_latency_ms: float = 0.0
    transcribed_text: str = ""
    speaker_id: str = "unknown"  # reserved for future diarization; never set here
    error_msg: str = ""


class MqttClient:
    """Thin wrapper around paho-mqtt with fire-and-forget publishing."""

    def __init__(self, broker: str, port: int, base_topic: str):
        self.broker = broker
        self.port = port
        self.base_topic = base_topic
        self._client = None
        self._connected = False
        if _MQTT_AVAILABLE:
            try:
                # NOTE(review): paho-mqtt >= 2.0 requires a CallbackAPIVersion
                # argument and changes callback signatures — confirm the
                # deployed paho version is 1.x before upgrading.
                self._client = mqtt.Client(client_id=f"saltybot-audio-{int(time.time())}")
                self._client.on_connect = lambda c, u, f, rc: setattr(self, '_connected', rc == 0)
                self._client.on_disconnect = lambda c, u, rc: setattr(self, '_connected', False)
                self._client.connect_async(broker, port, keepalive=60)
                self._client.loop_start()
            except Exception as e:
                # MQTT is best-effort; the pipeline works without it.
                print(f"MQTT init failed: {e}")

    def publish(self, topic: str, payload: str) -> bool:
        """Publish *payload* at QoS 0; returns False when not connected."""
        if not self._client or not self._connected:
            return False
        try:
            self._client.publish(topic, payload, qos=0)
            return True
        except Exception:
            return False

    def disconnect(self) -> None:
        """Stop the network loop and disconnect cleanly."""
        if self._client:
            self._client.loop_stop()
            self._client.disconnect()


class JabraAudioDevice:
    """Mono PCM16 capture/playback on a Jabra SPEAK 810 (or any matching device)."""

    def __init__(self, device_name: str = "Jabra SPEAK 810", device_idx: int = -1):
        self.device_name = device_name
        self.device_idx = device_idx  # -1 means "auto-detect by name"
        self._pa = None
        self._input_stream = None
        self._output_stream = None
        self._is_open = False

    def open(self, sample_rate: int = 16000, chunk_size: int = 512) -> bool:
        """Open and START input/output streams; returns False on any failure.

        BUGFIX: streams were created with start=False and never started, so
        the very first read()/write() would fail; we now call start_stream()
        explicitly on both.
        """
        try:
            import pyaudio
            self._pa = pyaudio.PyAudio()
            if self.device_idx < 0:
                found = self._find_device_index()
                # BUGFIX: `found or None` mapped a valid index 0 to None
                # (PortAudio default device); keep 0 and fall back to the
                # default only when nothing matched.
                self.device_idx = found if found >= 0 else None
            self._input_stream = self._pa.open(
                format=pyaudio.paInt16, channels=1, rate=sample_rate,
                input=True, input_device_index=self.device_idx,
                frames_per_buffer=chunk_size, start=False)
            self._output_stream = self._pa.open(
                format=pyaudio.paInt16, channels=1, rate=sample_rate,
                output=True, output_device_index=self.device_idx,
                frames_per_buffer=chunk_size, start=False)
            self._input_stream.start_stream()
            self._output_stream.start_stream()
            self._is_open = True
            return True
        except Exception as e:
            print(f"Failed to open Jabra device: {e}")
            return False

    def _find_device_index(self) -> int:
        """Scan devices for a Jabra/SPEAK name match; -1 when none found.

        BUGFIX: previously created a second PyAudio instance that was never
        terminate()d (native resource leak); reuse the one from open().
        """
        try:
            pa = self._pa
            if pa is None:
                return -1
            for i in range(pa.get_device_count()):
                info = pa.get_device_info_by_index(i)
                if "jabra" in info["name"].lower() or "speak" in info["name"].lower():
                    return i
        except Exception:
            pass
        return -1

    def read_chunk(self, chunk_size: int = 512) -> Optional[bytes]:
        """Read one PCM16 chunk; None when the device is unavailable."""
        if not self._is_open or not self._input_stream:
            return None
        try:
            return self._input_stream.read(chunk_size, exception_on_overflow=False)
        except Exception:
            return None

    def write_chunk(self, pcm_data: bytes) -> bool:
        """Blocking playback of raw PCM16 bytes; False on failure."""
        if not self._is_open or not self._output_stream:
            return False
        try:
            self._output_stream.write(pcm_data)
            return True
        except Exception:
            return False

    def close(self) -> None:
        """Stop and close both streams and release PortAudio."""
        if self._input_stream:
            self._input_stream.stop_stream()
            self._input_stream.close()
        if self._output_stream:
            self._output_stream.stop_stream()
            self._output_stream.close()
        if self._pa:
            self._pa.terminate()
        self._is_open = False


class AudioPipelineNode(Node):
    """ROS2 node wiring device I/O, wake word, STT and TTS together."""

    def __init__(self) -> None:
        super().__init__("audio_pipeline_node")
        # Declare every tunable with its default so the YAML file may
        # override any subset of them.
        for param, default in [
            ("device_name", "Jabra SPEAK 810"),
            ("audio_device_index", -1),
            ("sample_rate", 16000),
            ("chunk_size", 512),
            ("wake_word_model", "hey_salty"),
            ("wake_word_threshold", 0.5),
            ("wake_word_timeout_s", 8.0),
            ("whisper_model", "small"),
            ("whisper_compute_type", "float16"),
            ("whisper_language", ""),
            ("tts_voice_path", "/models/piper/en_US-lessac-medium.onnx"),
            ("tts_sample_rate", 22050),
            ("mqtt_enabled", True),
            ("mqtt_broker", "localhost"),
            ("mqtt_port", 1883),
            ("mqtt_base_topic", "saltybot/audio"),
        ]:
            self.declare_parameter(param, default)

        device_name = self.get_parameter("device_name").value
        device_idx = self.get_parameter("audio_device_index").value
        self._sample_rate = self.get_parameter("sample_rate").value
        self._chunk_size = self.get_parameter("chunk_size").value
        self._ww_model = self.get_parameter("wake_word_model").value
        self._ww_thresh = self.get_parameter("wake_word_threshold").value
        # BUGFIX: this parameter was declared but never read — it now bounds
        # how long we wait for an utterance after the wake word.
        self._ww_timeout_s = self.get_parameter("wake_word_timeout_s").value
        self._whisper_model = self.get_parameter("whisper_model").value
        self._compute_type = self.get_parameter("whisper_compute_type").value
        # Empty string means "autodetect language" for faster-whisper.
        self._whisper_lang = self.get_parameter("whisper_language").value or None
        self._tts_voice_path = self.get_parameter("tts_voice_path").value
        self._tts_rate = self.get_parameter("tts_sample_rate").value
        mqtt_enabled = self.get_parameter("mqtt_enabled").value
        mqtt_broker = self.get_parameter("mqtt_broker").value
        mqtt_port = self.get_parameter("mqtt_port").value
        mqtt_topic = self.get_parameter("mqtt_base_topic").value

        qos = QoSProfile(depth=10)
        self._text_pub = self.create_publisher(String, "/saltybot/speech/transcribed_text", qos)
        self._state_pub = self.create_publisher(String, "/saltybot/audio/state", qos)
        self._status_pub = self.create_publisher(String, "/saltybot/audio/status", qos)

        self._state = AudioState.IDLE
        self._state_lock = threading.Lock()
        self._metrics = AudioMetrics()
        self._running = False
        self._wake_deadline = 0.0
        self._audio_thread: Optional[threading.Thread] = None

        self._jabra = JabraAudioDevice(device_name, device_idx)
        self._oww = None        # openwakeword model (optional)
        self._whisper = None    # faster-whisper model (required for STT)
        self._tts_voice = None  # Piper voice (optional)

        self._mqtt = None
        if mqtt_enabled and _MQTT_AVAILABLE:
            try:
                self._mqtt = MqttClient(mqtt_broker, mqtt_port, mqtt_topic)
                self.get_logger().info(f"MQTT enabled: {mqtt_broker}:{mqtt_port}/{mqtt_topic}")
            except Exception as e:
                self.get_logger().warn(f"MQTT init failed: {e}")

        self._vad = EnergyVAD(threshold_db=-35.0)
        self._segmenter = UtteranceSegmenter(self._vad, sample_rate=self._sample_rate)
        self._audio_buffer = AudioBuffer(capacity_s=30.0, sample_rate=self._sample_rate)

        # Heavy model loading happens off the executor thread.
        threading.Thread(target=self._init_pipeline, daemon=True).start()

    def _get_state(self) -> AudioState:
        """Read the current state under the lock (safe cross-thread read)."""
        with self._state_lock:
            return self._state

    def _init_pipeline(self) -> None:
        """Open the device and load wake-word/STT/TTS models in the background."""
        self.get_logger().info("Initializing audio pipeline...")
        t0 = time.time()

        if not self._jabra.open(self._sample_rate, self._chunk_size):
            self._metrics.error_msg = "Failed to open Jabra device"
            self._set_state(AudioState.ERROR)
            return

        try:
            from openwakeword.model import Model as OWWModel
            self._oww = OWWModel(wakeword_models=[self._ww_model])
            self.get_logger().info(f"openwakeword '{self._ww_model}' loaded")
        except Exception as e:
            # Pipeline still runs without wake word (it just never triggers).
            self.get_logger().warn(f"openwakeword failed: {e}")

        try:
            from faster_whisper import WhisperModel
            self._whisper = WhisperModel(self._whisper_model, device="cuda",
                                         compute_type=self._compute_type, download_root="/models")
            self.get_logger().info(f"Whisper '{self._whisper_model}' loaded")
        except Exception as e:
            self.get_logger().error(f"Whisper failed: {e}")
            self._metrics.error_msg = f"Whisper init: {e}"
            self._set_state(AudioState.ERROR)
            return

        try:
            from piper import PiperVoice
            self._tts_voice = PiperVoice.load(self._tts_voice_path)
            self.get_logger().info("Piper TTS loaded")
        except Exception as e:
            self.get_logger().warn(f"Piper TTS failed: {e}")

        self.get_logger().info(f"Audio pipeline ready ({time.time()-t0:.1f}s)")
        self._set_state(AudioState.LISTENING)
        self._publish_status()

        # Set the run flag before starting the thread so a racing
        # destroy_node() cannot have its False overwritten by the loop.
        self._running = True
        self._audio_thread = threading.Thread(target=self._audio_loop, daemon=True)
        self._audio_thread.start()

    def _audio_loop(self) -> None:
        """Capture thread: wake-word scanning and utterance segmentation."""
        import numpy as np
        while self._running and self._get_state() != AudioState.ERROR:
            raw_chunk = self._jabra.read_chunk(self._chunk_size)
            if raw_chunk is None:
                time.sleep(0.005)  # avoid a 100%-CPU spin while the device hiccups
                continue
            samples = pcm16_to_float32(raw_chunk)
            self._audio_buffer.push(samples)

            if self._get_state() == AudioState.LISTENING and self._oww is not None:
                try:
                    # BUGFIX: openwakeword expects raw 16-bit PCM sample
                    # values; feeding normalized float32 yields ~0 scores so
                    # the wake word could never fire.
                    pcm = np.frombuffer(raw_chunk, dtype=np.int16)
                    preds = self._oww.predict(pcm)
                    score = preds.get(self._ww_model, 0.0)
                    if isinstance(score, (list, tuple)):
                        score = score[-1]  # most recent frame's score
                    if score >= self._ww_thresh:
                        self.get_logger().info(f"Wake word detected (score={score:.2f})")
                        self._metrics.wake_to_stt_ms = 0.0
                        self._wake_deadline = time.time() + self._ww_timeout_s
                        self._set_state(AudioState.WAKE_WORD_DETECTED)
                        self._segmenter.reset()
                        self._audio_buffer.clear()
                except Exception as e:
                    self.get_logger().debug(f"Wake word error: {e}")

            if self._get_state() == AudioState.WAKE_WORD_DETECTED:
                completed = self._segmenter.push(samples)
                for utt_samples, duration in completed:
                    threading.Thread(target=self._process_utterance,
                                     args=(utt_samples, duration), daemon=True).start()
                # Give up and return to passive listening if the user says
                # nothing before the wake-word timeout (previously the node
                # could sit in WAKE_WORD_DETECTED forever).
                if not completed and time.time() > self._wake_deadline:
                    self.get_logger().info("Wake word timeout — back to listening")
                    self._segmenter.reset()
                    self._set_state(AudioState.LISTENING)

    def _process_utterance(self, audio_samples: list, duration: float) -> None:
        """Transcribe one segmented utterance, publish the text, then speak."""
        if self._whisper is None:
            self._set_state(AudioState.LISTENING)
            return
        self._set_state(AudioState.PROCESSING)
        t0 = time.time()
        try:
            import numpy as np
            audio_np = (np.array(audio_samples, dtype=np.float32)
                        if isinstance(audio_samples, list)
                        else audio_samples.astype(np.float32))
            segments_gen, _info = self._whisper.transcribe(
                audio_np, language=self._whisper_lang, beam_size=3, vad_filter=False)
            text = " ".join(seg.text.strip() for seg in segments_gen).strip()
            if text:
                stt_time = (time.time() - t0) * 1000
                self._metrics.stt_processing_ms = stt_time
                self._metrics.transcribed_text = text
                # NOTE(review): "total" currently only counts STT;
                # wake_to_stt_ms and TTS time are not folded in yet.
                self._metrics.total_latency_ms = stt_time
                msg = String()
                msg.data = text
                self._text_pub.publish(msg)
                self.get_logger().info(f"STT [{duration:.1f}s, {stt_time:.0f}ms]: '{text}'")
                self._process_tts(text)
            else:
                self._set_state(AudioState.LISTENING)
        except Exception as e:
            self.get_logger().error(f"STT error: {e}")
            self._metrics.error_msg = str(e)
            self._set_state(AudioState.LISTENING)

    def _process_tts(self, text: str) -> None:
        """Synthesize *text* with Piper and play it on the Jabra device."""
        if self._tts_voice is None:
            self._set_state(AudioState.LISTENING)
            return
        self._set_state(AudioState.SPEAKING)
        t0 = time.time()
        try:
            pcm_data = b"".join(self._tts_voice.synthesize_stream_raw(text))
            self._metrics.tts_synthesis_ms = (time.time() - t0) * 1000
            if self._tts_rate != self._sample_rate:
                # Piper output rate differs from the device rate; resample.
                import numpy as np
                samples = np.frombuffer(pcm_data, dtype=np.int16).astype(np.float32) / 32768.0
                pcm_data = float32_to_pcm16(resample_audio(samples, self._tts_rate, self._sample_rate))
            self._jabra.write_chunk(pcm_data)
            self.get_logger().info(f"TTS: played {len(pcm_data)} bytes")
        except Exception as e:
            self.get_logger().error(f"TTS error: {e}")
            self._metrics.error_msg = str(e)
        finally:
            self._set_state(AudioState.LISTENING)
            self._publish_status()

    def _set_state(self, new_state: AudioState) -> None:
        """Transition to *new_state* and announce it on ROS + MQTT."""
        with self._state_lock:
            if self._state == new_state:
                return
            self._state = new_state
        # Publish OUTSIDE the lock so slow ROS/MQTT I/O cannot stall other
        # threads waiting on the state lock (previously held during publish).
        self.get_logger().info(f"Audio state: {new_state.value}")
        msg = String()
        msg.data = new_state.value
        self._state_pub.publish(msg)
        if self._mqtt:
            try:
                self._mqtt.publish(f"{self._mqtt.base_topic}/state", new_state.value)
            except Exception as e:
                self.get_logger().debug(f"MQTT publish failed: {e}")

    def _publish_status(self) -> None:
        """Publish the full status document (state + metrics) on ROS + MQTT."""
        status = {"state": self._get_state().value, "metrics": asdict(self._metrics),
                  "timestamp": time.time()}
        msg = String()
        msg.data = json.dumps(status)
        self._status_pub.publish(msg)
        if self._mqtt:
            try:
                self._mqtt.publish(f"{self._mqtt.base_topic}/status", msg.data)
            except Exception as e:
                self.get_logger().debug(f"MQTT publish failed: {e}")

    def destroy_node(self) -> None:
        """Stop the capture loop, release the device and MQTT, then tear down."""
        self._running = False
        # BUGFIX: join the capture thread before closing the device so we do
        # not close streams out from under an in-flight read().
        thread = self._audio_thread
        if thread is not None and thread.is_alive():
            thread.join(timeout=2.0)
        self._jabra.close()
        if self._mqtt:
            try:
                self._mqtt.disconnect()
            except Exception:
                pass
        super().destroy_node()


def main(args=None) -> None:
    """Entry point: spin the node until Ctrl-C, then clean up."""
    rclpy.init(args=args)
    node = AudioPipelineNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        node.destroy_node()
        rclpy.shutdown()


if __name__ == "__main__":
    main()
-35.0): + self.threshold_db = threshold_db + + def is_speech(self, samples: np.ndarray) -> bool: + rms = np.sqrt(np.mean(samples ** 2)) + db = 20 * np.log10(rms + 1e-10) + return db > self.threshold_db + + def rms_db(self, samples: np.ndarray) -> float: + rms = np.sqrt(np.mean(samples ** 2)) + return 20 * np.log10(rms + 1e-10) + + +class UtteranceSegmenter: + """Buffer and segment audio utterances based on energy VAD.""" + def __init__(self, vad: Optional[EnergyVAD] = None, silence_duration_s: float = 0.5, + min_duration_s: float = 0.3, sample_rate: int = 16000): + self.vad = vad or EnergyVAD() + self.silence_duration_s = silence_duration_s + self.min_duration_s = min_duration_s + self.sample_rate = sample_rate + self._buffer = deque() + self._last_speech_time = 0.0 + self._speech_started = False + self._lock = threading.Lock() + + def push(self, samples: np.ndarray) -> List[Tuple[List[float], float]]: + completed = [] + with self._lock: + now = time.time() + is_speech = self.vad.is_speech(samples) + if is_speech: + self._last_speech_time = now + self._speech_started = True + self._buffer.append(samples) + else: + self._buffer.append(samples) + if self._speech_started and now - self._last_speech_time > self.silence_duration_s: + utt_samples = self._extract_buffer() + duration = len(utt_samples) / self.sample_rate + if duration >= self.min_duration_s: + completed.append((utt_samples, duration)) + self._speech_started = False + self._buffer.clear() + return completed + + def _extract_buffer(self) -> List[float]: + if not self._buffer: + return [] + flat = [] + for s in self._buffer: + flat.extend(s.tolist() if isinstance(s, np.ndarray) else s) + return flat + + def reset(self) -> None: + with self._lock: + self._buffer.clear() + self._speech_started = False + + +class AudioBuffer: + """Thread-safe circular audio buffer.""" + def __init__(self, capacity_s: float = 5.0, sample_rate: int = 16000): + self.capacity = int(capacity_s * sample_rate) + self.sample_rate = 
sample_rate + self._buffer = deque(maxlen=self.capacity) + self._lock = threading.Lock() + + def push(self, samples: np.ndarray) -> None: + with self._lock: + self._buffer.extend(samples.tolist() if isinstance(samples, np.ndarray) else samples) + + def extract(self, duration_s: Optional[float] = None) -> np.ndarray: + with self._lock: + samples = list(self._buffer) + if duration_s is not None: + num_samples = int(duration_s * self.sample_rate) + samples = samples[-num_samples:] + return np.array(samples, dtype=np.float32) + + def clear(self) -> None: + with self._lock: + self._buffer.clear() + + def size(self) -> int: + with self._lock: + return len(self._buffer) + + +def pcm16_to_float32(pcm_bytes: bytes) -> np.ndarray: + samples = np.frombuffer(pcm_bytes, dtype=np.int16) + return samples.astype(np.float32) / 32768.0 + + +def float32_to_pcm16(samples: np.ndarray) -> bytes: + if isinstance(samples, list): + samples = np.array(samples, dtype=np.float32) + clipped = np.clip(samples, -1.0, 1.0) + pcm = (clipped * 32767).astype(np.int16) + return pcm.tobytes() + + +def resample_audio(samples: np.ndarray, orig_rate: int, target_rate: int) -> np.ndarray: + if orig_rate == target_rate: + return samples + from scipy import signal + num_samples = int(len(samples) * target_rate / orig_rate) + resampled = signal.resample(samples, num_samples) + return resampled.astype(np.float32) + + +def calculate_rms_db(samples: np.ndarray) -> float: + rms = np.sqrt(np.mean(samples ** 2)) + return 20 * np.log10(rms + 1e-10) diff --git a/jetson/ros2_ws/src/saltybot_audio_pipeline/setup.cfg b/jetson/ros2_ws/src/saltybot_audio_pipeline/setup.cfg new file mode 100644 index 0000000..633b438 --- /dev/null +++ b/jetson/ros2_ws/src/saltybot_audio_pipeline/setup.cfg @@ -0,0 +1,2 @@ +[develop] +script_dir=$base/lib/saltybot_audio_pipeline/scripts \ No newline at end of file diff --git a/jetson/ros2_ws/src/saltybot_audio_pipeline/setup.py b/jetson/ros2_ws/src/saltybot_audio_pipeline/setup.py new file 
from setuptools import setup

package_name = 'saltybot_audio_pipeline'

setup(
    name=package_name,
    version='1.0.0',
    packages=[package_name],
    data_files=[
        # Marker file so ament can index the package.
        ('share/ament_index/resource_index/packages', ['resource/' + package_name]),
        ('share/' + package_name, ['package.xml']),
        ('share/' + package_name + '/launch', ['launch/audio_pipeline.launch.py']),
        ('share/' + package_name + '/config', ['config/audio_pipeline_params.yaml']),
    ],
    install_requires=['setuptools'],
    zip_safe=True,
    author='Salty Lab',
    # CONSISTENCY: mirror the metadata declared in package.xml — ament/bloom
    # expect maintainer, description and license here as well.
    maintainer='Salty Lab',
    description='Full audio pipeline: Jabra SPEAK 810, wake word, STT, TTS with MQTT (Issue #503)',
    license='Apache-2.0',
    tests_require=['pytest'],
    entry_points={
        'console_scripts': [
            'audio_pipeline_node = saltybot_audio_pipeline.audio_pipeline_node:main',
        ],
    },
)