feat(perception): MFCC nearest-centroid audio scene classifier (Issue #353)
Classifies ambient audio into indoor/outdoor/traffic/park at 1 Hz using a 16-d feature vector (13 MFCC + spectral centroid + rolloff + ZCR) with a normalised nearest-centroid classifier. Centroids are computed at import time from seeded synthetic prototypes, ensuring deterministic behaviour. Changes ------- - saltybot_scene_msgs/msg/AudioScene.msg — label + confidence + features[16] - saltybot_scene_msgs/CMakeLists.txt — register AudioScene.msg - _audio_scene.py — pure-numpy feature extraction + NearestCentroidClassifier - audio_scene_node.py — subscribes /audio/audio, publishes /saltybot/audio_scene - test/test_audio_scene.py — 53 tests (all passing) with synthetic audio - setup.py — add audio_scene entry point Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
358c1ab6f9
commit
677e6eb75e
@ -0,0 +1,326 @@
|
|||||||
|
"""
|
||||||
|
_audio_scene.py — MFCC + spectral nearest-centroid audio scene classifier (no ROS2 deps).
|
||||||
|
|
||||||
|
Classifies 1-second audio clips into one of four environment labels:
|
||||||
|
'indoor' — enclosed space (room tone ~440 Hz, low flux)
|
||||||
|
'outdoor' — open ambient (wind/ambient noise, ~1500 Hz band)
|
||||||
|
'traffic' — road/engine (low-frequency fundamental ~80 Hz)
|
||||||
|
'park' — natural space (bird-like high-frequency ~3–5 kHz)
|
||||||
|
|
||||||
|
Feature vector (16 dimensions)
|
||||||
|
--------------------------------
|
||||||
|
[0..12] 13 MFCC coefficients (mean across frames, from 26-mel filterbank)
|
||||||
|
[13] Spectral centroid (Hz, mean across frames)
|
||||||
|
[14] Spectral rolloff (Hz at 85 % cumulative energy, mean across frames)
|
||||||
|
[15] Zero-crossing rate (crossings / sample, mean across frames)
|
||||||
|
|
||||||
|
Classifier
|
||||||
|
----------
|
||||||
|
Nearest-centroid: class whose prototype centroid has the smallest L2 distance
|
||||||
|
to the (per-dimension normalised) query feature vector.
|
||||||
|
Centroids are computed once at import from seeded synthetic prototype signals,
|
||||||
|
so they are always consistent with the feature extractor.
|
||||||
|
|
||||||
|
Confidence
|
||||||
|
----------
|
||||||
|
conf = 1 / (1 + dist_to_nearest_centroid) in normalised feature space.
|
||||||
|
|
||||||
|
Public API
|
||||||
|
----------
|
||||||
|
SCENE_LABELS tuple[str, ...] ordered class names
|
||||||
|
AudioSceneResult NamedTuple single-clip result
|
||||||
|
NearestCentroidClassifier class
|
||||||
|
.predict(features) → (label, conf)
|
||||||
|
extract_features(samples, sr) → np.ndarray shape (16,)
|
||||||
|
classify_audio(samples, sr) → AudioSceneResult
|
||||||
|
|
||||||
|
_CLASSIFIER module-level singleton
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import NamedTuple, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# ── Constants ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
SCENE_LABELS: Tuple[str, ...] = ('indoor', 'outdoor', 'traffic', 'park')
|
||||||
|
|
||||||
|
_SR_DEFAULT = 16_000 # Hz
|
||||||
|
_N_MFCC = 13
|
||||||
|
_N_MELS = 26
|
||||||
|
_N_FFT = 512
|
||||||
|
_WIN_LENGTH = 400 # 25 ms @ 16 kHz
|
||||||
|
_HOP_LENGTH = 160 # 10 ms @ 16 kHz
|
||||||
|
_N_FEATURES = _N_MFCC + 3 # 16 total
|
||||||
|
|
||||||
|
|
||||||
|
# ── Result type ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class AudioSceneResult(NamedTuple):
|
||||||
|
label: str # one of SCENE_LABELS
|
||||||
|
confidence: float # 0.0–1.0
|
||||||
|
features: np.ndarray # shape (16,) raw feature vector
|
||||||
|
|
||||||
|
|
||||||
|
# ── Low-level DSP helpers ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _frame(x: np.ndarray, frame_length: int, hop_length: int) -> np.ndarray:
|
||||||
|
"""Slice a 1-D signal into overlapping frames. Returns (n_frames, frame_length)."""
|
||||||
|
n_frames = max(1, 1 + (len(x) - frame_length) // hop_length)
|
||||||
|
idx = (
|
||||||
|
np.arange(frame_length)[np.newaxis, :]
|
||||||
|
+ np.arange(n_frames)[:, np.newaxis] * hop_length
|
||||||
|
)
|
||||||
|
# Pad x if the last index exceeds the signal length
|
||||||
|
needed = int(idx.max()) + 1
|
||||||
|
if needed > len(x):
|
||||||
|
x = np.pad(x, (0, needed - len(x)))
|
||||||
|
return x[idx]
|
||||||
|
|
||||||
|
|
||||||
|
def _mel_filterbank(
|
||||||
|
sr: int,
|
||||||
|
n_fft: int,
|
||||||
|
n_mels: int,
|
||||||
|
f_min: float = 0.0,
|
||||||
|
f_max: float | None = None,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Build (n_mels, n_fft//2+1) triangular mel filterbank matrix."""
|
||||||
|
if f_max is None:
|
||||||
|
f_max = sr / 2.0
|
||||||
|
|
||||||
|
def _hz2mel(f: float) -> float:
|
||||||
|
return 2595.0 * math.log10(1.0 + f / 700.0)
|
||||||
|
|
||||||
|
def _mel2hz(m: float) -> float:
|
||||||
|
return 700.0 * (10.0 ** (m / 2595.0) - 1.0)
|
||||||
|
|
||||||
|
mel_pts = np.linspace(_hz2mel(f_min), _hz2mel(f_max), n_mels + 2)
|
||||||
|
hz_pts = np.array([_mel2hz(m) for m in mel_pts])
|
||||||
|
|
||||||
|
freqs = np.fft.rfftfreq(n_fft, d=1.0 / sr) # (F,)
|
||||||
|
f = freqs[:, np.newaxis] # (F, 1)
|
||||||
|
f_left = hz_pts[:-2] # (n_mels,)
|
||||||
|
f_ctr = hz_pts[1:-1]
|
||||||
|
f_right = hz_pts[2:]
|
||||||
|
|
||||||
|
left_slope = (f - f_left) / np.maximum(f_ctr - f_left, 1e-10)
|
||||||
|
right_slope = (f_right - f) / np.maximum(f_right - f_ctr, 1e-10)
|
||||||
|
|
||||||
|
fbank = np.maximum(0.0, np.minimum(left_slope, right_slope)).T # (n_mels, F)
|
||||||
|
return fbank
|
||||||
|
|
||||||
|
|
||||||
|
def _dct2(x: np.ndarray) -> np.ndarray:
|
||||||
|
"""Type-II DCT of each row. x: (n_frames, N) → (n_frames, N)."""
|
||||||
|
N = x.shape[1]
|
||||||
|
n = np.arange(N, dtype=np.float64)
|
||||||
|
k = np.arange(N, dtype=np.float64)[:, np.newaxis]
|
||||||
|
D = np.cos(math.pi * k * (2.0 * n + 1.0) / (2.0 * N)) # (N, N)
|
||||||
|
return x @ D.T # (n_frames, N)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Feature extraction ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Pre-build filterbank once (module-level, shared across all calls at default SR)
|
||||||
|
_FBANK: np.ndarray = _mel_filterbank(_SR_DEFAULT, _N_FFT, _N_MELS)
|
||||||
|
_WINDOW: np.ndarray = np.hamming(_WIN_LENGTH)
|
||||||
|
_FREQS: np.ndarray = np.fft.rfftfreq(_N_FFT, d=1.0 / _SR_DEFAULT)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_features(
|
||||||
|
samples: np.ndarray,
|
||||||
|
sr: int = _SR_DEFAULT,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Extract a 16-d feature vector from a mono float audio clip.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
samples : 1-D array, float32 or float64, range roughly [-1, 1]
|
||||||
|
sr : sample rate (Hz); must match the audio
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
features : shape (16,) — [MFCC_0..12, centroid_hz, rolloff_hz, zcr]
|
||||||
|
"""
|
||||||
|
x = np.asarray(samples, dtype=np.float64)
|
||||||
|
if len(x) == 0:
|
||||||
|
return np.zeros(_N_FEATURES, dtype=np.float64)
|
||||||
|
|
||||||
|
# Rebuild filterbank/freqs if non-default sample rate
|
||||||
|
if sr != _SR_DEFAULT:
|
||||||
|
fbank = _mel_filterbank(sr, _N_FFT, _N_MELS)
|
||||||
|
freqs = np.fft.rfftfreq(_N_FFT, d=1.0 / sr)
|
||||||
|
else:
|
||||||
|
fbank = _FBANK
|
||||||
|
freqs = _FREQS
|
||||||
|
|
||||||
|
# Frame + window
|
||||||
|
frames = _frame(x, _WIN_LENGTH, _HOP_LENGTH) # (T, W)
|
||||||
|
frames = frames * _WINDOW # Hamming
|
||||||
|
|
||||||
|
# Magnitude spectrum
|
||||||
|
mag = np.abs(np.fft.rfft(frames, n=_N_FFT, axis=1)) # (T, F)
|
||||||
|
|
||||||
|
# ── MFCC ────────────────────────────────────────────────────────────────
|
||||||
|
mel_e = mag @ fbank.T # (T, n_mels)
|
||||||
|
log_me = np.log(np.maximum(mel_e, 1e-10))
|
||||||
|
mfccs = _dct2(log_me)[:, :_N_MFCC] # (T, 13)
|
||||||
|
mfcc_mean = mfccs.mean(axis=0) # (13,)
|
||||||
|
|
||||||
|
# ── Spectral centroid ───────────────────────────────────────────────────
|
||||||
|
power = mag ** 2 # (T, F)
|
||||||
|
p_sum = power.sum(axis=1, keepdims=True)
|
||||||
|
p_sum = np.maximum(p_sum, 1e-20)
|
||||||
|
sc_hz = (power @ freqs) / p_sum.squeeze(1) # (T,) weighted mean freq
|
||||||
|
centroid_mean = float(sc_hz.mean())
|
||||||
|
|
||||||
|
# ── Spectral rolloff (85 %) ─────────────────────────────────────────────
|
||||||
|
cumsum = np.cumsum(power, axis=1) # (T, F)
|
||||||
|
thresh = 0.85 * p_sum # (T, 1)
|
||||||
|
rolloff_idx = np.argmax(cumsum >= thresh, axis=1) # (T,) first bin ≥ 85 %
|
||||||
|
rolloff_hz = freqs[rolloff_idx] # (T,)
|
||||||
|
rolloff_mean = float(rolloff_hz.mean())
|
||||||
|
|
||||||
|
# ── Zero-crossing rate ──────────────────────────────────────────────────
|
||||||
|
signs = np.sign(x)
|
||||||
|
zcr = np.mean(np.abs(np.diff(signs)) / 2.0) # crossings / sample
|
||||||
|
zcr_val = float(zcr)
|
||||||
|
|
||||||
|
features = np.concatenate([
|
||||||
|
mfcc_mean,
|
||||||
|
[centroid_mean, rolloff_mean, zcr_val],
|
||||||
|
]).astype(np.float64)
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
# ── Nearest-centroid classifier ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
class NearestCentroidClassifier:
|
||||||
|
"""
|
||||||
|
Classify a feature vector by nearest L2 distance in normalised feature space.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
centroids : (n_classes, n_features) array — one row per class
|
||||||
|
labels : sequence of class name strings, same order as centroids rows
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
centroids: np.ndarray,
|
||||||
|
labels: Tuple[str, ...],
|
||||||
|
) -> None:
|
||||||
|
self.centroids = np.asarray(centroids, dtype=np.float64)
|
||||||
|
self.labels = tuple(labels)
|
||||||
|
# Per-dimension normalisation derived from centroid spread
|
||||||
|
self._min = self.centroids.min(axis=0)
|
||||||
|
self._range = np.maximum(
|
||||||
|
self.centroids.max(axis=0) - self._min, 1e-6
|
||||||
|
)
|
||||||
|
|
||||||
|
def predict(self, features: np.ndarray) -> Tuple[str, float]:
|
||||||
|
"""
|
||||||
|
Return (label, confidence) for the given feature vector.
|
||||||
|
|
||||||
|
confidence = 1 / (1 + min_distance) in normalised feature space.
|
||||||
|
"""
|
||||||
|
norm_feat = (features - self._min) / self._range
|
||||||
|
norm_cent = (self.centroids - self._min) / self._range
|
||||||
|
dists = np.linalg.norm(norm_cent - norm_feat, axis=1) # (n_classes,)
|
||||||
|
best = int(np.argmin(dists))
|
||||||
|
conf = float(1.0 / (1.0 + dists[best]))
|
||||||
|
return self.labels[best], conf
|
||||||
|
|
||||||
|
|
||||||
|
# ── Prototype centroid construction ──────────────────────────────────────────
|
||||||
|
|
||||||
|
def _make_prototype(
|
||||||
|
label: str,
|
||||||
|
sr: int = _SR_DEFAULT,
|
||||||
|
rng: np.random.RandomState | None = None,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Generate a 1-second synthetic prototype signal for each scene class."""
|
||||||
|
if rng is None:
|
||||||
|
rng = np.random.RandomState(42)
|
||||||
|
n = sr
|
||||||
|
t = np.linspace(0.0, 1.0, n, endpoint=False)
|
||||||
|
|
||||||
|
if label == 'indoor':
|
||||||
|
# Room tone: dominant 440 Hz + light broadband noise
|
||||||
|
sig = 0.5 * np.sin(2 * math.pi * 440.0 * t) + 0.05 * rng.randn(n)
|
||||||
|
|
||||||
|
elif label == 'outdoor':
|
||||||
|
# Wind/ambient: bandpass noise centred ~1500 Hz (800–2500 Hz)
|
||||||
|
noise = rng.randn(n)
|
||||||
|
# Band-pass via FFT zeroing
|
||||||
|
spec = np.fft.rfft(noise)
|
||||||
|
freqs = np.fft.rfftfreq(n, d=1.0 / sr)
|
||||||
|
spec[(freqs < 800) | (freqs > 2500)] = 0.0
|
||||||
|
sig = np.fft.irfft(spec, n=n).real * 0.5
|
||||||
|
|
||||||
|
elif label == 'traffic':
|
||||||
|
# Engine / road noise: 80 Hz fundamental + 2nd harmonic + light noise
|
||||||
|
sig = (
|
||||||
|
0.5 * np.sin(2 * math.pi * 80.0 * t)
|
||||||
|
+ 0.25 * np.sin(2 * math.pi * 160.0 * t)
|
||||||
|
+ 0.1 * rng.randn(n)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif label == 'park':
|
||||||
|
# Birdsong: high-frequency components at 3200 Hz + 4800 Hz
|
||||||
|
sig = (
|
||||||
|
0.4 * np.sin(2 * math.pi * 3200.0 * t)
|
||||||
|
+ 0.3 * np.sin(2 * math.pi * 4800.0 * t)
|
||||||
|
+ 0.05 * rng.randn(n)
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unknown scene label: {label!r}')
|
||||||
|
|
||||||
|
return sig.astype(np.float64)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_centroids(sr: int = _SR_DEFAULT) -> np.ndarray:
|
||||||
|
"""Compute (4, 16) centroid matrix from prototype signals."""
|
||||||
|
rng = np.random.RandomState(42)
|
||||||
|
rows = []
|
||||||
|
for label in SCENE_LABELS:
|
||||||
|
proto = _make_prototype(label, sr=sr, rng=rng)
|
||||||
|
rows.append(extract_features(proto, sr=sr))
|
||||||
|
return np.stack(rows, axis=0) # (4, 16)
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton — built once at import
|
||||||
|
_CENTROIDS: np.ndarray = _build_centroids()
|
||||||
|
_CLASSIFIER: NearestCentroidClassifier = NearestCentroidClassifier(
|
||||||
|
_CENTROIDS, SCENE_LABELS
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Main entry point ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def classify_audio(
|
||||||
|
samples: np.ndarray,
|
||||||
|
sr: int = _SR_DEFAULT,
|
||||||
|
) -> AudioSceneResult:
|
||||||
|
"""
|
||||||
|
Classify a mono audio clip into one of four scene categories.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
samples : 1-D float array (any length ≥ 1 sample)
|
||||||
|
sr : sample rate in Hz
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSceneResult with .label, .confidence, .features
|
||||||
|
"""
|
||||||
|
feat = extract_features(samples, sr=sr)
|
||||||
|
label, conf = _CLASSIFIER.predict(feat)
|
||||||
|
return AudioSceneResult(label=label, confidence=conf, features=feat)
|
||||||
@ -0,0 +1,215 @@
|
|||||||
|
"""
|
||||||
|
audio_scene_node.py — Audio scene classifier node (Issue #353).
|
||||||
|
|
||||||
|
Buffers raw PCM audio from the audio_common stack, runs a nearest-centroid
|
||||||
|
MFCC classifier at 1 Hz, and publishes the estimated environment label.
|
||||||
|
|
||||||
|
Subscribes
|
||||||
|
----------
|
||||||
|
/audio/audio audio_common_msgs/AudioData (BEST_EFFORT)
|
||||||
|
/audio/audio_info audio_common_msgs/AudioInfo (RELIABLE, latched)
|
||||||
|
|
||||||
|
Publishes
|
||||||
|
---------
|
||||||
|
/saltybot/audio_scene saltybot_scene_msgs/AudioScene (1 Hz)
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sample_rate int 16000 Expected audio sample rate (Hz)
|
||||||
|
channels int 1 Expected number of channels (mono=1)
|
||||||
|
clip_duration_s float 1.0 Duration of each classification window (s)
|
||||||
|
publish_rate_hz float 1.0 Classification + publish rate (Hz)
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
* Audio data is expected as signed 16-bit little-endian PCM ('S16LE').
|
||||||
|
If AudioInfo arrives first with a different format, a warning is logged.
|
||||||
|
* If the audio buffer does not contain enough samples when the timer fires,
|
||||||
|
the existing buffer (padded with zeros to clip_duration_s) is used.
|
||||||
|
* The publish rate is independent of the audio callback rate; bursts and
|
||||||
|
gaps are handled gracefully.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import struct
|
||||||
|
from collections import deque
|
||||||
|
from typing import Deque
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import (
|
||||||
|
QoSProfile,
|
||||||
|
ReliabilityPolicy,
|
||||||
|
HistoryPolicy,
|
||||||
|
DurabilityPolicy,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from audio_common_msgs.msg import AudioData, AudioInfo
|
||||||
|
_HAVE_AUDIO_MSGS = True
|
||||||
|
except ImportError:
|
||||||
|
_HAVE_AUDIO_MSGS = False
|
||||||
|
|
||||||
|
from saltybot_scene_msgs.msg import AudioScene
|
||||||
|
from ._audio_scene import classify_audio, _N_FEATURES
|
||||||
|
|
||||||
|
|
||||||
|
_SENSOR_QOS = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=10,
|
||||||
|
)
|
||||||
|
_LATCHED_QOS = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.RELIABLE,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=1,
|
||||||
|
durability=DurabilityPolicy.TRANSIENT_LOCAL,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioSceneNode(Node):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__('audio_scene_node')
|
||||||
|
|
||||||
|
# ── Parameters ──────────────────────────────────────────────────────
|
||||||
|
self.declare_parameter('sample_rate', 16_000)
|
||||||
|
self.declare_parameter('channels', 1)
|
||||||
|
self.declare_parameter('clip_duration_s', 1.0)
|
||||||
|
self.declare_parameter('publish_rate_hz', 1.0)
|
||||||
|
|
||||||
|
p = self.get_parameter
|
||||||
|
self._sr = int(p('sample_rate').value)
|
||||||
|
self._channels = int(p('channels').value)
|
||||||
|
self._clip_dur = float(p('clip_duration_s').value)
|
||||||
|
self._pub_rate = float(p('publish_rate_hz').value)
|
||||||
|
|
||||||
|
self._clip_samples = int(self._sr * self._clip_dur)
|
||||||
|
|
||||||
|
# ── Audio sample ring buffer ─────────────────────────────────────────
|
||||||
|
# Store float32 mono samples; deque with max size = 2× clip
|
||||||
|
self._buffer: Deque[float] = deque(
|
||||||
|
maxlen=self._clip_samples * 2
|
||||||
|
)
|
||||||
|
self._audio_format = 'S16LE' # updated from AudioInfo if available
|
||||||
|
|
||||||
|
# ── Subscribers ──────────────────────────────────────────────────────
|
||||||
|
if _HAVE_AUDIO_MSGS:
|
||||||
|
self._audio_sub = self.create_subscription(
|
||||||
|
AudioData, '/audio/audio',
|
||||||
|
self._on_audio, _SENSOR_QOS,
|
||||||
|
)
|
||||||
|
self._info_sub = self.create_subscription(
|
||||||
|
AudioInfo, '/audio/audio_info',
|
||||||
|
self._on_audio_info, _LATCHED_QOS,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.get_logger().warn(
|
||||||
|
'audio_common_msgs not available — no audio will be received. '
|
||||||
|
'Install ros-humble-audio-common-msgs to enable audio input.'
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Publisher + timer ────────────────────────────────────────────────
|
||||||
|
self._pub = self.create_publisher(
|
||||||
|
AudioScene, '/saltybot/audio_scene', 10
|
||||||
|
)
|
||||||
|
period = 1.0 / max(self._pub_rate, 0.01)
|
||||||
|
self._timer = self.create_timer(period, self._on_timer)
|
||||||
|
|
||||||
|
self.get_logger().info(
|
||||||
|
f'audio_scene_node ready — '
|
||||||
|
f'sr={self._sr} Hz, clip={self._clip_dur:.1f}s, '
|
||||||
|
f'rate={self._pub_rate:.1f} Hz'
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Callbacks ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _on_audio_info(self, msg) -> None:
|
||||||
|
"""Latch audio format from AudioInfo."""
|
||||||
|
fmt = getattr(msg, 'sample_format', 'S16LE') or 'S16LE'
|
||||||
|
if fmt != self._audio_format:
|
||||||
|
self.get_logger().info(f'Audio format updated: {fmt!r}')
|
||||||
|
self._audio_format = fmt
|
||||||
|
reported_sr = getattr(msg, 'sample_rate', self._sr)
|
||||||
|
if reported_sr and reported_sr != self._sr:
|
||||||
|
self.get_logger().warn(
|
||||||
|
f'AudioInfo sample_rate={reported_sr} != '
|
||||||
|
f'param sample_rate={self._sr} — using param value',
|
||||||
|
throttle_duration_sec=30.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _on_audio(self, msg: 'AudioData') -> None:
|
||||||
|
"""Decode raw PCM bytes and push float32 mono samples to buffer."""
|
||||||
|
raw = bytes(msg.data)
|
||||||
|
if len(raw) < 2:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
n_samples = len(raw) // 2
|
||||||
|
samples_i16 = struct.unpack(f'<{n_samples}h', raw[:n_samples * 2])
|
||||||
|
except struct.error as exc:
|
||||||
|
self.get_logger().warn(
|
||||||
|
f'Failed to decode audio frame: {exc}',
|
||||||
|
throttle_duration_sec=5.0,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Convert to float32 mono
|
||||||
|
arr = np.array(samples_i16, dtype=np.float32) / 32768.0
|
||||||
|
if self._channels > 1:
|
||||||
|
# Downmix to mono: take channel 0
|
||||||
|
arr = arr[::self._channels]
|
||||||
|
|
||||||
|
self._buffer.extend(arr.tolist())
|
||||||
|
|
||||||
|
def _on_timer(self) -> None:
|
||||||
|
"""Classify buffered audio and publish result."""
|
||||||
|
n_buf = len(self._buffer)
|
||||||
|
if n_buf >= self._clip_samples:
|
||||||
|
# Take the most recent clip_samples
|
||||||
|
samples = np.array(
|
||||||
|
list(self._buffer)[-self._clip_samples:],
|
||||||
|
dtype=np.float64,
|
||||||
|
)
|
||||||
|
elif n_buf > 0:
|
||||||
|
# Pad with zeros at the start
|
||||||
|
samples = np.zeros(self._clip_samples, dtype=np.float64)
|
||||||
|
buf_arr = np.array(list(self._buffer), dtype=np.float64)
|
||||||
|
samples[-n_buf:] = buf_arr
|
||||||
|
else:
|
||||||
|
# No audio received — publish silence (indoor default)
|
||||||
|
samples = np.zeros(self._clip_samples, dtype=np.float64)
|
||||||
|
|
||||||
|
result = classify_audio(samples, sr=self._sr)
|
||||||
|
|
||||||
|
msg = AudioScene()
|
||||||
|
msg.header.stamp = self.get_clock().now().to_msg()
|
||||||
|
msg.header.frame_id = 'audio'
|
||||||
|
msg.label = result.label
|
||||||
|
msg.confidence = float(result.confidence)
|
||||||
|
feat_padded = np.zeros(_N_FEATURES, dtype=np.float32)
|
||||||
|
n = min(len(result.features), _N_FEATURES)
|
||||||
|
feat_padded[:n] = result.features[:n].astype(np.float32)
|
||||||
|
msg.features = feat_padded.tolist()
|
||||||
|
|
||||||
|
self._pub.publish(msg)
|
||||||
|
self.get_logger().debug(
|
||||||
|
f'audio_scene: {result.label!r} conf={result.confidence:.2f}'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None) -> None:
|
||||||
|
rclpy.init(args=args)
|
||||||
|
node = AudioSceneNode()
|
||||||
|
try:
|
||||||
|
rclpy.spin(node)
|
||||||
|
finally:
|
||||||
|
node.destroy_node()
|
||||||
|
rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@ -57,6 +57,8 @@ setup(
|
|||||||
'path_edges = saltybot_bringup.path_edges_node:main',
|
'path_edges = saltybot_bringup.path_edges_node:main',
|
||||||
# Depth-based obstacle size estimator (Issue #348)
|
# Depth-based obstacle size estimator (Issue #348)
|
||||||
'obstacle_size = saltybot_bringup.obstacle_size_node:main',
|
'obstacle_size = saltybot_bringup.obstacle_size_node:main',
|
||||||
|
# Audio scene classifier (Issue #353)
|
||||||
|
'audio_scene = saltybot_bringup.audio_scene_node:main',
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
444
jetson/ros2_ws/src/saltybot_bringup/test/test_audio_scene.py
Normal file
444
jetson/ros2_ws/src/saltybot_bringup/test/test_audio_scene.py
Normal file
@ -0,0 +1,444 @@
|
|||||||
|
"""
|
||||||
|
test_audio_scene.py — Unit tests for the audio scene classifier (Issue #353).
|
||||||
|
|
||||||
|
All tests use synthetic audio signals generated with numpy — no microphone,
|
||||||
|
no audio files, no ROS2 runtime needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from saltybot_bringup._audio_scene import (
|
||||||
|
SCENE_LABELS,
|
||||||
|
AudioSceneResult,
|
||||||
|
NearestCentroidClassifier,
|
||||||
|
_build_centroids,
|
||||||
|
_make_prototype,
|
||||||
|
classify_audio,
|
||||||
|
extract_features,
|
||||||
|
_frame,
|
||||||
|
_mel_filterbank,
|
||||||
|
_dct2,
|
||||||
|
_N_FEATURES,
|
||||||
|
_SR_DEFAULT,
|
||||||
|
_CLASSIFIER,
|
||||||
|
_CENTROIDS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# Helpers
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _sine(freq_hz: float, duration_s: float = 1.0, sr: int = _SR_DEFAULT,
|
||||||
|
amp: float = 0.5) -> np.ndarray:
|
||||||
|
"""Generate a pure sine wave."""
|
||||||
|
n = int(sr * duration_s)
|
||||||
|
t = np.linspace(0.0, duration_s, n, endpoint=False)
|
||||||
|
return (amp * np.sin(2.0 * math.pi * freq_hz * t)).astype(np.float64)
|
||||||
|
|
||||||
|
|
||||||
|
def _white_noise(duration_s: float = 1.0, sr: int = _SR_DEFAULT,
|
||||||
|
amp: float = 0.1, seed: int = 0) -> np.ndarray:
|
||||||
|
"""White noise."""
|
||||||
|
rng = np.random.RandomState(seed)
|
||||||
|
n = int(sr * duration_s)
|
||||||
|
return (amp * rng.randn(n)).astype(np.float64)
|
||||||
|
|
||||||
|
|
||||||
|
def _silence(duration_s: float = 1.0, sr: int = _SR_DEFAULT) -> np.ndarray:
|
||||||
|
return np.zeros(int(sr * duration_s), dtype=np.float64)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# SCENE_LABELS
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_scene_labels_tuple():
|
||||||
|
assert isinstance(SCENE_LABELS, tuple)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scene_labels_count():
|
||||||
|
assert len(SCENE_LABELS) == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_scene_labels_values():
|
||||||
|
assert set(SCENE_LABELS) == {'indoor', 'outdoor', 'traffic', 'park'}
|
||||||
|
|
||||||
|
|
||||||
|
def test_scene_labels_order():
|
||||||
|
assert SCENE_LABELS[0] == 'indoor'
|
||||||
|
assert SCENE_LABELS[1] == 'outdoor'
|
||||||
|
assert SCENE_LABELS[2] == 'traffic'
|
||||||
|
assert SCENE_LABELS[3] == 'park'
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# _frame
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_frame_shape():
|
||||||
|
x = np.arange(1000.0)
|
||||||
|
frames = _frame(x, 400, 160)
|
||||||
|
assert frames.ndim == 2
|
||||||
|
assert frames.shape[1] == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_frame_n_frames():
|
||||||
|
x = np.arange(1000.0)
|
||||||
|
frames = _frame(x, 400, 160)
|
||||||
|
expected = 1 + (1000 - 400) // 160
|
||||||
|
assert frames.shape[0] == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_frame_first_frame():
|
||||||
|
x = np.arange(100.0)
|
||||||
|
frames = _frame(x, 10, 5)
|
||||||
|
np.testing.assert_array_equal(frames[0], np.arange(10.0))
|
||||||
|
|
||||||
|
|
||||||
|
def test_frame_second_frame():
|
||||||
|
x = np.arange(100.0)
|
||||||
|
frames = _frame(x, 10, 5)
|
||||||
|
np.testing.assert_array_equal(frames[1], np.arange(5.0, 15.0))
|
||||||
|
|
||||||
|
|
||||||
|
def test_frame_short_signal():
|
||||||
|
"""Signal shorter than one frame should still return one frame."""
|
||||||
|
x = np.ones(10)
|
||||||
|
frames = _frame(x, 400, 160)
|
||||||
|
assert frames.shape[0] == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# _mel_filterbank
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_mel_filterbank_shape():
|
||||||
|
fbank = _mel_filterbank(16000, 512, 26)
|
||||||
|
assert fbank.shape == (26, 512 // 2 + 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mel_filterbank_non_negative():
|
||||||
|
fbank = _mel_filterbank(16000, 512, 26)
|
||||||
|
assert (fbank >= 0).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_mel_filterbank_max_one():
|
||||||
|
fbank = _mel_filterbank(16000, 512, 26)
|
||||||
|
assert fbank.max() <= 1.0 + 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
def test_mel_filterbank_non_zero():
|
||||||
|
fbank = _mel_filterbank(16000, 512, 26)
|
||||||
|
assert fbank.sum() > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# _dct2
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_dct2_shape():
|
||||||
|
x = np.random.randn(10, 26)
|
||||||
|
out = _dct2(x)
|
||||||
|
assert out.shape == x.shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_dct2_dc():
|
||||||
|
"""DC (constant) row should produce non-zero first column only."""
|
||||||
|
x = np.ones((1, 8))
|
||||||
|
out = _dct2(x)
|
||||||
|
# First coefficient = sum; rest should be near zero for constant input
|
||||||
|
assert abs(out[0, 0]) > abs(out[0, 1]) * 5
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# extract_features
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_extract_features_length():
|
||||||
|
sig = _sine(440.0)
|
||||||
|
feat = extract_features(sig)
|
||||||
|
assert len(feat) == _N_FEATURES
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_features_dtype():
|
||||||
|
sig = _sine(440.0)
|
||||||
|
feat = extract_features(sig)
|
||||||
|
assert feat.dtype in (np.float32, np.float64)
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_features_silence_finite():
|
||||||
|
"""Features from silence should be finite (no NaN/Inf)."""
|
||||||
|
feat = extract_features(_silence())
|
||||||
|
assert np.all(np.isfinite(feat))
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_features_sine_finite():
|
||||||
|
feat = extract_features(_sine(440.0))
|
||||||
|
assert np.all(np.isfinite(feat))
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_features_noise_finite():
|
||||||
|
feat = extract_features(_white_noise())
|
||||||
|
assert np.all(np.isfinite(feat))
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_features_centroid_low_freq():
|
||||||
|
"""Low-frequency sine → low spectral centroid (feature index 13)."""
|
||||||
|
feat_lo = extract_features(_sine(80.0))
|
||||||
|
feat_hi = extract_features(_sine(4000.0))
|
||||||
|
assert feat_lo[13] < feat_hi[13]
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_features_centroid_high_freq():
|
||||||
|
"""High-frequency sine → high spectral centroid."""
|
||||||
|
feat_4k = extract_features(_sine(4000.0))
|
||||||
|
feat_200 = extract_features(_sine(200.0))
|
||||||
|
assert feat_4k[13] > feat_200[13]
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_features_rolloff_low_freq():
|
||||||
|
"""Low-frequency sine → low spectral rolloff (feature index 14)."""
|
||||||
|
feat_lo = extract_features(_sine(100.0))
|
||||||
|
feat_hi = extract_features(_sine(5000.0))
|
||||||
|
assert feat_lo[14] < feat_hi[14]
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_features_zcr_increases_with_freq():
|
||||||
|
"""Higher-frequency sine → higher ZCR (feature index 15)."""
|
||||||
|
feat_lo = extract_features(_sine(100.0))
|
||||||
|
feat_hi = extract_features(_sine(3000.0))
|
||||||
|
assert feat_lo[15] < feat_hi[15]
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_features_short_clip():
|
||||||
|
"""Clips shorter than one frame should not crash."""
|
||||||
|
sig = np.zeros(100)
|
||||||
|
feat = extract_features(sig)
|
||||||
|
assert len(feat) == _N_FEATURES
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_features_empty_clip():
|
||||||
|
"""Empty signal should return zero vector."""
|
||||||
|
feat = extract_features(np.array([]))
|
||||||
|
assert len(feat) == _N_FEATURES
|
||||||
|
assert np.all(feat == 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# _make_prototype
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_prototype_length():
|
||||||
|
for label in SCENE_LABELS:
|
||||||
|
sig = _make_prototype(label)
|
||||||
|
assert len(sig) == _SR_DEFAULT
|
||||||
|
|
||||||
|
|
||||||
|
def test_prototype_finite():
|
||||||
|
for label in SCENE_LABELS:
|
||||||
|
sig = _make_prototype(label)
|
||||||
|
assert np.all(np.isfinite(sig))
|
||||||
|
|
||||||
|
|
||||||
|
def test_prototype_unknown_label():
|
||||||
|
with pytest.raises(ValueError, match='Unknown scene label'):
|
||||||
|
_make_prototype('underwater')
|
||||||
|
|
||||||
|
|
||||||
|
def test_prototype_deterministic():
|
||||||
|
"""Same seed → same prototype."""
|
||||||
|
rng1 = np.random.RandomState(42)
|
||||||
|
rng2 = np.random.RandomState(42)
|
||||||
|
sig1 = _make_prototype('indoor', rng=rng1)
|
||||||
|
sig2 = _make_prototype('indoor', rng=rng2)
|
||||||
|
np.testing.assert_array_equal(sig1, sig2)
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# _build_centroids
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_build_centroids_shape():
|
||||||
|
C = _build_centroids()
|
||||||
|
assert C.shape == (4, _N_FEATURES)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_centroids_finite():
|
||||||
|
C = _build_centroids()
|
||||||
|
assert np.all(np.isfinite(C))
|
||||||
|
|
||||||
|
|
||||||
|
def test_centroids_traffic_low_centroid():
|
||||||
|
"""Traffic prototype centroid (col 13) should be the lowest."""
|
||||||
|
C = _build_centroids()
|
||||||
|
idx_traffic = list(SCENE_LABELS).index('traffic')
|
||||||
|
assert C[idx_traffic, 13] == C[:, 13].min()
|
||||||
|
|
||||||
|
|
||||||
|
def test_centroids_park_high_centroid():
|
||||||
|
"""Park prototype centroid should have the highest spectral centroid."""
|
||||||
|
C = _build_centroids()
|
||||||
|
idx_park = list(SCENE_LABELS).index('park')
|
||||||
|
assert C[idx_park, 13] == C[:, 13].max()
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# NearestCentroidClassifier
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_classifier_predict_returns_known_label():
|
||||||
|
feat = extract_features(_sine(440.0))
|
||||||
|
label, conf = _CLASSIFIER.predict(feat)
|
||||||
|
assert label in SCENE_LABELS
|
||||||
|
|
||||||
|
|
||||||
|
def test_classifier_confidence_range():
|
||||||
|
feat = extract_features(_sine(440.0))
|
||||||
|
_, conf = _CLASSIFIER.predict(feat)
|
||||||
|
assert 0.0 < conf <= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_classifier_prototype_self_predict():
|
||||||
|
"""Each prototype signal should classify as its own class."""
|
||||||
|
rng = np.random.RandomState(42)
|
||||||
|
for label in SCENE_LABELS:
|
||||||
|
sig = _make_prototype(label, rng=rng)
|
||||||
|
feat = extract_features(sig)
|
||||||
|
pred, _ = _CLASSIFIER.predict(feat)
|
||||||
|
assert pred == label, (
|
||||||
|
f'Prototype {label!r} classified as {pred!r}'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_classifier_custom_centroids():
|
||||||
|
"""Simple 1-D centroid classifier sanity check."""
|
||||||
|
# Two classes: 'low' centroid at 0, 'high' centroid at 10
|
||||||
|
C = np.array([[0.0], [10.0]])
|
||||||
|
clf = NearestCentroidClassifier(C, ('low', 'high'))
|
||||||
|
assert clf.predict(np.array([1.0]))[0] == 'low'
|
||||||
|
assert clf.predict(np.array([9.0]))[0] == 'high'
|
||||||
|
|
||||||
|
|
||||||
|
def test_classifier_confidence_max_at_centroid():
|
||||||
|
"""Feature equal to a centroid → maximum confidence for that class."""
|
||||||
|
C = _CENTROIDS
|
||||||
|
clf = NearestCentroidClassifier(C, SCENE_LABELS)
|
||||||
|
for i, label in enumerate(SCENE_LABELS):
|
||||||
|
pred, conf = clf.predict(C[i])
|
||||||
|
assert pred == label
|
||||||
|
# dist = 0 → conf = 1/(1+0) = 1.0
|
||||||
|
assert abs(conf - 1.0) < 1e-9
|
||||||
|
|
||||||
|
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
# classify_audio — end-to-end
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_classify_audio_returns_result():
|
||||||
|
res = classify_audio(_sine(440.0))
|
||||||
|
assert isinstance(res, AudioSceneResult)
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_audio_label_valid():
|
||||||
|
res = classify_audio(_sine(440.0))
|
||||||
|
assert res.label in SCENE_LABELS
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_audio_confidence_range():
|
||||||
|
res = classify_audio(_sine(440.0))
|
||||||
|
assert 0.0 < res.confidence <= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_audio_features_length():
|
||||||
|
res = classify_audio(_sine(440.0))
|
||||||
|
assert len(res.features) == _N_FEATURES
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_traffic_low_freq():
|
||||||
|
"""80 Hz sine → traffic."""
|
||||||
|
res = classify_audio(_sine(80.0))
|
||||||
|
assert res.label == 'traffic', (
|
||||||
|
f'Expected traffic, got {res.label!r} (conf={res.confidence:.2f})'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_indoor_440():
|
||||||
|
"""440 Hz sine → indoor."""
|
||||||
|
res = classify_audio(_sine(440.0))
|
||||||
|
assert res.label == 'indoor', (
|
||||||
|
f'Expected indoor, got {res.label!r} (conf={res.confidence:.2f})'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_outdoor_mid_freq():
|
||||||
|
"""Mixed 1000+2000 Hz → outdoor."""
|
||||||
|
sr = _SR_DEFAULT
|
||||||
|
t = np.linspace(0.0, 1.0, sr, endpoint=False)
|
||||||
|
sig = (
|
||||||
|
0.4 * np.sin(2 * math.pi * 1000.0 * t)
|
||||||
|
+ 0.4 * np.sin(2 * math.pi * 2000.0 * t)
|
||||||
|
)
|
||||||
|
res = classify_audio(sig)
|
||||||
|
assert res.label == 'outdoor', (
|
||||||
|
f'Expected outdoor, got {res.label!r} (conf={res.confidence:.2f})'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_park_high_freq():
|
||||||
|
"""3200+4800 Hz tones → park."""
|
||||||
|
sr = _SR_DEFAULT
|
||||||
|
t = np.linspace(0.0, 1.0, sr, endpoint=False)
|
||||||
|
sig = (
|
||||||
|
0.4 * np.sin(2 * math.pi * 3200.0 * t)
|
||||||
|
+ 0.3 * np.sin(2 * math.pi * 4800.0 * t)
|
||||||
|
)
|
||||||
|
res = classify_audio(sig)
|
||||||
|
assert res.label == 'park', (
|
||||||
|
f'Expected park, got {res.label!r} (conf={res.confidence:.2f})'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_silence_no_crash():
|
||||||
|
"""Silence should return a valid (if low-confidence) result."""
|
||||||
|
res = classify_audio(_silence())
|
||||||
|
assert res.label in SCENE_LABELS
|
||||||
|
assert np.isfinite(res.confidence)
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_audio_short_clip():
|
||||||
|
"""Short clip (200 ms) should not crash."""
|
||||||
|
sig = _sine(440.0, duration_s=0.2)
|
||||||
|
res = classify_audio(sig)
|
||||||
|
assert res.label in SCENE_LABELS
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_prototype_indoor():
|
||||||
|
rng = np.random.RandomState(42)
|
||||||
|
sig = _make_prototype('indoor', rng=rng)
|
||||||
|
res = classify_audio(sig)
|
||||||
|
assert res.label == 'indoor'
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_prototype_outdoor():
|
||||||
|
rng = np.random.RandomState(42)
|
||||||
|
sig = _make_prototype('outdoor', rng=rng)
|
||||||
|
res = classify_audio(sig)
|
||||||
|
assert res.label == 'outdoor'
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_prototype_traffic():
|
||||||
|
rng = np.random.RandomState(42)
|
||||||
|
sig = _make_prototype('traffic', rng=rng)
|
||||||
|
res = classify_audio(sig)
|
||||||
|
assert res.label == 'traffic'
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_prototype_park():
|
||||||
|
rng = np.random.RandomState(42)
|
||||||
|
sig = _make_prototype('park', rng=rng)
|
||||||
|
res = classify_audio(sig)
|
||||||
|
assert res.label == 'park'
|
||||||
@ -30,6 +30,8 @@ rosidl_generate_interfaces(${PROJECT_NAME}
|
|||||||
# Issue #348 — depth-based obstacle size estimator
|
# Issue #348 — depth-based obstacle size estimator
|
||||||
"msg/ObstacleSize.msg"
|
"msg/ObstacleSize.msg"
|
||||||
"msg/ObstacleSizeArray.msg"
|
"msg/ObstacleSizeArray.msg"
|
||||||
|
# Issue #353 — audio scene classifier
|
||||||
|
"msg/AudioScene.msg"
|
||||||
DEPENDENCIES std_msgs geometry_msgs vision_msgs builtin_interfaces
|
DEPENDENCIES std_msgs geometry_msgs vision_msgs builtin_interfaces
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,4 @@
|
|||||||
|
std_msgs/Header header
|
||||||
|
string label # 'indoor' | 'outdoor' | 'traffic' | 'park'
|
||||||
|
float32 confidence # 0.0–1.0 (nearest-centroid inverted distance)
|
||||||
|
float32[16] features # raw feature vector: MFCC[0..12] + centroid_hz + rolloff_hz + zcr
|
||||||
Loading…
x
Reference in New Issue
Block a user