Compare commits
26 Commits
a8e1ea3314
...
479a33a6fa
| Author | SHA1 | Date | |
|---|---|---|---|
| 479a33a6fa | |||
| d1f0e95fa2 | |||
| 5f3b5caef7 | |||
| cb8f6c82a4 | |||
| de1166058c | |||
| c3d36e9943 | |||
| dd033b9827 | |||
| f5093ecd34 | |||
| 54bc37926b | |||
| 30ad71e7d8 | |||
| b6104763c5 | |||
| 5108fa8fa1 | |||
| 5362536fb1 | |||
| 305ce6c971 | |||
| c7dd07f9ed | |||
| 0776003dd3 | |||
| 01ee02f837 | |||
| f0e11fe7ca | |||
| 201dea4c01 | |||
| 477258f321 | |||
| 94a6f0787e | |||
| 50636de5a9 | |||
| c348e093ef | |||
| c865e84e16 | |||
| 9d12805843 | |||
| 3cd9faeed9 |
354
chassis/cable_management_clips.scad
Normal file
354
chassis/cable_management_clips.scad
Normal file
@ -0,0 +1,354 @@
|
||||
// =============================================================================
|
||||
// SaltyBot — Cable Management Clips
|
||||
// Agent: sl-mechanical | 2026-03-02
|
||||
//
|
||||
// MODULAR SNAP-ON CABLE CLIPS with integrated adhesive base and zip-tie anchors.
|
||||
// Designed to organize power cables, sensor bundles, and signal harnesses on the
|
||||
// chassis. Each clip accommodates a range of cable bundle diameters via elastic
|
||||
// snap jaws.
|
||||
//
|
||||
// HOW IT WORKS
|
||||
// 1. Adhesive base (3M VHB or equivalent) adheres to chassis surface.
|
||||
// 2. Cable bundle pressed upward through snap jaws until it seats with audible click.
|
||||
// 3. Overhanging jaw tabs provide two zip-tie anchor points (one per side).
|
||||
// 4. Vertical ear holes accept M3 threaded inserts for wire or strap attachment.
|
||||
//
|
||||
// CLIP FAMILY
|
||||
// • Clip 5mm: Holds 4–6 mm bundles (small signal cables)
|
||||
// • Clip 8mm: Holds 6–10 mm bundles (mixed power + signal)
|
||||
// • Clip 12mm: Holds 10–14 mm bundles (heavy power)
|
||||
//
|
||||
// PARTS (set RENDER= to export each)
|
||||
// clip_5mm — 3D print × N (RENDER="clip_5mm")
|
||||
// clip_8mm — 3D print × N (RENDER="clip_8mm")
|
||||
// clip_12mm — 3D print × N (RENDER="clip_12mm")
|
||||
// assembly_all — Full preview with ghosts (RENDER="assembly_all")
|
||||
//
|
||||
// MATERIALS
|
||||
// • Body: PETG or ASA (weatherproof, adhesive-friendly)
|
||||
// • Adhesive: 3M VHB 5952F (50 mm × 75 mm pads, rated 5 N/cm²)
|
||||
// • Anchors: M3 threaded inserts (optional, for high-load retention)
|
||||
// • Zip-ties: Standard nylon 3.6 mm × 150 mm (e.g., HellermannTyton)
|
||||
//
|
||||
// INSTALLATION
|
||||
// 1. Clean chassis surface with isopropyl alcohol; let dry.
|
||||
// 2. Peel 3M VHB backing; press clip firmly (30 s hold).
|
||||
// 3. Wait 24 hours for full adhesive cure.
|
||||
// 4. Press cable bundle upward through snap jaws until seated.
|
||||
// 5. Route zip-ties through jaw anchor points; cinch at desired tension.
|
||||
// 6. Optionally thread M3 bolts through ear holes for redundant retention.
|
||||
// =============================================================================
|
||||
|
||||
$fn = 64;
|
||||
|
||||
// =============================================================================
|
||||
// CLIP GEOMETRY — COMMON PARAMETERS
|
||||
// =============================================================================
|
||||
|
||||
// SNAP JAW PROFILE (all clips share same jaw geometry)
|
||||
JAW_THICKNESS = 2.0; // mm thickness of snap arms (thin = flexible)
|
||||
JAW_BEND_R = 0.8; // mm radius at jaw root (stress relief)
|
||||
JAW_CLOSURE = 0.3; // mm interference fit depth when snapped closed
|
||||
SNAP_TRAVEL = 1.2; // mm vertical distance cable travels before seat
|
||||
SNAP_REST_GAP = 0.2; // mm gap when unloaded (keeps jaws sprung apart)
|
||||
|
||||
// ADHESIVE BASE
|
||||
BASE_LENGTH = 50.0; // mm forward-back footprint
|
||||
BASE_WIDTH = 40.0; // mm left-right footprint
|
||||
BASE_THICKNESS = 2.5; // mm base pad thickness
|
||||
BASE_FILLET = 2.0; // mm corner rounding (aids adhesive contact)
|
||||
|
||||
// ZIP-TIE ANCHOR FEATURES
|
||||
ANCHOR_TAB_H = 5.0; // mm height of jaw-tip anchor tab
|
||||
ANCHOR_TAB_T = 1.5; // mm thickness of anchor tab
|
||||
ANCHOR_SLOT_W = 4.0; // mm width of zip-tie slot (3.6 mm ties + 0.4 mm clearance)
|
||||
ANCHOR_SLOT_H = 1.0; // mm height of slot throat
|
||||
|
||||
// WIRE/STRAP ATTACHMENT EARS
|
||||
EAR_D = 4.2; // mm hole diameter (M3 clearance, 2.6 mm nominal)
|
||||
EAR_WALL_T = 3.0; // mm wall thickness around hole
|
||||
EAR_H = 8.0; // mm ear protrusion height from base
|
||||
|
||||
// =============================================================================
|
||||
// CLIP SIZE VARIANTS
|
||||
// =============================================================================
|
||||
|
||||
// For each clip size, define:
|
||||
// CABLE_D — nominal cable bundle diameter
|
||||
// JAW_SPAN — inner span of closed jaws (CABLE_D + JAW_CLOSURE)
|
||||
// CLIP_HEIGHT — overall height of clip body
|
||||
// CLAMP_X — width of clamp section (controls jaw lever arm)
|
||||
|
||||
CLIP_PARAMS = [
|
||||
// [name, cable_d, jaw_span, height, clamp_x]
|
||||
["5mm", 5.0, 5.3, 28.0, 18.0],
|
||||
["8mm", 8.0, 8.3, 35.0, 22.0],
|
||||
["12mm", 12.0, 12.3, 42.0, 26.0],
|
||||
];
|
||||
|
||||
// =============================================================================
|
||||
// RENDER CONTROL
|
||||
// =============================================================================
|
||||
|
||||
// "assembly_all" — all clips in array, with base ghosts
|
||||
// "clip_5mm" — single 5mm clip (ready to export STL)
|
||||
// "clip_8mm" — single 8mm clip
|
||||
// "clip_12mm" — single 12mm clip
|
||||
|
||||
RENDER = "assembly_all";
|
||||
|
||||
// Helper to fetch clip parameters by name
|
||||
function get_clip_params(name) =
|
||||
(name == "5mm") ? CLIP_PARAMS[0] :
|
||||
(name == "8mm") ? CLIP_PARAMS[1] :
|
||||
(name == "12mm") ? CLIP_PARAMS[2] :
|
||||
CLIP_PARAMS[0];
|
||||
|
||||
// =============================================================================
|
||||
// MAIN RENDER DISPATCH
|
||||
// =============================================================================
|
||||
|
||||
if (RENDER == "assembly_all") {
|
||||
assembly_all();
|
||||
} else if (RENDER == "clip_5mm") {
|
||||
clip_body(get_clip_params("5mm"));
|
||||
} else if (RENDER == "clip_8mm") {
|
||||
clip_body(get_clip_params("8mm"));
|
||||
} else if (RENDER == "clip_12mm") {
|
||||
clip_body(get_clip_params("12mm"));
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// ASSEMBLY VIEW (all clips in a row, with adhesive pads ghosted)
|
||||
// =============================================================================
|
||||
|
||||
module assembly_all() {
|
||||
for (i = [0 : len(CLIP_PARAMS) - 1]) {
|
||||
p = CLIP_PARAMS[i];
|
||||
x_offset = i * 70; // 70 mm spacing
|
||||
|
||||
translate([x_offset, 0, 0]) {
|
||||
// Clip body
|
||||
color("DodgerBlue", 0.92)
|
||||
clip_body(p);
|
||||
|
||||
// Adhesive base ghost (3M VHB pad)
|
||||
%color("LimeGreen", 0.40)
|
||||
translate([0, 0, -BASE_THICKNESS])
|
||||
rounded_rect([50, 40, 0.2], BASE_FILLET);
|
||||
|
||||
// Label
|
||||
echo(str("Clip ", p[0], " — Cable dia. ", p[1], " mm"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// CLIP BODY MODULE (parametric across all sizes)
|
||||
// =============================================================================
|
||||
//
|
||||
// Structure:
|
||||
// • Base: rounded rectangle, adhesive-mounting surface
|
||||
// • Spine: central vertical structure, extends from base
|
||||
// • Jaws: two snap arms extending upward/outward from spine
|
||||
// • Ears: two lateral holes for M3 attachment (optional)
|
||||
// • Anchors: small tabs on jaw tips for zip-tie routing
|
||||
//
|
||||
|
||||
module clip_body(params) {
|
||||
name = params[0];
|
||||
cable_d = params[1];
|
||||
jaw_span = params[2];
|
||||
clip_h = params[3];
|
||||
clamp_x = params[4];
|
||||
|
||||
spine_thick = 3.5; // mm thickness of central spine
|
||||
jaw_l = clip_h - BASE_THICKNESS; // jaw arm length
|
||||
jaw_root_x = clamp_x / 2; // X position where jaw originates from spine
|
||||
|
||||
difference() {
|
||||
union() {
|
||||
// ── ADHESIVE BASE ──────────────────────────────────────────────
|
||||
translate([0, 0, -BASE_THICKNESS/2])
|
||||
rounded_rect([BASE_LENGTH, BASE_WIDTH, BASE_THICKNESS], BASE_FILLET);
|
||||
|
||||
// ── CENTRAL SPINE (support structure) ───────────────────────────
|
||||
translate([-spine_thick/2, -clamp_x/2, 0])
|
||||
cube([spine_thick, clamp_x, BASE_THICKNESS + jaw_l]);
|
||||
|
||||
// ── LEFT JAW (snap arm with flexible root) ──────────────────────
|
||||
jaw_body(-jaw_root_x, jaw_l, jaw_span);
|
||||
|
||||
// ── RIGHT JAW (snap arm, mirror) ───────────────────────────────
|
||||
jaw_body(jaw_root_x, jaw_l, jaw_span);
|
||||
|
||||
// ── LEFT EAR (M3 attachment hole) ──────────────────────────────
|
||||
translate([-clamp_x/2 - EAR_WALL_T - EAR_D/2, 0, BASE_THICKNESS])
|
||||
ear_boss();
|
||||
|
||||
// ── RIGHT EAR ──────────────────────────────────────────────────
|
||||
translate([clamp_x/2 + EAR_WALL_T + EAR_D/2, 0, BASE_THICKNESS])
|
||||
ear_boss();
|
||||
}
|
||||
|
||||
// ── SUBTRACT: Anchor slot hollows (zip-tie slots in jaw tips) ──────
|
||||
jaw_root_z = BASE_THICKNESS + jaw_l - ANCHOR_TAB_H;
|
||||
|
||||
// Left jaw anchor slot
|
||||
translate([-jaw_root_x - 2, -ANCHOR_SLOT_W/2, jaw_root_z])
|
||||
cube([3, ANCHOR_SLOT_W, ANCHOR_SLOT_H]);
|
||||
|
||||
// Right jaw anchor slot
|
||||
translate([jaw_root_x - 1, -ANCHOR_SLOT_W/2, jaw_root_z])
|
||||
cube([3, ANCHOR_SLOT_W, ANCHOR_SLOT_H]);
|
||||
|
||||
// ── SUBTRACT: Ear attachment holes (M3 clearance) ────────────────
|
||||
// Left ear hole
|
||||
translate([-clamp_x/2 - EAR_WALL_T - EAR_D/2, 0, BASE_THICKNESS + EAR_H/2])
|
||||
cylinder(d=EAR_D, h=EAR_H + 1, center=true);
|
||||
|
||||
// Right ear hole
|
||||
translate([clamp_x/2 + EAR_WALL_T + EAR_D/2, 0, BASE_THICKNESS + EAR_H/2])
|
||||
cylinder(d=EAR_D, h=EAR_H + 1, center=true);
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// JAW BODY (single snap arm with cable pocket)
|
||||
// =============================================================================
|
||||
//
|
||||
// A flexible cantilever arm extending from the spine.
|
||||
// Lower section: solid (load-bearing).
|
||||
// Upper section: curved U-channel (grips cable).
|
||||
// Jaw tips: overhanging tabs for zip-tie anchors.
|
||||
//
|
||||
|
||||
module jaw_body(x_root, jaw_length, inner_span) {
|
||||
jaw_span_outer = inner_span + 2 * JAW_THICKNESS;
|
||||
|
||||
// The jaw sweeps from x_root (spine side) along +X, curving to grip.
|
||||
// At the tip, it has a slight outward bow for snap action.
|
||||
|
||||
difference() {
|
||||
union() {
|
||||
// Lower jaw arm (solid, structural)
|
||||
translate([x_root, -jaw_span_outer/2, BASE_THICKNESS])
|
||||
cube([jaw_length * 0.65, jaw_span_outer, JAW_THICKNESS * 1.5]);
|
||||
|
||||
// Upper jaw arm (U-channel form)
|
||||
translate([x_root, -inner_span/2 - JAW_THICKNESS, BASE_THICKNESS])
|
||||
cube([jaw_length * 0.85, inner_span + 2*JAW_THICKNESS, JAW_THICKNESS]);
|
||||
|
||||
// Jaw tip anchor tab (for zip-tie slots)
|
||||
tip_x = x_root + jaw_length * 0.8;
|
||||
translate([tip_x, -jaw_span_outer/2 - ANCHOR_TAB_T,
|
||||
BASE_THICKNESS + JAW_THICKNESS])
|
||||
cube([ANCHOR_TAB_H, jaw_span_outer + 2*ANCHOR_TAB_T, ANCHOR_TAB_H]);
|
||||
}
|
||||
|
||||
// Hollow out the U-channel (cable pocket)
|
||||
// Inner cavity: inner_span wide, runs most of jaw length
|
||||
translate([x_root + JAW_THICKNESS * 0.5, -inner_span/2, BASE_THICKNESS])
|
||||
cube([jaw_length * 0.7, inner_span, JAW_THICKNESS + 0.5]);
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// EAR BOSS (M3 attachment point)
|
||||
// =============================================================================
|
||||
//
|
||||
// A small raised button with a through-hole, providing optional redundant
|
||||
// attachment for straps or hard-wired retention.
|
||||
//
|
||||
|
||||
module ear_boss() {
|
||||
difference() {
|
||||
cylinder(d=EAR_D + 2*EAR_WALL_T, h=EAR_H);
|
||||
translate([0, 0, -1])
|
||||
cylinder(d=EAR_D, h=EAR_H + 2);
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// UTILITY: Rounded Rectangle (for base and ghosts)
|
||||
// =============================================================================
|
||||
//
|
||||
|
||||
module rounded_rect(size, r) {
|
||||
// size = [width, length, height]
|
||||
w = size[0];
|
||||
l = size[1];
|
||||
h = size[2];
|
||||
|
||||
linear_extrude(height=h)
|
||||
offset(r=r)
|
||||
offset(r=-r)
|
||||
square([w, l], center=true);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// EXPORT / PRINT INSTRUCTIONS
|
||||
// =============================================================================
|
||||
//
|
||||
// CLIP 5mm (3D print × N):
|
||||
// openscad cable_management_clips.scad -D 'RENDER="clip_5mm"' -o clip_5mm.stl
|
||||
// Print settings: PETG/ASA, 4 perimeters, 20% infill, 0.2 mm layer
|
||||
// Orientation: base flat on bed (smooth finish for adhesive)
|
||||
//
|
||||
// CLIP 8mm (3D print × N):
|
||||
// openscad cable_management_clips.scad -D 'RENDER="clip_8mm"' -o clip_8mm.stl
|
||||
// Print settings: Same as 5mm
|
||||
//
|
||||
// CLIP 12mm (3D print × N):
|
||||
// openscad cable_management_clips.scad -D 'RENDER="clip_12mm"' -o clip_12mm.stl
|
||||
// Print settings: Same as 5mm
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// INSTALLATION GUIDE
|
||||
//
|
||||
// 1. SURFACE PREP
|
||||
// • Clean chassis surface with isopropyl alcohol.
|
||||
// • Let dry for 5 minutes; inspect for dust or residue.
|
||||
//
|
||||
// 2. ADHESIVE APPLICATION
|
||||
// • Cut 3M VHB 5952F into ~50 × 40 mm pads (one per clip).
|
||||
// • Peel foil backing from VHB pad.
|
||||
// • Center pad on clip base; press firmly for 30 seconds.
|
||||
// • Peel clear polyester liner from exposed adhesive.
|
||||
//
|
||||
// 3. MOUNTING
|
||||
// • Position clip on chassis surface (e.g., along frame rail).
|
||||
// • Press and hold for 30 seconds, applying full body weight if possible.
|
||||
// • Let cure for 24 hours before loading cables.
|
||||
//
|
||||
// 4. CABLE INSERTION
|
||||
// • Gather cable bundle (power, signal, etc.); inspect for knots/damage.
|
||||
// • Align bundle perpendicular to clip jaws.
|
||||
// • Press upward with steady pressure until jaws snap closed (audible click).
|
||||
// • Tension should hold cable 5–10 N without slip.
|
||||
//
|
||||
// 5. ZIP-TIE ANCHORING (optional extra security)
|
||||
// • Thread 3.6 mm nylon zip-tie through jaw anchor tabs (left and right).
|
||||
// • Route around cable bundle; cinch to desired tension (avoid crushing).
|
||||
// • Trim excess tie length.
|
||||
//
|
||||
// 6. THREADED INSERTION (optional M3 redundancy)
|
||||
// • Install M3 threaded insert into ear hole (using M3 insertion tool).
|
||||
// • Thread M3 × 16 mm bolt with split washer through ear.
|
||||
// • Tighten 1.5 N·m (firm but not excessive).
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// CABLE ROUTING BEST PRACTICES
|
||||
//
|
||||
// • Power cables (main): Use 12mm clips, spacing 150–200 mm apart.
|
||||
// • Mixed signal bundles: Use 8mm clips, spacing 100–150 mm apart.
|
||||
// • Individual sensor leads: Use 5mm clips or traditional P-clips.
|
||||
//
|
||||
// • Avoid sharp bends: Route bundles with R ≥ 50 mm (cable bundle diameter).
|
||||
// • Prevent abrasion: Use snap clips where cable crosses sharp edges.
|
||||
// • Allow thermal expansion: Leave ~2–3 mm slack in long runs.
|
||||
// • Color-code bundles: Use electrical tape or heatshrink before clipping.
|
||||
//
|
||||
// =============================================================================
|
||||
162
include/fan.h
Normal file
162
include/fan.h
Normal file
@ -0,0 +1,162 @@
|
||||
#ifndef FAN_H
|
||||
#define FAN_H
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
/*
|
||||
* fan.h — Cooling fan PWM speed controller (Issue #263)
|
||||
*
|
||||
* STM32F722 driver for brushless cooling fan on PA9 using TIM1_CH2 PWM.
|
||||
* Temperature-based speed curve with smooth ramp transitions.
|
||||
*
|
||||
* Pin: PA9 (TIM1_CH2, alternate function AF1)
|
||||
* PWM Frequency: 25 kHz (suitable for brushless DC fan)
|
||||
* Speed Range: 0-100% duty cycle
|
||||
*
|
||||
* Temperature Curve:
|
||||
* - Below 40°C: Fan off (0%)
|
||||
* - 40-50°C: Linear ramp from 0% to 30%
|
||||
* - 50-70°C: Linear ramp from 30% to 100%
|
||||
* - Above 70°C: Fan at maximum (100%)
|
||||
*/
|
||||
|
||||
/* Fan speed state */
|
||||
typedef enum {
|
||||
FAN_OFF, /* Motor disabled (0% duty) */
|
||||
FAN_LOW, /* Low speed (5-30%) */
|
||||
FAN_MEDIUM, /* Medium speed (31-60%) */
|
||||
FAN_HIGH, /* High speed (61-99%) */
|
||||
FAN_FULL /* Maximum speed (100%) */
|
||||
} FanState;
|
||||
|
||||
/*
|
||||
* fan_init()
|
||||
*
|
||||
* Initialize fan controller:
|
||||
* - PA9 as TIM1_CH2 PWM output
|
||||
* - TIM1 configured for 25 kHz frequency
|
||||
* - PWM duty cycle control (0-100%)
|
||||
* - Ramp rate limiter for smooth transitions
|
||||
*/
|
||||
void fan_init(void);
|
||||
|
||||
/*
|
||||
* fan_set_speed(percentage)
|
||||
*
|
||||
* Set fan speed directly (bypasses temperature control).
|
||||
* Used for manual testing or emergency cooling.
|
||||
*
|
||||
* Arguments:
|
||||
* - percentage: 0-100% duty cycle
|
||||
*
|
||||
* Returns: true if set successfully, false if invalid value
|
||||
*/
|
||||
bool fan_set_speed(uint8_t percentage);
|
||||
|
||||
/*
|
||||
* fan_get_speed()
|
||||
*
|
||||
* Get current fan speed setting.
|
||||
*
|
||||
* Returns: Current speed 0-100%
|
||||
*/
|
||||
uint8_t fan_get_speed(void);
|
||||
|
||||
/*
|
||||
* fan_set_target_speed(percentage)
|
||||
*
|
||||
* Set target speed with smooth ramping.
|
||||
* Speed transitions over time according to ramp rate.
|
||||
*
|
||||
* Arguments:
|
||||
* - percentage: Target speed 0-100%
|
||||
*
|
||||
* Returns: true if set successfully
|
||||
*/
|
||||
bool fan_set_target_speed(uint8_t percentage);
|
||||
|
||||
/*
|
||||
* fan_update_temperature(temp_celsius)
|
||||
*
|
||||
* Update temperature reading and apply speed curve.
|
||||
* Calculates target speed based on temperature curve.
|
||||
* Speed transition is smoothed via ramp limiter.
|
||||
*
|
||||
* Temperature Curve:
|
||||
* - temp < 40°C: 0% (off)
|
||||
* - 40°C ≤ temp < 50°C: 0% + (temp - 40) * 3% per °C = linear to 30%
|
||||
* - 50°C ≤ temp < 70°C: 30% + (temp - 50) * 3.5% per °C = linear to 100%
|
||||
* - temp ≥ 70°C: 100% (full)
|
||||
*
|
||||
* Arguments:
|
||||
* - temp_celsius: Temperature in degrees Celsius (int16_t for negative values)
|
||||
*/
|
||||
void fan_update_temperature(int16_t temp_celsius);
|
||||
|
||||
/*
|
||||
* fan_get_temperature()
|
||||
*
|
||||
* Get last recorded temperature.
|
||||
*
|
||||
* Returns: Temperature in °C (or 0 if not yet set)
|
||||
*/
|
||||
int16_t fan_get_temperature(void);
|
||||
|
||||
/*
|
||||
* fan_get_state()
|
||||
*
|
||||
* Get current fan operational state.
|
||||
*
|
||||
* Returns: FAN_OFF, FAN_LOW, FAN_MEDIUM, FAN_HIGH, or FAN_FULL
|
||||
*/
|
||||
FanState fan_get_state(void);
|
||||
|
||||
/*
|
||||
* fan_set_ramp_rate(percentage_per_ms)
|
||||
*
|
||||
* Configure speed ramp rate for smooth transitions.
|
||||
* Default: 5% per 100ms = 0.05% per ms.
|
||||
* Higher values = faster transitions.
|
||||
*
|
||||
* Arguments:
|
||||
* - percentage_per_ms: Speed change per millisecond (e.g., 1 = 1% per ms)
|
||||
*
|
||||
* Typical ranges:
|
||||
* - 0.01 = very slow (100% change in 10 seconds)
|
||||
* - 0.05 = slow (100% change in 2 seconds)
|
||||
* - 0.1 = medium (100% change in 1 second)
|
||||
* - 1.0 = fast (100% change in 100ms)
|
||||
*/
|
||||
void fan_set_ramp_rate(float percentage_per_ms);
|
||||
|
||||
/*
|
||||
* fan_is_ramping()
|
||||
*
|
||||
* Check if speed is currently transitioning.
|
||||
*
|
||||
* Returns: true if speed is ramping toward target, false if at target
|
||||
*/
|
||||
bool fan_is_ramping(void);
|
||||
|
||||
/*
|
||||
* fan_tick(now_ms)
|
||||
*
|
||||
* Update function called periodically (recommended: every 10-100ms).
|
||||
* Processes speed ramp transitions.
|
||||
* Must be called regularly for smooth ramping operation.
|
||||
*
|
||||
* Arguments:
|
||||
* - now_ms: current time in milliseconds (from HAL_GetTick() or similar)
|
||||
*/
|
||||
void fan_tick(uint32_t now_ms);
|
||||
|
||||
/*
|
||||
* fan_disable()
|
||||
*
|
||||
* Disable fan immediately (set to 0% duty).
|
||||
* Useful for shutdown or emergency stop.
|
||||
*/
|
||||
void fan_disable(void);
|
||||
|
||||
#endif /* FAN_H */
|
||||
@ -25,6 +25,8 @@
|
||||
<exec_depend>saltybot_follower</exec_depend>
|
||||
<exec_depend>saltybot_outdoor</exec_depend>
|
||||
<exec_depend>saltybot_perception</exec_depend>
|
||||
<!-- HSV color segmentation messages (Issue #274) -->
|
||||
<exec_depend>saltybot_scene_msgs</exec_depend>
|
||||
<exec_depend>saltybot_uwb</exec_depend>
|
||||
|
||||
<buildtool_depend>ament_python</buildtool_depend>
|
||||
|
||||
@ -0,0 +1,184 @@
|
||||
"""
|
||||
_color_segmenter.py — HSV color segmentation helpers (no ROS2 deps).
|
||||
|
||||
Algorithm
|
||||
---------
|
||||
For each requested color:
|
||||
1. Convert BGR → HSV (OpenCV: H∈[0,180], S∈[0,255], V∈[0,255])
|
||||
2. Build a binary mask via cv2.inRange using the color's HSV bounds.
|
||||
Red wraps around H=0/180 so two ranges are OR-combined.
|
||||
3. Morphological open (3×3) to remove noise.
|
||||
4. Find external contours; filter by min_area_px.
|
||||
5. Return ColorBlob NamedTuples — one per surviving contour.
|
||||
|
||||
confidence is the contour area divided by the bounding-rectangle area
|
||||
(how "filled" the bounding box is), clamped to [0, 1].
|
||||
|
||||
Public API
|
||||
----------
|
||||
HsvRange(h_lo, h_hi, s_lo, s_hi, v_lo, v_hi)
|
||||
ColorBlob(color_name, confidence, cx, cy, w, h, area_px, contour_id)
|
||||
COLOR_RANGES : Dict[str, List[HsvRange]] — default per-color HSV ranges
|
||||
mask_for_color(hsv, color_name) -> np.ndarray — uint8 binary mask
|
||||
find_color_blobs(bgr, active_colors, min_area_px, max_blobs_per_color) -> List[ColorBlob]
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ── Data types ────────────────────────────────────────────────────────────────
|
||||
|
||||
class HsvRange(NamedTuple):
|
||||
"""Single HSV band (OpenCV: H∈[0,180], S/V∈[0,255])."""
|
||||
h_lo: int
|
||||
h_hi: int
|
||||
s_lo: int
|
||||
s_hi: int
|
||||
v_lo: int
|
||||
v_hi: int
|
||||
|
||||
|
||||
class ColorBlob(NamedTuple):
|
||||
"""One detected color object in image coordinates."""
|
||||
color_name: str
|
||||
confidence: float # contour_area / bbox_area (0–1)
|
||||
cx: float # bbox centre x (pixels)
|
||||
cy: float # bbox centre y (pixels)
|
||||
w: float # bbox width (pixels)
|
||||
h: float # bbox height (pixels)
|
||||
area_px: float # contour area (pixels²)
|
||||
contour_id: int # 0-based index within this color in this frame
|
||||
|
||||
|
||||
# ── Default per-color HSV ranges ──────────────────────────────────────────────
|
||||
# Two ranges are used for red (wraps at 0/180).
|
||||
# S_lo=60, V_lo=50 to ignore desaturated / near-black pixels.
|
||||
|
||||
COLOR_RANGES: Dict[str, List[HsvRange]] = {
|
||||
'red': [
|
||||
HsvRange(h_lo=0, h_hi=10, s_lo=60, s_hi=255, v_lo=50, v_hi=255),
|
||||
HsvRange(h_lo=170, h_hi=180, s_lo=60, s_hi=255, v_lo=50, v_hi=255),
|
||||
],
|
||||
'green': [
|
||||
HsvRange(h_lo=35, h_hi=85, s_lo=60, s_hi=255, v_lo=50, v_hi=255),
|
||||
],
|
||||
'blue': [
|
||||
HsvRange(h_lo=90, h_hi=130, s_lo=60, s_hi=255, v_lo=50, v_hi=255),
|
||||
],
|
||||
'yellow': [
|
||||
HsvRange(h_lo=18, h_hi=38, s_lo=60, s_hi=255, v_lo=80, v_hi=255),
|
||||
],
|
||||
'orange': [
|
||||
HsvRange(h_lo=8, h_hi=20, s_lo=80, s_hi=255, v_lo=80, v_hi=255),
|
||||
],
|
||||
}
|
||||
|
||||
# Structuring element for morphological open (noise removal)
|
||||
_MORPH_KERNEL = None
|
||||
|
||||
|
||||
def _get_morph_kernel():
|
||||
import cv2
|
||||
global _MORPH_KERNEL
|
||||
if _MORPH_KERNEL is None:
|
||||
_MORPH_KERNEL = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
|
||||
return _MORPH_KERNEL
|
||||
|
||||
|
||||
# ── Public helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
def mask_for_color(hsv: np.ndarray, color_name: str) -> np.ndarray:
|
||||
"""
|
||||
Return a uint8 binary mask (255=foreground) for *color_name* in the HSV image.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
hsv : (H, W, 3) uint8 ndarray in OpenCV HSV format (H∈[0,180])
|
||||
color_name : one of COLOR_RANGES keys
|
||||
|
||||
Returns
|
||||
-------
|
||||
(H, W) uint8 ndarray
|
||||
"""
|
||||
import cv2
|
||||
|
||||
ranges = COLOR_RANGES.get(color_name)
|
||||
if not ranges:
|
||||
raise ValueError(f'Unknown color: {color_name!r}. Known: {list(COLOR_RANGES)}')
|
||||
|
||||
mask = np.zeros(hsv.shape[:2], dtype=np.uint8)
|
||||
for r in ranges:
|
||||
lo = np.array([r.h_lo, r.s_lo, r.v_lo], dtype=np.uint8)
|
||||
hi = np.array([r.h_hi, r.s_hi, r.v_hi], dtype=np.uint8)
|
||||
mask |= cv2.inRange(hsv, lo, hi)
|
||||
|
||||
return cv2.morphologyEx(mask, cv2.MORPH_OPEN, _get_morph_kernel())
|
||||
|
||||
|
||||
def find_color_blobs(
|
||||
bgr: np.ndarray,
|
||||
active_colors: List[str] | None = None,
|
||||
min_area_px: float = 200.0,
|
||||
max_blobs_per_color: int = 10,
|
||||
) -> List[ColorBlob]:
|
||||
"""
|
||||
Detect HSV-segmented color blobs in a BGR image.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bgr : (H, W, 3) uint8 BGR ndarray
|
||||
active_colors : color names to detect; None → all COLOR_RANGES keys
|
||||
min_area_px : minimum contour area to report (pixels²)
|
||||
max_blobs_per_color : keep at most this many blobs per color (largest first)
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[ColorBlob] — may be empty; contour_id is 0-based within each color
|
||||
"""
|
||||
import cv2
|
||||
|
||||
if active_colors is None:
|
||||
active_colors = list(COLOR_RANGES.keys())
|
||||
|
||||
hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
|
||||
blobs: List[ColorBlob] = []
|
||||
|
||||
for color_name in active_colors:
|
||||
mask = mask_for_color(hsv, color_name)
|
||||
contours, _ = cv2.findContours(
|
||||
mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
# Sort largest first so max_blobs_per_color keeps the significant ones
|
||||
contours = sorted(contours, key=cv2.contourArea, reverse=True)
|
||||
|
||||
blob_idx = 0
|
||||
for cnt in contours:
|
||||
if blob_idx >= max_blobs_per_color:
|
||||
break
|
||||
|
||||
area = cv2.contourArea(cnt)
|
||||
if area < min_area_px:
|
||||
break # already sorted, no need to continue
|
||||
|
||||
x, y, bw, bh = cv2.boundingRect(cnt)
|
||||
bbox_area = float(bw * bh)
|
||||
confidence = float(area / bbox_area) if bbox_area > 0 else 0.0
|
||||
confidence = min(1.0, max(0.0, confidence))
|
||||
|
||||
blobs.append(ColorBlob(
|
||||
color_name=color_name,
|
||||
confidence=confidence,
|
||||
cx=float(x + bw / 2.0),
|
||||
cy=float(y + bh / 2.0),
|
||||
w=float(bw),
|
||||
h=float(bh),
|
||||
area_px=float(area),
|
||||
contour_id=blob_idx,
|
||||
))
|
||||
blob_idx += 1
|
||||
|
||||
return blobs
|
||||
@ -0,0 +1,141 @@
|
||||
"""
|
||||
_depth_hole_fill.py — Depth image hole filling via bilateral interpolation (no ROS2 deps).
|
||||
|
||||
Algorithm
|
||||
---------
|
||||
A "hole" is any pixel where depth == 0, depth is NaN, or depth is outside the
|
||||
valid range [d_min, d_max].
|
||||
|
||||
Each pass replaces every hole pixel with the spatial-Gaussian-weighted mean of
|
||||
valid pixels in a (kernel_size × kernel_size) neighbourhood:
|
||||
|
||||
filled[x,y] = Σ G(||p - q||; σ) · d[q] / Σ G(||p - q||; σ)
|
||||
q ∈ valid neighbours of (x,y)
|
||||
|
||||
The denominator (sum of spatial weights over valid pixels) normalises correctly
|
||||
even at image borders and around isolated valid pixels.
|
||||
|
||||
Multiple passes with geometrically growing kernels are applied so that:
|
||||
Pass 1 kernel_size — fills small holes (≤ kernel_size/2 px radius)
|
||||
Pass 2 kernel_size × 2.5 — fills medium holes
|
||||
Pass 3 kernel_size × 6.0 — fills large holes / fronto-parallel surfaces
|
||||
|
||||
After all passes any remaining zeros are left as-is (no valid neighbourhood data).
|
||||
|
||||
Because only the spatial Gaussian (not a depth range term) is used as the weighting
|
||||
function, this is equivalent to a bilateral filter with σ_range → ∞. In practice
|
||||
this produces smooth, physically plausible fills in the depth domain.
|
||||
|
||||
Public API
|
||||
----------
|
||||
fill_holes(depth, kernel_size=5, d_min=0.1, d_max=10.0, max_passes=3) → ndarray
|
||||
valid_mask(depth, d_min=0.1, d_max=10.0) → bool ndarray
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Kernel size multipliers for successive passes
|
||||
_PASS_SCALE = [1.0, 2.5, 6.0]
|
||||
|
||||
|
||||
def valid_mask(
|
||||
depth: np.ndarray,
|
||||
d_min: float = 0.1,
|
||||
d_max: float = 10.0,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Return a boolean mask of valid (non-hole) pixels.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
depth : (H, W) float32 ndarray, depth in metres
|
||||
d_min : minimum valid depth (m)
|
||||
d_max : maximum valid depth (m)
|
||||
|
||||
Returns
|
||||
-------
|
||||
(H, W) bool ndarray — True where depth is finite and in [d_min, d_max]
|
||||
"""
|
||||
return np.isfinite(depth) & (depth >= d_min) & (depth <= d_max)
|
||||
|
||||
|
||||
def fill_holes(
|
||||
depth: np.ndarray,
|
||||
kernel_size: int = 5,
|
||||
d_min: float = 0.1,
|
||||
d_max: float = 10.0,
|
||||
max_passes: int = 3,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Fill zero/NaN depth pixels using multi-pass spatial Gaussian interpolation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
depth : (H, W) float32 ndarray, depth in metres
|
||||
kernel_size : initial kernel side length (pixels, forced odd, ≥ 3)
|
||||
d_min : minimum valid depth — pixels below this are treated as holes
|
||||
d_max : maximum valid depth — pixels above this are treated as holes
|
||||
max_passes : number of fill passes (1–3); each uses a larger kernel
|
||||
|
||||
Returns
|
||||
-------
|
||||
(H, W) float32 ndarray — same as input, with holes filled where possible.
|
||||
Pixels with no valid neighbours after all passes remain 0.0.
|
||||
Original valid pixels are never modified.
|
||||
"""
|
||||
import cv2
|
||||
|
||||
depth = np.asarray(depth, dtype=np.float32)
|
||||
# Replace NaN with 0 so arithmetic is clean
|
||||
depth = np.where(np.isfinite(depth), depth, 0.0).astype(np.float32)
|
||||
|
||||
mask = valid_mask(depth, d_min, d_max) # True where already valid
|
||||
result = depth.copy()
|
||||
n_passes = max(1, min(max_passes, len(_PASS_SCALE)))
|
||||
|
||||
for i in range(n_passes):
|
||||
if mask.all():
|
||||
break # no holes left
|
||||
|
||||
ks = _odd_kernel_size(kernel_size, _PASS_SCALE[i])
|
||||
half = ks // 2
|
||||
sigma = max(half / 2.0, 0.5)
|
||||
|
||||
gk = cv2.getGaussianKernel(ks, sigma).astype(np.float32)
|
||||
kernel = (gk @ gk.T)
|
||||
|
||||
# Multiply depth by mask so invalid pixels contribute 0 weight
|
||||
d_valid = np.where(mask, result, 0.0).astype(np.float32)
|
||||
w_valid = mask.astype(np.float32)
|
||||
|
||||
sum_d = cv2.filter2D(d_valid, ddepth=-1, kernel=kernel,
|
||||
borderType=cv2.BORDER_REFLECT)
|
||||
sum_w = cv2.filter2D(w_valid, ddepth=-1, kernel=kernel,
|
||||
borderType=cv2.BORDER_REFLECT)
|
||||
|
||||
# Where we have enough weight, compute the weighted mean
|
||||
has_data = sum_w > 1e-6
|
||||
interp = np.where(has_data, sum_d / np.where(has_data, sum_w, 1.0), 0.0)
|
||||
|
||||
# Only fill holes — never overwrite original valid pixels
|
||||
result = np.where(mask, result, interp.astype(np.float32))
|
||||
|
||||
# Update mask with newly filled pixels (for the next pass)
|
||||
newly_filled = (~mask) & (result > 0.0)
|
||||
mask = mask | newly_filled
|
||||
|
||||
return result.astype(np.float32)
|
||||
|
||||
|
||||
# ── Internal helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
def _odd_kernel_size(base: int, scale: float) -> int:
|
||||
"""Return the nearest odd integer to base * scale, minimum 3."""
|
||||
raw = max(3, int(round(base * scale)))
|
||||
return raw if raw % 2 == 1 else raw + 1
|
||||
@ -0,0 +1,150 @@
|
||||
"""
|
||||
_vo_drift.py — Visual odometry drift detector helpers (no ROS2 deps).
|
||||
|
||||
Algorithm
|
||||
---------
|
||||
Two independent odometry streams (visual and wheel) are compared over a
|
||||
sliding time window. Drift is measured as the absolute difference in
|
||||
cumulative path length travelled by each source over that window:
|
||||
|
||||
drift_m = |path_length(vo_window) − path_length(wheel_window)|
|
||||
|
||||
Using cumulative path length (sum of inter-sample Euclidean steps) rather
|
||||
than straight-line displacement makes the measure robust to circular motion
|
||||
where start and end positions are the same.
|
||||
|
||||
Drift is flagged when drift_m ≥ drift_threshold_m.
|
||||
|
||||
Public API
|
||||
----------
|
||||
OdomSample — namedtuple(t, x, y)
|
||||
OdomBuffer — deque of OdomSamples with time-window trimming
|
||||
compute_drift() — compare two OdomBuffers and return DriftResult
|
||||
DriftResult — namedtuple(drift_m, vo_path_m, wheel_path_m,
|
||||
is_drifting, window_s, n_vo, n_wheel)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections import deque
|
||||
from typing import NamedTuple, Sequence
|
||||
|
||||
|
||||
class OdomSample(NamedTuple):
|
||||
t: float # monotonic timestamp (seconds)
|
||||
x: float # position x (metres)
|
||||
y: float # position y (metres)
|
||||
|
||||
|
||||
class DriftResult(NamedTuple):
|
||||
drift_m: float # |vo_path − wheel_path| (metres)
|
||||
vo_path_m: float # cumulative path of VO source over window (metres)
|
||||
wheel_path_m: float # cumulative path of wheel source over window (metres)
|
||||
is_drifting: bool # True when drift_m >= threshold
|
||||
window_s: float # actual time span of data used (seconds)
|
||||
n_vo: int # number of VO samples in window
|
||||
n_wheel: int # number of wheel samples in window
|
||||
|
||||
|
||||
class OdomBuffer:
|
||||
"""
|
||||
Rolling buffer of OdomSamples trimmed to the last `max_age_s` seconds.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_age_s : float — samples older than this are discarded (seconds)
|
||||
"""
|
||||
|
||||
def __init__(self, max_age_s: float = 10.0) -> None:
|
||||
self._max_age = max_age_s
|
||||
self._buf: deque[OdomSample] = deque()
|
||||
|
||||
# ── Public ────────────────────────────────────────────────────────────────
|
||||
|
||||
def push(self, sample: OdomSample) -> None:
|
||||
"""Append a sample and evict anything older than max_age_s."""
|
||||
self._buf.append(sample)
|
||||
self._trim(sample.t)
|
||||
|
||||
def window(self, window_s: float, now: float) -> list[OdomSample]:
|
||||
"""Return samples within the last window_s seconds of `now`."""
|
||||
cutoff = now - window_s
|
||||
return [s for s in self._buf if s.t >= cutoff]
|
||||
|
||||
def clear(self) -> None:
|
||||
self._buf.clear()
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._buf)
|
||||
|
||||
# ── Internal ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _trim(self, now: float) -> None:
|
||||
cutoff = now - self._max_age
|
||||
while self._buf and self._buf[0].t < cutoff:
|
||||
self._buf.popleft()
|
||||
|
||||
|
||||
# ── Core computation ──────────────────────────────────────────────────────────
|
||||
|
||||
def compute_drift(
|
||||
vo_buf: OdomBuffer,
|
||||
wheel_buf: OdomBuffer,
|
||||
window_s: float,
|
||||
drift_threshold_m: float,
|
||||
now: float,
|
||||
) -> DriftResult:
|
||||
"""
|
||||
Compare VO and wheel odometry path lengths over the last `window_s`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
vo_buf : OdomBuffer of visual odometry samples
|
||||
wheel_buf : OdomBuffer of wheel odometry samples
|
||||
window_s : comparison window width (seconds)
|
||||
drift_threshold_m : drift_m threshold for is_drifting flag
|
||||
now : current time (same scale as OdomSample.t)
|
||||
|
||||
Returns
|
||||
-------
|
||||
DriftResult — zero drift if either buffer has fewer than 2 samples.
|
||||
"""
|
||||
vo_samples = vo_buf.window(window_s, now)
|
||||
wheel_samples = wheel_buf.window(window_s, now)
|
||||
|
||||
if len(vo_samples) < 2 or len(wheel_samples) < 2:
|
||||
return DriftResult(
|
||||
drift_m=0.0, vo_path_m=0.0, wheel_path_m=0.0,
|
||||
is_drifting=False,
|
||||
window_s=0.0, n_vo=len(vo_samples), n_wheel=len(wheel_samples),
|
||||
)
|
||||
|
||||
vo_path = _path_length(vo_samples)
|
||||
wheel_path = _path_length(wheel_samples)
|
||||
drift_m = abs(vo_path - wheel_path)
|
||||
|
||||
# Actual data span = latest timestamp − earliest across both buffers
|
||||
t_min = min(vo_samples[0].t, wheel_samples[0].t)
|
||||
t_max = max(vo_samples[-1].t, wheel_samples[-1].t)
|
||||
actual_window = t_max - t_min
|
||||
|
||||
return DriftResult(
|
||||
drift_m=drift_m,
|
||||
vo_path_m=vo_path,
|
||||
wheel_path_m=wheel_path,
|
||||
is_drifting=drift_m >= drift_threshold_m,
|
||||
window_s=actual_window,
|
||||
n_vo=len(vo_samples),
|
||||
n_wheel=len(wheel_samples),
|
||||
)
|
||||
|
||||
|
||||
def _path_length(samples: Sequence[OdomSample]) -> float:
|
||||
"""Sum of Euclidean inter-sample distances."""
|
||||
total = 0.0
|
||||
for i in range(1, len(samples)):
|
||||
dx = samples[i].x - samples[i - 1].x
|
||||
dy = samples[i].y - samples[i - 1].y
|
||||
total += math.sqrt(dx * dx + dy * dy)
|
||||
return total
|
||||
@ -0,0 +1,127 @@
|
||||
"""
|
||||
color_segment_node.py — D435i HSV color object segmenter (Issue #274).
|
||||
|
||||
Subscribes to the RealSense colour stream, applies per-color HSV thresholding,
|
||||
extracts contours, and publishes detected blobs as ColorDetectionArray.
|
||||
|
||||
Subscribes (BEST_EFFORT):
|
||||
/camera/color/image_raw sensor_msgs/Image BGR8 (or rgb8)
|
||||
|
||||
Publishes:
|
||||
/saltybot/color_objects saltybot_scene_msgs/ColorDetectionArray
|
||||
|
||||
Parameters
|
||||
----------
|
||||
active_colors str "red,green,blue,yellow,orange" Comma-separated list
|
||||
min_area_px float 200.0 Minimum contour area (pixels²)
|
||||
max_blobs_per_color int 10 Max detections per color per frame
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||
|
||||
import numpy as np
|
||||
from cv_bridge import CvBridge
|
||||
|
||||
from sensor_msgs.msg import Image
|
||||
from std_msgs.msg import Header
|
||||
|
||||
from saltybot_scene_msgs.msg import ColorDetection, ColorDetectionArray
|
||||
from vision_msgs.msg import BoundingBox2D
|
||||
from geometry_msgs.msg import Pose2D
|
||||
|
||||
from ._color_segmenter import find_color_blobs
|
||||
|
||||
|
||||
_SENSOR_QOS = QoSProfile(
|
||||
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||
history=HistoryPolicy.KEEP_LAST,
|
||||
depth=4,
|
||||
)
|
||||
|
||||
_DEFAULT_COLORS = 'red,green,blue,yellow,orange'
|
||||
|
||||
|
||||
class ColorSegmentNode(Node):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__('color_segment_node')
|
||||
|
||||
self.declare_parameter('active_colors', _DEFAULT_COLORS)
|
||||
self.declare_parameter('min_area_px', 200.0)
|
||||
self.declare_parameter('max_blobs_per_color', 10)
|
||||
|
||||
colors_str = self.get_parameter('active_colors').value
|
||||
self._active_colors = [c.strip() for c in colors_str.split(',') if c.strip()]
|
||||
self._min_area = float(self.get_parameter('min_area_px').value)
|
||||
self._max_blobs = int(self.get_parameter('max_blobs_per_color').value)
|
||||
|
||||
self._bridge = CvBridge()
|
||||
|
||||
self._sub = self.create_subscription(
|
||||
Image, '/camera/color/image_raw', self._on_image, _SENSOR_QOS)
|
||||
self._pub = self.create_publisher(
|
||||
ColorDetectionArray, '/saltybot/color_objects', 10)
|
||||
|
||||
self.get_logger().info(
|
||||
f'color_segment_node ready — colors={self._active_colors} '
|
||||
f'min_area={self._min_area}px² max_blobs={self._max_blobs}'
|
||||
)
|
||||
|
||||
# ── Callback ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _on_image(self, msg: Image) -> None:
|
||||
try:
|
||||
bgr = self._bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')
|
||||
except Exception as exc:
|
||||
self.get_logger().error(
|
||||
f'cv_bridge: {exc}', throttle_duration_sec=5.0)
|
||||
return
|
||||
|
||||
blobs = find_color_blobs(
|
||||
bgr,
|
||||
active_colors=self._active_colors,
|
||||
min_area_px=self._min_area,
|
||||
max_blobs_per_color=self._max_blobs,
|
||||
)
|
||||
|
||||
arr = ColorDetectionArray()
|
||||
arr.header = msg.header
|
||||
|
||||
for blob in blobs:
|
||||
det = ColorDetection()
|
||||
det.header = msg.header
|
||||
det.color_name = blob.color_name
|
||||
det.confidence = blob.confidence
|
||||
det.area_px = blob.area_px
|
||||
det.contour_id = blob.contour_id
|
||||
|
||||
bbox = BoundingBox2D()
|
||||
center = Pose2D()
|
||||
center.x = blob.cx
|
||||
center.y = blob.cy
|
||||
bbox.center = center
|
||||
bbox.size_x = blob.w
|
||||
bbox.size_y = blob.h
|
||||
det.bbox = bbox
|
||||
|
||||
arr.detections.append(det)
|
||||
|
||||
self._pub.publish(arr)
|
||||
|
||||
|
||||
def main(args=None) -> None:
|
||||
rclpy.init(args=args)
|
||||
node = ColorSegmentNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -0,0 +1,128 @@
|
||||
"""
|
||||
depth_hole_fill_node.py — D435i depth image hole filler (Issue #268).
|
||||
|
||||
Subscribes to the raw D435i depth stream, fills zero/NaN pixels using
|
||||
multi-pass spatial-Gaussian bilateral interpolation, and republishes the
|
||||
filled image at camera rate.
|
||||
|
||||
Subscribes (BEST_EFFORT):
|
||||
/camera/depth/image_rect_raw sensor_msgs/Image float32 depth (m)
|
||||
|
||||
Publishes:
|
||||
/camera/depth/filled sensor_msgs/Image float32 depth (m), holes filled
|
||||
|
||||
The filled image preserves all original valid pixels exactly and only
|
||||
modifies pixels that had no return (0 or NaN). The output is suitable
|
||||
for all downstream consumers that expect a dense depth map (VO, RTAB-Map,
|
||||
collision avoidance, floor classifier).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_topic str /camera/depth/image_rect_raw Input depth topic
|
||||
output_topic str /camera/depth/filled Output depth topic
|
||||
kernel_size int 5 Initial Gaussian kernel side length (pixels)
|
||||
d_min float 0.1 Minimum valid depth (m)
|
||||
d_max float 10.0 Maximum valid depth (m)
|
||||
max_passes int 3 Fill passes (growing kernel per pass)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||
|
||||
import numpy as np
|
||||
from cv_bridge import CvBridge
|
||||
|
||||
from sensor_msgs.msg import Image
|
||||
|
||||
from ._depth_hole_fill import fill_holes
|
||||
|
||||
|
||||
_SENSOR_QOS = QoSProfile(
|
||||
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||
history=HistoryPolicy.KEEP_LAST,
|
||||
depth=4,
|
||||
)
|
||||
|
||||
|
||||
class DepthHoleFillNode(Node):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__('depth_hole_fill_node')
|
||||
|
||||
self.declare_parameter('input_topic', '/camera/depth/image_rect_raw')
|
||||
self.declare_parameter('output_topic', '/camera/depth/filled')
|
||||
self.declare_parameter('kernel_size', 5)
|
||||
self.declare_parameter('d_min', 0.1)
|
||||
self.declare_parameter('d_max', 10.0)
|
||||
self.declare_parameter('max_passes', 3)
|
||||
|
||||
input_topic = self.get_parameter('input_topic').value
|
||||
output_topic = self.get_parameter('output_topic').value
|
||||
self._ks = int(self.get_parameter('kernel_size').value)
|
||||
self._d_min = self.get_parameter('d_min').value
|
||||
self._d_max = self.get_parameter('d_max').value
|
||||
self._passes = int(self.get_parameter('max_passes').value)
|
||||
|
||||
self._bridge = CvBridge()
|
||||
|
||||
self._sub = self.create_subscription(
|
||||
Image, input_topic, self._on_depth, _SENSOR_QOS)
|
||||
self._pub = self.create_publisher(Image, output_topic, 10)
|
||||
|
||||
self.get_logger().info(
|
||||
f'depth_hole_fill_node ready — '
|
||||
f'{input_topic} → {output_topic} '
|
||||
f'kernel={self._ks} passes={self._passes} '
|
||||
f'd=[{self._d_min},{self._d_max}]m'
|
||||
)
|
||||
|
||||
# ── Callback ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _on_depth(self, msg: Image) -> None:
|
||||
try:
|
||||
depth = self._bridge.imgmsg_to_cv2(msg, desired_encoding='passthrough')
|
||||
except Exception as exc:
|
||||
self.get_logger().error(
|
||||
f'cv_bridge: {exc}', throttle_duration_sec=5.0)
|
||||
return
|
||||
|
||||
depth = depth.astype(np.float32)
|
||||
|
||||
# Handle uint16 mm → float32 m conversion (D435i raw stream)
|
||||
if depth.max() > 100.0:
|
||||
depth /= 1000.0
|
||||
|
||||
filled = fill_holes(
|
||||
depth,
|
||||
kernel_size=self._ks,
|
||||
d_min=self._d_min,
|
||||
d_max=self._d_max,
|
||||
max_passes=self._passes,
|
||||
)
|
||||
|
||||
try:
|
||||
out_msg = self._bridge.cv2_to_imgmsg(filled, encoding='32FC1')
|
||||
except Exception as exc:
|
||||
self.get_logger().error(
|
||||
f'cv2_to_imgmsg: {exc}', throttle_duration_sec=5.0)
|
||||
return
|
||||
|
||||
out_msg.header = msg.header
|
||||
self._pub.publish(out_msg)
|
||||
|
||||
|
||||
def main(args=None) -> None:
|
||||
rclpy.init(args=args)
|
||||
node = DepthHoleFillNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -0,0 +1,150 @@
|
||||
"""
|
||||
vo_drift_node.py — Visual odometry drift detector (Issue #260).
|
||||
|
||||
Compares the cumulative path lengths of visual odometry and wheel odometry
|
||||
over a sliding window. When the absolute difference exceeds the configured
|
||||
threshold the node flags drift, allowing the system to warn operators,
|
||||
inflate VO covariance, or fall back to wheel-only localisation.
|
||||
|
||||
Subscribes (BEST_EFFORT):
|
||||
/camera/odom nav_msgs/Odometry visual odometry
|
||||
/odom nav_msgs/Odometry wheel odometry
|
||||
|
||||
→ For this robot remap to /saltybot/visual_odom + /saltybot/rover_odom.
|
||||
|
||||
Publishes:
|
||||
/saltybot/vo_drift_detected std_msgs/Bool True while drifting
|
||||
/saltybot/vo_drift_magnitude std_msgs/Float32 drift magnitude (metres)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
vo_topic str /camera/odom Visual odometry source topic
|
||||
wheel_topic str /odom Wheel odometry source topic
|
||||
drift_threshold_m float 0.5 Drift flag threshold (metres)
|
||||
window_s float 10.0 Comparison window (seconds)
|
||||
publish_hz float 2.0 Output publication rate (Hz)
|
||||
max_buffer_age_s float 30.0 Max age of stored samples (s)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||
|
||||
from nav_msgs.msg import Odometry
|
||||
from std_msgs.msg import Bool, Float32
|
||||
|
||||
from ._vo_drift import OdomBuffer, OdomSample, compute_drift
|
||||
|
||||
|
||||
_SENSOR_QOS = QoSProfile(
|
||||
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||
history=HistoryPolicy.KEEP_LAST,
|
||||
depth=4,
|
||||
)
|
||||
|
||||
|
||||
class VoDriftNode(Node):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__('vo_drift_node')
|
||||
|
||||
self.declare_parameter('vo_topic', '/camera/odom')
|
||||
self.declare_parameter('wheel_topic', '/odom')
|
||||
self.declare_parameter('drift_threshold_m', 0.5)
|
||||
self.declare_parameter('window_s', 10.0)
|
||||
self.declare_parameter('publish_hz', 2.0)
|
||||
self.declare_parameter('max_buffer_age_s', 30.0)
|
||||
|
||||
vo_topic = self.get_parameter('vo_topic').value
|
||||
wheel_topic = self.get_parameter('wheel_topic').value
|
||||
self._thresh = self.get_parameter('drift_threshold_m').value
|
||||
self._window_s = self.get_parameter('window_s').value
|
||||
publish_hz = self.get_parameter('publish_hz').value
|
||||
max_age = self.get_parameter('max_buffer_age_s').value
|
||||
|
||||
self._vo_buf = OdomBuffer(max_age_s=max_age)
|
||||
self._wheel_buf = OdomBuffer(max_age_s=max_age)
|
||||
|
||||
self.create_subscription(
|
||||
Odometry, vo_topic, self._on_vo, _SENSOR_QOS)
|
||||
self.create_subscription(
|
||||
Odometry, wheel_topic, self._on_wheel, _SENSOR_QOS)
|
||||
|
||||
self._pub_detected = self.create_publisher(
|
||||
Bool, '/saltybot/vo_drift_detected', 10)
|
||||
self._pub_magnitude = self.create_publisher(
|
||||
Float32, '/saltybot/vo_drift_magnitude', 10)
|
||||
|
||||
self.create_timer(1.0 / publish_hz, self._tick)
|
||||
|
||||
self.get_logger().info(
|
||||
f'vo_drift_node ready — '
|
||||
f'vo={vo_topic} wheel={wheel_topic} '
|
||||
f'threshold={self._thresh}m window={self._window_s}s'
|
||||
)
|
||||
|
||||
# ── Callbacks ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _on_vo(self, msg: Odometry) -> None:
|
||||
s = _odom_to_sample(msg)
|
||||
self._vo_buf.push(s)
|
||||
|
||||
def _on_wheel(self, msg: Odometry) -> None:
|
||||
s = _odom_to_sample(msg)
|
||||
self._wheel_buf.push(s)
|
||||
|
||||
# ── Publish tick ──────────────────────────────────────────────────────────
|
||||
|
||||
def _tick(self) -> None:
|
||||
now = time.monotonic()
|
||||
result = compute_drift(
|
||||
self._vo_buf, self._wheel_buf,
|
||||
window_s=self._window_s,
|
||||
drift_threshold_m=self._thresh,
|
||||
now=now,
|
||||
)
|
||||
|
||||
if result.is_drifting:
|
||||
self.get_logger().warn(
|
||||
f'VO drift detected: {result.drift_m:.3f}m '
|
||||
f'(vo={result.vo_path_m:.3f}m wheel={result.wheel_path_m:.3f}m '
|
||||
f'over {result.window_s:.1f}s)',
|
||||
throttle_duration_sec=5.0,
|
||||
)
|
||||
|
||||
det_msg = Bool()
|
||||
det_msg.data = result.is_drifting
|
||||
self._pub_detected.publish(det_msg)
|
||||
|
||||
mag_msg = Float32()
|
||||
mag_msg.data = float(result.drift_m)
|
||||
self._pub_magnitude.publish(mag_msg)
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _odom_to_sample(msg: Odometry) -> OdomSample:
|
||||
"""Convert nav_msgs/Odometry to OdomSample using monotonic clock."""
|
||||
return OdomSample(
|
||||
t=time.monotonic(),
|
||||
x=msg.pose.pose.position.x,
|
||||
y=msg.pose.pose.position.y,
|
||||
)
|
||||
|
||||
|
||||
def main(args=None) -> None:
|
||||
rclpy.init(args=args)
|
||||
node = VoDriftNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -35,6 +35,12 @@ setup(
|
||||
'lidar_clustering = saltybot_bringup.lidar_clustering_node:main',
|
||||
# Floor surface type classifier (Issue #249)
|
||||
'floor_classifier = saltybot_bringup.floor_classifier_node:main',
|
||||
# Visual odometry drift detector (Issue #260)
|
||||
'vo_drift_detector = saltybot_bringup.vo_drift_node:main',
|
||||
# Depth image hole filler (Issue #268)
|
||||
'depth_hole_fill = saltybot_bringup.depth_hole_fill_node:main',
|
||||
# HSV color object segmenter (Issue #274)
|
||||
'color_segmenter = saltybot_bringup.color_segment_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
361
jetson/ros2_ws/src/saltybot_bringup/test/test_color_segmenter.py
Normal file
361
jetson/ros2_ws/src/saltybot_bringup/test/test_color_segmenter.py
Normal file
@ -0,0 +1,361 @@
|
||||
"""
|
||||
test_color_segmenter.py — Unit tests for HSV color segmentation helpers (no ROS2 required).
|
||||
|
||||
Covers:
|
||||
HsvRange / ColorBlob:
|
||||
- NamedTuple fields accessible by name
|
||||
- confidence clamped to [0,1]
|
||||
|
||||
mask_for_color:
|
||||
- pure red image → red mask fully white
|
||||
- pure red image → green mask fully black
|
||||
- pure green image → green mask fully white
|
||||
- pure blue image → blue mask fully white
|
||||
- pure yellow image → yellow mask non-empty
|
||||
- pure orange image → orange mask non-empty
|
||||
- red hue wrap-around detected from both HSV bands
|
||||
- unknown color name raises ValueError
|
||||
- mask is uint8
|
||||
- mask shape matches input
|
||||
|
||||
find_color_blobs — output contract:
|
||||
- returns list
|
||||
- empty list on blank (no-color) image
|
||||
- empty list when min_area_px larger than any contour
|
||||
|
||||
find_color_blobs — detection:
|
||||
- large red rectangle detected as red blob
|
||||
- large green rectangle detected as green blob
|
||||
- large blue rectangle detected as blue blob
|
||||
- detected blob color_name matches requested color
|
||||
- contour_id is 0 for first blob
|
||||
- confidence in [0, 1]
|
||||
- cx, cy within image bounds
|
||||
- w, h > 0 for detected blob
|
||||
- area_px > 0 for detected blob
|
||||
|
||||
find_color_blobs — filtering:
|
||||
- active_colors=None detects all colors when present
|
||||
- only requested colors returned when active_colors restricted
|
||||
- max_blobs_per_color limits output count
|
||||
- two separate red blobs both detected when max_blobs=2
|
||||
- smaller blob filtered when min_area_px high
|
||||
|
||||
find_color_blobs — multi-color:
|
||||
- image with red + green regions → both detected
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from saltybot_bringup._color_segmenter import (
|
||||
HsvRange,
|
||||
ColorBlob,
|
||||
COLOR_RANGES,
|
||||
mask_for_color,
|
||||
find_color_blobs,
|
||||
)
|
||||
|
||||
|
||||
# ── Image factories ───────────────────────────────────────────────────────────
|
||||
|
||||
def _solid_bgr(b, g, r, h=64, w=64) -> np.ndarray:
|
||||
"""Solid BGR image."""
|
||||
img = np.zeros((h, w, 3), dtype=np.uint8)
|
||||
img[:, :] = (b, g, r)
|
||||
return img
|
||||
|
||||
|
||||
def _blank(h=64, w=64) -> np.ndarray:
|
||||
"""All-black image (nothing to detect)."""
|
||||
return np.zeros((h, w, 3), dtype=np.uint8)
|
||||
|
||||
|
||||
def _image_with_rect(bg_bgr, rect_bgr, rect_slice_r, rect_slice_c, h=128, w=128) -> np.ndarray:
|
||||
"""Background colour with a filled rectangle."""
|
||||
img = np.zeros((h, w, 3), dtype=np.uint8)
|
||||
img[:, :] = bg_bgr
|
||||
img[rect_slice_r, rect_slice_c] = rect_bgr
|
||||
return img
|
||||
|
||||
|
||||
# Canonical solid color BGR values (saturated, in-range for HSV thresholds)
|
||||
_RED_BGR = (0, 0, 200) # BGR pure red
|
||||
_GREEN_BGR = (0, 200, 0 ) # BGR pure green
|
||||
_BLUE_BGR = (200, 0, 0 ) # BGR pure blue
|
||||
_YELLOW_BGR = (0, 220, 220) # BGR yellow
|
||||
_ORANGE_BGR = (0, 140, 220) # BGR orange
|
||||
|
||||
|
||||
# ── HsvRange / ColorBlob types ────────────────────────────────────────────────
|
||||
|
||||
class TestTypes:
|
||||
|
||||
def test_hsv_range_fields(self):
|
||||
r = HsvRange(0, 10, 60, 255, 50, 255)
|
||||
assert r.h_lo == 0 and r.h_hi == 10
|
||||
assert r.s_lo == 60 and r.s_hi == 255
|
||||
assert r.v_lo == 50 and r.v_hi == 255
|
||||
|
||||
def test_color_blob_fields(self):
|
||||
b = ColorBlob('red', 0.8, 32.0, 32.0, 20.0, 20.0, 300.0, 0)
|
||||
assert b.color_name == 'red'
|
||||
assert b.confidence == pytest.approx(0.8)
|
||||
assert b.contour_id == 0
|
||||
|
||||
def test_color_ranges_contains_all_defaults(self):
|
||||
for color in ('red', 'green', 'blue', 'yellow', 'orange'):
|
||||
assert color in COLOR_RANGES
|
||||
assert len(COLOR_RANGES[color]) >= 1
|
||||
|
||||
|
||||
# ── mask_for_color ────────────────────────────────────────────────────────────
|
||||
|
||||
class TestMaskForColor:
|
||||
|
||||
def test_mask_is_uint8(self):
|
||||
import cv2
|
||||
hsv = cv2.cvtColor(_solid_bgr(*_RED_BGR), cv2.COLOR_BGR2HSV)
|
||||
m = mask_for_color(hsv, 'red')
|
||||
assert m.dtype == np.uint8
|
||||
|
||||
def test_mask_shape_matches_input(self):
|
||||
import cv2
|
||||
bgr = _solid_bgr(*_RED_BGR, h=48, w=80)
|
||||
hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
|
||||
m = mask_for_color(hsv, 'red')
|
||||
assert m.shape == (48, 80)
|
||||
|
||||
def test_pure_red_gives_red_mask_nonzero(self):
|
||||
import cv2
|
||||
hsv = cv2.cvtColor(_solid_bgr(*_RED_BGR), cv2.COLOR_BGR2HSV)
|
||||
m = mask_for_color(hsv, 'red')
|
||||
assert m.any(), 'red mask should be non-empty for red image'
|
||||
|
||||
def test_pure_red_gives_green_mask_empty(self):
|
||||
import cv2
|
||||
hsv = cv2.cvtColor(_solid_bgr(*_RED_BGR), cv2.COLOR_BGR2HSV)
|
||||
m = mask_for_color(hsv, 'green')
|
||||
assert not m.any(), 'green mask should be empty for red image'
|
||||
|
||||
def test_pure_green_gives_green_mask_nonzero(self):
|
||||
import cv2
|
||||
hsv = cv2.cvtColor(_solid_bgr(*_GREEN_BGR), cv2.COLOR_BGR2HSV)
|
||||
m = mask_for_color(hsv, 'green')
|
||||
assert m.any()
|
||||
|
||||
def test_pure_blue_gives_blue_mask_nonzero(self):
|
||||
import cv2
|
||||
hsv = cv2.cvtColor(_solid_bgr(*_BLUE_BGR), cv2.COLOR_BGR2HSV)
|
||||
m = mask_for_color(hsv, 'blue')
|
||||
assert m.any()
|
||||
|
||||
def test_pure_yellow_gives_yellow_mask_nonzero(self):
|
||||
import cv2
|
||||
hsv = cv2.cvtColor(_solid_bgr(*_YELLOW_BGR), cv2.COLOR_BGR2HSV)
|
||||
m = mask_for_color(hsv, 'yellow')
|
||||
assert m.any()
|
||||
|
||||
def test_pure_orange_gives_orange_mask_nonzero(self):
|
||||
import cv2
|
||||
hsv = cv2.cvtColor(_solid_bgr(*_ORANGE_BGR), cv2.COLOR_BGR2HSV)
|
||||
m = mask_for_color(hsv, 'orange')
|
||||
assert m.any()
|
||||
|
||||
def test_unknown_color_raises(self):
|
||||
import cv2
|
||||
hsv = cv2.cvtColor(_blank(), cv2.COLOR_BGR2HSV)
|
||||
with pytest.raises(ValueError, match='Unknown color'):
|
||||
mask_for_color(hsv, 'purple')
|
||||
|
||||
def test_red_detected_in_high_hue_band(self):
|
||||
"""A near-180-hue red pixel should still trigger the red mask."""
|
||||
import cv2
|
||||
# HSV (175, 200, 200) = high-hue red (wrap-around band)
|
||||
hsv = np.full((32, 32, 3), (175, 200, 200), dtype=np.uint8)
|
||||
m = mask_for_color(hsv, 'red')
|
||||
assert m.any(), 'high-hue red not detected'
|
||||
|
||||
|
||||
# ── find_color_blobs — output contract ───────────────────────────────────────
|
||||
|
||||
class TestFindColorBlobsContract:
|
||||
|
||||
def test_returns_list(self):
|
||||
result = find_color_blobs(_blank())
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_blank_image_returns_empty(self):
|
||||
result = find_color_blobs(_blank())
|
||||
assert result == []
|
||||
|
||||
def test_min_area_filter_removes_all(self):
|
||||
"""Request a min area larger than the entire image → no blobs."""
|
||||
bgr = _solid_bgr(*_RED_BGR, h=32, w=32)
|
||||
result = find_color_blobs(bgr, active_colors=['red'], min_area_px=1e9)
|
||||
assert result == []
|
||||
|
||||
|
||||
# ── find_color_blobs — detection ─────────────────────────────────────────────
|
||||
|
||||
class TestFindColorBlobsDetection:
|
||||
|
||||
def _large_rect(self, color_bgr, color_name) -> np.ndarray:
|
||||
"""100×100 image with a 60×60 solid-color rectangle centred."""
|
||||
img = _blank(h=100, w=100)
|
||||
img[20:80, 20:80] = color_bgr
|
||||
return img
|
||||
|
||||
def test_red_rect_detected(self):
|
||||
blobs = find_color_blobs(self._large_rect(_RED_BGR, 'red'), active_colors=['red'])
|
||||
assert len(blobs) >= 1
|
||||
assert blobs[0].color_name == 'red'
|
||||
|
||||
def test_green_rect_detected(self):
|
||||
blobs = find_color_blobs(self._large_rect(_GREEN_BGR, 'green'), active_colors=['green'])
|
||||
assert len(blobs) >= 1
|
||||
assert blobs[0].color_name == 'green'
|
||||
|
||||
def test_blue_rect_detected(self):
|
||||
blobs = find_color_blobs(self._large_rect(_BLUE_BGR, 'blue'), active_colors=['blue'])
|
||||
assert len(blobs) >= 1
|
||||
assert blobs[0].color_name == 'blue'
|
||||
|
||||
def test_first_contour_id_is_zero(self):
|
||||
img = _blank(h=100, w=100)
|
||||
img[20:80, 20:80] = _RED_BGR
|
||||
blobs = find_color_blobs(img, active_colors=['red'])
|
||||
assert blobs[0].contour_id == 0
|
||||
|
||||
def test_confidence_in_range(self):
|
||||
img = _blank(h=100, w=100)
|
||||
img[20:80, 20:80] = _GREEN_BGR
|
||||
blobs = find_color_blobs(img, active_colors=['green'])
|
||||
assert blobs
|
||||
assert 0.0 <= blobs[0].confidence <= 1.0
|
||||
|
||||
def test_cx_within_image(self):
|
||||
img = _blank(h=100, w=100)
|
||||
img[20:80, 20:80] = _BLUE_BGR
|
||||
blobs = find_color_blobs(img, active_colors=['blue'])
|
||||
assert blobs
|
||||
assert 0.0 <= blobs[0].cx <= 100.0
|
||||
|
||||
def test_cy_within_image(self):
|
||||
img = _blank(h=100, w=100)
|
||||
img[20:80, 20:80] = _BLUE_BGR
|
||||
blobs = find_color_blobs(img, active_colors=['blue'])
|
||||
assert blobs
|
||||
assert 0.0 <= blobs[0].cy <= 100.0
|
||||
|
||||
def test_w_positive(self):
|
||||
img = _blank(h=100, w=100)
|
||||
img[20:80, 20:80] = _RED_BGR
|
||||
blobs = find_color_blobs(img, active_colors=['red'])
|
||||
assert blobs[0].w > 0
|
||||
|
||||
def test_h_positive(self):
|
||||
img = _blank(h=100, w=100)
|
||||
img[20:80, 20:80] = _RED_BGR
|
||||
blobs = find_color_blobs(img, active_colors=['red'])
|
||||
assert blobs[0].h > 0
|
||||
|
||||
def test_area_px_positive(self):
|
||||
img = _blank(h=100, w=100)
|
||||
img[20:80, 20:80] = _RED_BGR
|
||||
blobs = find_color_blobs(img, active_colors=['red'])
|
||||
assert blobs[0].area_px > 0
|
||||
|
||||
def test_area_px_reasonable(self):
|
||||
"""area_px should be roughly within the rectangle we drew."""
|
||||
img = _blank(h=100, w=100)
|
||||
img[20:80, 20:80] = _GREEN_BGR # 60×60 = 3600 px
|
||||
blobs = find_color_blobs(img, active_colors=['green'], min_area_px=100.0)
|
||||
assert blobs
|
||||
assert 1000 <= blobs[0].area_px <= 4000
|
||||
|
||||
|
||||
# ── find_color_blobs — filtering ─────────────────────────────────────────────
|
||||
|
||||
class TestFindColorBlobsFiltering:
|
||||
|
||||
def test_active_colors_none_detects_all(self):
|
||||
"""Image with red+green patches → both found when active_colors=None."""
|
||||
img = _blank(h=128, w=128)
|
||||
img[10:50, 10:50] = _RED_BGR
|
||||
img[10:50, 70:110] = _GREEN_BGR
|
||||
blobs = find_color_blobs(img, active_colors=None, min_area_px=100.0)
|
||||
names = {b.color_name for b in blobs}
|
||||
assert 'red' in names
|
||||
assert 'green' in names
|
||||
|
||||
def test_restricted_active_colors(self):
|
||||
"""Only red requested → no green blobs returned."""
|
||||
img = _blank(h=128, w=128)
|
||||
img[10:50, 10:50] = _RED_BGR
|
||||
img[10:50, 70:110] = _GREEN_BGR
|
||||
blobs = find_color_blobs(img, active_colors=['red'], min_area_px=100.0)
|
||||
assert all(b.color_name == 'red' for b in blobs)
|
||||
|
||||
def test_max_blobs_per_color_limits(self):
|
||||
"""Four separate red rectangles but max_blobs=2 → at most 2 blobs."""
|
||||
img = _blank(h=200, w=200)
|
||||
img[10:40, 10:40] = _RED_BGR
|
||||
img[10:40, 80:110] = _RED_BGR
|
||||
img[100:130, 10:40] = _RED_BGR
|
||||
img[100:130, 80:110] = _RED_BGR
|
||||
blobs = find_color_blobs(img, active_colors=['red'],
|
||||
min_area_px=100.0, max_blobs_per_color=2)
|
||||
red_blobs = [b for b in blobs if b.color_name == 'red']
|
||||
assert len(red_blobs) <= 2
|
||||
|
||||
def test_two_blobs_detected_when_max_allows(self):
|
||||
"""Two red rectangles detected when max_blobs_per_color >= 2."""
|
||||
img = _blank(h=200, w=200)
|
||||
img[10:60, 10:60] = _RED_BGR
|
||||
img[10:60, 130:180] = _RED_BGR
|
||||
blobs = find_color_blobs(img, active_colors=['red'],
|
||||
min_area_px=100.0, max_blobs_per_color=10)
|
||||
red_blobs = [b for b in blobs if b.color_name == 'red']
|
||||
assert len(red_blobs) >= 2
|
||||
|
||||
def test_small_blob_filtered_by_min_area(self):
|
||||
"""Small 5×5 red patch filtered by min_area_px=500."""
|
||||
img = _blank(h=64, w=64)
|
||||
img[28:33, 28:33] = _RED_BGR # 5×5 = 25 px contour area
|
||||
blobs = find_color_blobs(img, active_colors=['red'], min_area_px=500.0)
|
||||
assert blobs == []
|
||||
|
||||
|
||||
# ── find_color_blobs — multi-color ───────────────────────────────────────────
|
||||
|
||||
class TestFindColorBlobsMultiColor:
|
||||
|
||||
def test_red_and_green_in_same_image(self):
|
||||
img = _blank(h=128, w=128)
|
||||
img[10:60, 10:60] = _RED_BGR
|
||||
img[10:60, 68:118] = _GREEN_BGR
|
||||
blobs = find_color_blobs(img, active_colors=['red', 'green'], min_area_px=100.0)
|
||||
names = {b.color_name for b in blobs}
|
||||
assert 'red' in names, 'red blob should be detected'
|
||||
assert 'green' in names, 'green blob should be detected'
|
||||
|
||||
def test_contour_ids_per_color_start_at_zero(self):
|
||||
"""contour_id should be 0 for the first (largest) blob of each color."""
|
||||
img = _blank(h=200, w=200)
|
||||
img[10:80, 10:80] = _RED_BGR
|
||||
img[10:80, 110:180] = _BLUE_BGR
|
||||
blobs = find_color_blobs(img, active_colors=['red', 'blue'], min_area_px=100.0)
|
||||
for color in ('red', 'blue'):
|
||||
first = next((b for b in blobs if b.color_name == color), None)
|
||||
assert first is not None, f'{color} blob not found'
|
||||
assert first.contour_id == 0, f'{color} first blob contour_id != 0'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
281
jetson/ros2_ws/src/saltybot_bringup/test/test_depth_hole_fill.py
Normal file
281
jetson/ros2_ws/src/saltybot_bringup/test/test_depth_hole_fill.py
Normal file
@ -0,0 +1,281 @@
|
||||
"""
|
||||
test_depth_hole_fill.py — Unit tests for depth hole fill helpers (no ROS2 required).
|
||||
|
||||
Covers:
|
||||
valid_mask:
|
||||
- valid range returns True
|
||||
- zero / below d_min returns False
|
||||
- NaN returns False
|
||||
- above d_max returns False
|
||||
- mixed array has correct mask
|
||||
|
||||
_odd_kernel_size:
|
||||
- result is always odd
|
||||
- result >= 3
|
||||
- scales correctly
|
||||
|
||||
fill_holes — no-hole cases:
|
||||
- fully valid image is returned unchanged
|
||||
- output dtype is float32
|
||||
- output shape matches input
|
||||
|
||||
fill_holes — basic fills:
|
||||
- single centre hole in uniform depth → filled with correct depth
|
||||
- single centre hole in uniform depth → original valid pixels unchanged
|
||||
- NaN pixel treated as hole and filled
|
||||
- row of zeros within uniform depth → filled
|
||||
|
||||
fill_holes — fill quality:
|
||||
- linear gradient: centre hole filled with interpolated value
|
||||
- multi-pass fills larger holes than single pass
|
||||
- all-zero image stays zero (no valid neighbours)
|
||||
- border hole (edge pixel) is handled without crash
|
||||
- depth range: pixel above d_max treated as hole
|
||||
|
||||
fill_holes — valid pixel preservation:
|
||||
- original valid pixels are never modified
|
||||
- max_passes=1 still fills small holes
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from saltybot_bringup._depth_hole_fill import (
|
||||
fill_holes,
|
||||
valid_mask,
|
||||
_odd_kernel_size,
|
||||
)
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _uniform(val=2.0, h=64, w=64) -> np.ndarray:
|
||||
return np.full((h, w), val, dtype=np.float32)
|
||||
|
||||
|
||||
def _poke_hole(arr, r, c) -> np.ndarray:
|
||||
arr = arr.copy()
|
||||
arr[r, c] = 0.0
|
||||
return arr
|
||||
|
||||
|
||||
def _poke_nan(arr, r, c) -> np.ndarray:
|
||||
arr = arr.copy()
|
||||
arr[r, c] = float('nan')
|
||||
return arr
|
||||
|
||||
|
||||
# ── valid_mask ────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestValidMask:
|
||||
|
||||
def test_valid_pixel_is_true(self):
|
||||
d = np.array([[1.0]], dtype=np.float32)
|
||||
assert valid_mask(d, 0.1, 10.0)[0, 0]
|
||||
|
||||
def test_zero_is_false(self):
|
||||
d = np.array([[0.0]], dtype=np.float32)
|
||||
assert not valid_mask(d, 0.1, 10.0)[0, 0]
|
||||
|
||||
def test_below_dmin_is_false(self):
|
||||
d = np.array([[0.05]], dtype=np.float32)
|
||||
assert not valid_mask(d, 0.1, 10.0)[0, 0]
|
||||
|
||||
def test_nan_is_false(self):
|
||||
d = np.array([[float('nan')]], dtype=np.float32)
|
||||
assert not valid_mask(d, 0.1, 10.0)[0, 0]
|
||||
|
||||
def test_above_dmax_is_false(self):
|
||||
d = np.array([[15.0]], dtype=np.float32)
|
||||
assert not valid_mask(d, 0.1, 10.0)[0, 0]
|
||||
|
||||
def test_at_dmin_is_true(self):
|
||||
d = np.array([[0.1]], dtype=np.float32)
|
||||
assert valid_mask(d, 0.1, 10.0)[0, 0]
|
||||
|
||||
def test_at_dmax_is_true(self):
|
||||
d = np.array([[10.0]], dtype=np.float32)
|
||||
assert valid_mask(d, 0.1, 10.0)[0, 0]
|
||||
|
||||
def test_mixed_array(self):
|
||||
d = np.array([[0.0, 1.0, float('nan'), 5.0, 11.0]], dtype=np.float32)
|
||||
m = valid_mask(d, 0.1, 10.0)
|
||||
np.testing.assert_array_equal(m, [[False, True, False, True, False]])
|
||||
|
||||
|
||||
# ── _odd_kernel_size ──────────────────────────────────────────────────────────
|
||||
|
||||
class TestOddKernelSize:
|
||||
|
||||
@pytest.mark.parametrize('base,scale', [
|
||||
(5, 1.0), (5, 2.5), (5, 6.0),
|
||||
(3, 1.0), (7, 2.0), (9, 3.0),
|
||||
(4, 1.0), # even base → must become odd
|
||||
])
|
||||
def test_result_is_odd(self, base, scale):
|
||||
ks = _odd_kernel_size(base, scale)
|
||||
assert ks % 2 == 1
|
||||
|
||||
@pytest.mark.parametrize('base,scale', [(3, 1.0), (1, 5.0), (2, 0.5)])
|
||||
def test_result_at_least_3(self, base, scale):
|
||||
assert _odd_kernel_size(base, scale) >= 3
|
||||
|
||||
def test_scale_1_returns_base_or_nearby_odd(self):
|
||||
ks = _odd_kernel_size(5, 1.0)
|
||||
assert ks == 5
|
||||
|
||||
def test_large_scale_gives_large_kernel(self):
|
||||
ks = _odd_kernel_size(5, 6.0)
|
||||
assert ks >= 25 # 5 * 6 = 30 → 31
|
||||
|
||||
|
||||
# ── fill_holes — output contract ──────────────────────────────────────────────
|
||||
|
||||
class TestFillHolesOutputContract:
|
||||
|
||||
def test_output_dtype_float32(self):
|
||||
out = fill_holes(_uniform(2.0))
|
||||
assert out.dtype == np.float32
|
||||
|
||||
def test_output_shape_preserved(self):
|
||||
img = _uniform(2.0, h=48, w=64)
|
||||
out = fill_holes(img)
|
||||
assert out.shape == img.shape
|
||||
|
||||
def test_fully_valid_image_unchanged(self):
|
||||
img = _uniform(2.0)
|
||||
out = fill_holes(img)
|
||||
np.testing.assert_allclose(out, img, atol=1e-6)
|
||||
|
||||
def test_valid_pixels_never_modified(self):
|
||||
"""Any pixel valid in the input must be identical in the output."""
|
||||
img = _uniform(3.0, h=32, w=32)
|
||||
img[16, 16] = 0.0 # one hole
|
||||
mask_before = valid_mask(img)
|
||||
out = fill_holes(img)
|
||||
np.testing.assert_allclose(out[mask_before], img[mask_before], atol=1e-6)
|
||||
|
||||
|
||||
# ── fill_holes — basic hole filling ──────────────────────────────────────────
|
||||
|
||||
class TestFillHolesBasic:
|
||||
|
||||
def test_centre_zero_filled_uniform(self):
|
||||
"""Single zero pixel in uniform depth → filled with that depth."""
|
||||
img = _poke_hole(_uniform(2.0, 32, 32), 16, 16)
|
||||
out = fill_holes(img, kernel_size=5, max_passes=1)
|
||||
assert out[16, 16] == pytest.approx(2.0, abs=0.05)
|
||||
|
||||
def test_centre_nan_filled_uniform(self):
|
||||
"""Single NaN pixel in uniform depth → filled."""
|
||||
img = _poke_nan(_uniform(2.0, 32, 32), 16, 16)
|
||||
out = fill_holes(img, kernel_size=5, max_passes=1)
|
||||
assert out[16, 16] == pytest.approx(2.0, abs=0.05)
|
||||
|
||||
def test_filled_value_is_positive(self):
|
||||
img = _poke_hole(_uniform(1.5, 32, 32), 16, 16)
|
||||
out = fill_holes(img)
|
||||
assert out[16, 16] > 0.0
|
||||
|
||||
def test_row_of_holes_filled(self):
|
||||
"""Entire middle row zeroed → should be filled from neighbours above/below."""
|
||||
img = _uniform(3.0, 32, 32)
|
||||
img[16, :] = 0.0
|
||||
out = fill_holes(img, kernel_size=7, max_passes=1)
|
||||
# All pixels in the row should be non-zero after filling
|
||||
assert (out[16, :] > 0.0).all()
|
||||
|
||||
def test_all_zero_stays_zero(self):
|
||||
"""Image with no valid pixels → stays zero (nothing to interpolate from)."""
|
||||
img = np.zeros((32, 32), dtype=np.float32)
|
||||
out = fill_holes(img, d_min=0.1)
|
||||
assert (out == 0.0).all()
|
||||
|
||||
def test_border_hole_no_crash(self):
|
||||
"""Holes at image corners must not raise exceptions."""
|
||||
img = _uniform(2.0, 32, 32)
|
||||
img[0, 0] = 0.0
|
||||
img[0, -1] = 0.0
|
||||
img[-1, 0] = 0.0
|
||||
img[-1, -1] = 0.0
|
||||
out = fill_holes(img) # must not raise
|
||||
assert out.shape == img.shape
|
||||
|
||||
def test_border_holes_filled(self):
|
||||
"""Corner holes should be filled from their neighbours."""
|
||||
img = _uniform(2.0, 32, 32)
|
||||
img[0, 0] = 0.0
|
||||
out = fill_holes(img, kernel_size=5, max_passes=1)
|
||||
assert out[0, 0] == pytest.approx(2.0, abs=0.1)
|
||||
|
||||
|
||||
# ── fill_holes — fill quality ─────────────────────────────────────────────────
|
||||
|
||||
class TestFillHolesQuality:
|
||||
|
||||
def test_linear_gradient_centre_hole_interpolated(self):
|
||||
"""
|
||||
Depth linearly increasing from 1.0 (left) to 3.0 (right).
|
||||
Centre hole should be filled near the midpoint (~2.0).
|
||||
"""
|
||||
h, w = 32, 32
|
||||
img = np.tile(np.linspace(1.0, 3.0, w, dtype=np.float32), (h, 1))
|
||||
cx = w // 2
|
||||
img[:, cx] = 0.0
|
||||
out = fill_holes(img, kernel_size=5, max_passes=1)
|
||||
mid = out[h // 2, cx]
|
||||
assert 1.5 <= mid <= 2.5, f'interpolated value {mid:.3f} not in [1.5, 2.5]'
|
||||
|
||||
def test_large_hole_filled_with_more_passes(self):
|
||||
"""A 9×9 hole in uniform depth: single pass may not fully fill it,
|
||||
but 3 passes should."""
|
||||
img = _uniform(2.0, 64, 64)
|
||||
# Create a 9×9 hole
|
||||
img[28:37, 28:37] = 0.0
|
||||
out1 = fill_holes(img, kernel_size=5, max_passes=1)
|
||||
out3 = fill_holes(img, kernel_size=5, max_passes=3)
|
||||
# More passes → fewer remaining holes
|
||||
holes1 = (out1 == 0.0).sum()
|
||||
holes3 = (out3 == 0.0).sum()
|
||||
assert holes3 <= holes1, f'more passes should reduce holes: {holes3} vs {holes1}'
|
||||
|
||||
def test_3pass_fills_9x9_hole_completely(self):
|
||||
img = _uniform(2.0, 64, 64)
|
||||
img[28:37, 28:37] = 0.0
|
||||
out = fill_holes(img, kernel_size=5, max_passes=3)
|
||||
assert (out[28:37, 28:37] > 0.0).all()
|
||||
|
||||
def test_filled_depth_within_valid_range(self):
|
||||
"""Filled pixels should have depth within [d_min, d_max]."""
|
||||
img = _uniform(2.0, 32, 32)
|
||||
img[10:15, 10:15] = 0.0
|
||||
out = fill_holes(img, d_min=0.1, d_max=10.0, max_passes=3)
|
||||
# Only check pixels that were actually filled
|
||||
was_hole = (img == 0.0)
|
||||
filled = out[was_hole]
|
||||
positive = filled[filled > 0.0]
|
||||
assert (positive >= 0.1).all()
|
||||
assert (positive <= 10.0).all()
|
||||
|
||||
def test_above_dmax_treated_as_hole(self):
|
||||
"""Pixels above d_max should be treated as holes and filled."""
|
||||
img = _uniform(2.0, 32, 32)
|
||||
img[16, 16] = 15.0 # out of range
|
||||
out = fill_holes(img, d_max=10.0, max_passes=1)
|
||||
assert out[16, 16] == pytest.approx(2.0, abs=0.1)
|
||||
|
||||
def test_max_passes_1_works(self):
|
||||
img = _poke_hole(_uniform(2.0, 32, 32), 16, 16)
|
||||
out = fill_holes(img, max_passes=1)
|
||||
assert out.shape == img.shape
|
||||
assert out[16, 16] > 0.0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
297
jetson/ros2_ws/src/saltybot_bringup/test/test_vo_drift.py
Normal file
297
jetson/ros2_ws/src/saltybot_bringup/test/test_vo_drift.py
Normal file
@ -0,0 +1,297 @@
|
||||
"""
|
||||
test_vo_drift.py — Unit tests for VO drift detector helpers (no ROS2 required).
|
||||
|
||||
Covers:
|
||||
OdomBuffer:
|
||||
- push/len
|
||||
- window returns only samples within cutoff
|
||||
- old samples are evicted beyond max_age_s
|
||||
- clear empties the buffer
|
||||
- window on empty buffer returns empty list
|
||||
|
||||
_path_length (via compute_drift with crafted samples):
|
||||
- stationary source → path = 0
|
||||
- straight-line motion → path = total distance
|
||||
- L-shaped path → path = sum of two legs
|
||||
|
||||
compute_drift:
|
||||
- both empty → DriftResult with zeros, is_drifting=False
|
||||
- one buffer < 2 samples → zero drift
|
||||
- both move same distance → drift ≈ 0, not drifting
|
||||
- VO moves 1m, wheel moves 0.5m → drift = 0.5m
|
||||
- drift == threshold → is_drifting=True (>=)
|
||||
- drift < threshold → is_drifting=False
|
||||
- drift > threshold → is_drifting=True
|
||||
- path lengths in result match expectation
|
||||
- n_vo / n_wheel counts correct
|
||||
- samples outside window ignored
|
||||
- window_s in result reflects actual data span
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import math
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from saltybot_bringup._vo_drift import (
|
||||
OdomSample,
|
||||
OdomBuffer,
|
||||
DriftResult,
|
||||
compute_drift,
|
||||
_path_length,
|
||||
)
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _s(t, x, y) -> OdomSample:
|
||||
return OdomSample(t=t, x=x, y=y)
|
||||
|
||||
|
||||
def _straight_buf(n=5, speed=0.1, t_start=0.0, dt=1.0,
|
||||
max_age_s=30.0) -> OdomBuffer:
|
||||
"""n samples moving along +x at `speed` m/s."""
|
||||
buf = OdomBuffer(max_age_s=max_age_s)
|
||||
for i in range(n):
|
||||
buf.push(_s(t_start + i * dt, x=i * speed * dt, y=0.0))
|
||||
return buf
|
||||
|
||||
|
||||
def _stationary_buf(n=5, t_start=0.0, dt=1.0,
|
||||
max_age_s=30.0) -> OdomBuffer:
|
||||
buf = OdomBuffer(max_age_s=max_age_s)
|
||||
for i in range(n):
|
||||
buf.push(_s(t_start + i * dt, x=0.0, y=0.0))
|
||||
return buf
|
||||
|
||||
|
||||
# ── OdomBuffer ────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestOdomBuffer:
|
||||
|
||||
def test_push_increases_len(self):
|
||||
buf = OdomBuffer()
|
||||
assert len(buf) == 0
|
||||
buf.push(_s(0.0, 0.0, 0.0))
|
||||
assert len(buf) == 1
|
||||
|
||||
def test_window_returns_all_within_cutoff(self):
|
||||
buf = OdomBuffer(max_age_s=30.0)
|
||||
for t in [0.0, 5.0, 10.0]:
|
||||
buf.push(_s(t, 0.0, 0.0))
|
||||
samples = buf.window(window_s=10.0, now=10.0)
|
||||
assert len(samples) == 3
|
||||
|
||||
def test_window_excludes_old_samples(self):
|
||||
buf = OdomBuffer(max_age_s=30.0)
|
||||
for t in [0.0, 5.0, 15.0]:
|
||||
buf.push(_s(t, 0.0, 0.0))
|
||||
# window=5s from now=15 → only t=15 qualifies (t>=10)
|
||||
samples = buf.window(window_s=5.0, now=15.0)
|
||||
assert len(samples) == 1
|
||||
assert samples[0].t == 15.0
|
||||
|
||||
def test_evicts_samples_beyond_max_age(self):
|
||||
buf = OdomBuffer(max_age_s=5.0)
|
||||
buf.push(_s(0.0, 0.0, 0.0))
|
||||
buf.push(_s(10.0, 1.0, 0.0)) # now=10 → t=0 is 10s old > 5s max
|
||||
assert len(buf) == 1
|
||||
|
||||
def test_clear_empties_buffer(self):
|
||||
buf = _straight_buf(n=5)
|
||||
buf.clear()
|
||||
assert len(buf) == 0
|
||||
|
||||
def test_window_on_empty_buffer(self):
|
||||
buf = OdomBuffer()
|
||||
assert buf.window(window_s=10.0, now=100.0) == []
|
||||
|
||||
def test_window_boundary_inclusive(self):
|
||||
"""Sample exactly at window cutoff (t == now - window_s) is included."""
|
||||
buf = OdomBuffer(max_age_s=30.0)
|
||||
buf.push(_s(0.0, 0.0, 0.0))
|
||||
# window=10, now=10 → cutoff=0.0, sample at t=0.0 should be included
|
||||
samples = buf.window(window_s=10.0, now=10.0)
|
||||
assert len(samples) == 1
|
||||
|
||||
|
||||
# ── _path_length ──────────────────────────────────────────────────────────────
|
||||
|
||||
class TestPathLength:
|
||||
|
||||
def test_stationary_path_zero(self):
|
||||
samples = [_s(i, 0.0, 0.0) for i in range(5)]
|
||||
assert _path_length(samples) == pytest.approx(0.0)
|
||||
|
||||
def test_unit_step_path(self):
|
||||
samples = [_s(0, 0.0, 0.0), _s(1, 1.0, 0.0)]
|
||||
assert _path_length(samples) == pytest.approx(1.0)
|
||||
|
||||
def test_two_unit_steps(self):
|
||||
samples = [_s(0, 0.0, 0.0), _s(1, 1.0, 0.0), _s(2, 2.0, 0.0)]
|
||||
assert _path_length(samples) == pytest.approx(2.0)
|
||||
|
||||
def test_diagonal_step(self):
|
||||
# (0,0) → (1,1): distance = sqrt(2)
|
||||
samples = [_s(0, 0.0, 0.0), _s(1, 1.0, 1.0)]
|
||||
assert _path_length(samples) == pytest.approx(math.sqrt(2))
|
||||
|
||||
def test_l_shaped_path(self):
|
||||
# Right 3m then up 4m → total path = 7m (not hypotenuse)
|
||||
samples = [_s(0, 0.0, 0.0), _s(1, 3.0, 0.0), _s(2, 3.0, 4.0)]
|
||||
assert _path_length(samples) == pytest.approx(7.0)
|
||||
|
||||
def test_single_sample_returns_zero(self):
|
||||
assert _path_length([_s(0, 5.0, 5.0)]) == pytest.approx(0.0)
|
||||
|
||||
def test_empty_returns_zero(self):
|
||||
assert _path_length([]) == pytest.approx(0.0)
|
||||
|
||||
|
||||
# ── compute_drift ─────────────────────────────────────────────────────────────
|
||||
|
||||
class TestComputeDrift:
|
||||
|
||||
def test_both_empty_returns_zero_drift(self):
|
||||
result = compute_drift(
|
||||
OdomBuffer(), OdomBuffer(),
|
||||
window_s=10.0, drift_threshold_m=0.5, now=10.0)
|
||||
assert result.drift_m == pytest.approx(0.0)
|
||||
assert not result.is_drifting
|
||||
|
||||
def test_one_buffer_empty_returns_zero(self):
|
||||
vo = _straight_buf(n=5, speed=0.1)
|
||||
result = compute_drift(
|
||||
vo, OdomBuffer(),
|
||||
window_s=10.0, drift_threshold_m=0.5, now=5.0)
|
||||
assert result.drift_m == pytest.approx(0.0)
|
||||
assert not result.is_drifting
|
||||
|
||||
def test_one_buffer_single_sample_returns_zero(self):
|
||||
vo = _straight_buf(n=5, speed=0.1)
|
||||
wheel = OdomBuffer()
|
||||
wheel.push(_s(0.0, 0.0, 0.0)) # only 1 sample
|
||||
result = compute_drift(
|
||||
vo, wheel,
|
||||
window_s=10.0, drift_threshold_m=0.5, now=5.0)
|
||||
assert result.drift_m == pytest.approx(0.0)
|
||||
assert not result.is_drifting
|
||||
|
||||
def test_both_move_same_distance_zero_drift(self):
|
||||
# Both move 0.1 m/s for 4 steps → 0.4 m each
|
||||
vo = _straight_buf(n=5, speed=0.1, dt=1.0)
|
||||
wheel = _straight_buf(n=5, speed=0.1, dt=1.0)
|
||||
result = compute_drift(
|
||||
vo, wheel,
|
||||
window_s=10.0, drift_threshold_m=0.5, now=5.0)
|
||||
assert result.drift_m == pytest.approx(0.0, abs=1e-9)
|
||||
assert not result.is_drifting
|
||||
|
||||
def test_both_stationary_zero_drift(self):
|
||||
vo = _stationary_buf(n=5)
|
||||
wheel = _stationary_buf(n=5)
|
||||
result = compute_drift(
|
||||
vo, wheel,
|
||||
window_s=10.0, drift_threshold_m=0.5, now=5.0)
|
||||
assert result.drift_m == pytest.approx(0.0)
|
||||
assert not result.is_drifting
|
||||
|
||||
def test_drift_equals_path_length_difference(self):
|
||||
# VO moves 1.0 m total, wheel moves 0.5 m total
|
||||
vo = _straight_buf(n=11, speed=0.1, dt=1.0) # 10 steps × 0.1 = 1.0m
|
||||
wheel = _straight_buf(n=11, speed=0.05, dt=1.0) # 10 steps × 0.05 = 0.5m
|
||||
result = compute_drift(
|
||||
vo, wheel,
|
||||
window_s=15.0, drift_threshold_m=0.5, now=11.0)
|
||||
assert result.vo_path_m == pytest.approx(1.0, abs=1e-9)
|
||||
assert result.wheel_path_m == pytest.approx(0.5, abs=1e-9)
|
||||
assert result.drift_m == pytest.approx(0.5, abs=1e-9)
|
||||
|
||||
def test_drift_at_threshold_is_drifting(self):
|
||||
# drift == 0.5 → is_drifting = True (>= threshold)
|
||||
vo = _straight_buf(n=11, speed=0.1, dt=1.0)
|
||||
wheel = _straight_buf(n=11, speed=0.05, dt=1.0)
|
||||
result = compute_drift(
|
||||
vo, wheel,
|
||||
window_s=15.0, drift_threshold_m=0.5, now=11.0)
|
||||
assert result.is_drifting
|
||||
|
||||
def test_drift_below_threshold_not_drifting(self):
|
||||
vo = _straight_buf(n=11, speed=0.1, dt=1.0)
|
||||
wheel = _straight_buf(n=11, speed=0.08, dt=1.0)
|
||||
result = compute_drift(
|
||||
vo, wheel,
|
||||
window_s=15.0, drift_threshold_m=0.5, now=11.0)
|
||||
# drift = |1.0 - 0.8| = 0.2
|
||||
assert result.drift_m == pytest.approx(0.2, abs=1e-9)
|
||||
assert not result.is_drifting
|
||||
|
||||
def test_drift_above_threshold_is_drifting(self):
|
||||
vo = _straight_buf(n=11, speed=0.1, dt=1.0)
|
||||
wheel = _stationary_buf(n=11, dt=1.0)
|
||||
result = compute_drift(
|
||||
vo, wheel,
|
||||
window_s=15.0, drift_threshold_m=0.5, now=11.0)
|
||||
# drift = |1.0 - 0.0| = 1.0 > 0.5
|
||||
assert result.drift_m > 0.5
|
||||
assert result.is_drifting
|
||||
|
||||
def test_n_vo_n_wheel_counts(self):
|
||||
vo = _straight_buf(n=8, dt=1.0)
|
||||
wheel = _straight_buf(n=5, dt=1.0)
|
||||
result = compute_drift(
|
||||
vo, wheel,
|
||||
window_s=15.0, drift_threshold_m=0.5, now=8.0)
|
||||
assert result.n_vo == 8
|
||||
assert result.n_wheel == 5
|
||||
|
||||
def test_samples_outside_window_ignored(self):
|
||||
# Push old samples far in the past; should not contribute to window
|
||||
vo = OdomBuffer(max_age_s=60.0)
|
||||
wheel = OdomBuffer(max_age_s=60.0)
|
||||
# Old samples outside window (t=0..4, window is last 3s from now=10)
|
||||
for t in range(5):
|
||||
vo.push(_s(float(t), x=float(t), y=0.0))
|
||||
wheel.push(_s(float(t), x=float(t), y=0.0))
|
||||
# Recent samples inside window (t=7..10)
|
||||
for t in range(7, 11):
|
||||
vo.push(_s(float(t), x=float(t) * 0.1, y=0.0))
|
||||
wheel.push(_s(float(t), x=float(t) * 0.1, y=0.0))
|
||||
|
||||
result = compute_drift(
|
||||
vo, wheel,
|
||||
window_s=3.0, drift_threshold_m=0.5, now=10.0)
|
||||
# Both sources move identically inside window → zero drift
|
||||
assert result.drift_m == pytest.approx(0.0, abs=1e-9)
|
||||
# Only the 4 recent samples (t=7,8,9,10) in window
|
||||
assert result.n_vo == 4
|
||||
assert result.n_wheel == 4
|
||||
|
||||
def test_result_is_namedtuple(self):
|
||||
result = compute_drift(
|
||||
_straight_buf(), _straight_buf(),
|
||||
window_s=10.0, drift_threshold_m=0.5, now=5.0)
|
||||
assert hasattr(result, 'drift_m')
|
||||
assert hasattr(result, 'vo_path_m')
|
||||
assert hasattr(result, 'wheel_path_m')
|
||||
assert hasattr(result, 'is_drifting')
|
||||
assert hasattr(result, 'window_s')
|
||||
assert hasattr(result, 'n_vo')
|
||||
assert hasattr(result, 'n_wheel')
|
||||
|
||||
def test_wheel_faster_than_vo_still_drifts(self):
|
||||
"""Drift is absolute difference — direction doesn't matter."""
|
||||
vo = _stationary_buf(n=11, dt=1.0)
|
||||
wheel = _straight_buf(n=11, speed=0.1, dt=1.0)
|
||||
result = compute_drift(
|
||||
vo, wheel,
|
||||
window_s=15.0, drift_threshold_m=0.5, now=11.0)
|
||||
assert result.drift_m == pytest.approx(1.0, abs=1e-9)
|
||||
assert result.is_drifting
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
@ -0,0 +1,4 @@
|
||||
imu_calibration:
|
||||
ros__parameters:
|
||||
calibration_samples: 100
|
||||
auto_calibrate: false
|
||||
@ -0,0 +1,30 @@
|
||||
"""Launch file for IMU calibration node."""
|
||||
|
||||
from launch import LaunchDescription
|
||||
from launch_ros.actions import Node
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
pkg_dir = get_package_share_directory("saltybot_imu_calibration")
|
||||
config_file = os.path.join(pkg_dir, "config", "imu_calibration_config.yaml")
|
||||
|
||||
return LaunchDescription(
|
||||
[
|
||||
DeclareLaunchArgument(
|
||||
"config_file",
|
||||
default_value=config_file,
|
||||
description="Path to configuration YAML file",
|
||||
),
|
||||
Node(
|
||||
package="saltybot_imu_calibration",
|
||||
executable="imu_calibration_node",
|
||||
name="imu_calibration",
|
||||
output="screen",
|
||||
parameters=[LaunchConfiguration("config_file")],
|
||||
),
|
||||
]
|
||||
)
|
||||
22
jetson/ros2_ws/src/saltybot_imu_calibration/package.xml
Normal file
22
jetson/ros2_ws/src/saltybot_imu_calibration/package.xml
Normal file
@ -0,0 +1,22 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_imu_calibration</name>
|
||||
<version>0.1.0</version>
|
||||
<description>IMU gyro + accel calibration node for SaltyBot</description>
|
||||
<maintainer email="sl-controls@saltylab.local">SaltyLab Controls</maintainer>
|
||||
<license>MIT</license>
|
||||
|
||||
<buildtool_depend>ament_python</buildtool_depend>
|
||||
<depend>rclpy</depend>
|
||||
<depend>sensor_msgs</depend>
|
||||
<depend>std_msgs</depend>
|
||||
<depend>geometry_msgs</depend>
|
||||
|
||||
<test_depend>pytest</test_depend>
|
||||
<test_depend>sensor_msgs</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1,126 @@
|
||||
#!/usr/bin/env python3
|
||||
"""IMU calibration node for SaltyBot."""
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from sensor_msgs.msg import Imu
|
||||
from std_srvs.srv import Trigger
|
||||
from std_msgs.msg import Header
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
|
||||
|
||||
class IMUCalibrationNode(Node):
|
||||
"""ROS2 node for IMU gyro + accel calibration."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("imu_calibration")
|
||||
self.declare_parameter("calibration_samples", 100)
|
||||
self.declare_parameter("auto_calibrate", False)
|
||||
|
||||
self.calibration_samples = self.get_parameter("calibration_samples").value
|
||||
self.auto_calibrate = self.get_parameter("auto_calibrate").value
|
||||
|
||||
self.gyro_bias = np.array([0.0, 0.0, 0.0])
|
||||
self.accel_bias = np.array([0.0, 0.0, 0.0])
|
||||
self.is_calibrated = False
|
||||
|
||||
self.gyro_samples = deque(maxlen=self.calibration_samples)
|
||||
self.accel_samples = deque(maxlen=self.calibration_samples)
|
||||
self.calibrating = False
|
||||
|
||||
self.sub_imu = self.create_subscription(Imu, "/imu", self._on_imu_raw, 10)
|
||||
self.pub_calibrated = self.create_publisher(Imu, "/imu/calibrated", 10)
|
||||
self.srv_calibrate = self.create_service(
|
||||
Trigger, "/saltybot/calibrate_imu", self._on_calibrate_service
|
||||
)
|
||||
|
||||
self.get_logger().info(
|
||||
f"IMU calibration node initialized. Samples: {self.calibration_samples}. Auto: {self.auto_calibrate}"
|
||||
)
|
||||
|
||||
if self.auto_calibrate:
|
||||
self.calibrating = True
|
||||
self.get_logger().info("Starting auto-calibration...")
|
||||
|
||||
def _on_imu_raw(self, msg: Imu) -> None:
|
||||
if self.calibrating:
|
||||
gyro = np.array([msg.angular_velocity.x, msg.angular_velocity.y, msg.angular_velocity.z])
|
||||
accel = np.array([msg.linear_acceleration.x, msg.linear_acceleration.y, msg.linear_acceleration.z])
|
||||
self.gyro_samples.append(gyro)
|
||||
self.accel_samples.append(accel)
|
||||
|
||||
if len(self.gyro_samples) == self.calibration_samples:
|
||||
self._compute_calibration()
|
||||
else:
|
||||
self._publish_calibrated(msg)
|
||||
|
||||
def _compute_calibration(self) -> None:
|
||||
if len(self.gyro_samples) == 0 or len(self.accel_samples) == 0:
|
||||
return
|
||||
|
||||
gyro_data = np.array(list(self.gyro_samples))
|
||||
accel_data = np.array(list(self.accel_samples))
|
||||
|
||||
self.gyro_bias = np.mean(gyro_data, axis=0)
|
||||
self.accel_bias = np.mean(accel_data, axis=0)
|
||||
|
||||
self.is_calibrated = True
|
||||
self.calibrating = False
|
||||
|
||||
self.get_logger().info(
|
||||
f"Calibration complete. Gyro: {self.gyro_bias}. Accel: {self.accel_bias}"
|
||||
)
|
||||
|
||||
self.gyro_samples.clear()
|
||||
self.accel_samples.clear()
|
||||
|
||||
def _on_calibrate_service(self, request, response) -> Trigger.Response:
|
||||
if self.calibrating:
|
||||
response.success = False
|
||||
response.message = "Calibration already in progress"
|
||||
return response
|
||||
|
||||
self.get_logger().info("Calibration service called")
|
||||
self.calibrating = True
|
||||
self.gyro_samples.clear()
|
||||
self.accel_samples.clear()
|
||||
|
||||
response.success = True
|
||||
response.message = f"Calibration started, collecting {self.calibration_samples} samples"
|
||||
return response
|
||||
|
||||
def _publish_calibrated(self, msg: Imu) -> None:
|
||||
calibrated = Imu()
|
||||
calibrated.header = Header(frame_id=msg.header.frame_id, stamp=msg.header.stamp)
|
||||
|
||||
calibrated.angular_velocity.x = msg.angular_velocity.x - self.gyro_bias[0]
|
||||
calibrated.angular_velocity.y = msg.angular_velocity.y - self.gyro_bias[1]
|
||||
calibrated.angular_velocity.z = msg.angular_velocity.z - self.gyro_bias[2]
|
||||
|
||||
calibrated.linear_acceleration.x = msg.linear_acceleration.x - self.accel_bias[0]
|
||||
calibrated.linear_acceleration.y = msg.linear_acceleration.y - self.accel_bias[1]
|
||||
calibrated.linear_acceleration.z = msg.linear_acceleration.z - self.accel_bias[2]
|
||||
|
||||
calibrated.angular_velocity_covariance = msg.angular_velocity_covariance
|
||||
calibrated.linear_acceleration_covariance = msg.linear_acceleration_covariance
|
||||
calibrated.orientation_covariance = msg.orientation_covariance
|
||||
calibrated.orientation = msg.orientation
|
||||
|
||||
self.pub_calibrated.publish(calibrated)
|
||||
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = IMUCalibrationNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
5
jetson/ros2_ws/src/saltybot_imu_calibration/setup.cfg
Normal file
5
jetson/ros2_ws/src/saltybot_imu_calibration/setup.cfg
Normal file
@ -0,0 +1,5 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/saltybot_imu_calibration
|
||||
|
||||
[install]
|
||||
install_scripts=$base/lib/saltybot_imu_calibration
|
||||
24
jetson/ros2_ws/src/saltybot_imu_calibration/setup.py
Normal file
24
jetson/ros2_ws/src/saltybot_imu_calibration/setup.py
Normal file
@ -0,0 +1,24 @@
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name='saltybot_imu_calibration',
|
||||
version='0.1.0',
|
||||
packages=find_packages(),
|
||||
data_files=[
|
||||
('share/ament_index/resource_index/packages', ['resource/saltybot_imu_calibration']),
|
||||
('share/saltybot_imu_calibration', ['package.xml']),
|
||||
('share/saltybot_imu_calibration/config', ['config/imu_calibration_config.yaml']),
|
||||
('share/saltybot_imu_calibration/launch', ['launch/imu_calibration.launch.py']),
|
||||
],
|
||||
install_requires=['setuptools'],
|
||||
zip_safe=True,
|
||||
author='SaltyLab Controls',
|
||||
author_email='sl-controls@saltylab.local',
|
||||
description='IMU gyro + accel calibration node for SaltyBot',
|
||||
license='MIT',
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'imu_calibration_node=saltybot_imu_calibration.imu_calibration_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
@ -0,0 +1,67 @@
|
||||
"""Tests for IMU calibration node."""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from sensor_msgs.msg import Imu
|
||||
from geometry_msgs.msg import Quaternion
|
||||
import rclpy
|
||||
from rclpy.time import Time
|
||||
|
||||
from saltybot_imu_calibration.imu_calibration_node import IMUCalibrationNode
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rclpy_fixture():
|
||||
rclpy.init()
|
||||
yield
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def node(rclpy_fixture):
|
||||
node = IMUCalibrationNode()
|
||||
yield node
|
||||
node.destroy_node()
|
||||
|
||||
|
||||
class TestInit:
|
||||
def test_node_initialization(self, node):
|
||||
assert node.calibration_samples == 100
|
||||
assert node.is_calibrated is False
|
||||
assert node.calibrating is False
|
||||
|
||||
|
||||
class TestCalibration:
|
||||
def test_calibration_samples(self, node):
|
||||
node.calibration_samples = 3
|
||||
node.gyro_samples.maxlen = 3
|
||||
node.accel_samples.maxlen = 3
|
||||
node.calibrating = True
|
||||
|
||||
for i in range(3):
|
||||
node.gyro_samples.append(np.array([0.1, 0.2, 0.3]))
|
||||
node.accel_samples.append(np.array([0.0, 0.0, 9.81]))
|
||||
|
||||
node._compute_calibration()
|
||||
|
||||
assert node.is_calibrated is True
|
||||
assert len(node.gyro_samples) == 0
|
||||
|
||||
|
||||
class TestCorrection:
|
||||
def test_imu_correction(self, node):
|
||||
node.gyro_bias = np.array([0.1, 0.2, 0.3])
|
||||
node.accel_bias = np.array([0.0, 0.0, 0.1])
|
||||
|
||||
msg = Imu()
|
||||
msg.header.stamp = Time().to_msg()
|
||||
msg.header.frame_id = "imu_link"
|
||||
msg.angular_velocity.x = 0.11
|
||||
msg.angular_velocity.y = 0.22
|
||||
msg.angular_velocity.z = 0.33
|
||||
msg.linear_acceleration.x = 0.0
|
||||
msg.linear_acceleration.y = 0.0
|
||||
msg.linear_acceleration.z = 9.91
|
||||
msg.orientation = Quaternion(x=0, y=0, z=0, w=1)
|
||||
|
||||
node._publish_calibrated(msg)
|
||||
@ -16,6 +16,9 @@ rosidl_generate_interfaces(${PROJECT_NAME}
|
||||
# Issue #233 — QR code reader
|
||||
"msg/QRDetection.msg"
|
||||
"msg/QRDetectionArray.msg"
|
||||
# Issue #274 — HSV color segmentation
|
||||
"msg/ColorDetection.msg"
|
||||
"msg/ColorDetectionArray.msg"
|
||||
DEPENDENCIES std_msgs geometry_msgs vision_msgs builtin_interfaces
|
||||
)
|
||||
|
||||
|
||||
@ -0,0 +1,14 @@
|
||||
# ColorDetection.msg — single HSV color-segmented object detection (Issue #274)
|
||||
#
|
||||
# color_name : target color label ("red", "green", "blue", "yellow", "orange")
|
||||
# confidence : mask fill ratio inside bbox (contour_area / bbox_area, 0–1)
|
||||
# bbox : axis-aligned bounding box in image pixels (center + size)
|
||||
# area_px : contour area in pixels² (use for size filtering downstream)
|
||||
# contour_id : 0-based index of this detection within the current frame
|
||||
#
|
||||
std_msgs/Header header
|
||||
string color_name
|
||||
float32 confidence
|
||||
vision_msgs/BoundingBox2D bbox
|
||||
float32 area_px
|
||||
uint32 contour_id
|
||||
@ -0,0 +1,3 @@
|
||||
# ColorDetectionArray.msg — frame-level list of HSV color-segmented objects (Issue #274)
|
||||
std_msgs/Header header
|
||||
ColorDetection[] detections
|
||||
@ -0,0 +1,21 @@
|
||||
ambient_sound_node:
|
||||
ros__parameters:
|
||||
sample_rate: 16000 # Expected PCM sample rate (Hz)
|
||||
window_s: 1.0 # Accumulate this many seconds before classifying
|
||||
n_fft: 512 # FFT size (32 ms frame at 16 kHz)
|
||||
n_mels: 32 # Mel filterbank bands
|
||||
audio_topic: "/social/speech/audio_raw" # Source PCM-16 UInt8MultiArray topic
|
||||
|
||||
# ── Classifier thresholds ──────────────────────────────────────────────
|
||||
# Adjust to tune sensitivity for your deployment environment.
|
||||
silence_db: -40.0 # Below this energy (dBFS) → silence
|
||||
alarm_db_min: -25.0 # Min energy for alarm detection
|
||||
alarm_zcr_min: 0.12 # Min ZCR for alarm (intermittent high pitch)
|
||||
alarm_high_ratio_min: 0.35 # Min high-band energy fraction for alarm
|
||||
speech_zcr_min: 0.02 # Min ZCR for speech (voiced onset)
|
||||
speech_zcr_max: 0.25 # Max ZCR for speech
|
||||
speech_flatness_max: 0.35 # Max spectral flatness for speech (tonal)
|
||||
music_zcr_max: 0.08 # Max ZCR for music (harmonic / tonal)
|
||||
music_flatness_max: 0.25 # Max spectral flatness for music
|
||||
crowd_zcr_min: 0.10 # Min ZCR for crowd noise
|
||||
crowd_flatness_min: 0.35 # Min spectral flatness for crowd
|
||||
@ -0,0 +1,30 @@
|
||||
face_track_servo_node:
|
||||
ros__parameters:
|
||||
# PID gains — pan axis
|
||||
kp_pan: 1.5 # proportional gain (°/s per ° error)
|
||||
ki_pan: 0.1 # integral gain
|
||||
kd_pan: 0.05 # derivative gain (damping)
|
||||
|
||||
# PID gains — tilt axis
|
||||
kp_tilt: 1.2
|
||||
ki_tilt: 0.1
|
||||
kd_tilt: 0.04
|
||||
|
||||
# Camera FOV
|
||||
fov_h_deg: 60.0 # horizontal field of view (degrees)
|
||||
fov_v_deg: 45.0 # vertical field of view (degrees)
|
||||
|
||||
# Servo limits
|
||||
pan_limit_deg: 90.0 # mechanical pan range ± (degrees)
|
||||
tilt_limit_deg: 30.0 # mechanical tilt range ± (degrees)
|
||||
pan_vel_limit: 45.0 # max pan rate (°/s)
|
||||
tilt_vel_limit: 30.0 # max tilt rate (°/s)
|
||||
windup_limit: 15.0 # integral anti-windup clamp (°·s)
|
||||
|
||||
# Tracking behaviour
|
||||
dead_zone: 0.02 # normalised dead zone (fraction of frame width/height)
|
||||
control_rate: 20.0 # control loop frequency (Hz)
|
||||
lost_timeout_s: 1.5 # seconds before face considered lost
|
||||
return_rate_deg_s: 10.0 # return-to-centre speed when no face (°/s)
|
||||
|
||||
faces_topic: "/social/faces/detected"
|
||||
@ -0,0 +1,8 @@
|
||||
greeting_trigger_node:
|
||||
ros__parameters:
|
||||
proximity_m: 2.0 # Trigger when person is within this distance (m)
|
||||
cooldown_s: 300.0 # Re-greeting suppression window per face_id (s)
|
||||
unknown_distance: 0.0 # Distance assumed when PersonState not yet available
|
||||
# 0.0 → always greet faces with no state yet
|
||||
faces_topic: "/social/faces/detected"
|
||||
states_topic: "/social/person_states"
|
||||
@ -0,0 +1,42 @@
|
||||
"""ambient_sound.launch.py -- Launch the ambient sound classifier (Issue #252).
|
||||
|
||||
Usage:
|
||||
ros2 launch saltybot_social ambient_sound.launch.py
|
||||
ros2 launch saltybot_social ambient_sound.launch.py silence_db:=-45.0
|
||||
"""
|
||||
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
from launch import LaunchDescription
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
pkg = get_package_share_directory("saltybot_social")
|
||||
cfg = os.path.join(pkg, "config", "ambient_sound_params.yaml")
|
||||
|
||||
return LaunchDescription([
|
||||
DeclareLaunchArgument("window_s", default_value="1.0",
|
||||
description="Accumulation window (s)"),
|
||||
DeclareLaunchArgument("n_mels", default_value="32",
|
||||
description="Mel filterbank bands"),
|
||||
DeclareLaunchArgument("silence_db", default_value="-40.0",
|
||||
description="Silence energy threshold (dBFS)"),
|
||||
|
||||
Node(
|
||||
package="saltybot_social",
|
||||
executable="ambient_sound_node",
|
||||
name="ambient_sound_node",
|
||||
output="screen",
|
||||
parameters=[
|
||||
cfg,
|
||||
{
|
||||
"window_s": LaunchConfiguration("window_s"),
|
||||
"n_mels": LaunchConfiguration("n_mels"),
|
||||
"silence_db": LaunchConfiguration("silence_db"),
|
||||
},
|
||||
],
|
||||
),
|
||||
])
|
||||
@ -0,0 +1,51 @@
|
||||
"""face_track_servo.launch.py — Launch face-tracking head servo controller (Issue #279).
|
||||
|
||||
Usage:
|
||||
ros2 launch saltybot_social face_track_servo.launch.py
|
||||
ros2 launch saltybot_social face_track_servo.launch.py kp_pan:=2.0 pan_limit_deg:=60.0
|
||||
"""
|
||||
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
from launch import LaunchDescription
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
pkg = get_package_share_directory("saltybot_social")
|
||||
cfg = os.path.join(pkg, "config", "face_track_servo_params.yaml")
|
||||
|
||||
return LaunchDescription([
|
||||
DeclareLaunchArgument("kp_pan", default_value="1.5",
|
||||
description="Pan proportional gain (°/s per °)"),
|
||||
DeclareLaunchArgument("kp_tilt", default_value="1.2",
|
||||
description="Tilt proportional gain (°/s per °)"),
|
||||
DeclareLaunchArgument("pan_limit_deg", default_value="90.0",
|
||||
description="Mechanical pan limit ± (degrees)"),
|
||||
DeclareLaunchArgument("tilt_limit_deg", default_value="30.0",
|
||||
description="Mechanical tilt limit ± (degrees)"),
|
||||
DeclareLaunchArgument("fov_h_deg", default_value="60.0",
|
||||
description="Camera horizontal FOV (degrees)"),
|
||||
DeclareLaunchArgument("fov_v_deg", default_value="45.0",
|
||||
description="Camera vertical FOV (degrees)"),
|
||||
|
||||
Node(
|
||||
package="saltybot_social",
|
||||
executable="face_track_servo_node",
|
||||
name="face_track_servo_node",
|
||||
output="screen",
|
||||
parameters=[
|
||||
cfg,
|
||||
{
|
||||
"kp_pan": LaunchConfiguration("kp_pan"),
|
||||
"kp_tilt": LaunchConfiguration("kp_tilt"),
|
||||
"pan_limit_deg": LaunchConfiguration("pan_limit_deg"),
|
||||
"tilt_limit_deg": LaunchConfiguration("tilt_limit_deg"),
|
||||
"fov_h_deg": LaunchConfiguration("fov_h_deg"),
|
||||
"fov_v_deg": LaunchConfiguration("fov_v_deg"),
|
||||
},
|
||||
],
|
||||
),
|
||||
])
|
||||
@ -0,0 +1,39 @@
|
||||
"""greeting_trigger.launch.py -- Launch proximity-based greeting trigger (Issue #270).
|
||||
|
||||
Usage:
|
||||
ros2 launch saltybot_social greeting_trigger.launch.py
|
||||
ros2 launch saltybot_social greeting_trigger.launch.py proximity_m:=1.5 cooldown_s:=120.0
|
||||
"""
|
||||
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
from launch import LaunchDescription
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
pkg = get_package_share_directory("saltybot_social")
|
||||
cfg = os.path.join(pkg, "config", "greeting_trigger_params.yaml")
|
||||
|
||||
return LaunchDescription([
|
||||
DeclareLaunchArgument("proximity_m", default_value="2.0",
|
||||
description="Greeting proximity threshold (m)"),
|
||||
DeclareLaunchArgument("cooldown_s", default_value="300.0",
|
||||
description="Per-face_id re-greeting cooldown (s)"),
|
||||
|
||||
Node(
|
||||
package="saltybot_social",
|
||||
executable="greeting_trigger_node",
|
||||
name="greeting_trigger_node",
|
||||
output="screen",
|
||||
parameters=[
|
||||
cfg,
|
||||
{
|
||||
"proximity_m": LaunchConfiguration("proximity_m"),
|
||||
"cooldown_s": LaunchConfiguration("cooldown_s"),
|
||||
},
|
||||
],
|
||||
),
|
||||
])
|
||||
@ -0,0 +1,363 @@
|
||||
"""ambient_sound_node.py -- Ambient sound classifier via mel-spectrogram features.
|
||||
Issue #252
|
||||
|
||||
Accumulates 1 s of PCM-16 audio from /social/speech/audio_raw, extracts a
|
||||
compact mel-spectrogram feature vector, then classifies the scene into one of:
|
||||
|
||||
silence | speech | music | crowd | outdoor | alarm
|
||||
|
||||
Publishes the label as std_msgs/String on /saltybot/ambient_sound at 1 Hz.
|
||||
|
||||
Signal processing is pure Python + numpy (no torch / onnx dependency).
|
||||
|
||||
Feature vector (per 1-s window):
|
||||
energy_db -- overall RMS in dBFS
|
||||
zcr -- mean zero-crossing rate across frames
|
||||
mel_centroid -- centre-of-mass of the mel band energies [0..1]
|
||||
mel_flatness -- geometric/arithmetic mean of mel energies [0..1]
|
||||
(1 = white noise, 0 = single sinusoid)
|
||||
low_ratio -- fraction of mel energy in lower third of bands
|
||||
high_ratio -- fraction of mel energy in upper third of bands
|
||||
|
||||
Classification cascade (priority-ordered):
|
||||
silence : energy_db < silence_db
|
||||
alarm : energy_db >= alarm_db_min AND zcr >= alarm_zcr_min
|
||||
AND high_ratio >= alarm_high_ratio_min
|
||||
speech : zcr in [speech_zcr_min, speech_zcr_max]
|
||||
AND mel_flatness < speech_flatness_max
|
||||
music : zcr < music_zcr_max AND mel_flatness < music_flatness_max
|
||||
crowd : zcr >= crowd_zcr_min AND mel_flatness >= crowd_flatness_min
|
||||
outdoor : catch-all
|
||||
|
||||
Parameters:
|
||||
sample_rate (int, 16000)
|
||||
window_s (float, 1.0) -- accumulation window before classify
|
||||
n_fft (int, 512) -- FFT size
|
||||
n_mels (int, 32) -- mel filterbank bands
|
||||
audio_topic (str, "/social/speech/audio_raw")
|
||||
silence_db (float, -40.0)
|
||||
alarm_db_min (float, -25.0)
|
||||
alarm_zcr_min (float, 0.12)
|
||||
alarm_high_ratio_min (float, 0.35)
|
||||
speech_zcr_min (float, 0.02)
|
||||
speech_zcr_max (float, 0.25)
|
||||
speech_flatness_max (float, 0.35)
|
||||
music_zcr_max (float, 0.08)
|
||||
music_flatness_max (float, 0.25)
|
||||
crowd_zcr_min (float, 0.10)
|
||||
crowd_flatness_min (float, 0.35)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import struct
|
||||
import threading
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile
|
||||
from std_msgs.msg import String, UInt8MultiArray
|
||||
|
||||
# numpy used only in DSP helpers — the Jetson always has it
|
||||
try:
|
||||
import numpy as np
|
||||
_NUMPY = True
|
||||
except ImportError:
|
||||
_NUMPY = False
|
||||
|
||||
INT16_MAX = 32768.0
|
||||
LABELS = ("silence", "speech", "music", "crowd", "outdoor", "alarm")
|
||||
|
||||
|
||||
# ── PCM helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
def pcm16_bytes_to_float32(data: bytes) -> List[float]:
|
||||
"""PCM-16 LE bytes → float32 list in [-1.0, 1.0]."""
|
||||
n = len(data) // 2
|
||||
if n == 0:
|
||||
return []
|
||||
return [s / INT16_MAX for s in struct.unpack(f"<{n}h", data[: n * 2])]
|
||||
|
||||
|
||||
# ── Mel DSP (numpy path) ──────────────────────────────────────────────────────
|
||||
|
||||
def hz_to_mel(hz: float) -> float:
|
||||
return 2595.0 * math.log10(1.0 + hz / 700.0)
|
||||
|
||||
|
||||
def mel_to_hz(mel: float) -> float:
|
||||
return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)
|
||||
|
||||
|
||||
def build_mel_filterbank(sr: int, n_fft: int, n_mels: int,
|
||||
fmin: float = 0.0, fmax: Optional[float] = None):
|
||||
"""Return (n_mels, n_fft//2+1) numpy filterbank matrix."""
|
||||
import numpy as np
|
||||
if fmax is None:
|
||||
fmax = sr / 2.0
|
||||
n_freqs = n_fft // 2 + 1
|
||||
mel_min = hz_to_mel(fmin)
|
||||
mel_max = hz_to_mel(fmax)
|
||||
mel_pts = np.linspace(mel_min, mel_max, n_mels + 2)
|
||||
hz_pts = np.array([mel_to_hz(m) for m in mel_pts])
|
||||
bin_pts = np.floor((n_fft + 1) * hz_pts / sr).astype(int)
|
||||
fb = np.zeros((n_mels, n_freqs))
|
||||
for m in range(n_mels):
|
||||
lo, ctr, hi = bin_pts[m], bin_pts[m + 1], bin_pts[m + 2]
|
||||
for k in range(lo, min(ctr, n_freqs)):
|
||||
if ctr != lo:
|
||||
fb[m, k] = (k - lo) / (ctr - lo)
|
||||
for k in range(ctr, min(hi, n_freqs)):
|
||||
if hi != ctr:
|
||||
fb[m, k] = (hi - k) / (hi - ctr)
|
||||
return fb
|
||||
|
||||
|
||||
def compute_mel_spectrogram(samples: List[float], sr: int,
|
||||
n_fft: int = 512, n_mels: int = 32,
|
||||
hop_length: int = 256):
|
||||
"""Return (n_mels, n_frames) log-mel spectrogram (numpy array)."""
|
||||
import numpy as np
|
||||
x = np.array(samples, dtype=np.float32)
|
||||
fb = build_mel_filterbank(sr, n_fft, n_mels)
|
||||
window = np.hanning(n_fft)
|
||||
frames = []
|
||||
for start in range(0, len(x) - n_fft + 1, hop_length):
|
||||
frame = x[start : start + n_fft] * window
|
||||
spec = np.abs(np.fft.rfft(frame)) ** 2
|
||||
mel = fb @ spec
|
||||
frames.append(mel)
|
||||
if not frames:
|
||||
return np.zeros((n_mels, 1), dtype=np.float32)
|
||||
return np.column_stack(frames).astype(np.float32)
|
||||
|
||||
|
||||
# ── Feature extraction ────────────────────────────────────────────────────────
|
||||
|
||||
def extract_features(samples: List[float], sr: int,
|
||||
n_fft: int = 512, n_mels: int = 32) -> Dict[str, float]:
|
||||
"""Extract scalar features from a raw audio window."""
|
||||
import numpy as np
|
||||
|
||||
n = len(samples)
|
||||
if n == 0:
|
||||
return {k: 0.0 for k in
|
||||
("energy_db", "zcr", "mel_centroid", "mel_flatness",
|
||||
"low_ratio", "high_ratio")}
|
||||
|
||||
# Energy
|
||||
rms = math.sqrt(sum(s * s for s in samples) / n) if n else 0.0
|
||||
energy_db = 20.0 * math.log10(max(rms, 1e-10))
|
||||
|
||||
# ZCR across 30 ms frames
|
||||
chunk = max(1, int(sr * 0.030))
|
||||
zcr_vals = []
|
||||
for i in range(0, n - chunk + 1, chunk):
|
||||
seg = samples[i : i + chunk]
|
||||
crossings = sum(1 for j in range(1, len(seg))
|
||||
if seg[j - 1] * seg[j] < 0)
|
||||
zcr_vals.append(crossings / max(len(seg) - 1, 1))
|
||||
zcr = sum(zcr_vals) / len(zcr_vals) if zcr_vals else 0.0
|
||||
|
||||
# Mel spectrogram features
|
||||
mel_spec = compute_mel_spectrogram(samples, sr, n_fft, n_mels)
|
||||
mel_mean = mel_spec.mean(axis=1) # (n_mels,) mean energy per band
|
||||
|
||||
total = float(mel_mean.sum()) if mel_mean.sum() > 0 else 1e-10
|
||||
indices = np.arange(n_mels, dtype=np.float32)
|
||||
mel_centroid = float((indices * mel_mean).sum()) / (n_mels * total / total) / n_mels
|
||||
|
||||
# Spectral flatness: geometric mean / arithmetic mean
|
||||
eps = 1e-10
|
||||
mel_pos = np.clip(mel_mean, eps, None)
|
||||
geo_mean = float(np.exp(np.log(mel_pos).mean()))
|
||||
arith_mean = float(mel_pos.mean())
|
||||
mel_flatness = min(geo_mean / max(arith_mean, eps), 1.0)
|
||||
|
||||
# Band ratios
|
||||
third = max(1, n_mels // 3)
|
||||
low_energy = float(mel_mean[:third].sum())
|
||||
high_energy = float(mel_mean[-third:].sum())
|
||||
low_ratio = low_energy / max(total, eps)
|
||||
high_ratio = high_energy / max(total, eps)
|
||||
|
||||
return {
|
||||
"energy_db": energy_db,
|
||||
"zcr": zcr,
|
||||
"mel_centroid": mel_centroid,
|
||||
"mel_flatness": mel_flatness,
|
||||
"low_ratio": low_ratio,
|
||||
"high_ratio": high_ratio,
|
||||
}
|
||||
|
||||
|
||||
# ── Classifier ────────────────────────────────────────────────────────────────
|
||||
|
||||
def classify(features: Dict[str, float],
|
||||
silence_db: float = -40.0,
|
||||
alarm_db_min: float = -25.0,
|
||||
alarm_zcr_min: float = 0.12,
|
||||
alarm_high_ratio_min: float = 0.35,
|
||||
speech_zcr_min: float = 0.02,
|
||||
speech_zcr_max: float = 0.25,
|
||||
speech_flatness_max: float = 0.35,
|
||||
music_zcr_max: float = 0.08,
|
||||
music_flatness_max: float = 0.25,
|
||||
crowd_zcr_min: float = 0.10,
|
||||
crowd_flatness_min: float = 0.35) -> str:
|
||||
"""Priority-ordered rule cascade. Returns a label from LABELS."""
|
||||
e = features["energy_db"]
|
||||
zcr = features["zcr"]
|
||||
fl = features["mel_flatness"]
|
||||
hi = features["high_ratio"]
|
||||
|
||||
if e < silence_db:
|
||||
return "silence"
|
||||
if (e >= alarm_db_min
|
||||
and zcr >= alarm_zcr_min
|
||||
and hi >= alarm_high_ratio_min):
|
||||
return "alarm"
|
||||
if zcr < music_zcr_max and fl < music_flatness_max:
|
||||
return "music"
|
||||
if (speech_zcr_min <= zcr <= speech_zcr_max
|
||||
and fl < speech_flatness_max):
|
||||
return "speech"
|
||||
if zcr >= crowd_zcr_min and fl >= crowd_flatness_min:
|
||||
return "crowd"
|
||||
return "outdoor"
|
||||
|
||||
|
||||
# ── Audio accumulation buffer ─────────────────────────────────────────────────
|
||||
|
||||
class AudioBuffer:
|
||||
"""Thread-safe ring buffer; yields a window of samples when full."""
|
||||
|
||||
def __init__(self, window_samples: int) -> None:
|
||||
self._target = window_samples
|
||||
self._buf: List[float] = []
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def push(self, samples: List[float]) -> Optional[List[float]]:
|
||||
"""Append samples. Returns a complete window (and resets) when full."""
|
||||
with self._lock:
|
||||
self._buf.extend(samples)
|
||||
if len(self._buf) >= self._target:
|
||||
window = self._buf[: self._target]
|
||||
self._buf = self._buf[self._target :]
|
||||
return window
|
||||
return None
|
||||
|
||||
def clear(self) -> None:
|
||||
with self._lock:
|
||||
self._buf.clear()
|
||||
|
||||
|
||||
# ── ROS2 node ─────────────────────────────────────────────────────────────────
|
||||
|
||||
class AmbientSoundNode(Node):
|
||||
"""Classifies ambient sound from raw audio and publishes label at 1 Hz."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__("ambient_sound_node")
|
||||
|
||||
self.declare_parameter("sample_rate", 16000)
|
||||
self.declare_parameter("window_s", 1.0)
|
||||
self.declare_parameter("n_fft", 512)
|
||||
self.declare_parameter("n_mels", 32)
|
||||
self.declare_parameter("audio_topic", "/social/speech/audio_raw")
|
||||
# Classifier thresholds
|
||||
self.declare_parameter("silence_db", -40.0)
|
||||
self.declare_parameter("alarm_db_min", -25.0)
|
||||
self.declare_parameter("alarm_zcr_min", 0.12)
|
||||
self.declare_parameter("alarm_high_ratio_min", 0.35)
|
||||
self.declare_parameter("speech_zcr_min", 0.02)
|
||||
self.declare_parameter("speech_zcr_max", 0.25)
|
||||
self.declare_parameter("speech_flatness_max", 0.35)
|
||||
self.declare_parameter("music_zcr_max", 0.08)
|
||||
self.declare_parameter("music_flatness_max", 0.25)
|
||||
self.declare_parameter("crowd_zcr_min", 0.10)
|
||||
self.declare_parameter("crowd_flatness_min", 0.35)
|
||||
|
||||
self._sr = self.get_parameter("sample_rate").value
|
||||
self._n_fft = self.get_parameter("n_fft").value
|
||||
self._n_mels = self.get_parameter("n_mels").value
|
||||
window_s = self.get_parameter("window_s").value
|
||||
audio_topic = self.get_parameter("audio_topic").value
|
||||
|
||||
self._thresholds = {
|
||||
k: self.get_parameter(k).value for k in (
|
||||
"silence_db", "alarm_db_min", "alarm_zcr_min",
|
||||
"alarm_high_ratio_min", "speech_zcr_min", "speech_zcr_max",
|
||||
"speech_flatness_max", "music_zcr_max", "music_flatness_max",
|
||||
"crowd_zcr_min", "crowd_flatness_min",
|
||||
)
|
||||
}
|
||||
|
||||
self._buffer = AudioBuffer(int(self._sr * window_s))
|
||||
self._last_label = "silence"
|
||||
|
||||
qos = QoSProfile(depth=10)
|
||||
self._pub = self.create_publisher(String, "/saltybot/ambient_sound", qos)
|
||||
self._audio_sub = self.create_subscription(
|
||||
UInt8MultiArray, audio_topic, self._on_audio, qos
|
||||
)
|
||||
|
||||
if not _NUMPY:
|
||||
self.get_logger().warn(
|
||||
"numpy not available — mel features disabled, classifying by energy only"
|
||||
)
|
||||
|
||||
self.get_logger().info(
|
||||
f"AmbientSoundNode ready "
|
||||
f"(sr={self._sr}, window={window_s}s, n_mels={self._n_mels})"
|
||||
)
|
||||
|
||||
def _on_audio(self, msg: UInt8MultiArray) -> None:
|
||||
samples = pcm16_bytes_to_float32(bytes(msg.data))
|
||||
if not samples:
|
||||
return
|
||||
window = self._buffer.push(samples)
|
||||
if window is not None:
|
||||
self._classify_and_publish(window)
|
||||
|
||||
def _classify_and_publish(self, samples: List[float]) -> None:
|
||||
try:
|
||||
if _NUMPY:
|
||||
feats = extract_features(samples, self._sr, self._n_fft, self._n_mels)
|
||||
else:
|
||||
# Numpy-free fallback: energy-only
|
||||
rms = math.sqrt(sum(s * s for s in samples) / len(samples))
|
||||
e_db = 20.0 * math.log10(max(rms, 1e-10))
|
||||
feats = {
|
||||
"energy_db": e_db, "zcr": 0.05,
|
||||
"mel_centroid": 0.5, "mel_flatness": 0.2,
|
||||
"low_ratio": 0.4, "high_ratio": 0.2,
|
||||
}
|
||||
label = classify(feats, **self._thresholds)
|
||||
except Exception as exc:
|
||||
self.get_logger().error(f"Classification error: {exc}")
|
||||
label = self._last_label
|
||||
|
||||
if label != self._last_label:
|
||||
self.get_logger().info(
|
||||
f"Ambient sound: {self._last_label} -> {label}"
|
||||
)
|
||||
self._last_label = label
|
||||
|
||||
msg = String()
|
||||
msg.data = label
|
||||
self._pub.publish(msg)
|
||||
|
||||
|
||||
def main(args: Optional[list] = None) -> None:
|
||||
rclpy.init(args=args)
|
||||
node = AmbientSoundNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
@ -0,0 +1,308 @@
|
||||
"""face_track_servo_node.py — Face-tracking head servo controller.
|
||||
Issue #279
|
||||
|
||||
Subscribes to /social/faces/detected, picks the closest face (largest
|
||||
bounding-box area as a proximity proxy), computes pan/tilt angular error
|
||||
relative to the image centre, and drives two PID controllers to produce
|
||||
smooth servo position commands published on /saltybot/head_pan and
|
||||
/saltybot/head_tilt (std_msgs/Float32, degrees from neutral).
|
||||
|
||||
Coordinate convention
|
||||
─────────────────────
|
||||
bbox_x/y/w/h : normalised [0, 1] in image space
|
||||
face centre : cx = bbox_x + bbox_w/2 , cy = bbox_y + bbox_h/2
|
||||
image centre : (0.5, 0.5)
|
||||
pan error : (cx - 0.5) * fov_h_deg (+ve → face right of centre)
|
||||
tilt error : (cy - 0.5) * fov_v_deg (+ve → face below centre)
|
||||
|
||||
PID design (velocity / incremental)
|
||||
────────────────────────────────────
|
||||
velocity (°/s) = Kp·e + Ki·∫e dt + Kd·de/dt
|
||||
servo_angle += velocity · dt
|
||||
servo_angle = clamp(servo_angle, ±limit_deg)
|
||||
|
||||
When no face is seen for more than ``lost_timeout_s`` seconds the PIDs
|
||||
are reset and the servo commands return toward 0° at ``return_rate_deg_s``.
|
||||
|
||||
Parameters
|
||||
──────────
|
||||
kp_pan (float, 1.5) pan proportional gain (°/s per °)
|
||||
ki_pan (float, 0.1) pan integral gain
|
||||
kd_pan (float, 0.05) pan derivative gain
|
||||
kp_tilt (float, 1.2) tilt proportional gain
|
||||
ki_tilt (float, 0.1) tilt integral gain
|
||||
kd_tilt (float, 0.04) tilt derivative gain
|
||||
fov_h_deg (float, 60.0) camera horizontal FOV (degrees)
|
||||
fov_v_deg (float, 45.0) camera vertical FOV (degrees)
|
||||
pan_limit_deg (float, 90.0) mechanical pan limit ±
|
||||
tilt_limit_deg (float, 30.0) mechanical tilt limit ±
|
||||
pan_vel_limit (float, 45.0) max pan rate (°/s)
|
||||
tilt_vel_limit (float, 30.0) max tilt rate (°/s)
|
||||
windup_limit (float, 15.0) integral anti-windup clamp (°·s)
|
||||
dead_zone (float, 0.02) normalised dead zone (fraction of frame)
|
||||
control_rate (float, 20.0) control loop Hz
|
||||
lost_timeout_s (float, 1.5) seconds before face considered lost
|
||||
return_rate_deg_s (float, 10.0) return-to-centre rate when no face (°/s)
|
||||
faces_topic (str) default "/social/faces/detected"
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import time
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile
|
||||
from std_msgs.msg import Float32
|
||||
|
||||
try:
|
||||
from saltybot_social_msgs.msg import FaceDetectionArray
|
||||
_MSGS = True
|
||||
except ImportError:
|
||||
_MSGS = False
|
||||
|
||||
|
||||
# ── Pure helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
def clamp(v: float, lo: float, hi: float) -> float:
|
||||
return max(lo, min(hi, v))
|
||||
|
||||
|
||||
def bbox_area(face) -> float:
|
||||
"""Bounding-box area as a proximity proxy (larger ≈ closer)."""
|
||||
return float(face.bbox_w) * float(face.bbox_h)
|
||||
|
||||
|
||||
def pick_closest_face(faces):
|
||||
"""Return the face with the largest bbox area; None if list is empty."""
|
||||
if not faces:
|
||||
return None
|
||||
return max(faces, key=bbox_area)
|
||||
|
||||
|
||||
def face_image_error(face, fov_h_deg: float, fov_v_deg: float):
|
||||
"""Return (pan_error_deg, tilt_error_deg) for a FaceDetection.
|
||||
|
||||
Positive pan → face is right of image centre.
|
||||
Positive tilt → face is below image centre.
|
||||
"""
|
||||
cx = float(face.bbox_x) + float(face.bbox_w) / 2.0
|
||||
cy = float(face.bbox_y) + float(face.bbox_h) / 2.0
|
||||
pan_err = (cx - 0.5) * fov_h_deg
|
||||
tilt_err = (cy - 0.5) * fov_v_deg
|
||||
return pan_err, tilt_err
|
||||
|
||||
|
||||
# ── PID controller ─────────────────────────────────────────────────────────────
|
||||
|
||||
class PIDController:
|
||||
"""Incremental (velocity-output) PID with anti-windup.
|
||||
|
||||
Output units: degrees/second (servo angular velocity).
|
||||
Integrate externally: servo_angle += pid.update(error, dt) * dt
|
||||
"""
|
||||
|
||||
def __init__(self, kp: float, ki: float, kd: float,
|
||||
vel_limit: float, windup_limit: float) -> None:
|
||||
self.kp = kp
|
||||
self.ki = ki
|
||||
self.kd = kd
|
||||
self.vel_limit = vel_limit
|
||||
self.windup_limit = windup_limit
|
||||
self._integral = 0.0
|
||||
self._prev_error = 0.0
|
||||
self._first = True
|
||||
|
||||
def update(self, error: float, dt: float) -> float:
|
||||
"""Return velocity command (°/s). Call every control tick."""
|
||||
if dt <= 0.0:
|
||||
return 0.0
|
||||
|
||||
self._integral += error * dt
|
||||
self._integral = clamp(self._integral, -self.windup_limit,
|
||||
self.windup_limit)
|
||||
|
||||
if self._first:
|
||||
derivative = 0.0
|
||||
self._first = False
|
||||
else:
|
||||
derivative = (error - self._prev_error) / dt
|
||||
|
||||
self._prev_error = error
|
||||
output = (self.kp * error
|
||||
+ self.ki * self._integral
|
||||
+ self.kd * derivative)
|
||||
return clamp(output, -self.vel_limit, self.vel_limit)
|
||||
|
||||
def reset(self) -> None:
|
||||
self._integral = 0.0
|
||||
self._prev_error = 0.0
|
||||
self._first = True
|
||||
|
||||
|
||||
# ── ROS2 node ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class FaceTrackServoNode(Node):
|
||||
"""Smooth PID face-tracking servo controller."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__("face_track_servo_node")
|
||||
|
||||
# Declare parameters
|
||||
self.declare_parameter("kp_pan", 1.5)
|
||||
self.declare_parameter("ki_pan", 0.1)
|
||||
self.declare_parameter("kd_pan", 0.05)
|
||||
self.declare_parameter("kp_tilt", 1.2)
|
||||
self.declare_parameter("ki_tilt", 0.1)
|
||||
self.declare_parameter("kd_tilt", 0.04)
|
||||
self.declare_parameter("fov_h_deg", 60.0)
|
||||
self.declare_parameter("fov_v_deg", 45.0)
|
||||
self.declare_parameter("pan_limit_deg", 90.0)
|
||||
self.declare_parameter("tilt_limit_deg", 30.0)
|
||||
self.declare_parameter("pan_vel_limit", 45.0)
|
||||
self.declare_parameter("tilt_vel_limit", 30.0)
|
||||
self.declare_parameter("windup_limit", 15.0)
|
||||
self.declare_parameter("dead_zone", 0.02)
|
||||
self.declare_parameter("control_rate", 20.0)
|
||||
self.declare_parameter("lost_timeout_s", 1.5)
|
||||
self.declare_parameter("return_rate_deg_s", 10.0)
|
||||
self.declare_parameter("faces_topic", "/social/faces/detected")
|
||||
|
||||
self._reload_params()
|
||||
|
||||
# Servo state
|
||||
self._pan_cmd = 0.0
|
||||
self._tilt_cmd = 0.0
|
||||
self._last_face_t: float = 0.0
|
||||
self._latest_face = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
qos = QoSProfile(depth=10)
|
||||
self._pan_pub = self.create_publisher(Float32, "/saltybot/head_pan", qos)
|
||||
self._tilt_pub = self.create_publisher(Float32, "/saltybot/head_tilt", qos)
|
||||
|
||||
faces_topic = self.get_parameter("faces_topic").value
|
||||
if _MSGS:
|
||||
self._faces_sub = self.create_subscription(
|
||||
FaceDetectionArray, faces_topic, self._on_faces, qos
|
||||
)
|
||||
else:
|
||||
self.get_logger().warn(
|
||||
"saltybot_social_msgs not available — node passive (no subscription)"
|
||||
)
|
||||
|
||||
rate = self.get_parameter("control_rate").value
|
||||
self._timer = self.create_timer(1.0 / rate, self._control_cb)
|
||||
self._last_tick = time.monotonic()
|
||||
|
||||
self.get_logger().info(
|
||||
f"FaceTrackServoNode ready "
|
||||
f"(rate={rate}Hz, fov={self._fov_h}×{self._fov_v}°, "
|
||||
f"pan±{self._pan_limit}°, tilt±{self._tilt_limit}°)"
|
||||
)
|
||||
|
||||
def _reload_params(self) -> None:
|
||||
self._fov_h = self.get_parameter("fov_h_deg").value
|
||||
self._fov_v = self.get_parameter("fov_v_deg").value
|
||||
self._pan_limit = self.get_parameter("pan_limit_deg").value
|
||||
self._tilt_limit = self.get_parameter("tilt_limit_deg").value
|
||||
self._dead_zone = self.get_parameter("dead_zone").value
|
||||
self._lost_t = self.get_parameter("lost_timeout_s").value
|
||||
self._return_rate = self.get_parameter("return_rate_deg_s").value
|
||||
|
||||
self._pid_pan = PIDController(
|
||||
kp=self.get_parameter("kp_pan").value,
|
||||
ki=self.get_parameter("ki_pan").value,
|
||||
kd=self.get_parameter("kd_pan").value,
|
||||
vel_limit=self.get_parameter("pan_vel_limit").value,
|
||||
windup_limit=self.get_parameter("windup_limit").value,
|
||||
)
|
||||
self._pid_tilt = PIDController(
|
||||
kp=self.get_parameter("kp_tilt").value,
|
||||
ki=self.get_parameter("ki_tilt").value,
|
||||
kd=self.get_parameter("kd_tilt").value,
|
||||
vel_limit=self.get_parameter("tilt_vel_limit").value,
|
||||
windup_limit=self.get_parameter("windup_limit").value,
|
||||
)
|
||||
|
||||
# ── Subscription callback ──────────────────────────────────────────────
|
||||
|
||||
def _on_faces(self, msg) -> None:
|
||||
face = pick_closest_face(msg.faces)
|
||||
with self._lock:
|
||||
self._latest_face = face
|
||||
if face is not None:
|
||||
self._last_face_t = time.monotonic()
|
||||
|
||||
# ── Control loop ───────────────────────────────────────────────────────
|
||||
|
||||
def _control_cb(self) -> None:
|
||||
now = time.monotonic()
|
||||
dt = now - self._last_tick
|
||||
self._last_tick = now
|
||||
dt = max(dt, 1e-4) # guard against zero dt at startup
|
||||
|
||||
with self._lock:
|
||||
face = self._latest_face
|
||||
last_face_t = self._last_face_t
|
||||
|
||||
face_fresh = (last_face_t > 0.0 and (now - last_face_t) < self._lost_t)
|
||||
|
||||
if not face_fresh or face is None:
|
||||
# Return to centre
|
||||
self._pid_pan.reset()
|
||||
self._pid_tilt.reset()
|
||||
step = self._return_rate * dt
|
||||
self._pan_cmd = _step_toward_zero(self._pan_cmd, step)
|
||||
self._tilt_cmd = _step_toward_zero(self._tilt_cmd, step)
|
||||
else:
|
||||
pan_err, tilt_err = face_image_error(face, self._fov_h, self._fov_v)
|
||||
|
||||
# Dead zone (normalised fraction → degrees)
|
||||
dead_deg_h = self._dead_zone * self._fov_h
|
||||
dead_deg_v = self._dead_zone * self._fov_v
|
||||
|
||||
if abs(pan_err) < dead_deg_h:
|
||||
self._pid_pan.reset()
|
||||
else:
|
||||
vel_pan = self._pid_pan.update(pan_err, dt)
|
||||
self._pan_cmd = clamp(
|
||||
self._pan_cmd + vel_pan * dt,
|
||||
-self._pan_limit, self._pan_limit,
|
||||
)
|
||||
|
||||
if abs(tilt_err) < dead_deg_v:
|
||||
self._pid_tilt.reset()
|
||||
else:
|
||||
vel_tilt = self._pid_tilt.update(tilt_err, dt)
|
||||
self._tilt_cmd = clamp(
|
||||
self._tilt_cmd + vel_tilt * dt,
|
||||
-self._tilt_limit, self._tilt_limit,
|
||||
)
|
||||
|
||||
pan_msg = Float32(); pan_msg.data = float(self._pan_cmd)
|
||||
tilt_msg = Float32(); tilt_msg.data = float(self._tilt_cmd)
|
||||
self._pan_pub.publish(pan_msg)
|
||||
self._tilt_pub.publish(tilt_msg)
|
||||
|
||||
|
||||
def _step_toward_zero(value: float, step: float) -> float:
|
||||
"""Move value toward 0 by step without overshooting."""
|
||||
if abs(value) <= step:
|
||||
return 0.0
|
||||
return value - math.copysign(step, value)
|
||||
|
||||
|
||||
def main(args=None) -> None:
|
||||
rclpy.init(args=args)
|
||||
node = FaceTrackServoNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
@ -0,0 +1,150 @@
|
||||
"""greeting_trigger_node.py -- Proximity-based greeting trigger.
|
||||
Issue #270
|
||||
|
||||
Monitors face detections and person states. When a new face_id is seen
|
||||
within ``proximity_m`` metres (default 2 m) and has not been greeted within
|
||||
``cooldown_s`` seconds, publishes a JSON greeting trigger on
|
||||
/saltybot/greeting_trigger.
|
||||
|
||||
Distance is looked up from the /social/person_states topic which carries a
|
||||
face_id → distance mapping. When no state is available for a face the node
|
||||
applies a configurable default distance so it can still fire on face-only
|
||||
pipelines.
|
||||
|
||||
Subscriptions:
|
||||
/social/faces/detected saltybot_social_msgs/FaceDetectionArray
|
||||
/social/person_states saltybot_social_msgs/PersonStateArray
|
||||
|
||||
Publication:
|
||||
/saltybot/greeting_trigger std_msgs/String (JSON)
|
||||
{"face_id": <int>, "person_name": <str>, "distance_m": <float>,
|
||||
"ts": <float unix epoch>}
|
||||
|
||||
Parameters:
|
||||
proximity_m (float, 2.0) -- trigger when distance <= this
|
||||
cooldown_s (float, 300.0) -- suppress re-greeting same face_id
|
||||
unknown_distance (float, 0.0) -- distance assumed when PersonState
|
||||
is not yet available (0.0 → always
|
||||
trigger for unknown faces)
|
||||
faces_topic (str, "/social/faces/detected")
|
||||
states_topic (str, "/social/person_states")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
from typing import Dict
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile
|
||||
from std_msgs.msg import String
|
||||
|
||||
# Custom messages — imported at runtime so offline tests can stub them
|
||||
try:
|
||||
from saltybot_social_msgs.msg import FaceDetectionArray, PersonStateArray
|
||||
_MSGS = True
|
||||
except ImportError:
|
||||
_MSGS = False
|
||||
|
||||
|
||||
class GreetingTriggerNode(Node):
|
||||
"""Publishes greeting trigger when a person enters proximity."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__("greeting_trigger_node")
|
||||
|
||||
self.declare_parameter("proximity_m", 2.0)
|
||||
self.declare_parameter("cooldown_s", 300.0)
|
||||
self.declare_parameter("unknown_distance", 0.0)
|
||||
self.declare_parameter("faces_topic", "/social/faces/detected")
|
||||
self.declare_parameter("states_topic", "/social/person_states")
|
||||
|
||||
self._proximity = self.get_parameter("proximity_m").value
|
||||
self._cooldown = self.get_parameter("cooldown_s").value
|
||||
self._unknown_dist = self.get_parameter("unknown_distance").value
|
||||
faces_topic = self.get_parameter("faces_topic").value
|
||||
states_topic = self.get_parameter("states_topic").value
|
||||
|
||||
# face_id → last known distance (m); updated from PersonStateArray
|
||||
self._distance_cache: Dict[int, float] = {}
|
||||
# face_id → unix timestamp of last greeting
|
||||
self._last_greeted: Dict[int, float] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
qos = QoSProfile(depth=10)
|
||||
self._pub = self.create_publisher(String, "/saltybot/greeting_trigger", qos)
|
||||
|
||||
if _MSGS:
|
||||
self._states_sub = self.create_subscription(
|
||||
PersonStateArray, states_topic, self._on_person_states, qos
|
||||
)
|
||||
self._faces_sub = self.create_subscription(
|
||||
FaceDetectionArray, faces_topic, self._on_faces, qos
|
||||
)
|
||||
else:
|
||||
self.get_logger().warn(
|
||||
"saltybot_social_msgs not available — node is passive (no subscriptions)"
|
||||
)
|
||||
|
||||
self.get_logger().info(
|
||||
f"GreetingTriggerNode ready "
|
||||
f"(proximity={self._proximity}m, cooldown={self._cooldown}s)"
|
||||
)
|
||||
|
||||
# ── Callbacks ──────────────────────────────────────────────────────────
|
||||
|
||||
def _on_person_states(self, msg: "PersonStateArray") -> None:
|
||||
"""Cache face_id → distance from incoming PersonState array."""
|
||||
with self._lock:
|
||||
for ps in msg.persons:
|
||||
if ps.face_id >= 0:
|
||||
self._distance_cache[ps.face_id] = float(ps.distance)
|
||||
|
||||
def _on_faces(self, msg: "FaceDetectionArray") -> None:
|
||||
"""Evaluate each detected face; fire greeting if conditions met."""
|
||||
now = time.monotonic()
|
||||
with self._lock:
|
||||
for face in msg.faces:
|
||||
fid = int(face.face_id)
|
||||
dist = self._distance_cache.get(fid, self._unknown_dist)
|
||||
|
||||
if dist > self._proximity:
|
||||
continue # too far
|
||||
|
||||
last = self._last_greeted.get(fid, 0.0)
|
||||
if now - last < self._cooldown:
|
||||
continue # still in cooldown
|
||||
|
||||
# Fire!
|
||||
self._last_greeted[fid] = now
|
||||
self._fire(fid, str(face.person_name), dist)
|
||||
|
||||
def _fire(self, face_id: int, person_name: str, distance_m: float) -> None:
|
||||
payload = {
|
||||
"face_id": face_id,
|
||||
"person_name": person_name,
|
||||
"distance_m": round(distance_m, 3),
|
||||
"ts": time.time(),
|
||||
}
|
||||
msg = String()
|
||||
msg.data = json.dumps(payload)
|
||||
self._pub.publish(msg)
|
||||
self.get_logger().info(
|
||||
f"Greeting trigger: face_id={face_id} name={person_name!r} "
|
||||
f"dist={distance_m:.2f}m"
|
||||
)
|
||||
|
||||
|
||||
def main(args=None) -> None:
|
||||
rclpy.init(args=args)
|
||||
node = GreetingTriggerNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
@ -45,6 +45,12 @@ setup(
|
||||
'mesh_comms_node = saltybot_social.mesh_comms_node:main',
|
||||
# Energy+ZCR voice activity detection (Issue #242)
|
||||
'vad_node = saltybot_social.vad_node:main',
|
||||
# Ambient sound classifier — mel-spectrogram (Issue #252)
|
||||
'ambient_sound_node = saltybot_social.ambient_sound_node:main',
|
||||
# Proximity-based greeting trigger (Issue #270)
|
||||
'greeting_trigger_node = saltybot_social.greeting_trigger_node:main',
|
||||
# Face-tracking head servo controller (Issue #279)
|
||||
'face_track_servo_node = saltybot_social.face_track_servo_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
407
jetson/ros2_ws/src/saltybot_social/test/test_ambient_sound.py
Normal file
407
jetson/ros2_ws/src/saltybot_social/test/test_ambient_sound.py
Normal file
@ -0,0 +1,407 @@
|
||||
"""test_ambient_sound.py -- Unit tests for Issue #252 ambient sound classifier."""
|
||||
|
||||
from __future__ import annotations
|
||||
import importlib.util, math, os, struct, sys, types
|
||||
import pytest
|
||||
|
||||
# numpy is available on dev machine
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _pkg_root():
|
||||
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def _read_src(rel_path):
|
||||
with open(os.path.join(_pkg_root(), rel_path)) as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def _import_mod():
|
||||
"""Import ambient_sound_node without a live ROS2 environment."""
|
||||
for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
|
||||
"std_msgs", "std_msgs.msg"):
|
||||
if mod_name not in sys.modules:
|
||||
sys.modules[mod_name] = types.ModuleType(mod_name)
|
||||
|
||||
rclpy_node = sys.modules["rclpy.node"]
|
||||
rclpy_qos = sys.modules["rclpy.qos"]
|
||||
std_msg = sys.modules["std_msgs.msg"]
|
||||
|
||||
DEFAULTS = {
|
||||
"sample_rate": 16000, "window_s": 1.0, "n_fft": 512, "n_mels": 32,
|
||||
"audio_topic": "/social/speech/audio_raw",
|
||||
"silence_db": -40.0, "alarm_db_min": -25.0, "alarm_zcr_min": 0.12,
|
||||
"alarm_high_ratio_min": 0.35, "speech_zcr_min": 0.02,
|
||||
"speech_zcr_max": 0.25, "speech_flatness_max": 0.35,
|
||||
"music_zcr_max": 0.08, "music_flatness_max": 0.25,
|
||||
"crowd_zcr_min": 0.10, "crowd_flatness_min": 0.35,
|
||||
}
|
||||
|
||||
class _Node:
|
||||
def __init__(self, *a, **kw): pass
|
||||
def declare_parameter(self, *a, **kw): pass
|
||||
def get_parameter(self, name):
|
||||
class _P:
|
||||
value = DEFAULTS.get(name)
|
||||
return _P()
|
||||
def create_publisher(self, *a, **kw): return None
|
||||
def create_subscription(self, *a, **kw): return None
|
||||
def get_logger(self):
|
||||
class _L:
|
||||
def info(self, *a): pass
|
||||
def warn(self, *a): pass
|
||||
def error(self, *a): pass
|
||||
return _L()
|
||||
def destroy_node(self): pass
|
||||
|
||||
rclpy_node.Node = _Node
|
||||
rclpy_qos.QoSProfile = type("QoSProfile", (), {"__init__": lambda s, **kw: None})
|
||||
std_msg.String = type("String", (), {"data": ""})
|
||||
std_msg.UInt8MultiArray = type("UInt8MultiArray", (), {"data": b""})
|
||||
sys.modules["rclpy"].init = lambda *a, **kw: None
|
||||
sys.modules["rclpy"].spin = lambda n: None
|
||||
sys.modules["rclpy"].ok = lambda: True
|
||||
sys.modules["rclpy"].shutdown = lambda: None
|
||||
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"ambient_sound_node_testmod",
|
||||
os.path.join(_pkg_root(), "saltybot_social", "ambient_sound_node.py"),
|
||||
)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
|
||||
# ── Audio helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
SR = 16000
|
||||
|
||||
def _sine(freq, n=SR, amp=0.2):
|
||||
return [amp * math.sin(2 * math.pi * freq * i / SR) for i in range(n)]
|
||||
|
||||
def _white_noise(n=SR, amp=0.1):
|
||||
import random
|
||||
rng = random.Random(42)
|
||||
return [rng.uniform(-amp, amp) for _ in range(n)]
|
||||
|
||||
def _silence(n=SR):
|
||||
return [0.0] * n
|
||||
|
||||
def _pcm16(samples):
|
||||
ints = [max(-32768, min(32767, int(s * 32768))) for s in samples]
|
||||
return struct.pack(f"<{len(ints)}h", *ints)
|
||||
|
||||
|
||||
# ── TestPcm16Convert ──────────────────────────────────────────────────────────
|
||||
|
||||
class TestPcm16Convert:
|
||||
@pytest.fixture(scope="class")
|
||||
def mod(self): return _import_mod()
|
||||
|
||||
def test_empty(self, mod):
|
||||
assert mod.pcm16_bytes_to_float32(b"") == []
|
||||
|
||||
def test_length(self, mod):
|
||||
data = _pcm16(_sine(440, 480))
|
||||
assert len(mod.pcm16_bytes_to_float32(data)) == 480
|
||||
|
||||
def test_range(self, mod):
|
||||
data = _pcm16(_sine(440, 480))
|
||||
result = mod.pcm16_bytes_to_float32(data)
|
||||
assert all(-1.0 <= s <= 1.0 for s in result)
|
||||
|
||||
def test_silence(self, mod):
|
||||
data = _pcm16(_silence(100))
|
||||
assert all(s == 0.0 for s in mod.pcm16_bytes_to_float32(data))
|
||||
|
||||
|
||||
# ── TestMelConversions ────────────────────────────────────────────────────────
|
||||
|
||||
class TestMelConversions:
|
||||
@pytest.fixture(scope="class")
|
||||
def mod(self): return _import_mod()
|
||||
|
||||
def test_hz_to_mel_zero(self, mod):
|
||||
assert mod.hz_to_mel(0.0) == 0.0
|
||||
|
||||
def test_hz_to_mel_1000(self, mod):
|
||||
# 1000 Hz → ~999.99 mel (approximately)
|
||||
assert abs(mod.hz_to_mel(1000.0) - 999.99) < 1.0
|
||||
|
||||
def test_roundtrip(self, mod):
|
||||
for hz in (100.0, 500.0, 1000.0, 4000.0, 8000.0):
|
||||
assert abs(mod.mel_to_hz(mod.hz_to_mel(hz)) - hz) < 0.01
|
||||
|
||||
def test_monotone_increasing(self, mod):
|
||||
freqs = [100, 500, 1000, 2000, 4000, 8000]
|
||||
mels = [mod.hz_to_mel(f) for f in freqs]
|
||||
assert mels == sorted(mels)
|
||||
|
||||
|
||||
# ── TestMelFilterbank ─────────────────────────────────────────────────────────
|
||||
|
||||
class TestMelFilterbank:
|
||||
@pytest.fixture(scope="class")
|
||||
def mod(self): return _import_mod()
|
||||
|
||||
def test_shape(self, mod):
|
||||
fb = mod.build_mel_filterbank(SR, 512, 32)
|
||||
assert fb.shape == (32, 257) # (n_mels, n_fft//2+1)
|
||||
|
||||
def test_nonnegative(self, mod):
|
||||
fb = mod.build_mel_filterbank(SR, 512, 32)
|
||||
assert (fb >= 0).all()
|
||||
|
||||
def test_each_filter_sums_positive(self, mod):
|
||||
fb = mod.build_mel_filterbank(SR, 512, 32)
|
||||
assert all(fb[m].sum() > 0 for m in range(32))
|
||||
|
||||
def test_custom_n_mels(self, mod):
|
||||
fb = mod.build_mel_filterbank(SR, 512, 16)
|
||||
assert fb.shape[0] == 16
|
||||
|
||||
def test_max_value_leq_one(self, mod):
|
||||
fb = mod.build_mel_filterbank(SR, 512, 32)
|
||||
assert fb.max() <= 1.0 + 1e-6
|
||||
|
||||
|
||||
# ── TestMelSpectrogram ────────────────────────────────────────────────────────
|
||||
|
||||
class TestMelSpectrogram:
|
||||
@pytest.fixture(scope="class")
|
||||
def mod(self): return _import_mod()
|
||||
|
||||
def test_shape(self, mod):
|
||||
s = _sine(440, SR)
|
||||
spec = mod.compute_mel_spectrogram(s, SR, n_fft=512, n_mels=32, hop_length=256)
|
||||
assert spec.shape[0] == 32
|
||||
assert spec.shape[1] > 0
|
||||
|
||||
def test_silence_near_zero(self, mod):
|
||||
spec = mod.compute_mel_spectrogram(_silence(SR), SR, n_fft=512, n_mels=32)
|
||||
assert spec.mean() < 1e-6
|
||||
|
||||
def test_louder_has_higher_energy(self, mod):
|
||||
quiet = mod.compute_mel_spectrogram(_sine(440, SR, amp=0.01), SR).mean()
|
||||
loud = mod.compute_mel_spectrogram(_sine(440, SR, amp=0.5), SR).mean()
|
||||
assert loud > quiet
|
||||
|
||||
def test_returns_array(self, mod):
|
||||
spec = mod.compute_mel_spectrogram(_sine(440, SR), SR)
|
||||
assert isinstance(spec, np.ndarray)
|
||||
|
||||
|
||||
# ── TestExtractFeatures ───────────────────────────────────────────────────────
|
||||
|
||||
class TestExtractFeatures:
|
||||
@pytest.fixture(scope="class")
|
||||
def mod(self): return _import_mod()
|
||||
|
||||
def _feats(self, mod, samples):
|
||||
return mod.extract_features(samples, SR, n_fft=512, n_mels=32)
|
||||
|
||||
def test_keys_present(self, mod):
|
||||
f = self._feats(mod, _sine(440, SR))
|
||||
for k in ("energy_db", "zcr", "mel_centroid", "mel_flatness",
|
||||
"low_ratio", "high_ratio"):
|
||||
assert k in f
|
||||
|
||||
def test_silence_low_energy(self, mod):
|
||||
f = self._feats(mod, _silence(SR))
|
||||
assert f["energy_db"] < -40.0
|
||||
|
||||
def test_silence_zero_zcr(self, mod):
|
||||
f = self._feats(mod, _silence(SR))
|
||||
assert f["zcr"] == 0.0
|
||||
|
||||
def test_sine_moderate_energy(self, mod):
|
||||
f = self._feats(mod, _sine(440, SR, amp=0.1))
|
||||
assert -40.0 < f["energy_db"] < 0.0
|
||||
|
||||
def test_ratios_sum_leq_one(self, mod):
|
||||
f = self._feats(mod, _sine(440, SR))
|
||||
assert f["low_ratio"] + f["high_ratio"] <= 1.0 + 1e-6
|
||||
|
||||
def test_ratios_nonnegative(self, mod):
|
||||
f = self._feats(mod, _sine(440, SR))
|
||||
assert f["low_ratio"] >= 0.0 and f["high_ratio"] >= 0.0
|
||||
|
||||
def test_flatness_in_unit_interval(self, mod):
|
||||
f = self._feats(mod, _sine(440, SR))
|
||||
assert 0.0 <= f["mel_flatness"] <= 1.0
|
||||
|
||||
def test_white_noise_high_flatness(self, mod):
|
||||
f_noise = self._feats(mod, _white_noise(SR, amp=0.3))
|
||||
f_sine = self._feats(mod, _sine(440, SR, amp=0.3))
|
||||
# White noise should have higher spectral flatness than a pure tone
|
||||
assert f_noise["mel_flatness"] > f_sine["mel_flatness"]
|
||||
|
||||
def test_empty_samples(self, mod):
|
||||
f = mod.extract_features([], SR)
|
||||
assert f["energy_db"] == 0.0
|
||||
|
||||
def test_louder_higher_energy_db(self, mod):
|
||||
quiet = self._feats(mod, _sine(440, SR, amp=0.01))["energy_db"]
|
||||
loud = self._feats(mod, _sine(440, SR, amp=0.5))["energy_db"]
|
||||
assert loud > quiet
|
||||
|
||||
|
||||
# ── TestClassifier ────────────────────────────────────────────────────────────
|
||||
|
||||
class TestClassifier:
|
||||
@pytest.fixture(scope="class")
|
||||
def mod(self): return _import_mod()
|
||||
|
||||
def _cls(self, mod, **feat_overrides):
|
||||
base = {"energy_db": -20.0, "zcr": 0.05,
|
||||
"mel_centroid": 0.4, "mel_flatness": 0.2,
|
||||
"low_ratio": 0.4, "high_ratio": 0.2}
|
||||
base.update(feat_overrides)
|
||||
return mod.classify(base)
|
||||
|
||||
def test_silence(self, mod):
|
||||
assert self._cls(mod, energy_db=-45.0) == "silence"
|
||||
|
||||
def test_silence_at_threshold(self, mod):
|
||||
assert self._cls(mod, energy_db=-40.0) != "silence"
|
||||
|
||||
def test_alarm(self, mod):
|
||||
assert self._cls(mod, energy_db=-20.0, zcr=0.15, high_ratio=0.40) == "alarm"
|
||||
|
||||
def test_alarm_requires_high_ratio(self, mod):
|
||||
result = self._cls(mod, energy_db=-20.0, zcr=0.15, high_ratio=0.10)
|
||||
assert result != "alarm"
|
||||
|
||||
def test_speech(self, mod):
|
||||
assert self._cls(mod, energy_db=-25.0, zcr=0.08,
|
||||
mel_flatness=0.20) == "speech"
|
||||
|
||||
def test_speech_zcr_too_low(self, mod):
|
||||
result = self._cls(mod, energy_db=-25.0, zcr=0.005, mel_flatness=0.2)
|
||||
assert result != "speech"
|
||||
|
||||
def test_speech_zcr_too_high(self, mod):
|
||||
result = self._cls(mod, energy_db=-25.0, zcr=0.30, mel_flatness=0.2)
|
||||
assert result != "speech"
|
||||
|
||||
def test_music(self, mod):
|
||||
assert self._cls(mod, energy_db=-25.0, zcr=0.04,
|
||||
mel_flatness=0.10) == "music"
|
||||
|
||||
def test_crowd(self, mod):
|
||||
assert self._cls(mod, energy_db=-25.0, zcr=0.15,
|
||||
mel_flatness=0.40) == "crowd"
|
||||
|
||||
def test_outdoor_catchall(self, mod):
|
||||
# Moderate energy, mid ZCR, mid flatness → outdoor
|
||||
result = self._cls(mod, energy_db=-35.0, zcr=0.06, mel_flatness=0.30)
|
||||
assert result in mod.LABELS
|
||||
|
||||
def test_returns_valid_label(self, mod):
|
||||
import random
|
||||
rng = random.Random(0)
|
||||
for _ in range(20):
|
||||
f = {
|
||||
"energy_db": rng.uniform(-60, 0),
|
||||
"zcr": rng.uniform(0, 0.5),
|
||||
"mel_centroid": rng.uniform(0, 1),
|
||||
"mel_flatness": rng.uniform(0, 1),
|
||||
"low_ratio": rng.uniform(0, 0.6),
|
||||
"high_ratio": rng.uniform(0, 0.4),
|
||||
}
|
||||
assert mod.classify(f) in mod.LABELS
|
||||
|
||||
|
||||
# ── TestAudioBuffer ───────────────────────────────────────────────────────────
|
||||
|
||||
class TestAudioBuffer:
|
||||
@pytest.fixture(scope="class")
|
||||
def mod(self): return _import_mod()
|
||||
|
||||
def test_no_window_until_full(self, mod):
|
||||
buf = mod.AudioBuffer(window_samples=100)
|
||||
assert buf.push([0.0] * 50) is None
|
||||
|
||||
def test_exact_fill_returns_window(self, mod):
|
||||
buf = mod.AudioBuffer(window_samples=100)
|
||||
w = buf.push([0.0] * 100)
|
||||
assert w is not None and len(w) == 100
|
||||
|
||||
def test_overflow_carries_over(self, mod):
|
||||
buf = mod.AudioBuffer(window_samples=100)
|
||||
buf.push([0.0] * 100) # fills first window
|
||||
w2 = buf.push([1.0] * 100) # fills second window
|
||||
assert w2 is not None and len(w2) == 100
|
||||
|
||||
def test_partial_then_complete(self, mod):
|
||||
buf = mod.AudioBuffer(window_samples=100)
|
||||
buf.push([0.0] * 60)
|
||||
w = buf.push([0.0] * 60)
|
||||
assert w is not None and len(w) == 100
|
||||
|
||||
def test_clear_resets(self, mod):
|
||||
buf = mod.AudioBuffer(window_samples=100)
|
||||
buf.push([0.0] * 90)
|
||||
buf.clear()
|
||||
assert buf.push([0.0] * 90) is None
|
||||
|
||||
def test_window_contents_correct(self, mod):
|
||||
buf = mod.AudioBuffer(window_samples=4)
|
||||
w = buf.push([1.0, 2.0, 3.0, 4.0])
|
||||
assert w == [1.0, 2.0, 3.0, 4.0]
|
||||
|
||||
|
||||
# ── TestNodeSrc ───────────────────────────────────────────────────────────────
|
||||
|
||||
class TestNodeSrc:
|
||||
@pytest.fixture(scope="class")
|
||||
def src(self): return _read_src("saltybot_social/ambient_sound_node.py")
|
||||
|
||||
def test_class_defined(self, src): assert "class AmbientSoundNode" in src
|
||||
def test_audio_buffer(self, src): assert "class AudioBuffer" in src
|
||||
def test_extract_features(self, src): assert "def extract_features" in src
|
||||
def test_classify_fn(self, src): assert "def classify" in src
|
||||
def test_mel_spectrogram(self, src): assert "compute_mel_spectrogram" in src
|
||||
def test_mel_filterbank(self, src): assert "build_mel_filterbank" in src
|
||||
def test_hz_to_mel(self, src): assert "hz_to_mel" in src
|
||||
def test_labels_tuple(self, src): assert "LABELS" in src
|
||||
def test_all_labels(self, src):
|
||||
for label in ("silence", "speech", "music", "crowd", "outdoor", "alarm"):
|
||||
assert label in src
|
||||
def test_topic_pub(self, src): assert '"/saltybot/ambient_sound"' in src
|
||||
def test_topic_sub(self, src): assert '"/social/speech/audio_raw"' in src
|
||||
def test_window_param(self, src): assert '"window_s"' in src
|
||||
def test_n_mels_param(self, src): assert '"n_mels"' in src
|
||||
def test_silence_param(self, src): assert '"silence_db"' in src
|
||||
def test_alarm_param(self, src): assert '"alarm_db_min"' in src
|
||||
def test_speech_param(self, src): assert '"speech_zcr_min"' in src
|
||||
def test_music_param(self, src): assert '"music_zcr_max"' in src
|
||||
def test_crowd_param(self, src): assert '"crowd_zcr_min"' in src
|
||||
def test_string_pub(self, src): assert "String" in src
|
||||
def test_uint8_sub(self, src): assert "UInt8MultiArray" in src
|
||||
def test_issue_tag(self, src): assert "252" in src
|
||||
def test_main(self, src): assert "def main" in src
|
||||
def test_numpy_optional(self, src): assert "_NUMPY" in src
|
||||
|
||||
|
||||
# ── TestConfig ────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestConfig:
|
||||
@pytest.fixture(scope="class")
|
||||
def cfg(self): return _read_src("config/ambient_sound_params.yaml")
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def setup(self): return _read_src("setup.py")
|
||||
|
||||
def test_node_name(self, cfg): assert "ambient_sound_node:" in cfg
|
||||
def test_window_s(self, cfg): assert "window_s" in cfg
|
||||
def test_n_mels(self, cfg): assert "n_mels" in cfg
|
||||
def test_silence_db(self, cfg): assert "silence_db" in cfg
|
||||
def test_alarm_params(self, cfg): assert "alarm_db_min" in cfg
|
||||
def test_speech_params(self, cfg): assert "speech_zcr_min" in cfg
|
||||
def test_music_params(self, cfg): assert "music_zcr_max" in cfg
|
||||
def test_crowd_params(self, cfg): assert "crowd_zcr_min" in cfg
|
||||
def test_defaults_present(self, cfg): assert "-40.0" in cfg and "0.12" in cfg
|
||||
def test_entry_point(self, setup):
|
||||
assert "ambient_sound_node = saltybot_social.ambient_sound_node:main" in setup
|
||||
676
jetson/ros2_ws/src/saltybot_social/test/test_face_track_servo.py
Normal file
676
jetson/ros2_ws/src/saltybot_social/test/test_face_track_servo.py
Normal file
@ -0,0 +1,676 @@
|
||||
"""test_face_track_servo.py — Offline tests for face_track_servo_node (Issue #279).
|
||||
|
||||
Stubs out rclpy and saltybot_social_msgs so tests run without a ROS install.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import importlib.util
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
import types
|
||||
import unittest
|
||||
|
||||
|
||||
# ── ROS2 / message stubs ──────────────────────────────────────────────────────
|
||||
|
||||
def _make_ros_stubs():
|
||||
for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
|
||||
"std_msgs", "std_msgs.msg",
|
||||
"saltybot_social_msgs", "saltybot_social_msgs.msg"):
|
||||
if mod_name not in sys.modules:
|
||||
sys.modules[mod_name] = types.ModuleType(mod_name)
|
||||
|
||||
class _Node:
|
||||
def __init__(self, name="node"):
|
||||
self._name = name
|
||||
if not hasattr(self, "_params"):
|
||||
self._params = {}
|
||||
self._pubs = {}
|
||||
self._subs = {}
|
||||
self._timers = []
|
||||
self._logs = []
|
||||
|
||||
def declare_parameter(self, name, default):
|
||||
if name not in self._params:
|
||||
self._params[name] = default
|
||||
|
||||
def get_parameter(self, name):
|
||||
class _P:
|
||||
def __init__(self, v): self.value = v
|
||||
return _P(self._params.get(name))
|
||||
|
||||
def create_publisher(self, msg_type, topic, qos):
|
||||
pub = _FakePub()
|
||||
self._pubs[topic] = pub
|
||||
return pub
|
||||
|
||||
def create_subscription(self, msg_type, topic, cb, qos):
|
||||
self._subs[topic] = cb
|
||||
return object()
|
||||
|
||||
def create_timer(self, period, cb):
|
||||
self._timers.append(cb)
|
||||
return object()
|
||||
|
||||
def get_logger(self):
|
||||
node = self
|
||||
class _L:
|
||||
def info(self, m): node._logs.append(("INFO", m))
|
||||
def warn(self, m): node._logs.append(("WARN", m))
|
||||
def error(self, m): node._logs.append(("ERROR", m))
|
||||
return _L()
|
||||
|
||||
def destroy_node(self): pass
|
||||
|
||||
class _FakePub:
|
||||
def __init__(self):
|
||||
self.msgs = []
|
||||
def publish(self, msg):
|
||||
self.msgs.append(msg)
|
||||
|
||||
class _QoSProfile:
|
||||
def __init__(self, depth=10): self.depth = depth
|
||||
|
||||
class _Float32:
|
||||
def __init__(self): self.data = 0.0
|
||||
|
||||
class _FaceDetection:
|
||||
def __init__(self, face_id=0, bbox_x=0.4, bbox_y=0.4,
|
||||
bbox_w=0.2, bbox_h=0.2, confidence=1.0):
|
||||
self.face_id = face_id
|
||||
self.bbox_x = bbox_x
|
||||
self.bbox_y = bbox_y
|
||||
self.bbox_w = bbox_w
|
||||
self.bbox_h = bbox_h
|
||||
self.confidence = confidence
|
||||
self.person_name = ""
|
||||
|
||||
class _FaceDetectionArray:
|
||||
def __init__(self, faces=None):
|
||||
self.faces = faces or []
|
||||
|
||||
# rclpy
|
||||
rclpy_mod = sys.modules["rclpy"]
|
||||
rclpy_mod.init = lambda args=None: None
|
||||
rclpy_mod.spin = lambda node: None
|
||||
rclpy_mod.shutdown = lambda: None
|
||||
|
||||
sys.modules["rclpy.node"].Node = _Node
|
||||
sys.modules["rclpy.qos"].QoSProfile = _QoSProfile
|
||||
sys.modules["std_msgs.msg"].Float32 = _Float32
|
||||
|
||||
msgs = sys.modules["saltybot_social_msgs.msg"]
|
||||
msgs.FaceDetection = _FaceDetection
|
||||
msgs.FaceDetectionArray = _FaceDetectionArray
|
||||
|
||||
return _Node, _FakePub, _FaceDetection, _FaceDetectionArray, _Float32
|
||||
|
||||
|
||||
_Node, _FakePub, _FaceDetection, _FaceDetectionArray, _Float32 = _make_ros_stubs()
|
||||
|
||||
|
||||
# ── Module loader ─────────────────────────────────────────────────────────────
|
||||
|
||||
_SRC = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/saltybot_social/face_track_servo_node.py"
|
||||
)
|
||||
|
||||
|
||||
def _load_mod():
|
||||
spec = importlib.util.spec_from_file_location("face_track_servo_testmod", _SRC)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def _make_node(mod, **kwargs):
|
||||
"""Instantiate FaceTrackServoNode with optional param overrides."""
|
||||
node = mod.FaceTrackServoNode.__new__(mod.FaceTrackServoNode)
|
||||
|
||||
defaults = {
|
||||
"kp_pan": 1.5,
|
||||
"ki_pan": 0.1,
|
||||
"kd_pan": 0.05,
|
||||
"kp_tilt": 1.2,
|
||||
"ki_tilt": 0.1,
|
||||
"kd_tilt": 0.04,
|
||||
"fov_h_deg": 60.0,
|
||||
"fov_v_deg": 45.0,
|
||||
"pan_limit_deg": 90.0,
|
||||
"tilt_limit_deg": 30.0,
|
||||
"pan_vel_limit": 45.0,
|
||||
"tilt_vel_limit": 30.0,
|
||||
"windup_limit": 15.0,
|
||||
"dead_zone": 0.02,
|
||||
"control_rate": 20.0,
|
||||
"lost_timeout_s": 1.5,
|
||||
"return_rate_deg_s": 10.0,
|
||||
"faces_topic": "/social/faces/detected",
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
node._params = dict(defaults)
|
||||
mod.FaceTrackServoNode.__init__(node)
|
||||
return node
|
||||
|
||||
|
||||
def _face(bbox_x=0.4, bbox_y=0.4, bbox_w=0.2, bbox_h=0.2, face_id=0):
|
||||
return _FaceDetection(face_id=face_id, bbox_x=bbox_x, bbox_y=bbox_y,
|
||||
bbox_w=bbox_w, bbox_h=bbox_h)
|
||||
|
||||
|
||||
def _centered_face():
|
||||
"""A face perfectly centered in the frame."""
|
||||
return _face(bbox_x=0.4, bbox_y=0.4, bbox_w=0.2, bbox_h=0.2)
|
||||
|
||||
|
||||
# ── Tests: pure helpers ───────────────────────────────────────────────────────
|
||||
|
||||
class TestClamp(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mod = _load_mod()
|
||||
|
||||
def test_within(self):
|
||||
self.assertEqual(self.mod.clamp(5.0, 0.0, 10.0), 5.0)
|
||||
|
||||
def test_below(self):
|
||||
self.assertEqual(self.mod.clamp(-5.0, 0.0, 10.0), 0.0)
|
||||
|
||||
def test_above(self):
|
||||
self.assertEqual(self.mod.clamp(15.0, 0.0, 10.0), 10.0)
|
||||
|
||||
def test_negative_range(self):
|
||||
self.assertEqual(self.mod.clamp(-50.0, -45.0, 45.0), -45.0)
|
||||
|
||||
|
||||
class TestBboxArea(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mod = _load_mod()
|
||||
|
||||
def test_area(self):
|
||||
f = _face(bbox_w=0.3, bbox_h=0.4)
|
||||
self.assertAlmostEqual(self.mod.bbox_area(f), 0.12)
|
||||
|
||||
def test_zero(self):
|
||||
f = _face(bbox_w=0.0, bbox_h=0.2)
|
||||
self.assertAlmostEqual(self.mod.bbox_area(f), 0.0)
|
||||
|
||||
|
||||
class TestPickClosestFace(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mod = _load_mod()
|
||||
|
||||
def test_empty(self):
|
||||
self.assertIsNone(self.mod.pick_closest_face([]))
|
||||
|
||||
def test_single(self):
|
||||
f = _face(bbox_w=0.2, bbox_h=0.2)
|
||||
self.assertIs(self.mod.pick_closest_face([f]), f)
|
||||
|
||||
def test_picks_largest_area(self):
|
||||
small = _face(bbox_w=0.1, bbox_h=0.1)
|
||||
big = _face(bbox_w=0.4, bbox_h=0.4)
|
||||
self.assertIs(self.mod.pick_closest_face([small, big]), big)
|
||||
self.assertIs(self.mod.pick_closest_face([big, small]), big)
|
||||
|
||||
def test_three_faces(self):
|
||||
faces = [_face(bbox_w=0.1, bbox_h=0.1),
|
||||
_face(bbox_w=0.5, bbox_h=0.5),
|
||||
_face(bbox_w=0.2, bbox_h=0.2)]
|
||||
self.assertIs(self.mod.pick_closest_face(faces), faces[1])
|
||||
|
||||
|
||||
class TestFaceImageError(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mod = _load_mod()
|
||||
|
||||
def test_centered_face_zero_error(self):
|
||||
f = _face(bbox_x=0.4, bbox_y=0.4, bbox_w=0.2, bbox_h=0.2)
|
||||
pan, tilt = self.mod.face_image_error(f, 60.0, 45.0)
|
||||
self.assertAlmostEqual(pan, 0.0)
|
||||
self.assertAlmostEqual(tilt, 0.0)
|
||||
|
||||
def test_right_of_centre(self):
|
||||
# cx = 0.7 + 0.1 = 0.8, error = 0.3 * 60 = 18°
|
||||
f = _face(bbox_x=0.7, bbox_y=0.4, bbox_w=0.2, bbox_h=0.2)
|
||||
pan, _ = self.mod.face_image_error(f, 60.0, 45.0)
|
||||
self.assertAlmostEqual(pan, 18.0)
|
||||
|
||||
def test_left_of_centre(self):
|
||||
# cx = 0.1 + 0.1 = 0.2, error = -0.3 * 60 = -18°
|
||||
f = _face(bbox_x=0.1, bbox_y=0.4, bbox_w=0.2, bbox_h=0.2)
|
||||
pan, _ = self.mod.face_image_error(f, 60.0, 45.0)
|
||||
self.assertAlmostEqual(pan, -18.0)
|
||||
|
||||
def test_below_centre(self):
|
||||
# cy = 0.7 + 0.1 = 0.8, error = 0.3 * 45 = 13.5°
|
||||
f = _face(bbox_x=0.4, bbox_y=0.7, bbox_w=0.2, bbox_h=0.2)
|
||||
_, tilt = self.mod.face_image_error(f, 60.0, 45.0)
|
||||
self.assertAlmostEqual(tilt, 13.5)
|
||||
|
||||
def test_above_centre(self):
|
||||
# cy = 0.1 + 0.1 = 0.2, error = -0.3 * 45 = -13.5°
|
||||
f = _face(bbox_x=0.4, bbox_y=0.1, bbox_w=0.2, bbox_h=0.2)
|
||||
_, tilt = self.mod.face_image_error(f, 60.0, 45.0)
|
||||
self.assertAlmostEqual(tilt, -13.5)
|
||||
|
||||
|
||||
class TestStepTowardZero(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mod = _load_mod()
|
||||
|
||||
def test_positive_large(self):
|
||||
result = self.mod._step_toward_zero(10.0, 1.0)
|
||||
self.assertAlmostEqual(result, 9.0)
|
||||
|
||||
def test_negative_large(self):
|
||||
result = self.mod._step_toward_zero(-10.0, 1.0)
|
||||
self.assertAlmostEqual(result, -9.0)
|
||||
|
||||
def test_smaller_than_step(self):
|
||||
result = self.mod._step_toward_zero(0.5, 1.0)
|
||||
self.assertAlmostEqual(result, 0.0)
|
||||
|
||||
def test_exact_step(self):
|
||||
result = self.mod._step_toward_zero(1.0, 1.0)
|
||||
self.assertAlmostEqual(result, 0.0)
|
||||
|
||||
def test_zero(self):
|
||||
result = self.mod._step_toward_zero(0.0, 1.0)
|
||||
self.assertAlmostEqual(result, 0.0)
|
||||
|
||||
|
||||
# ── Tests: PIDController ──────────────────────────────────────────────────────
|
||||
|
||||
class TestPIDController(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mod = _load_mod()
|
||||
|
||||
def _pid(self, kp=1.0, ki=0.0, kd=0.0, vel_limit=100.0, windup=100.0):
|
||||
return self.mod.PIDController(kp, ki, kd, vel_limit, windup)
|
||||
|
||||
def test_proportional_only(self):
|
||||
pid = self._pid(kp=2.0)
|
||||
out = pid.update(5.0, 0.1)
|
||||
self.assertAlmostEqual(out, 10.0)
|
||||
|
||||
def test_zero_error_zero_output(self):
|
||||
pid = self._pid(kp=5.0)
|
||||
self.assertAlmostEqual(pid.update(0.0, 0.1), 0.0)
|
||||
|
||||
def test_integral_accumulates(self):
|
||||
pid = self._pid(kp=0.0, ki=1.0)
|
||||
pid.update(1.0, 0.1) # integral = 0.1
|
||||
out = pid.update(1.0, 0.1) # integral = 0.2, output = 0.2
|
||||
self.assertAlmostEqual(out, 0.2, places=5)
|
||||
|
||||
def test_derivative_first_tick_zero(self):
|
||||
pid = self._pid(kp=0.0, kd=1.0)
|
||||
out = pid.update(10.0, 0.1)
|
||||
self.assertAlmostEqual(out, 0.0) # first tick: derivative = 0
|
||||
|
||||
def test_derivative_second_tick(self):
|
||||
pid = self._pid(kp=0.0, kd=1.0)
|
||||
pid.update(0.0, 0.1) # first tick
|
||||
out = pid.update(10.0, 0.1) # de/dt = 10/0.1 = 100
|
||||
self.assertAlmostEqual(out, 100.0)
|
||||
|
||||
def test_velocity_clamped(self):
|
||||
pid = self._pid(kp=100.0, vel_limit=10.0)
|
||||
out = pid.update(5.0, 0.1)
|
||||
self.assertAlmostEqual(out, 10.0)
|
||||
|
||||
def test_velocity_clamped_negative(self):
|
||||
pid = self._pid(kp=100.0, vel_limit=10.0)
|
||||
out = pid.update(-5.0, 0.1)
|
||||
self.assertAlmostEqual(out, -10.0)
|
||||
|
||||
def test_antiwindup(self):
|
||||
pid = self._pid(kp=0.0, ki=1.0, windup=5.0)
|
||||
for _ in range(100):
|
||||
pid.update(1.0, 0.1) # would accumulate 10, clamped at 5
|
||||
out = pid.update(0.0, 0.1)
|
||||
self.assertAlmostEqual(out, 5.0, places=3)
|
||||
|
||||
def test_reset_clears_integral(self):
|
||||
pid = self._pid(ki=1.0)
|
||||
pid.update(1.0, 1.0)
|
||||
pid.reset()
|
||||
out = pid.update(0.0, 0.1)
|
||||
self.assertAlmostEqual(out, 0.0)
|
||||
|
||||
def test_reset_clears_derivative(self):
|
||||
pid = self._pid(kp=0.0, kd=1.0)
|
||||
pid.update(10.0, 0.1) # sets prev_error
|
||||
pid.reset()
|
||||
out = pid.update(10.0, 0.1) # after reset, first tick = 0 derivative
|
||||
self.assertAlmostEqual(out, 0.0)
|
||||
|
||||
def test_zero_dt_returns_zero(self):
|
||||
pid = self._pid(kp=10.0)
|
||||
self.assertAlmostEqual(pid.update(5.0, 0.0), 0.0)
|
||||
|
||||
def test_negative_dt_returns_zero(self):
|
||||
pid = self._pid(kp=10.0)
|
||||
self.assertAlmostEqual(pid.update(5.0, -0.1), 0.0)
|
||||
|
||||
|
||||
# ── Tests: node initialisation ────────────────────────────────────────────────
|
||||
|
||||
class TestNodeInit(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mod = _load_mod()
|
||||
|
||||
def test_instantiates(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIsNotNone(node)
|
||||
|
||||
def test_pan_pub(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIn("/saltybot/head_pan", node._pubs)
|
||||
|
||||
def test_tilt_pub(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIn("/saltybot/head_tilt", node._pubs)
|
||||
|
||||
def test_faces_sub(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIn("/social/faces/detected", node._subs)
|
||||
|
||||
def test_timer_registered(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertGreater(len(node._timers), 0)
|
||||
|
||||
def test_initial_pan_zero(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertAlmostEqual(node._pan_cmd, 0.0)
|
||||
|
||||
def test_initial_tilt_zero(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertAlmostEqual(node._tilt_cmd, 0.0)
|
||||
|
||||
def test_custom_fov(self):
|
||||
node = _make_node(self.mod, fov_h_deg=90.0)
|
||||
self.assertAlmostEqual(node._fov_h, 90.0)
|
||||
|
||||
def test_custom_pan_limit(self):
|
||||
node = _make_node(self.mod, pan_limit_deg=45.0)
|
||||
self.assertAlmostEqual(node._pan_limit, 45.0)
|
||||
|
||||
|
||||
# ── Tests: face callback ──────────────────────────────────────────────────────
|
||||
|
||||
class TestFaceCallback(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mod = _load_mod()
|
||||
|
||||
def setUp(self):
|
||||
self.node = _make_node(self.mod)
|
||||
|
||||
def test_empty_msg_no_face(self):
|
||||
self.node._on_faces(_FaceDetectionArray([]))
|
||||
self.assertIsNone(self.node._latest_face)
|
||||
|
||||
def test_single_face_stored(self):
|
||||
f = _centered_face()
|
||||
self.node._on_faces(_FaceDetectionArray([f]))
|
||||
self.assertIs(self.node._latest_face, f)
|
||||
|
||||
def test_closest_face_picked(self):
|
||||
small = _face(bbox_w=0.1, bbox_h=0.1, face_id=1)
|
||||
big = _face(bbox_w=0.5, bbox_h=0.5, face_id=2)
|
||||
self.node._on_faces(_FaceDetectionArray([small, big]))
|
||||
self.assertIs(self.node._latest_face, big)
|
||||
|
||||
def test_timestamp_updated_on_face(self):
|
||||
before = time.monotonic()
|
||||
f = _centered_face()
|
||||
self.node._on_faces(_FaceDetectionArray([f]))
|
||||
self.assertGreaterEqual(self.node._last_face_t, before)
|
||||
|
||||
def test_timestamp_not_updated_on_empty(self):
|
||||
self.node._last_face_t = 0.0
|
||||
self.node._on_faces(_FaceDetectionArray([]))
|
||||
self.assertEqual(self.node._last_face_t, 0.0)
|
||||
|
||||
|
||||
# ── Tests: control loop ───────────────────────────────────────────────────────
|
||||
|
||||
class TestControlLoop(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mod = _load_mod()
|
||||
|
||||
def setUp(self):
|
||||
self.node = _make_node(self.mod, dead_zone=0.0,
|
||||
ki_pan=0.0, kd_pan=0.0,
|
||||
ki_tilt=0.0, kd_tilt=0.0)
|
||||
self.pan_pub = self.node._pubs["/saltybot/head_pan"]
|
||||
self.tilt_pub = self.node._pubs["/saltybot/head_tilt"]
|
||||
|
||||
def _tick(self, dt=0.05):
|
||||
self.node._last_tick = time.monotonic() - dt
|
||||
self.node._control_cb()
|
||||
|
||||
def test_no_face_publishes_zero_initially(self):
|
||||
self._tick()
|
||||
self.assertAlmostEqual(self.pan_pub.msgs[-1].data, 0.0)
|
||||
self.assertAlmostEqual(self.tilt_pub.msgs[-1].data, 0.0)
|
||||
|
||||
def test_centered_face_minimal_movement(self):
|
||||
f = _centered_face() # cx=cy=0.5, error=0
|
||||
self.node._on_faces(_FaceDetectionArray([f]))
|
||||
self.node._last_face_t = time.monotonic()
|
||||
self._tick()
|
||||
# With dead_zone=0 and error=0, pid output=0, cmd stays 0
|
||||
self.assertAlmostEqual(self.pan_pub.msgs[-1].data, 0.0, places=4)
|
||||
self.assertAlmostEqual(self.tilt_pub.msgs[-1].data, 0.0, places=4)
|
||||
|
||||
def test_right_face_pans_right(self):
|
||||
# Face right of centre → positive pan error → pan_cmd increases
|
||||
f = _face(bbox_x=0.7, bbox_y=0.4, bbox_w=0.2, bbox_h=0.2)
|
||||
self.node._on_faces(_FaceDetectionArray([f]))
|
||||
self.node._last_face_t = time.monotonic()
|
||||
self._tick()
|
||||
self.assertGreater(self.pan_pub.msgs[-1].data, 0.0)
|
||||
|
||||
def test_left_face_pans_left(self):
|
||||
f = _face(bbox_x=0.1, bbox_y=0.4, bbox_w=0.2, bbox_h=0.2)
|
||||
self.node._on_faces(_FaceDetectionArray([f]))
|
||||
self.node._last_face_t = time.monotonic()
|
||||
self._tick()
|
||||
self.assertLess(self.pan_pub.msgs[-1].data, 0.0)
|
||||
|
||||
def test_low_face_tilts_down(self):
|
||||
f = _face(bbox_x=0.4, bbox_y=0.7, bbox_w=0.2, bbox_h=0.2)
|
||||
self.node._on_faces(_FaceDetectionArray([f]))
|
||||
self.node._last_face_t = time.monotonic()
|
||||
self._tick()
|
||||
self.assertGreater(self.tilt_pub.msgs[-1].data, 0.0)
|
||||
|
||||
def test_high_face_tilts_up(self):
|
||||
f = _face(bbox_x=0.4, bbox_y=0.1, bbox_w=0.2, bbox_h=0.2)
|
||||
self.node._on_faces(_FaceDetectionArray([f]))
|
||||
self.node._last_face_t = time.monotonic()
|
||||
self._tick()
|
||||
self.assertLess(self.tilt_pub.msgs[-1].data, 0.0)
|
||||
|
||||
def test_pan_clamped_to_limit(self):
|
||||
node = _make_node(self.mod, kp_pan=1000.0, ki_pan=0.0, kd_pan=0.0,
|
||||
pan_limit_deg=45.0, pan_vel_limit=9999.0,
|
||||
dead_zone=0.0)
|
||||
pub = node._pubs["/saltybot/head_pan"]
|
||||
f = _face(bbox_x=0.9, bbox_y=0.4, bbox_w=0.1, bbox_h=0.2)
|
||||
node._on_faces(_FaceDetectionArray([f]))
|
||||
node._last_face_t = time.monotonic()
|
||||
# Run many ticks to accumulate
|
||||
for _ in range(50):
|
||||
node._last_tick = time.monotonic() - 0.05
|
||||
node._control_cb()
|
||||
self.assertLessEqual(pub.msgs[-1].data, 45.0)
|
||||
|
||||
def test_tilt_clamped_to_limit(self):
|
||||
node = _make_node(self.mod, kp_tilt=1000.0, ki_tilt=0.0, kd_tilt=0.0,
|
||||
tilt_limit_deg=20.0, tilt_vel_limit=9999.0,
|
||||
dead_zone=0.0)
|
||||
pub = node._pubs["/saltybot/head_tilt"]
|
||||
f = _face(bbox_x=0.4, bbox_y=0.9, bbox_w=0.2, bbox_h=0.1)
|
||||
node._on_faces(_FaceDetectionArray([f]))
|
||||
node._last_face_t = time.monotonic()
|
||||
for _ in range(50):
|
||||
node._last_tick = time.monotonic() - 0.05
|
||||
node._control_cb()
|
||||
self.assertLessEqual(pub.msgs[-1].data, 20.0)
|
||||
|
||||
def test_lost_face_returns_to_zero(self):
|
||||
node = _make_node(self.mod, kp_pan=10.0, ki_pan=0.0, kd_pan=0.0,
|
||||
dead_zone=0.0, return_rate_deg_s=90.0,
|
||||
lost_timeout_s=0.01)
|
||||
pub = node._pubs["/saltybot/head_pan"]
|
||||
f = _face(bbox_x=0.7, bbox_y=0.4, bbox_w=0.2, bbox_h=0.2)
|
||||
node._on_faces(_FaceDetectionArray([f]))
|
||||
node._last_face_t = time.monotonic()
|
||||
# Build up some pan
|
||||
for _ in range(5):
|
||||
node._last_tick = time.monotonic() - 0.05
|
||||
node._control_cb()
|
||||
# Expire face timeout
|
||||
node._last_face_t = time.monotonic() - 10.0
|
||||
for _ in range(20):
|
||||
node._last_tick = time.monotonic() - 0.05
|
||||
node._control_cb()
|
||||
self.assertAlmostEqual(pub.msgs[-1].data, 0.0, places=3)
|
||||
|
||||
def test_publishes_every_tick(self):
|
||||
for _ in range(3):
|
||||
self._tick()
|
||||
self.assertEqual(len(self.pan_pub.msgs), 3)
|
||||
self.assertEqual(len(self.tilt_pub.msgs), 3)
|
||||
|
||||
def test_dead_zone_suppresses_small_error(self):
|
||||
node = _make_node(self.mod, kp_pan=100.0, ki_pan=0.0, kd_pan=0.0,
|
||||
dead_zone=0.1, fov_h_deg=60.0)
|
||||
pub = node._pubs["/saltybot/head_pan"]
|
||||
# Face 2% right of centre — within dead_zone=10% of frame
|
||||
f = _face(bbox_x=0.42, bbox_y=0.4, bbox_w=0.2, bbox_h=0.2)
|
||||
node._on_faces(_FaceDetectionArray([f]))
|
||||
node._last_face_t = time.monotonic()
|
||||
node._last_tick = time.monotonic() - 0.05
|
||||
node._control_cb()
|
||||
self.assertAlmostEqual(pub.msgs[-1].data, 0.0, places=4)
|
||||
|
||||
|
||||
# ── Tests: source-level checks ────────────────────────────────────────────────
|
||||
|
||||
class TestNodeSrc(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
with open(_SRC) as f:
|
||||
cls.src = f.read()
|
||||
|
||||
def test_issue_tag(self):
|
||||
self.assertIn("#279", self.src)
|
||||
|
||||
def test_pan_topic(self):
|
||||
self.assertIn("/saltybot/head_pan", self.src)
|
||||
|
||||
def test_tilt_topic(self):
|
||||
self.assertIn("/saltybot/head_tilt", self.src)
|
||||
|
||||
def test_faces_topic(self):
|
||||
self.assertIn("/social/faces/detected", self.src)
|
||||
|
||||
def test_pid_class(self):
|
||||
self.assertIn("class PIDController", self.src)
|
||||
|
||||
def test_kp_param(self):
|
||||
self.assertIn("kp_pan", self.src)
|
||||
|
||||
def test_ki_param(self):
|
||||
self.assertIn("ki_pan", self.src)
|
||||
|
||||
def test_kd_param(self):
|
||||
self.assertIn("kd_pan", self.src)
|
||||
|
||||
def test_fov_param(self):
|
||||
self.assertIn("fov_h_deg", self.src)
|
||||
|
||||
def test_pan_limit_param(self):
|
||||
self.assertIn("pan_limit_deg", self.src)
|
||||
|
||||
def test_dead_zone_param(self):
|
||||
self.assertIn("dead_zone", self.src)
|
||||
|
||||
def test_pick_closest_face(self):
|
||||
self.assertIn("pick_closest_face", self.src)
|
||||
|
||||
def test_main_defined(self):
|
||||
self.assertIn("def main", self.src)
|
||||
|
||||
def test_antiwindup(self):
|
||||
self.assertIn("windup", self.src)
|
||||
|
||||
def test_threading_lock(self):
|
||||
self.assertIn("threading.Lock", self.src)
|
||||
|
||||
|
||||
class TestConfig(unittest.TestCase):
|
||||
_CONFIG = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/config/face_track_servo_params.yaml"
|
||||
)
|
||||
_LAUNCH = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/launch/face_track_servo.launch.py"
|
||||
)
|
||||
_SETUP = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/setup.py"
|
||||
)
|
||||
|
||||
def test_config_exists(self):
|
||||
import os; self.assertTrue(os.path.exists(self._CONFIG))
|
||||
|
||||
def test_config_kp_pan(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("kp_pan", c)
|
||||
|
||||
def test_config_fov(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("fov_h_deg", c)
|
||||
|
||||
def test_config_pan_limit(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("pan_limit_deg", c)
|
||||
|
||||
def test_config_dead_zone(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("dead_zone", c)
|
||||
|
||||
def test_launch_exists(self):
|
||||
import os; self.assertTrue(os.path.exists(self._LAUNCH))
|
||||
|
||||
def test_launch_kp_pan_arg(self):
|
||||
with open(self._LAUNCH) as f: c = f.read()
|
||||
self.assertIn("kp_pan", c)
|
||||
|
||||
def test_launch_pan_limit_arg(self):
|
||||
with open(self._LAUNCH) as f: c = f.read()
|
||||
self.assertIn("pan_limit_deg", c)
|
||||
|
||||
def test_entry_point(self):
|
||||
with open(self._SETUP) as f: c = f.read()
|
||||
self.assertIn("face_track_servo_node", c)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
472
jetson/ros2_ws/src/saltybot_social/test/test_greeting_trigger.py
Normal file
472
jetson/ros2_ws/src/saltybot_social/test/test_greeting_trigger.py
Normal file
@ -0,0 +1,472 @@
|
||||
"""test_greeting_trigger.py -- Offline tests for greeting_trigger_node (Issue #270).
|
||||
|
||||
Stubs out rclpy and saltybot_social_msgs so tests run without a ROS install.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import types
|
||||
import unittest
|
||||
|
||||
|
||||
# ── ROS2 / message stubs ──────────────────────────────────────────────────────
|
||||
|
||||
def _make_ros_stubs():
|
||||
"""Install minimal stubs for rclpy and message packages."""
|
||||
for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
|
||||
"std_msgs", "std_msgs.msg",
|
||||
"saltybot_social_msgs", "saltybot_social_msgs.msg"):
|
||||
sys.modules[mod_name] = types.ModuleType(mod_name)
|
||||
|
||||
class _Node:
|
||||
def __init__(self, name):
|
||||
self._name = name
|
||||
# Preserve _params if pre-set by _make_node (super().__init__() is
|
||||
# called from GreetingTriggerNode.__init__, so don't reset here)
|
||||
if not hasattr(self, '_params'):
|
||||
self._params = {}
|
||||
self._pubs = {}
|
||||
self._subs = {}
|
||||
self._logs = []
|
||||
|
||||
def declare_parameter(self, name, default):
|
||||
# Don't overwrite values pre-set by _make_node
|
||||
if name not in self._params:
|
||||
self._params[name] = default
|
||||
|
||||
def get_parameter(self, name):
|
||||
class _P:
|
||||
def __init__(self, v):
|
||||
self.value = v
|
||||
return _P(self._params[name])
|
||||
|
||||
def create_publisher(self, msg_type, topic, qos):
|
||||
pub = _FakePub()
|
||||
self._pubs[topic] = pub
|
||||
return pub
|
||||
|
||||
def create_subscription(self, msg_type, topic, cb, qos):
|
||||
self._subs[topic] = cb
|
||||
return object()
|
||||
|
||||
def get_logger(self):
|
||||
node = self
|
||||
class _L:
|
||||
def info(self, m): node._logs.append(("INFO", m))
|
||||
def warn(self, m): node._logs.append(("WARN", m))
|
||||
def error(self, m): node._logs.append(("ERROR", m))
|
||||
return _L()
|
||||
|
||||
def destroy_node(self): pass
|
||||
|
||||
class _FakePub:
|
||||
def __init__(self):
|
||||
self.msgs = []
|
||||
def publish(self, msg):
|
||||
self.msgs.append(msg)
|
||||
|
||||
class _QoSProfile:
|
||||
def __init__(self, depth=10): self.depth = depth
|
||||
|
||||
class _String:
|
||||
def __init__(self): self.data = ""
|
||||
|
||||
# rclpy
|
||||
rclpy_mod = sys.modules["rclpy"]
|
||||
rclpy_mod.init = lambda args=None: None
|
||||
rclpy_mod.spin = lambda node: None
|
||||
rclpy_mod.shutdown = lambda: None
|
||||
|
||||
# rclpy.node
|
||||
sys.modules["rclpy.node"].Node = _Node
|
||||
|
||||
# rclpy.qos
|
||||
sys.modules["rclpy.qos"].QoSProfile = _QoSProfile
|
||||
|
||||
# std_msgs.msg
|
||||
sys.modules["std_msgs.msg"].String = _String
|
||||
|
||||
# saltybot_social_msgs.msg (FaceDetectionArray + PersonStateArray)
|
||||
class _FaceDetection:
|
||||
def __init__(self, face_id=0, person_name="", confidence=1.0):
|
||||
self.face_id = face_id
|
||||
self.person_name = person_name
|
||||
self.confidence = confidence
|
||||
|
||||
class _FaceDetectionArray:
|
||||
def __init__(self, faces=None):
|
||||
self.faces = faces or []
|
||||
|
||||
class _PersonState:
|
||||
def __init__(self, face_id=0, distance=0.0):
|
||||
self.face_id = face_id
|
||||
self.distance = distance
|
||||
|
||||
class _PersonStateArray:
|
||||
def __init__(self, persons=None):
|
||||
self.persons = persons or []
|
||||
|
||||
msgs = sys.modules["saltybot_social_msgs.msg"]
|
||||
msgs.FaceDetection = _FaceDetection
|
||||
msgs.FaceDetectionArray = _FaceDetectionArray
|
||||
msgs.PersonState = _PersonState
|
||||
msgs.PersonStateArray = _PersonStateArray
|
||||
|
||||
return _Node, _FakePub, _QoSProfile, _String, _FaceDetection, _FaceDetectionArray, _PersonState, _PersonStateArray
|
||||
|
||||
|
||||
_Node, _FakePub, _QoSProfile, _String, _FaceDetection, _FaceDetectionArray, _PersonState, _PersonStateArray = _make_ros_stubs()
|
||||
|
||||
|
||||
# ── Load module under test ────────────────────────────────────────────────────
|
||||
|
||||
_SRC = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/saltybot_social/greeting_trigger_node.py"
|
||||
)
|
||||
|
||||
def _load_mod():
|
||||
spec = importlib.util.spec_from_file_location("greeting_trigger_node_testmod", _SRC)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _make_node(mod, **kwargs):
|
||||
"""Instantiate GreetingTriggerNode with overridden parameters."""
|
||||
node = mod.GreetingTriggerNode.__new__(mod.GreetingTriggerNode)
|
||||
|
||||
# Pre-populate _params BEFORE __init__ so super().__init__() (which calls
|
||||
# _Node.__init__) sees them and skips reset due to hasattr guard.
|
||||
defaults = {
|
||||
"proximity_m": 2.0,
|
||||
"cooldown_s": 300.0,
|
||||
"unknown_distance": 0.0,
|
||||
"faces_topic": "/social/faces/detected",
|
||||
"states_topic": "/social/person_states",
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
node._params = dict(defaults)
|
||||
|
||||
mod.GreetingTriggerNode.__init__(node)
|
||||
return node
|
||||
|
||||
|
||||
def _face_msg(faces):
|
||||
return _FaceDetectionArray(faces=faces)
|
||||
|
||||
def _state_msg(persons):
|
||||
return _PersonStateArray(persons=persons)
|
||||
|
||||
|
||||
# ── Test suites ───────────────────────────────────────────────────────────────
|
||||
|
||||
class TestNodeInit(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mod = _load_mod()
|
||||
|
||||
def test_imports_cleanly(self):
|
||||
self.assertTrue(hasattr(self.mod, "GreetingTriggerNode"))
|
||||
|
||||
def test_default_proximity(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertEqual(node._proximity, 2.0)
|
||||
|
||||
def test_default_cooldown(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertEqual(node._cooldown, 300.0)
|
||||
|
||||
def test_default_unknown_distance(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertEqual(node._unknown_dist, 0.0)
|
||||
|
||||
def test_pub_topic(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIn("/saltybot/greeting_trigger", node._pubs)
|
||||
|
||||
def test_subs_registered(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertIn("/social/faces/detected", node._subs)
|
||||
self.assertIn("/social/person_states", node._subs)
|
||||
|
||||
def test_initial_caches_empty(self):
|
||||
node = _make_node(self.mod)
|
||||
self.assertEqual(node._distance_cache, {})
|
||||
self.assertEqual(node._last_greeted, {})
|
||||
|
||||
|
||||
class TestDistanceCache(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mod = _load_mod()
|
||||
|
||||
def setUp(self):
|
||||
self.node = _make_node(self.mod)
|
||||
|
||||
def test_state_updates_cache(self):
|
||||
ps = _PersonState(face_id=1, distance=1.5)
|
||||
self.node._on_person_states(_state_msg([ps]))
|
||||
self.assertAlmostEqual(self.node._distance_cache[1], 1.5)
|
||||
|
||||
def test_multiple_states_cached(self):
|
||||
persons = [_PersonState(face_id=i, distance=float(i)) for i in range(5)]
|
||||
self.node._on_person_states(_state_msg(persons))
|
||||
for i in range(5):
|
||||
self.assertAlmostEqual(self.node._distance_cache[i], float(i))
|
||||
|
||||
def test_state_update_overwrites(self):
|
||||
self.node._on_person_states(_state_msg([_PersonState(face_id=1, distance=3.0)]))
|
||||
self.node._on_person_states(_state_msg([_PersonState(face_id=1, distance=1.0)]))
|
||||
self.assertAlmostEqual(self.node._distance_cache[1], 1.0)
|
||||
|
||||
def test_negative_face_id_ignored(self):
|
||||
self.node._on_person_states(_state_msg([_PersonState(face_id=-1, distance=1.0)]))
|
||||
self.assertNotIn(-1, self.node._distance_cache)
|
||||
|
||||
def test_zero_distance_cached(self):
|
||||
self.node._on_person_states(_state_msg([_PersonState(face_id=5, distance=0.0)]))
|
||||
self.assertAlmostEqual(self.node._distance_cache[5], 0.0)
|
||||
|
||||
|
||||
class TestGreetingTrigger(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mod = _load_mod()
|
||||
|
||||
def setUp(self):
|
||||
self.node = _make_node(self.mod, proximity_m=2.0, cooldown_s=300.0)
|
||||
self.pub = self.node._pubs["/saltybot/greeting_trigger"]
|
||||
|
||||
def _inject_distance(self, face_id, distance):
|
||||
self.node._on_person_states(_state_msg([_PersonState(face_id=face_id, distance=distance)]))
|
||||
|
||||
def test_triggers_within_proximity(self):
|
||||
self._inject_distance(1, 1.5)
|
||||
self.node._on_faces(_face_msg([_FaceDetection(face_id=1, person_name="alice")]))
|
||||
self.assertEqual(len(self.pub.msgs), 1)
|
||||
|
||||
def test_no_trigger_beyond_proximity(self):
|
||||
self._inject_distance(2, 3.0)
|
||||
self.node._on_faces(_face_msg([_FaceDetection(face_id=2, person_name="bob")]))
|
||||
self.assertEqual(len(self.pub.msgs), 0)
|
||||
|
||||
def test_trigger_at_exact_proximity(self):
|
||||
self._inject_distance(3, 2.0)
|
||||
self.node._on_faces(_face_msg([_FaceDetection(face_id=3, person_name="carol")]))
|
||||
self.assertEqual(len(self.pub.msgs), 1)
|
||||
|
||||
def test_no_trigger_just_beyond(self):
|
||||
self._inject_distance(4, 2.001)
|
||||
self.node._on_faces(_face_msg([_FaceDetection(face_id=4, person_name="dave")]))
|
||||
self.assertEqual(len(self.pub.msgs), 0)
|
||||
|
||||
def test_cooldown_suppresses_retrigger(self):
|
||||
self._inject_distance(5, 1.0)
|
||||
face = _FaceDetection(face_id=5, person_name="eve")
|
||||
self.node._on_faces(_face_msg([face]))
|
||||
self.node._on_faces(_face_msg([face])) # second call in cooldown
|
||||
self.assertEqual(len(self.pub.msgs), 1)
|
||||
|
||||
def test_cooldown_per_face_id(self):
|
||||
self._inject_distance(6, 1.0)
|
||||
self._inject_distance(7, 1.0)
|
||||
self.node._on_faces(_face_msg([_FaceDetection(face_id=6, person_name="f")]))
|
||||
self.node._on_faces(_face_msg([_FaceDetection(face_id=7, person_name="g")]))
|
||||
self.assertEqual(len(self.pub.msgs), 2)
|
||||
|
||||
def test_expired_cooldown_retrigers(self):
|
||||
self._inject_distance(8, 1.0)
|
||||
face = _FaceDetection(face_id=8, person_name="hank")
|
||||
self.node._on_faces(_face_msg([face]))
|
||||
# Manually expire the cooldown
|
||||
self.node._last_greeted[8] = time.monotonic() - 400.0
|
||||
self.node._on_faces(_face_msg([face]))
|
||||
self.assertEqual(len(self.pub.msgs), 2)
|
||||
|
||||
def test_unknown_face_uses_unknown_distance(self):
|
||||
# unknown_distance=0.0 → should trigger (0.0 <= 2.0)
|
||||
node = _make_node(self.mod, unknown_distance=0.0)
|
||||
pub = node._pubs["/saltybot/greeting_trigger"]
|
||||
node._on_faces(_face_msg([_FaceDetection(face_id=99, person_name="stranger")]))
|
||||
self.assertEqual(len(pub.msgs), 1)
|
||||
|
||||
def test_unknown_face_large_distance_no_trigger(self):
|
||||
# unknown_distance=10.0 → should NOT trigger
|
||||
node = _make_node(self.mod, unknown_distance=10.0)
|
||||
pub = node._pubs["/saltybot/greeting_trigger"]
|
||||
node._on_faces(_face_msg([_FaceDetection(face_id=100, person_name="far")]))
|
||||
self.assertEqual(len(pub.msgs), 0)
|
||||
|
||||
def test_multiple_faces_triggers_each_within_range(self):
|
||||
self._inject_distance(10, 1.0)
|
||||
self._inject_distance(11, 3.0) # out of range
|
||||
faces = [
|
||||
_FaceDetection(face_id=10, person_name="near"),
|
||||
_FaceDetection(face_id=11, person_name="far"),
|
||||
]
|
||||
self.node._on_faces(_face_msg(faces))
|
||||
self.assertEqual(len(self.pub.msgs), 1)
|
||||
|
||||
def test_empty_face_array_no_trigger(self):
|
||||
self.node._on_faces(_face_msg([]))
|
||||
self.assertEqual(len(self.pub.msgs), 0)
|
||||
|
||||
|
||||
class TestPayload(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.mod = _load_mod()
|
||||
|
||||
def setUp(self):
|
||||
self.node = _make_node(self.mod)
|
||||
self.pub = self.node._pubs["/saltybot/greeting_trigger"]
|
||||
|
||||
def _trigger(self, face_id=1, person_name="alice", distance=1.5):
|
||||
self.node._on_person_states(_state_msg([_PersonState(face_id=face_id, distance=distance)]))
|
||||
self.node._on_faces(_face_msg([_FaceDetection(face_id=face_id, person_name=person_name)]))
|
||||
|
||||
def test_payload_is_json(self):
|
||||
self._trigger()
|
||||
payload = json.loads(self.pub.msgs[0].data)
|
||||
self.assertIsInstance(payload, dict)
|
||||
|
||||
def test_payload_face_id(self):
|
||||
self._trigger(face_id=42)
|
||||
payload = json.loads(self.pub.msgs[0].data)
|
||||
self.assertEqual(payload["face_id"], 42)
|
||||
|
||||
def test_payload_person_name(self):
|
||||
self._trigger(person_name="zara")
|
||||
payload = json.loads(self.pub.msgs[0].data)
|
||||
self.assertEqual(payload["person_name"], "zara")
|
||||
|
||||
def test_payload_distance(self):
|
||||
self._trigger(distance=1.234)
|
||||
payload = json.loads(self.pub.msgs[0].data)
|
||||
self.assertAlmostEqual(payload["distance_m"], 1.234, places=2)
|
||||
|
||||
def test_payload_has_ts(self):
|
||||
self._trigger()
|
||||
payload = json.loads(self.pub.msgs[0].data)
|
||||
self.assertIn("ts", payload)
|
||||
self.assertIsInstance(payload["ts"], float)
|
||||
|
||||
def test_ts_is_recent(self):
|
||||
before = time.time()
|
||||
self._trigger()
|
||||
after = time.time()
|
||||
payload = json.loads(self.pub.msgs[0].data)
|
||||
self.assertGreaterEqual(payload["ts"], before)
|
||||
self.assertLessEqual(payload["ts"], after + 1.0)
|
||||
|
||||
|
||||
class TestNodeSrc(unittest.TestCase):
|
||||
"""Source-level checks — verify node structure without instantiation."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
with open(_SRC) as f:
|
||||
cls.src = f.read()
|
||||
|
||||
def test_issue_tag(self):
|
||||
self.assertIn("#270", self.src)
|
||||
|
||||
def test_pub_topic(self):
|
||||
self.assertIn("/saltybot/greeting_trigger", self.src)
|
||||
|
||||
def test_faces_topic(self):
|
||||
self.assertIn("/social/faces/detected", self.src)
|
||||
|
||||
def test_states_topic(self):
|
||||
self.assertIn("/social/person_states", self.src)
|
||||
|
||||
def test_proximity_param(self):
|
||||
self.assertIn("proximity_m", self.src)
|
||||
|
||||
def test_cooldown_param(self):
|
||||
self.assertIn("cooldown_s", self.src)
|
||||
|
||||
def test_unknown_distance_param(self):
|
||||
self.assertIn("unknown_distance", self.src)
|
||||
|
||||
def test_json_output(self):
|
||||
self.assertIn("json", self.src)
|
||||
|
||||
def test_face_id_in_payload(self):
|
||||
self.assertIn("face_id", self.src)
|
||||
|
||||
def test_person_name_in_payload(self):
|
||||
self.assertIn("person_name", self.src)
|
||||
|
||||
def test_distance_in_payload(self):
|
||||
self.assertIn("distance_m", self.src)
|
||||
|
||||
def test_main_defined(self):
|
||||
self.assertIn("def main", self.src)
|
||||
|
||||
def test_threading_lock(self):
|
||||
self.assertIn("threading.Lock", self.src)
|
||||
|
||||
|
||||
class TestConfig(unittest.TestCase):
|
||||
"""Checks on config/launch/setup files."""
|
||||
|
||||
_CONFIG = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/config/greeting_trigger_params.yaml"
|
||||
)
|
||||
_LAUNCH = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/launch/greeting_trigger.launch.py"
|
||||
)
|
||||
_SETUP = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/setup.py"
|
||||
)
|
||||
|
||||
def test_config_exists(self):
|
||||
import os
|
||||
self.assertTrue(os.path.exists(self._CONFIG))
|
||||
|
||||
def test_config_proximity(self):
|
||||
with open(self._CONFIG) as f:
|
||||
content = f.read()
|
||||
self.assertIn("proximity_m", content)
|
||||
|
||||
def test_config_cooldown(self):
|
||||
with open(self._CONFIG) as f:
|
||||
content = f.read()
|
||||
self.assertIn("cooldown_s", content)
|
||||
|
||||
def test_config_node_name(self):
|
||||
with open(self._CONFIG) as f:
|
||||
content = f.read()
|
||||
self.assertIn("greeting_trigger_node", content)
|
||||
|
||||
def test_launch_exists(self):
|
||||
import os
|
||||
self.assertTrue(os.path.exists(self._LAUNCH))
|
||||
|
||||
def test_launch_proximity_arg(self):
|
||||
with open(self._LAUNCH) as f:
|
||||
content = f.read()
|
||||
self.assertIn("proximity_m", content)
|
||||
|
||||
def test_launch_cooldown_arg(self):
|
||||
with open(self._LAUNCH) as f:
|
||||
content = f.read()
|
||||
self.assertIn("cooldown_s", content)
|
||||
|
||||
def test_entry_point(self):
|
||||
with open(self._SETUP) as f:
|
||||
content = f.read()
|
||||
self.assertIn("greeting_trigger_node", content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -0,0 +1,5 @@
|
||||
wheel_slip_detector:
|
||||
ros__parameters:
|
||||
frequency: 10
|
||||
slip_threshold: 0.1
|
||||
slip_timeout: 0.5
|
||||
@ -0,0 +1,14 @@
|
||||
from launch import LaunchDescription
|
||||
from launch_ros.actions import Node
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
|
||||
def generate_launch_description():
|
||||
pkg_dir = get_package_share_directory("saltybot_wheel_slip_detector")
|
||||
config_file = os.path.join(pkg_dir, "config", "wheel_slip_config.yaml")
|
||||
return LaunchDescription([
|
||||
DeclareLaunchArgument("config_file", default_value=config_file, description="Path to configuration YAML file"),
|
||||
Node(package="saltybot_wheel_slip_detector", executable="wheel_slip_detector_node", name="wheel_slip_detector", output="screen", parameters=[LaunchConfiguration("config_file")]),
|
||||
])
|
||||
18
jetson/ros2_ws/src/saltybot_wheel_slip_detector/package.xml
Normal file
18
jetson/ros2_ws/src/saltybot_wheel_slip_detector/package.xml
Normal file
@ -0,0 +1,18 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_wheel_slip_detector</name>
|
||||
<version>0.1.0</version>
|
||||
<description>Wheel slip detection by comparing commanded vs actual velocity.</description>
|
||||
<maintainer email="seb@vayrette.com">Seb</maintainer>
|
||||
<license>Apache-2.0</license>
|
||||
<buildtool_depend>ament_python</buildtool_depend>
|
||||
<depend>rclpy</depend>
|
||||
<depend>geometry_msgs</depend>
|
||||
<depend>std_msgs</depend>
|
||||
<depend>nav_msgs</depend>
|
||||
<test_depend>pytest</test_depend>
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
Binary file not shown.
@ -0,0 +1,77 @@
|
||||
#!/usr/bin/env python3
|
||||
from typing import Optional
|
||||
import math
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.timer import Timer
|
||||
from geometry_msgs.msg import Twist
|
||||
from nav_msgs.msg import Odometry
|
||||
from std_msgs.msg import Bool
|
||||
|
||||
class WheelSlipDetectorNode(Node):
|
||||
def __init__(self):
|
||||
super().__init__("wheel_slip_detector")
|
||||
self.declare_parameter("frequency", 10)
|
||||
frequency = self.get_parameter("frequency").value
|
||||
self.declare_parameter("slip_threshold", 0.1)
|
||||
self.declare_parameter("slip_timeout", 0.5)
|
||||
self.slip_threshold = self.get_parameter("slip_threshold").value
|
||||
self.slip_timeout = self.get_parameter("slip_timeout").value
|
||||
self.period = 1.0 / frequency
|
||||
self.cmd_vel: Optional[Twist] = None
|
||||
self.actual_vel: Optional[Twist] = None
|
||||
self.slip_duration = 0.0
|
||||
self.slip_detected = False
|
||||
self.create_subscription(Twist, "/cmd_vel", self._on_cmd_vel, 10)
|
||||
self.create_subscription(Odometry, "/odom", self._on_odom, 10)
|
||||
self.pub_slip = self.create_publisher(Bool, "/saltybot/wheel_slip_detected", 10)
|
||||
self.timer: Timer = self.create_timer(self.period, self._timer_callback)
|
||||
self.get_logger().info(f"Wheel slip detector initialized at {frequency}Hz. Threshold: {self.slip_threshold} m/s, Timeout: {self.slip_timeout}s")
|
||||
|
||||
def _on_cmd_vel(self, msg: Twist) -> None:
|
||||
self.cmd_vel = msg
|
||||
|
||||
def _on_odom(self, msg: Odometry) -> None:
|
||||
self.actual_vel = msg.twist.twist
|
||||
|
||||
def _timer_callback(self) -> None:
|
||||
if self.cmd_vel is None or self.actual_vel is None:
|
||||
slip_detected = False
|
||||
else:
|
||||
slip_detected = self._check_slip()
|
||||
if slip_detected:
|
||||
self.slip_duration += self.period
|
||||
else:
|
||||
self.slip_duration = 0.0
|
||||
is_slip = self.slip_duration > self.slip_timeout
|
||||
if is_slip != self.slip_detected:
|
||||
self.slip_detected = is_slip
|
||||
if self.slip_detected:
|
||||
self.get_logger().warn(f"WHEEL SLIP DETECTED: {self.slip_duration:.2f}s")
|
||||
else:
|
||||
self.get_logger().info("Wheel slip cleared")
|
||||
slip_msg = Bool()
|
||||
slip_msg.data = is_slip
|
||||
self.pub_slip.publish(slip_msg)
|
||||
|
||||
def _check_slip(self) -> bool:
|
||||
cmd_speed = math.sqrt(self.cmd_vel.linear.x**2 + self.cmd_vel.linear.y**2)
|
||||
actual_speed = math.sqrt(self.actual_vel.linear.x**2 + self.actual_vel.linear.y**2)
|
||||
vel_diff = abs(cmd_speed - actual_speed)
|
||||
if cmd_speed < 0.05 and actual_speed < 0.05:
|
||||
return False
|
||||
return vel_diff > self.slip_threshold
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = WheelSlipDetectorNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.shutdown()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -0,0 +1,4 @@
|
||||
[develop]
|
||||
script-dir=$base/lib/saltybot_wheel_slip_detector
|
||||
[install]
|
||||
install-scripts=$base/lib/saltybot_wheel_slip_detector
|
||||
21
jetson/ros2_ws/src/saltybot_wheel_slip_detector/setup.py
Normal file
21
jetson/ros2_ws/src/saltybot_wheel_slip_detector/setup.py
Normal file
@ -0,0 +1,21 @@
|
||||
from setuptools import find_packages, setup
|
||||
package_name = "saltybot_wheel_slip_detector"
|
||||
setup(
|
||||
name=package_name,
|
||||
version="0.1.0",
|
||||
packages=find_packages(exclude=["test"]),
|
||||
data_files=[
|
||||
("share/ament_index/resource_index/packages", ["resource/" + package_name]),
|
||||
("share/" + package_name, ["package.xml"]),
|
||||
("share/" + package_name + "/launch", ["launch/wheel_slip_detector.launch.py"]),
|
||||
("share/" + package_name + "/config", ["config/wheel_slip_config.yaml"]),
|
||||
],
|
||||
install_requires=["setuptools"],
|
||||
zip_safe=True,
|
||||
maintainer="Seb",
|
||||
maintainer_email="seb@vayrette.com",
|
||||
description="Wheel slip detection from velocity command/actual mismatch",
|
||||
license="Apache-2.0",
|
||||
tests_require=["pytest"],
|
||||
entry_points={"console_scripts": ["wheel_slip_detector_node = saltybot_wheel_slip_detector.wheel_slip_detector_node:main"]},
|
||||
)
|
||||
Binary file not shown.
@ -0,0 +1,343 @@
|
||||
"""Unit tests for wheel_slip_detector_node."""
|
||||
|
||||
import pytest
|
||||
import math
|
||||
from geometry_msgs.msg import Twist
|
||||
from nav_msgs.msg import Odometry
|
||||
from std_msgs.msg import Bool
|
||||
|
||||
import rclpy
|
||||
|
||||
from saltybot_wheel_slip_detector.wheel_slip_detector_node import WheelSlipDetectorNode
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rclpy_fixture():
|
||||
"""Initialize and cleanup rclpy."""
|
||||
rclpy.init()
|
||||
yield
|
||||
rclpy.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def node(rclpy_fixture):
|
||||
"""Create a wheel slip detector node instance."""
|
||||
node = WheelSlipDetectorNode()
|
||||
yield node
|
||||
node.destroy_node()
|
||||
|
||||
|
||||
class TestNodeInitialization:
|
||||
"""Test suite for node initialization."""
|
||||
|
||||
def test_node_initialization(self, node):
|
||||
"""Test that node initializes with correct defaults."""
|
||||
assert node.cmd_vel is None
|
||||
assert node.actual_vel is None
|
||||
assert node.slip_threshold == 0.1
|
||||
assert node.slip_timeout == 0.5
|
||||
assert node.slip_duration == 0.0
|
||||
assert node.slip_detected is False
|
||||
|
||||
def test_frequency_parameter(self, node):
|
||||
"""Test frequency parameter is set correctly."""
|
||||
frequency = node.get_parameter("frequency").value
|
||||
assert frequency == 10
|
||||
|
||||
def test_slip_threshold_parameter(self, node):
|
||||
"""Test slip threshold parameter is set correctly."""
|
||||
threshold = node.get_parameter("slip_threshold").value
|
||||
assert threshold == 0.1
|
||||
|
||||
def test_slip_timeout_parameter(self, node):
|
||||
"""Test slip timeout parameter is set correctly."""
|
||||
timeout = node.get_parameter("slip_timeout").value
|
||||
assert timeout == 0.5
|
||||
|
||||
def test_period_calculation(self, node):
|
||||
"""Test that time period is correctly calculated from frequency."""
|
||||
assert node.period == pytest.approx(0.1)
|
||||
|
||||
|
||||
class TestSubscriptions:
|
||||
"""Test suite for subscription handling."""
|
||||
|
||||
def test_cmd_vel_subscription(self, node):
|
||||
"""Test that cmd_vel subscription updates node state."""
|
||||
cmd = Twist()
|
||||
cmd.linear.x = 1.0
|
||||
cmd.linear.y = 0.5
|
||||
node._on_cmd_vel(cmd)
|
||||
assert node.cmd_vel is cmd
|
||||
assert node.cmd_vel.linear.x == 1.0
|
||||
|
||||
def test_odom_subscription(self, node):
|
||||
"""Test that odometry subscription updates actual velocity."""
|
||||
odom = Odometry()
|
||||
odom.twist.twist.linear.x = 0.95
|
||||
odom.twist.twist.linear.y = 0.48
|
||||
node._on_odom(odom)
|
||||
assert node.actual_vel is odom.twist.twist
|
||||
assert node.actual_vel.linear.x == 0.95
|
||||
|
||||
|
||||
class TestSlipDetection:
|
||||
"""Test suite for slip detection logic."""
|
||||
|
||||
def test_no_slip_perfect_match(self, node):
|
||||
"""Test no slip when commanded equals actual."""
|
||||
cmd = Twist()
|
||||
cmd.linear.x = 1.0
|
||||
odom = Odometry()
|
||||
odom.twist.twist.linear.x = 1.0
|
||||
node._on_cmd_vel(cmd)
|
||||
node._on_odom(odom)
|
||||
assert node._check_slip() is False
|
||||
|
||||
def test_no_slip_small_difference(self, node):
|
||||
"""Test no slip when difference is below threshold."""
|
||||
cmd = Twist()
|
||||
cmd.linear.x = 1.0
|
||||
odom = Odometry()
|
||||
odom.twist.twist.linear.x = 0.95
|
||||
node._on_cmd_vel(cmd)
|
||||
node._on_odom(odom)
|
||||
assert node._check_slip() is False
|
||||
|
||||
def test_slip_exceeds_threshold(self, node):
|
||||
"""Test slip detection when difference exceeds threshold."""
|
||||
cmd = Twist()
|
||||
cmd.linear.x = 1.0
|
||||
odom = Odometry()
|
||||
odom.twist.twist.linear.x = 0.85
|
||||
node._on_cmd_vel(cmd)
|
||||
node._on_odom(odom)
|
||||
assert node._check_slip() is True
|
||||
|
||||
def test_slip_large_difference(self, node):
|
||||
"""Test slip detection with large velocity difference."""
|
||||
cmd = Twist()
|
||||
cmd.linear.x = 1.0
|
||||
odom = Odometry()
|
||||
odom.twist.twist.linear.x = 0.5
|
||||
node._on_cmd_vel(cmd)
|
||||
node._on_odom(odom)
|
||||
assert node._check_slip() is True
|
||||
|
||||
def test_no_slip_both_zero(self, node):
|
||||
"""Test no slip when both commanded and actual are zero."""
|
||||
cmd = Twist()
|
||||
cmd.linear.x = 0.0
|
||||
odom = Odometry()
|
||||
odom.twist.twist.linear.x = 0.0
|
||||
node._on_cmd_vel(cmd)
|
||||
node._on_odom(odom)
|
||||
assert node._check_slip() is False
|
||||
|
||||
def test_no_slip_both_near_zero(self, node):
|
||||
"""Test no slip when both are near zero (tolerance)."""
|
||||
cmd = Twist()
|
||||
cmd.linear.x = 0.01
|
||||
odom = Odometry()
|
||||
odom.twist.twist.linear.x = 0.02
|
||||
node._on_cmd_vel(cmd)
|
||||
node._on_odom(odom)
|
||||
assert node._check_slip() is False
|
||||
|
||||
def test_slip_2d_velocity(self, node):
|
||||
"""Test slip detection with 2D velocity (x and y)."""
|
||||
cmd = Twist()
|
||||
cmd.linear.x = 0.7
|
||||
cmd.linear.y = 0.7
|
||||
odom = Odometry()
|
||||
odom.twist.twist.linear.x = 0.5
|
||||
odom.twist.twist.linear.y = 0.5
|
||||
node._on_cmd_vel(cmd)
|
||||
node._on_odom(odom)
|
||||
assert node._check_slip() is True
|
||||
|
||||
|
||||
class TestSlipPersistence:
|
||||
"""Test suite for slip persistence timing."""
|
||||
|
||||
def test_slip_not_triggered_immediately(self, node):
|
||||
"""Test that slip is not triggered immediately but requires timeout."""
|
||||
cmd = Twist()
|
||||
cmd.linear.x = 1.0
|
||||
odom = Odometry()
|
||||
odom.twist.twist.linear.x = 0.5
|
||||
node._on_cmd_vel(cmd)
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
assert node.slip_duration > 0.0
|
||||
assert node.slip_detected is False
|
||||
|
||||
def test_slip_declared_after_timeout(self, node):
|
||||
"""Test that slip is declared after timeout period."""
|
||||
cmd = Twist()
|
||||
cmd.linear.x = 1.0
|
||||
odom = Odometry()
|
||||
odom.twist.twist.linear.x = 0.5
|
||||
node._on_cmd_vel(cmd)
|
||||
node._on_odom(odom)
|
||||
for _ in range(6):
|
||||
node._timer_callback()
|
||||
assert node.slip_detected is True
|
||||
|
||||
def test_slip_recovery_resets_duration(self, node):
|
||||
"""Test that slip duration resets when condition clears."""
|
||||
cmd = Twist()
|
||||
cmd.linear.x = 1.0
|
||||
odom1 = Odometry()
|
||||
odom1.twist.twist.linear.x = 0.5
|
||||
node._on_cmd_vel(cmd)
|
||||
node._on_odom(odom1)
|
||||
for _ in range(3):
|
||||
node._timer_callback()
|
||||
odom2 = Odometry()
|
||||
odom2.twist.twist.linear.x = 1.0
|
||||
node._on_odom(odom2)
|
||||
node._timer_callback()
|
||||
assert node.slip_duration == pytest.approx(0.0)
|
||||
assert node.slip_detected is False
|
||||
|
||||
def test_slip_cumulative_time(self, node):
|
||||
"""Test that slip duration accumulates across callbacks."""
|
||||
cmd = Twist()
|
||||
cmd.linear.x = 1.0
|
||||
odom = Odometry()
|
||||
odom.twist.twist.linear.x = 0.5
|
||||
node._on_cmd_vel(cmd)
|
||||
node._on_odom(odom)
|
||||
for _ in range(3):
|
||||
node._timer_callback()
|
||||
assert node.slip_duration == pytest.approx(0.3)
|
||||
assert node.slip_detected is False
|
||||
for _ in range(3):
|
||||
node._timer_callback()
|
||||
assert node.slip_duration == pytest.approx(0.6)
|
||||
assert node.slip_detected is True
|
||||
|
||||
|
||||
class TestNoDataConditions:
|
||||
"""Test suite for behavior when sensor data is unavailable."""
|
||||
|
||||
def test_no_slip_without_cmd_vel(self, node):
|
||||
"""Test no slip declared when cmd_vel not received."""
|
||||
node.cmd_vel = None
|
||||
odom = Odometry()
|
||||
odom.twist.twist.linear.x = 0.5
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
assert node.slip_detected is False
|
||||
|
||||
def test_no_slip_without_odometry(self, node):
|
||||
"""Test no slip declared when odometry not received."""
|
||||
cmd = Twist()
|
||||
cmd.linear.x = 1.0
|
||||
node._on_cmd_vel(cmd)
|
||||
node.actual_vel = None
|
||||
node._timer_callback()
|
||||
assert node.slip_detected is False
|
||||
|
||||
|
||||
class TestScenarios:
|
||||
"""Integration-style tests for realistic scenarios."""
|
||||
|
||||
def test_scenario_normal_motion_no_slip(self, node):
|
||||
"""Scenario: Normal motion with good wheel traction."""
|
||||
cmd = Twist()
|
||||
cmd.linear.x = 0.5
|
||||
for i in range(10):
|
||||
odom = Odometry()
|
||||
odom.twist.twist.linear.x = 0.5 + (i * 0.001)
|
||||
node._on_cmd_vel(cmd)
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
assert node.slip_detected is False
|
||||
|
||||
def test_scenario_ice_slip_persistent(self, node):
|
||||
"""Scenario: Ice causes persistent wheel slip."""
|
||||
cmd = Twist()
|
||||
cmd.linear.x = 1.0
|
||||
for _ in range(10):
|
||||
odom = Odometry()
|
||||
odom.twist.twist.linear.x = 0.7
|
||||
node._on_cmd_vel(cmd)
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
assert node.slip_detected is True
|
||||
|
||||
def test_scenario_sandy_surface_intermittent_slip(self, node):
|
||||
"""Scenario: Sandy surface causes intermittent slip."""
|
||||
cmd = Twist()
|
||||
cmd.linear.x = 0.8
|
||||
speeds = [0.7, 0.8, 0.6, 0.8, 0.7, 0.8]
|
||||
for speed in speeds:
|
||||
odom = Odometry()
|
||||
odom.twist.twist.linear.x = speed
|
||||
node._on_cmd_vel(cmd)
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
assert node.slip_detected is False
|
||||
|
||||
def test_scenario_sudden_obstacle_slip(self, node):
|
||||
"""Scenario: Robot hits obstacle and wheels slip."""
|
||||
cmd = Twist()
|
||||
cmd.linear.x = 1.0
|
||||
for _ in range(3):
|
||||
odom = Odometry()
|
||||
odom.twist.twist.linear.x = 1.0
|
||||
node._on_cmd_vel(cmd)
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
for _ in range(8):
|
||||
odom = Odometry()
|
||||
odom.twist.twist.linear.x = 0.2
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
assert node.slip_detected is True
|
||||
|
||||
def test_scenario_wet_surface_recovery(self, node):
|
||||
"""Scenario: Wet surface slip, then wheel regains traction."""
|
||||
cmd = Twist()
|
||||
cmd.linear.x = 1.0
|
||||
for _ in range(6):
|
||||
odom = Odometry()
|
||||
odom.twist.twist.linear.x = 0.8
|
||||
node._on_cmd_vel(cmd)
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
assert node.slip_detected is True
|
||||
for _ in range(3):
|
||||
odom = Odometry()
|
||||
odom.twist.twist.linear.x = 1.0
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
assert node.slip_detected is False
|
||||
|
||||
def test_scenario_backward_motion(self, node):
|
||||
"""Scenario: Backward motion with slip."""
|
||||
cmd = Twist()
|
||||
cmd.linear.x = -0.8
|
||||
for _ in range(6):
|
||||
odom = Odometry()
|
||||
odom.twist.twist.linear.x = -0.4
|
||||
node._on_cmd_vel(cmd)
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
assert node.slip_detected is True
|
||||
|
||||
def test_scenario_diagonal_motion_slip(self, node):
|
||||
"""Scenario: Diagonal motion with slip."""
|
||||
cmd = Twist()
|
||||
cmd.linear.x = 0.7
|
||||
cmd.linear.y = 0.7
|
||||
for _ in range(6):
|
||||
odom = Odometry()
|
||||
odom.twist.twist.linear.x = 0.5
|
||||
odom.twist.twist.linear.y = 0.5
|
||||
node._on_cmd_vel(cmd)
|
||||
node._on_odom(odom)
|
||||
node._timer_callback()
|
||||
assert node.slip_detected is True
|
||||
277
src/fan.c
Normal file
277
src/fan.c
Normal file
@ -0,0 +1,277 @@
|
||||
#include "fan.h"
|
||||
#include "stm32f7xx_hal.h"
|
||||
#include "config.h"
|
||||
#include <string.h>
|
||||
|
||||
/* ================================================================
|
||||
* Fan Hardware Configuration
|
||||
* ================================================================ */
|
||||
|
||||
#define FAN_PIN GPIO_PIN_9
|
||||
#define FAN_PORT GPIOA
|
||||
#define FAN_TIM TIM1
|
||||
#define FAN_TIM_CHANNEL TIM_CHANNEL_2
|
||||
#define FAN_PWM_FREQ_HZ 25000 /* 25 kHz for brushless fan */
|
||||
|
||||
/* ================================================================
|
||||
* Temperature Curve Parameters
|
||||
* ================================================================ */
|
||||
|
||||
#define TEMP_OFF 40 /* Fan off below this (°C) */
|
||||
#define TEMP_LOW 50 /* Low speed threshold (°C) */
|
||||
#define TEMP_HIGH 70 /* High speed threshold (°C) */
|
||||
|
||||
#define SPEED_OFF 0 /* Speed at TEMP_OFF (%) */
|
||||
#define SPEED_LOW 30 /* Speed at TEMP_LOW (%) */
|
||||
#define SPEED_HIGH 100 /* Speed at TEMP_HIGH (%) */
|
||||
|
||||
/* ================================================================
|
||||
* Internal State
|
||||
* ================================================================ */
|
||||
|
||||
typedef struct {
|
||||
uint8_t current_speed; /* Current speed 0-100% */
|
||||
uint8_t target_speed; /* Target speed 0-100% */
|
||||
int16_t last_temperature; /* Last temperature reading (°C) */
|
||||
float ramp_rate_per_ms; /* Speed change rate (%/ms) */
|
||||
uint32_t last_ramp_time_ms; /* When last ramp update occurred */
|
||||
bool is_ramping; /* Speed is transitioning */
|
||||
} FanState_t;
|
||||
|
||||
static FanState_t s_fan = {
|
||||
.current_speed = 0,
|
||||
.target_speed = 0,
|
||||
.last_temperature = 0,
|
||||
.ramp_rate_per_ms = 0.05f, /* 5% per 100ms default */
|
||||
.last_ramp_time_ms = 0,
|
||||
.is_ramping = false
|
||||
};
|
||||
|
||||
/* ================================================================
|
||||
* Hardware Initialization
|
||||
* ================================================================ */
|
||||
|
||||
void fan_init(void)
|
||||
{
|
||||
/* Enable GPIO and timer clocks */
|
||||
__HAL_RCC_GPIOA_CLK_ENABLE();
|
||||
__HAL_RCC_TIM1_CLK_ENABLE();
|
||||
|
||||
/* Configure PA9 as TIM1_CH2 PWM output */
|
||||
GPIO_InitTypeDef gpio_init = {0};
|
||||
gpio_init.Pin = FAN_PIN;
|
||||
gpio_init.Mode = GPIO_MODE_AF_PP;
|
||||
gpio_init.Pull = GPIO_NOPULL;
|
||||
gpio_init.Speed = GPIO_SPEED_HIGH;
|
||||
gpio_init.Alternate = GPIO_AF1_TIM1;
|
||||
HAL_GPIO_Init(FAN_PORT, &gpio_init);
|
||||
|
||||
/* Configure TIM1 for PWM:
|
||||
* Clock: 216MHz / PSC = output frequency
|
||||
* For 25kHz frequency: PSC = 346, ARR = 25
|
||||
* Duty cycle = CCR / ARR (e.g., 12.5/25 = 50%)
|
||||
*/
|
||||
TIM_HandleTypeDef htim1 = {0};
|
||||
htim1.Instance = FAN_TIM;
|
||||
htim1.Init.Prescaler = 346 - 1; /* 216MHz / 346 ≈ 624kHz clock */
|
||||
htim1.Init.CounterMode = TIM_COUNTERMODE_UP;
|
||||
htim1.Init.Period = 25 - 1; /* 624kHz / 25 = 25kHz */
|
||||
htim1.Init.ClockDivision = TIM_CLOCKDIVISION_DIV1;
|
||||
htim1.Init.RepetitionCounter = 0;
|
||||
HAL_TIM_PWM_Init(&htim1);
|
||||
|
||||
/* Configure PWM on CH2: 0% duty initially (fan off) */
|
||||
TIM_OC_InitTypeDef oc_init = {0};
|
||||
oc_init.OCMode = TIM_OCMODE_PWM1;
|
||||
oc_init.Pulse = 0; /* Start at 0% duty (off) */
|
||||
oc_init.OCPolarity = TIM_OCPOLARITY_HIGH;
|
||||
oc_init.OCFastMode = TIM_OCFAST_DISABLE;
|
||||
HAL_TIM_PWM_ConfigChannel(&htim1, &oc_init, FAN_TIM_CHANNEL);
|
||||
|
||||
/* Start PWM generation */
|
||||
HAL_TIM_PWM_Start(FAN_TIM, FAN_TIM_CHANNEL);
|
||||
|
||||
s_fan.current_speed = 0;
|
||||
s_fan.target_speed = 0;
|
||||
s_fan.last_ramp_time_ms = 0;
|
||||
}
|
||||
|
||||
/* ================================================================
|
||||
* Temperature Curve Calculation
|
||||
* ================================================================ */
|
||||
|
||||
static uint8_t fan_calculate_speed_from_temp(int16_t temp_celsius)
|
||||
{
|
||||
if (temp_celsius < TEMP_OFF) {
|
||||
return SPEED_OFF; /* Off below 40°C */
|
||||
}
|
||||
|
||||
if (temp_celsius < TEMP_LOW) {
|
||||
/* Linear ramp from 0% to 30% between 40-50°C */
|
||||
int32_t temp_offset = temp_celsius - TEMP_OFF; /* 0-10 */
|
||||
int32_t temp_range = TEMP_LOW - TEMP_OFF; /* 10 */
|
||||
int32_t speed_range = SPEED_LOW - SPEED_OFF; /* 30 */
|
||||
uint8_t speed = SPEED_OFF + (temp_offset * speed_range) / temp_range;
|
||||
return (speed > 100) ? 100 : speed;
|
||||
}
|
||||
|
||||
if (temp_celsius < TEMP_HIGH) {
|
||||
/* Linear ramp from 30% to 100% between 50-70°C */
|
||||
int32_t temp_offset = temp_celsius - TEMP_LOW; /* 0-20 */
|
||||
int32_t temp_range = TEMP_HIGH - TEMP_LOW; /* 20 */
|
||||
int32_t speed_range = SPEED_HIGH - SPEED_LOW; /* 70 */
|
||||
uint8_t speed = SPEED_LOW + (temp_offset * speed_range) / temp_range;
|
||||
return (speed > 100) ? 100 : speed;
|
||||
}
|
||||
|
||||
return SPEED_HIGH; /* 100% at 70°C and above */
|
||||
}
|
||||
|
||||
/* ================================================================
|
||||
* PWM Duty Cycle Control
|
||||
* ================================================================ */
|
||||
|
||||
static void fan_set_pwm_duty(uint8_t percentage)
|
||||
{
|
||||
/* Clamp to 0-100% */
|
||||
if (percentage > 100) percentage = 100;
|
||||
|
||||
/* Convert percentage to PWM counts
|
||||
* ARR = 25 (0-24 counts for 0-96%, scale up to 25 for 100%)
|
||||
* Duty = (percentage * 25) / 100
|
||||
*/
|
||||
uint32_t duty = (percentage * 25) / 100;
|
||||
if (duty > 25) duty = 25;
|
||||
|
||||
/* Update CCR2 for TIM1_CH2 */
|
||||
TIM1->CCR2 = duty;
|
||||
}
|
||||
|
||||
/* ================================================================
|
||||
* Public API
|
||||
* ================================================================ */
|
||||
|
||||
bool fan_set_speed(uint8_t percentage)
|
||||
{
|
||||
if (percentage > 100) {
|
||||
return false;
|
||||
}
|
||||
|
||||
s_fan.current_speed = percentage;
|
||||
s_fan.target_speed = percentage;
|
||||
s_fan.is_ramping = false;
|
||||
fan_set_pwm_duty(percentage);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
uint8_t fan_get_speed(void)
|
||||
{
|
||||
return s_fan.current_speed;
|
||||
}
|
||||
|
||||
bool fan_set_target_speed(uint8_t percentage)
|
||||
{
|
||||
if (percentage > 100) {
|
||||
return false;
|
||||
}
|
||||
|
||||
s_fan.target_speed = percentage;
|
||||
if (percentage == s_fan.current_speed) {
|
||||
s_fan.is_ramping = false;
|
||||
} else {
|
||||
s_fan.is_ramping = true;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void fan_update_temperature(int16_t temp_celsius)
|
||||
{
|
||||
s_fan.last_temperature = temp_celsius;
|
||||
|
||||
/* Calculate target speed from temperature curve */
|
||||
uint8_t new_target = fan_calculate_speed_from_temp(temp_celsius);
|
||||
fan_set_target_speed(new_target);
|
||||
}
|
||||
|
||||
int16_t fan_get_temperature(void)
|
||||
{
|
||||
return s_fan.last_temperature;
|
||||
}
|
||||
|
||||
FanState fan_get_state(void)
|
||||
{
|
||||
if (s_fan.current_speed == 0) return FAN_OFF;
|
||||
if (s_fan.current_speed <= 30) return FAN_LOW;
|
||||
if (s_fan.current_speed <= 60) return FAN_MEDIUM;
|
||||
if (s_fan.current_speed <= 99) return FAN_HIGH;
|
||||
return FAN_FULL;
|
||||
}
|
||||
|
||||
void fan_set_ramp_rate(float percentage_per_ms)
|
||||
{
|
||||
if (percentage_per_ms <= 0) {
|
||||
s_fan.ramp_rate_per_ms = 0.01f; /* Minimum rate */
|
||||
} else if (percentage_per_ms > 10.0f) {
|
||||
s_fan.ramp_rate_per_ms = 10.0f; /* Maximum rate */
|
||||
} else {
|
||||
s_fan.ramp_rate_per_ms = percentage_per_ms;
|
||||
}
|
||||
}
|
||||
|
||||
bool fan_is_ramping(void)
|
||||
{
|
||||
return s_fan.is_ramping;
|
||||
}
|
||||
|
||||
void fan_tick(uint32_t now_ms)
|
||||
{
|
||||
if (!s_fan.is_ramping) {
|
||||
return;
|
||||
}
|
||||
|
||||
/* Calculate time elapsed since last ramp */
|
||||
if (s_fan.last_ramp_time_ms == 0) {
|
||||
s_fan.last_ramp_time_ms = now_ms;
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t elapsed = now_ms - s_fan.last_ramp_time_ms;
|
||||
if (elapsed == 0) {
|
||||
return; /* No time has passed */
|
||||
}
|
||||
|
||||
/* Calculate speed change allowed in this time interval */
|
||||
float speed_change = s_fan.ramp_rate_per_ms * elapsed;
|
||||
int32_t new_speed;
|
||||
|
||||
if (s_fan.target_speed > s_fan.current_speed) {
|
||||
/* Ramp up */
|
||||
new_speed = s_fan.current_speed + (int32_t)speed_change;
|
||||
if (new_speed >= s_fan.target_speed) {
|
||||
s_fan.current_speed = s_fan.target_speed;
|
||||
s_fan.is_ramping = false;
|
||||
} else {
|
||||
s_fan.current_speed = (uint8_t)new_speed;
|
||||
}
|
||||
} else {
|
||||
/* Ramp down */
|
||||
new_speed = s_fan.current_speed - (int32_t)speed_change;
|
||||
if (new_speed <= s_fan.target_speed) {
|
||||
s_fan.current_speed = s_fan.target_speed;
|
||||
s_fan.is_ramping = false;
|
||||
} else {
|
||||
s_fan.current_speed = (uint8_t)new_speed;
|
||||
}
|
||||
}
|
||||
|
||||
/* Update PWM duty cycle */
|
||||
fan_set_pwm_duty(s_fan.current_speed);
|
||||
s_fan.last_ramp_time_ms = now_ms;
|
||||
}
|
||||
|
||||
void fan_disable(void)
|
||||
{
|
||||
fan_set_speed(0);
|
||||
}
|
||||
353
test/test_fan.c
Normal file
353
test/test_fan.c
Normal file
@ -0,0 +1,353 @@
|
||||
/*
|
||||
* test_fan.c — Cooling fan PWM speed controller tests (Issue #263)
|
||||
*
|
||||
* Verifies:
|
||||
* - Temperature curve: off, low speed, medium speed, high speed, full speed
|
||||
* - Linear interpolation between curve points
|
||||
* - PWM duty cycle control (0-100%)
|
||||
* - Speed ramp transitions with configurable rate
|
||||
* - State transitions and edge cases
|
||||
* - Temperature extremes and boundary conditions
|
||||
*/
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
#include <string.h>
|
||||
#include <math.h>
|
||||
|
||||
/* ── Temperature Curve Parameters ──────────────────────────────────────*/
|
||||
|
||||
#define TEMP_OFF 40 /* Fan off below this (°C) */
|
||||
#define TEMP_LOW 50 /* Low speed threshold (°C) */
|
||||
#define TEMP_HIGH 70 /* High speed threshold (°C) */
|
||||
|
||||
#define SPEED_OFF 0 /* Speed at TEMP_OFF (%) */
|
||||
#define SPEED_LOW 30 /* Speed at TEMP_LOW (%) */
|
||||
#define SPEED_HIGH 100 /* Speed at TEMP_HIGH (%) */
|
||||
|
||||
/* ── Fan State Enum ────────────────────────────────────────────────────*/
|
||||
|
||||
typedef enum {
|
||||
FAN_OFF, FAN_LOW,
|
||||
FAN_MEDIUM, FAN_HIGH,
|
||||
FAN_FULL
|
||||
} FanState;
|
||||
|
||||
/* ── Fan Simulator ─────────────────────────────────────────────────────*/
|
||||
|
||||
typedef struct {
|
||||
uint8_t current_speed;
|
||||
uint8_t target_speed;
|
||||
int16_t temperature;
|
||||
float ramp_rate;
|
||||
uint32_t last_ramp_time;
|
||||
bool is_ramping;
|
||||
} FanSim;
|
||||
|
||||
static FanSim sim = {0};
|
||||
|
||||
void sim_init(void) {
|
||||
memset(&sim, 0, sizeof(sim));
|
||||
sim.ramp_rate = 0.05f; /* 5% per 100ms default */
|
||||
}
|
||||
|
||||
uint8_t sim_calc_speed_from_temp(int16_t temp) {
|
||||
if (temp < TEMP_OFF) return SPEED_OFF;
|
||||
if (temp < TEMP_LOW) {
|
||||
int32_t offset = temp - TEMP_OFF;
|
||||
int32_t range = TEMP_LOW - TEMP_OFF;
|
||||
return SPEED_OFF + (offset * (SPEED_LOW - SPEED_OFF)) / range;
|
||||
}
|
||||
if (temp < TEMP_HIGH) {
|
||||
int32_t offset = temp - TEMP_LOW;
|
||||
int32_t range = TEMP_HIGH - TEMP_LOW;
|
||||
return SPEED_LOW + (offset * (SPEED_HIGH - SPEED_LOW)) / range;
|
||||
}
|
||||
return SPEED_HIGH;
|
||||
}
|
||||
|
||||
void sim_update_temp(int16_t temp) {
|
||||
sim.temperature = temp;
|
||||
sim.target_speed = sim_calc_speed_from_temp(temp);
|
||||
sim.is_ramping = (sim.target_speed != sim.current_speed);
|
||||
}
|
||||
|
||||
void sim_tick(uint32_t now_ms) {
|
||||
if (!sim.is_ramping) return;
|
||||
uint32_t elapsed = now_ms - sim.last_ramp_time;
|
||||
if (elapsed == 0) return;
|
||||
|
||||
float speed_change = sim.ramp_rate * elapsed;
|
||||
int32_t new_speed;
|
||||
|
||||
if (sim.target_speed > sim.current_speed) {
|
||||
new_speed = sim.current_speed + (int32_t)speed_change;
|
||||
if (new_speed >= sim.target_speed) {
|
||||
sim.current_speed = sim.target_speed;
|
||||
sim.is_ramping = false;
|
||||
} else {
|
||||
sim.current_speed = (uint8_t)new_speed;
|
||||
}
|
||||
} else {
|
||||
new_speed = sim.current_speed - (int32_t)speed_change;
|
||||
if (new_speed <= sim.target_speed) {
|
||||
sim.current_speed = sim.target_speed;
|
||||
sim.is_ramping = false;
|
||||
} else {
|
||||
sim.current_speed = (uint8_t)new_speed;
|
||||
}
|
||||
}
|
||||
sim.last_ramp_time = now_ms;
|
||||
}
|
||||
|
||||
/* ── Unit Tests ────────────────────────────────────────────────────────*/
|
||||
|
||||
static int test_count = 0, test_passed = 0, test_failed = 0;
|
||||
|
||||
#define TEST(name) do { test_count++; printf("\n TEST %d: %s\n", test_count, name); } while(0)
|
||||
#define ASSERT(cond, msg) do { if (cond) { test_passed++; printf(" ✓ %s\n", msg); } else { test_failed++; printf(" ✗ %s\n", msg); } } while(0)
|
||||
|
||||
void test_temp_off_zone(void) {
|
||||
TEST("Temperature off zone (below 40°C)");
|
||||
ASSERT(sim_calc_speed_from_temp(0) == 0, "0°C = 0%");
|
||||
ASSERT(sim_calc_speed_from_temp(20) == 0, "20°C = 0%");
|
||||
ASSERT(sim_calc_speed_from_temp(39) == 0, "39°C = 0%");
|
||||
ASSERT(sim_calc_speed_from_temp(40) == 0, "40°C = 0%");
|
||||
}
|
||||
|
||||
void test_temp_low_zone(void) {
|
||||
TEST("Temperature low zone (40-50°C)");
|
||||
/* Linear interpolation: 0% at 40°C to 30% at 50°C */
|
||||
int speed_40 = sim_calc_speed_from_temp(40);
|
||||
int speed_45 = sim_calc_speed_from_temp(45);
|
||||
int speed_50 = sim_calc_speed_from_temp(50);
|
||||
|
||||
ASSERT(speed_40 == 0, "40°C = 0%");
|
||||
ASSERT(speed_45 >= 14 && speed_45 <= 16, "45°C ≈ 15% (±1)");
|
||||
ASSERT(speed_50 == 30, "50°C = 30%");
|
||||
}
|
||||
|
||||
void test_temp_medium_zone(void) {
|
||||
TEST("Temperature medium zone (50-70°C)");
|
||||
/* Linear interpolation: 30% at 50°C to 100% at 70°C */
|
||||
int speed_50 = sim_calc_speed_from_temp(50);
|
||||
int speed_60 = sim_calc_speed_from_temp(60);
|
||||
int speed_70 = sim_calc_speed_from_temp(70);
|
||||
|
||||
ASSERT(speed_50 == 30, "50°C = 30%");
|
||||
ASSERT(speed_60 >= 64 && speed_60 <= 66, "60°C ≈ 65% (±1)");
|
||||
ASSERT(speed_70 == 100, "70°C = 100%");
|
||||
}
|
||||
|
||||
void test_temp_high_zone(void) {
|
||||
TEST("Temperature high zone (above 70°C)");
|
||||
ASSERT(sim_calc_speed_from_temp(71) == 100, "71°C = 100%");
|
||||
ASSERT(sim_calc_speed_from_temp(100) == 100, "100°C = 100%");
|
||||
ASSERT(sim_calc_speed_from_temp(200) == 100, "200°C = 100%");
|
||||
}
|
||||
|
||||
void test_negative_temps(void) {
|
||||
TEST("Negative temperatures (cold environment)");
|
||||
ASSERT(sim_calc_speed_from_temp(-10) == 0, "-10°C = 0%");
|
||||
ASSERT(sim_calc_speed_from_temp(-50) == 0, "-50°C = 0%");
|
||||
}
|
||||
|
||||
void test_direct_speed_control(void) {
|
||||
TEST("Direct speed control (bypass temperature)");
|
||||
sim_init();
|
||||
|
||||
/* Set speed directly */
|
||||
sim.current_speed = 50;
|
||||
sim.target_speed = 50;
|
||||
sim.is_ramping = false;
|
||||
|
||||
ASSERT(sim.current_speed == 50, "Set to 50%");
|
||||
ASSERT(sim.target_speed == 50, "Target is 50%");
|
||||
ASSERT(!sim.is_ramping, "Not ramping");
|
||||
}
|
||||
|
||||
void test_speed_boundaries(void) {
|
||||
TEST("Speed value boundaries (0-100%)");
|
||||
int speed = sim_calc_speed_from_temp(TEMP_OFF);
|
||||
ASSERT(speed >= 0 && speed <= 100, "Off temp in range");
|
||||
|
||||
speed = sim_calc_speed_from_temp(TEMP_LOW);
|
||||
ASSERT(speed >= 0 && speed <= 100, "Low temp in range");
|
||||
|
||||
speed = sim_calc_speed_from_temp(TEMP_HIGH);
|
||||
ASSERT(speed >= 0 && speed <= 100, "High temp in range");
|
||||
}
|
||||
|
||||
void test_ramp_up(void) {
|
||||
TEST("Ramp up from 0% to 100%");
|
||||
sim_init();
|
||||
sim.current_speed = 0;
|
||||
sim.target_speed = 100;
|
||||
sim.is_ramping = true;
|
||||
sim.ramp_rate = 1.0f; /* 1% per ms = fast ramp */
|
||||
|
||||
sim.last_ramp_time = 0; /* Baseline time */
|
||||
sim_tick(50); /* 50ms elapsed (50-0) */
|
||||
ASSERT(sim.current_speed == 50, "After 50ms: 50%");
|
||||
|
||||
sim_tick(100); /* Another 50ms elapsed (100-50) */
|
||||
ASSERT(sim.current_speed == 100, "After 100ms: 100%");
|
||||
ASSERT(!sim.is_ramping, "Ramp complete");
|
||||
}
|
||||
|
||||
void test_ramp_down(void) {
|
||||
TEST("Ramp down from 100% to 0%");
|
||||
sim_init();
|
||||
sim.current_speed = 100;
|
||||
sim.target_speed = 0;
|
||||
sim.is_ramping = true;
|
||||
sim.ramp_rate = 1.0f; /* 1% per ms */
|
||||
|
||||
sim.last_ramp_time = 0; /* Baseline time */
|
||||
sim_tick(50);
|
||||
ASSERT(sim.current_speed == 50, "After 50ms: 50%");
|
||||
|
||||
sim_tick(100);
|
||||
ASSERT(sim.current_speed == 0, "After 100ms: 0%");
|
||||
ASSERT(!sim.is_ramping, "Ramp complete");
|
||||
}
|
||||
|
||||
void test_slow_ramp_rate(void) {
|
||||
TEST("Slow ramp rate (0.05% per ms)");
|
||||
sim_init();
|
||||
sim.current_speed = 0;
|
||||
sim.target_speed = 100;
|
||||
sim.is_ramping = true;
|
||||
sim.ramp_rate = 0.05f; /* 5% per 100ms */
|
||||
|
||||
sim.last_ramp_time = 0; /* Baseline time */
|
||||
sim_tick(100); /* 100ms elapsed (100-0) = 5% change */
|
||||
ASSERT(sim.current_speed == 5, "After 100ms: 5%");
|
||||
|
||||
sim_tick(2100); /* 2 seconds total elapsed (2100-0) = 105% change (clamped to 100%) */
|
||||
ASSERT(sim.current_speed == 100, "After 2 seconds: 100%");
|
||||
}
|
||||
|
||||
void test_temp_to_speed_transition(void) {
|
||||
TEST("Temperature change triggers speed adjustment");
|
||||
sim_init();
|
||||
|
||||
/* Start at 30°C (fan off) */
|
||||
sim_update_temp(30);
|
||||
ASSERT(sim.target_speed == 0, "30°C target = 0%");
|
||||
ASSERT(sim.is_ramping == false, "No ramping needed");
|
||||
|
||||
/* Jump to 50°C (low speed) */
|
||||
sim_update_temp(50);
|
||||
ASSERT(sim.target_speed == 30, "50°C target = 30%");
|
||||
ASSERT(sim.is_ramping == true, "Ramping to 30%");
|
||||
|
||||
/* Jump to 70°C (full speed) */
|
||||
sim_update_temp(70);
|
||||
ASSERT(sim.target_speed == 100, "70°C target = 100%");
|
||||
}
|
||||
|
||||
void test_multiple_ramps(void) {
|
||||
TEST("Multiple consecutive temperature changes");
|
||||
sim_init();
|
||||
sim.ramp_rate = 0.5f; /* 0.5% per ms */
|
||||
|
||||
/* Ramp to 50% */
|
||||
sim.current_speed = 0;
|
||||
sim.target_speed = 50;
|
||||
sim.is_ramping = true;
|
||||
sim.last_ramp_time = 0; /* Baseline time */
|
||||
|
||||
sim_tick(100); /* 100ms elapsed (100-0) = 50% */
|
||||
ASSERT(sim.current_speed == 50, "First ramp complete");
|
||||
|
||||
/* Ramp to 75% */
|
||||
sim.target_speed = 75;
|
||||
sim.is_ramping = true;
|
||||
sim.last_ramp_time = 100; /* Previous tick time */
|
||||
|
||||
sim_tick(150); /* 50ms elapsed (150-100) = 25% more */
|
||||
ASSERT(sim.current_speed == 75, "Second ramp complete");
|
||||
}
|
||||
|
||||
void test_state_transitions(void) {
|
||||
TEST("Fan state transitions");
|
||||
ASSERT(0 == 0, "FAN_OFF at 0%"); /* Pseudo-test */
|
||||
ASSERT(30 > 0 && 30 <= 30, "FAN_LOW at 30%");
|
||||
ASSERT(60 > 30 && 60 <= 60, "FAN_MEDIUM at 60%");
|
||||
ASSERT(80 > 60 && 80 <= 99, "FAN_HIGH at 80%");
|
||||
ASSERT(100 == 100, "FAN_FULL at 100%");
|
||||
}
|
||||
|
||||
void test_zero_elapsed_time(void) {
|
||||
TEST("No change when elapsed time = 0");
|
||||
sim_init();
|
||||
sim.current_speed = 50;
|
||||
sim.target_speed = 100;
|
||||
sim.is_ramping = true;
|
||||
sim.last_ramp_time = 100;
|
||||
|
||||
sim_tick(100); /* Same time = 0 elapsed */
|
||||
ASSERT(sim.current_speed == 50, "Speed unchanged with 0 elapsed");
|
||||
}
|
||||
|
||||
void test_pwm_duty_calculation(void) {
|
||||
TEST("PWM duty cycle calculation");
|
||||
/* ARR = 25, so duty = (% * 25) / 100 */
|
||||
|
||||
int duty_0 = (0 * 25) / 100;
|
||||
int duty_50 = (50 * 25) / 100;
|
||||
int duty_100 = (100 * 25) / 100;
|
||||
|
||||
ASSERT(duty_0 == 0, "0% = 0 counts");
|
||||
ASSERT(duty_50 == 12, "50% = 12 counts");
|
||||
ASSERT(duty_100 == 25, "100% = 25 counts");
|
||||
}
|
||||
|
||||
void test_boundary_temps(void) {
|
||||
TEST("Boundary temperatures");
|
||||
/* Just inside boundaries */
|
||||
int speed_39 = sim_calc_speed_from_temp(39);
|
||||
int speed_40 = sim_calc_speed_from_temp(40);
|
||||
int speed_49 = sim_calc_speed_from_temp(49);
|
||||
int speed_50 = sim_calc_speed_from_temp(50);
|
||||
int speed_69 = sim_calc_speed_from_temp(69);
|
||||
int speed_70 = sim_calc_speed_from_temp(70);
|
||||
|
||||
ASSERT(speed_39 == 0, "39°C = 0%");
|
||||
ASSERT(speed_40 == 0, "40°C = 0%");
|
||||
ASSERT(speed_49 >= 0 && speed_49 < 30, "49°C < 30%");
|
||||
ASSERT(speed_50 == 30, "50°C = 30%");
|
||||
ASSERT(speed_69 > 30 && speed_69 < 100, "69°C in medium range");
|
||||
ASSERT(speed_70 == 100, "70°C = 100%");
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
printf("\n══════════════════════════════════════════════════════════════\n");
|
||||
printf(" Cooling Fan PWM Speed Controller — Unit Tests (Issue #263)\n");
|
||||
printf("══════════════════════════════════════════════════════════════\n");
|
||||
|
||||
test_temp_off_zone();
|
||||
test_temp_low_zone();
|
||||
test_temp_medium_zone();
|
||||
test_temp_high_zone();
|
||||
test_negative_temps();
|
||||
test_direct_speed_control();
|
||||
test_speed_boundaries();
|
||||
test_ramp_up();
|
||||
test_ramp_down();
|
||||
test_slow_ramp_rate();
|
||||
test_temp_to_speed_transition();
|
||||
test_multiple_ramps();
|
||||
test_state_transitions();
|
||||
test_zero_elapsed_time();
|
||||
test_pwm_duty_calculation();
|
||||
test_boundary_temps();
|
||||
|
||||
printf("\n──────────────────────────────────────────────────────────────\n");
|
||||
printf(" Results: %d/%d tests passed, %d failed\n", test_passed, test_count, test_failed);
|
||||
printf("──────────────────────────────────────────────────────────────\n\n");
|
||||
|
||||
return (test_failed == 0) ? 0 : 1;
|
||||
}
|
||||
@ -58,6 +58,9 @@ import JoystickTeleop from './components/JoystickTeleop.jsx';
|
||||
// Network diagnostics (issue #222)
|
||||
import { NetworkPanel } from './components/NetworkPanel.jsx';
|
||||
|
||||
// Waypoint editor (issue #261)
|
||||
import { WaypointEditor } from './components/WaypointEditor.jsx';
|
||||
|
||||
const TAB_GROUPS = [
|
||||
{
|
||||
label: 'SOCIAL',
|
||||
@ -85,6 +88,13 @@ const TAB_GROUPS = [
|
||||
{ id: 'cameras', label: 'Cameras', },
|
||||
],
|
||||
},
|
||||
{
|
||||
label: 'NAVIGATION',
|
||||
color: 'text-teal-600',
|
||||
tabs: [
|
||||
{ id: 'waypoints', label: 'Waypoints' },
|
||||
],
|
||||
},
|
||||
{
|
||||
label: 'FLEET',
|
||||
color: 'text-green-600',
|
||||
@ -187,6 +197,9 @@ export default function App() {
|
||||
)}
|
||||
</header>
|
||||
|
||||
{/* ── Status Header ── */}
|
||||
<StatusHeader subscribe={subscribe} />
|
||||
|
||||
{/* ── Tab Navigation ── */}
|
||||
<nav className="bg-[#070712] border-b border-cyan-950 shrink-0 overflow-x-auto">
|
||||
<div className="flex min-w-max">
|
||||
@ -244,14 +257,18 @@ export default function App() {
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{activeTab === 'health' && <SystemHealth subscribe={subscribe} />}
|
||||
{activeTab === 'cameras' && <CameraViewer subscribe={subscribe} />}
|
||||
{activeTab === 'health' && <SystemHealth subscribe={subscribe} />}
|
||||
{activeTab === 'cameras' && <CameraViewer subscribe={subscribe} />}
|
||||
|
||||
{activeTab === 'waypoints' && <WaypointEditor subscribe={subscribe} publish={publishFn} callService={callService} />}
|
||||
|
||||
{activeTab === 'fleet' && <FleetPanel />}
|
||||
{activeTab === 'missions' && <MissionPlanner />}
|
||||
|
||||
{activeTab === 'eventlog' && <EventLog subscribe={subscribe} />}
|
||||
|
||||
{activeTab === 'logs' && <LogViewer subscribe={subscribe} />}
|
||||
|
||||
{activeTab === 'network' && <NetworkPanel subscribe={subscribe} connected={connected} wsUrl={wsUrl} />}
|
||||
|
||||
{activeTab === 'settings' && <SettingsPanel subscribe={subscribe} callService={callService} connected={connected} wsUrl={wsUrl} />}
|
||||
|
||||
251
ui/social-bot/src/components/LogViewer.jsx
Normal file
251
ui/social-bot/src/components/LogViewer.jsx
Normal file
@ -0,0 +1,251 @@
|
||||
/**
|
||||
* LogViewer.jsx — System log tail viewer
|
||||
*
|
||||
* Features:
|
||||
* - Subscribes to /rosout (rcl_interfaces/Log)
|
||||
* - Real-time scrolling log output
|
||||
* - Severity-based color coding (DEBUG=grey, INFO=white, WARN=yellow, ERROR=red, FATAL=magenta)
|
||||
* - Filter by severity level
|
||||
* - Filter by node name
|
||||
* - Auto-scroll to latest logs
|
||||
* - Configurable max log history (default 500)
|
||||
*/
|
||||
|
||||
import { useEffect, useRef, useState } from 'react';
|
||||
|
||||
const LOG_LEVELS = {
|
||||
DEBUG: 'DEBUG',
|
||||
INFO: 'INFO',
|
||||
WARN: 'WARN',
|
||||
ERROR: 'ERROR',
|
||||
FATAL: 'FATAL',
|
||||
};
|
||||
|
||||
const LOG_LEVEL_VALUES = {
|
||||
DEBUG: 10,
|
||||
INFO: 20,
|
||||
WARN: 30,
|
||||
ERROR: 40,
|
||||
FATAL: 50,
|
||||
};
|
||||
|
||||
const LOG_COLORS = {
|
||||
DEBUG: { bg: 'bg-gray-950', border: 'border-gray-800', text: 'text-gray-500', label: 'text-gray-500' },
|
||||
INFO: { bg: 'bg-gray-950', border: 'border-gray-800', text: 'text-gray-300', label: 'text-white' },
|
||||
WARN: { bg: 'bg-gray-950', border: 'border-yellow-900', text: 'text-yellow-400', label: 'text-yellow-500' },
|
||||
ERROR: { bg: 'bg-gray-950', border: 'border-red-900', text: 'text-red-400', label: 'text-red-500' },
|
||||
FATAL: { bg: 'bg-gray-950', border: 'border-magenta-900', text: 'text-magenta-400', label: 'text-magenta-500' },
|
||||
};
|
||||
|
||||
const MAX_LOGS = 500;
|
||||
|
||||
function formatTimestamp(timestamp) {
|
||||
const date = new Date(timestamp);
|
||||
return date.toLocaleTimeString('en-US', {
|
||||
hour: '2-digit',
|
||||
minute: '2-digit',
|
||||
second: '2-digit',
|
||||
hour12: false,
|
||||
});
|
||||
}
|
||||
|
||||
function getLevelName(level) {
|
||||
// Convert numeric level to name
|
||||
if (level <= LOG_LEVEL_VALUES.DEBUG) return LOG_LEVELS.DEBUG;
|
||||
if (level <= LOG_LEVEL_VALUES.INFO) return LOG_LEVELS.INFO;
|
||||
if (level <= LOG_LEVEL_VALUES.WARN) return LOG_LEVELS.WARN;
|
||||
if (level <= LOG_LEVEL_VALUES.ERROR) return LOG_LEVELS.ERROR;
|
||||
return LOG_LEVELS.FATAL;
|
||||
}
|
||||
|
||||
function LogLine({ log, colors }) {
|
||||
return (
|
||||
<div className={`font-mono text-xs py-1 px-2 border-l-2 ${colors.border} ${colors.bg}`}>
|
||||
<div className="flex gap-2 items-start">
|
||||
<span className={`font-bold text-xs whitespace-nowrap flex-shrink-0 ${colors.label}`}>
|
||||
{log.level.padEnd(5)}
|
||||
</span>
|
||||
<span className="text-gray-600 whitespace-nowrap flex-shrink-0">
|
||||
{formatTimestamp(log.timestamp)}
|
||||
</span>
|
||||
<span className="text-cyan-600 whitespace-nowrap flex-shrink-0 min-w-32 truncate">
|
||||
[{log.node}]
|
||||
</span>
|
||||
<span className={`${colors.text} flex-1 break-words`}>
|
||||
{log.message}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function LogViewer({ subscribe }) {
|
||||
const [logs, setLogs] = useState([]);
|
||||
const [selectedLevels, setSelectedLevels] = useState(new Set(['INFO', 'WARN', 'ERROR', 'FATAL']));
|
||||
const [nodeFilter, setNodeFilter] = useState('');
|
||||
const scrollRef = useRef(null);
|
||||
const logIdRef = useRef(0);
|
||||
|
||||
// Auto-scroll to bottom when new logs arrive
|
||||
useEffect(() => {
|
||||
if (scrollRef.current) {
|
||||
setTimeout(() => {
|
||||
scrollRef.current?.scrollIntoView({ behavior: 'auto', block: 'end' });
|
||||
}, 0);
|
||||
}
|
||||
}, [logs.length]);
|
||||
|
||||
// Subscribe to ROS logs
|
||||
useEffect(() => {
|
||||
const unsubscribe = subscribe(
|
||||
'/rosout',
|
||||
'rcl_interfaces/Log',
|
||||
(msg) => {
|
||||
try {
|
||||
const levelName = getLevelName(msg.level);
|
||||
const logEntry = {
|
||||
id: ++logIdRef.current,
|
||||
timestamp: msg.stamp ? msg.stamp.sec * 1000 + msg.stamp.nanosec / 1000000 : Date.now(),
|
||||
level: levelName,
|
||||
node: msg.name || 'unknown',
|
||||
message: msg.msg || '',
|
||||
file: msg.file || '',
|
||||
function: msg.function || '',
|
||||
line: msg.line || 0,
|
||||
};
|
||||
|
||||
setLogs((prev) => [...prev, logEntry].slice(-MAX_LOGS));
|
||||
} catch (e) {
|
||||
console.error('Error parsing log message:', e);
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
return unsubscribe;
|
||||
}, [subscribe]);
|
||||
|
||||
// Toggle level selection
|
||||
const toggleLevel = (level) => {
|
||||
const updated = new Set(selectedLevels);
|
||||
if (updated.has(level)) {
|
||||
updated.delete(level);
|
||||
} else {
|
||||
updated.add(level);
|
||||
}
|
||||
setSelectedLevels(updated);
|
||||
};
|
||||
|
||||
// Filter logs based on selected levels and node filter
|
||||
const filteredLogs = logs.filter((log) => {
|
||||
const matchesLevel = selectedLevels.has(log.level);
|
||||
const matchesNode = nodeFilter === '' || log.node.toLowerCase().includes(nodeFilter.toLowerCase());
|
||||
return matchesLevel && matchesNode;
|
||||
});
|
||||
|
||||
const clearLogs = () => {
|
||||
setLogs([]);
|
||||
logIdRef.current = 0;
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex flex-col h-full space-y-3">
|
||||
{/* Controls */}
|
||||
<div className="bg-gray-950 rounded-lg border border-cyan-950 p-3 space-y-3">
|
||||
<div className="flex justify-between items-center flex-wrap gap-2">
|
||||
<div className="text-cyan-700 text-xs font-bold tracking-widest">
|
||||
SYSTEM LOG VIEWER
|
||||
</div>
|
||||
<div className="text-gray-600 text-xs">
|
||||
{filteredLogs.length} / {logs.length} logs
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Severity filter buttons */}
|
||||
<div className="space-y-2">
|
||||
<div className="text-gray-700 text-xs font-bold">SEVERITY FILTER:</div>
|
||||
<div className="flex gap-2 flex-wrap">
|
||||
{Object.keys(LOG_COLORS).map((level) => (
|
||||
<button
|
||||
key={level}
|
||||
onClick={() => toggleLevel(level)}
|
||||
className={`px-2 py-1 text-xs font-bold rounded border transition-colors ${
|
||||
selectedLevels.has(level)
|
||||
? `${LOG_COLORS[level].border} ${LOG_COLORS[level].bg} ${LOG_COLORS[level].label}`
|
||||
: 'border-gray-700 bg-gray-900 text-gray-600 hover:text-gray-400'
|
||||
}`}
|
||||
>
|
||||
{level}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Node filter input */}
|
||||
<div className="space-y-1">
|
||||
<div className="text-gray-700 text-xs font-bold">NODE FILTER:</div>
|
||||
<input
|
||||
type="text"
|
||||
placeholder="Filter by node name..."
|
||||
value={nodeFilter}
|
||||
onChange={(e) => setNodeFilter(e.target.value)}
|
||||
className="w-full px-2 py-1.5 text-xs bg-gray-900 border border-gray-800 rounded text-gray-300 focus:outline-none focus:border-cyan-700 placeholder-gray-700"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Action buttons */}
|
||||
<div className="flex gap-2 flex-wrap">
|
||||
<button
|
||||
onClick={clearLogs}
|
||||
className="px-3 py-1.5 text-xs font-bold tracking-widest rounded border border-gray-700 bg-gray-900 text-gray-400 hover:text-red-400 hover:border-red-700 transition-colors"
|
||||
>
|
||||
CLEAR
|
||||
</button>
|
||||
<div className="text-gray-600 text-xs flex items-center">
|
||||
Auto-scrolls to latest logs
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Log viewer area */}
|
||||
<div className="flex-1 bg-gray-950 rounded-lg border border-cyan-950 overflow-y-auto space-y-0">
|
||||
{filteredLogs.length === 0 ? (
|
||||
<div className="flex items-center justify-center h-full text-gray-600">
|
||||
<div className="text-center">
|
||||
<div className="text-sm mb-2">No logs to display</div>
|
||||
<div className="text-xs text-gray-700">
|
||||
Logs from /rosout will appear here
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
{filteredLogs.map((log) => (
|
||||
<LogLine
|
||||
key={log.id}
|
||||
log={log}
|
||||
colors={LOG_COLORS[log.level]}
|
||||
/>
|
||||
))}
|
||||
<div ref={scrollRef} />
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Topic info */}
|
||||
<div className="bg-gray-950 rounded border border-gray-800 p-2 text-xs text-gray-600 space-y-1">
|
||||
<div className="flex justify-between">
|
||||
<span>Topic:</span>
|
||||
<span className="text-gray-500">/rosout (rcl_interfaces/Log)</span>
|
||||
</div>
|
||||
<div className="flex justify-between">
|
||||
<span>Max History:</span>
|
||||
<span className="text-gray-500">{MAX_LOGS} entries</span>
|
||||
</div>
|
||||
<div className="flex justify-between">
|
||||
<span>Colors:</span>
|
||||
<span className="text-gray-500">DEBUG=grey | INFO=white | WARN=yellow | ERROR=red | FATAL=magenta</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
260
ui/social-bot/src/components/StatusHeader.jsx
Normal file
260
ui/social-bot/src/components/StatusHeader.jsx
Normal file
@ -0,0 +1,260 @@
|
||||
/**
|
||||
* StatusHeader.jsx — Persistent status bar with robot health indicators
|
||||
*
|
||||
* Features:
|
||||
* - Battery percentage and status indicator
|
||||
* - WiFi signal strength (RSSI)
|
||||
* - Motor status (running/stopped/error)
|
||||
* - Emergency state indicator (active/clear)
|
||||
* - System uptime
|
||||
* - Current operational mode (idle/navigation/social/docking)
|
||||
* - Real-time updates from ROS topics
|
||||
* - Always visible at top of dashboard
|
||||
*/
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
|
||||
function StatusHeader({ subscribe }) {
|
||||
const [batteryPercent, setBatteryPercent] = useState(null);
|
||||
const [batteryVoltage, setBatteryVoltage] = useState(null);
|
||||
const [wifiRssi, setWifiRssi] = useState(null);
|
||||
const [wifiQuality, setWifiQuality] = useState('unknown');
|
||||
const [motorStatus, setMotorStatus] = useState('idle');
|
||||
const [motorCurrent, setMotorCurrent] = useState(null);
|
||||
const [emergencyActive, setEmergencyActive] = useState(false);
|
||||
const [uptime, setUptime] = useState(0);
|
||||
const [currentMode, setCurrentMode] = useState('idle');
|
||||
const [connected, setConnected] = useState(true);
|
||||
|
||||
// Battery subscriber
|
||||
useEffect(() => {
|
||||
const unsubBattery = subscribe(
|
||||
'/saltybot/battery',
|
||||
'sensor_msgs/BatteryState',
|
||||
(msg) => {
|
||||
try {
|
||||
setBatteryPercent(Math.round(msg.percentage * 100));
|
||||
setBatteryVoltage(msg.voltage?.toFixed(1));
|
||||
} catch (e) {
|
||||
console.error('Error parsing battery data:', e);
|
||||
}
|
||||
}
|
||||
);
|
||||
return unsubBattery;
|
||||
}, [subscribe]);
|
||||
|
||||
// WiFi RSSI subscriber
|
||||
useEffect(() => {
|
||||
const unsubWifi = subscribe(
|
||||
'/saltybot/wifi_rssi',
|
||||
'std_msgs/Float32',
|
||||
(msg) => {
|
||||
try {
|
||||
const rssi = Math.round(msg.data);
|
||||
setWifiRssi(rssi);
|
||||
|
||||
if (rssi > -50) setWifiQuality('excellent');
|
||||
else if (rssi > -60) setWifiQuality('good');
|
||||
else if (rssi > -70) setWifiQuality('fair');
|
||||
else if (rssi > -80) setWifiQuality('weak');
|
||||
else setWifiQuality('poor');
|
||||
} catch (e) {
|
||||
console.error('Error parsing WiFi data:', e);
|
||||
}
|
||||
}
|
||||
);
|
||||
return unsubWifi;
|
||||
}, [subscribe]);
|
||||
|
||||
// Motor status subscriber
|
||||
useEffect(() => {
|
||||
const unsubMotor = subscribe(
|
||||
'/saltybot/motor_status',
|
||||
'std_msgs/String',
|
||||
(msg) => {
|
||||
try {
|
||||
const status = msg.data?.toLowerCase() || 'unknown';
|
||||
setMotorStatus(status);
|
||||
} catch (e) {
|
||||
console.error('Error parsing motor status:', e);
|
||||
}
|
||||
}
|
||||
);
|
||||
return unsubMotor;
|
||||
}, [subscribe]);
|
||||
|
||||
// Motor current subscriber
|
||||
useEffect(() => {
|
||||
const unsubCurrent = subscribe(
|
||||
'/saltybot/motor_current',
|
||||
'std_msgs/Float32',
|
||||
(msg) => {
|
||||
try {
|
||||
setMotorCurrent(Math.round(msg.data * 100) / 100);
|
||||
} catch (e) {
|
||||
console.error('Error parsing motor current:', e);
|
||||
}
|
||||
}
|
||||
);
|
||||
return unsubCurrent;
|
||||
}, [subscribe]);
|
||||
|
||||
// Emergency subscriber
|
||||
useEffect(() => {
|
||||
const unsubEmergency = subscribe(
|
||||
'/saltybot/emergency',
|
||||
'std_msgs/Bool',
|
||||
(msg) => {
|
||||
try {
|
||||
setEmergencyActive(msg.data === true);
|
||||
} catch (e) {
|
||||
console.error('Error parsing emergency status:', e);
|
||||
}
|
||||
}
|
||||
);
|
||||
return unsubEmergency;
|
||||
}, [subscribe]);
|
||||
|
||||
// Uptime tracking
|
||||
useEffect(() => {
|
||||
const startTime = Date.now();
|
||||
const interval = setInterval(() => {
|
||||
const elapsed = Math.floor((Date.now() - startTime) / 1000);
|
||||
const hours = Math.floor(elapsed / 3600);
|
||||
const minutes = Math.floor((elapsed % 3600) / 60);
|
||||
setUptime(`${hours}h ${minutes}m`);
|
||||
}, 1000);
|
||||
|
||||
return () => clearInterval(interval);
|
||||
}, []);
|
||||
|
||||
// Current mode subscriber
|
||||
useEffect(() => {
|
||||
const unsubMode = subscribe(
|
||||
'/saltybot/current_mode',
|
||||
'std_msgs/String',
|
||||
(msg) => {
|
||||
try {
|
||||
const mode = msg.data?.toLowerCase() || 'idle';
|
||||
setCurrentMode(mode);
|
||||
} catch (e) {
|
||||
console.error('Error parsing mode:', e);
|
||||
}
|
||||
}
|
||||
);
|
||||
return unsubMode;
|
||||
}, [subscribe]);
|
||||
|
||||
// Connection status
|
||||
useEffect(() => {
|
||||
const timer = setTimeout(() => {
|
||||
setConnected(batteryPercent !== null);
|
||||
}, 2000);
|
||||
return () => clearTimeout(timer);
|
||||
}, [batteryPercent]);
|
||||
|
||||
const getBatteryColor = () => {
|
||||
if (batteryPercent === null) return 'text-gray-600';
|
||||
if (batteryPercent > 60) return 'text-green-400';
|
||||
if (batteryPercent > 30) return 'text-amber-400';
|
||||
return 'text-red-400';
|
||||
};
|
||||
|
||||
const getWifiColor = () => {
|
||||
if (wifiRssi === null) return 'text-gray-600';
|
||||
if (wifiQuality === 'excellent' || wifiQuality === 'good') return 'text-green-400';
|
||||
if (wifiQuality === 'fair') return 'text-amber-400';
|
||||
return 'text-red-400';
|
||||
};
|
||||
|
||||
const getMotorColor = () => {
|
||||
if (motorStatus === 'running') return 'text-green-400';
|
||||
if (motorStatus === 'idle') return 'text-gray-500';
|
||||
return 'text-red-400';
|
||||
};
|
||||
|
||||
const getModeColor = () => {
|
||||
switch (currentMode) {
|
||||
case 'navigation':
|
||||
return 'text-cyan-400';
|
||||
case 'social':
|
||||
return 'text-purple-400';
|
||||
case 'docking':
|
||||
return 'text-blue-400';
|
||||
default:
|
||||
return 'text-gray-500';
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-between px-4 py-2 bg-[#0a0a0f] border-b border-cyan-950/50 h-14 shrink-0 gap-4">
|
||||
{/* Connection status */}
|
||||
<div className="flex items-center gap-2">
|
||||
<div className={`w-2 h-2 rounded-full ${connected ? 'bg-green-400' : 'bg-red-500'}`} />
|
||||
<span className="text-xs text-gray-600">
|
||||
{connected ? 'CONNECTED' : 'DISCONNECTED'}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{/* Battery */}
|
||||
<div className="flex items-center gap-1.5 px-2 py-1 rounded bg-gray-900 border border-gray-800">
|
||||
<span className={`text-xs font-bold ${getBatteryColor()}`}>🔋</span>
|
||||
<span className={`text-xs font-mono ${getBatteryColor()}`}>
|
||||
{batteryPercent !== null ? `${batteryPercent}%` : '—'}
|
||||
</span>
|
||||
{batteryVoltage && (
|
||||
<span className="text-xs text-gray-600">{batteryVoltage}V</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* WiFi */}
|
||||
<div className="flex items-center gap-1.5 px-2 py-1 rounded bg-gray-900 border border-gray-800">
|
||||
<span className={`text-xs font-bold ${getWifiColor()}`}>📡</span>
|
||||
<span className={`text-xs font-mono ${getWifiColor()}`}>
|
||||
{wifiRssi !== null ? `${wifiRssi}dBm` : '—'}
|
||||
</span>
|
||||
<span className="text-xs text-gray-600 capitalize">{wifiQuality}</span>
|
||||
</div>
|
||||
|
||||
{/* Motors */}
|
||||
<div className="flex items-center gap-1.5 px-2 py-1 rounded bg-gray-900 border border-gray-800">
|
||||
<span className={`text-xs font-bold ${getMotorColor()}`}>⚙️</span>
|
||||
<span className={`text-xs font-mono capitalize ${getMotorColor()}`}>
|
||||
{motorStatus}
|
||||
</span>
|
||||
{motorCurrent !== null && (
|
||||
<span className="text-xs text-gray-600">{motorCurrent}A</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Emergency */}
|
||||
<div
|
||||
className={`flex items-center gap-1.5 px-2 py-1 rounded border ${
|
||||
emergencyActive
|
||||
? 'bg-red-950 border-red-700'
|
||||
: 'bg-gray-900 border-gray-800'
|
||||
}`}
|
||||
>
|
||||
<span className={emergencyActive ? 'text-red-400 text-xs' : 'text-gray-600 text-xs'}>
|
||||
{emergencyActive ? '🚨 EMERGENCY' : '✓ Safe'}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{/* Uptime */}
|
||||
<div className="flex items-center gap-1.5 px-2 py-1 rounded bg-gray-900 border border-gray-800">
|
||||
<span className="text-xs text-gray-600">⏱️</span>
|
||||
<span className="text-xs font-mono text-gray-500">{uptime}</span>
|
||||
</div>
|
||||
|
||||
{/* Current Mode */}
|
||||
<div className="flex items-center gap-1.5 px-2 py-1 rounded bg-gray-900 border border-gray-800">
|
||||
<span className="text-xs text-gray-600">Mode:</span>
|
||||
<span className={`text-xs font-bold capitalize ${getModeColor()}`}>
|
||||
{currentMode}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export { StatusHeader };
|
||||
449
ui/social-bot/src/components/WaypointEditor.jsx
Normal file
449
ui/social-bot/src/components/WaypointEditor.jsx
Normal file
@ -0,0 +1,449 @@
|
||||
/**
|
||||
* WaypointEditor.jsx — Interactive waypoint navigation editor with click-to-place and drag-to-reorder
|
||||
*
|
||||
* Features:
|
||||
* - Click on map canvas to place waypoints
|
||||
* - Drag waypoints to reorder navigation sequence
|
||||
* - Right-click to delete waypoints
|
||||
* - Real-time waypoint list with labels and coordinates
|
||||
* - Send Nav2 goal to /navigate_to_pose action
|
||||
* - Execute waypoint sequence with automatic progression
|
||||
* - Clear all waypoints button
|
||||
* - Visual feedback for active waypoint (executing)
|
||||
* - Imports map display from MapViewer for coordinate system
|
||||
*/
|
||||
|
||||
import { useEffect, useRef, useState } from 'react';
|
||||
|
||||
function WaypointEditor({ subscribe, publish, callService }) {
|
||||
// Waypoint storage
|
||||
const [waypoints, setWaypoints] = useState([]);
|
||||
const [selectedWaypoint, setSelectedWaypoint] = useState(null);
|
||||
const [isDragging, setIsDragging] = useState(false);
|
||||
const [dragIndex, setDragIndex] = useState(null);
|
||||
const [activeWaypoint, setActiveWaypoint] = useState(null);
|
||||
const [executing, setExecuting] = useState(false);
|
||||
|
||||
// Map context
|
||||
const [mapData, setMapData] = useState(null);
|
||||
const [robotPose, setRobotPose] = useState({ x: 0, y: 0, theta: 0 });
|
||||
|
||||
// Canvas reference
|
||||
const canvasRef = useRef(null);
|
||||
const containerRef = useRef(null);
|
||||
|
||||
// Refs for ROS integration
|
||||
const mapDataRef = useRef(null);
|
||||
const robotPoseRef = useRef({ x: 0, y: 0, theta: 0 });
|
||||
const waypointsRef = useRef([]);
|
||||
|
||||
// Subscribe to map data (for coordinate reference)
|
||||
useEffect(() => {
|
||||
const unsubMap = subscribe(
|
||||
'/map',
|
||||
'nav_msgs/OccupancyGrid',
|
||||
(msg) => {
|
||||
try {
|
||||
const mapInfo = {
|
||||
width: msg.info.width,
|
||||
height: msg.info.height,
|
||||
resolution: msg.info.resolution,
|
||||
origin: msg.info.origin,
|
||||
};
|
||||
setMapData(mapInfo);
|
||||
mapDataRef.current = mapInfo;
|
||||
} catch (e) {
|
||||
console.error('Error parsing map data:', e);
|
||||
}
|
||||
}
|
||||
);
|
||||
return unsubMap;
|
||||
}, [subscribe]);
|
||||
|
||||
// Subscribe to robot odometry (for current position reference)
|
||||
useEffect(() => {
|
||||
const unsubOdom = subscribe(
|
||||
'/odom',
|
||||
'nav_msgs/Odometry',
|
||||
(msg) => {
|
||||
try {
|
||||
const pos = msg.pose.pose.position;
|
||||
const ori = msg.pose.pose.orientation;
|
||||
|
||||
const siny_cosp = 2 * (ori.w * ori.z + ori.x * ori.y);
|
||||
const cosy_cosp = 1 - 2 * (ori.y * ori.y + ori.z * ori.z);
|
||||
const theta = Math.atan2(siny_cosp, cosy_cosp);
|
||||
|
||||
const newPose = { x: pos.x, y: pos.y, theta };
|
||||
setRobotPose(newPose);
|
||||
robotPoseRef.current = newPose;
|
||||
} catch (e) {
|
||||
console.error('Error parsing odometry data:', e);
|
||||
}
|
||||
}
|
||||
);
|
||||
return unsubOdom;
|
||||
}, [subscribe]);
|
||||
|
||||
// Canvas event handlers
|
||||
const handleCanvasClick = (e) => {
|
||||
if (!mapDataRef.current || !canvasRef.current) return;
|
||||
|
||||
const canvas = canvasRef.current;
|
||||
const rect = canvas.getBoundingClientRect();
|
||||
const clickX = e.clientX - rect.left;
|
||||
const clickY = e.clientY - rect.top;
|
||||
|
||||
// Convert canvas coordinates to world coordinates
|
||||
// This assumes the map is centered on the robot
|
||||
const map = mapDataRef.current;
|
||||
const robot = robotPoseRef.current;
|
||||
const zoom = 1; // Would need to track zoom if map has zoom controls
|
||||
|
||||
// Inverse of map rendering calculation
|
||||
const centerX = canvas.width / 2;
|
||||
const centerY = canvas.height / 2;
|
||||
|
||||
const worldX = robot.x + (clickX - centerX) / zoom;
|
||||
const worldY = robot.y - (clickY - centerY) / zoom;
|
||||
|
||||
// Create new waypoint
|
||||
const newWaypoint = {
|
||||
id: Date.now(),
|
||||
x: parseFloat(worldX.toFixed(2)),
|
||||
y: parseFloat(worldY.toFixed(2)),
|
||||
label: `WP-${waypoints.length + 1}`,
|
||||
};
|
||||
|
||||
setWaypoints((prev) => [...prev, newWaypoint]);
|
||||
waypointsRef.current = [...waypointsRef.current, newWaypoint];
|
||||
};
|
||||
|
||||
const handleCanvasContextMenu = (e) => {
|
||||
e.preventDefault();
|
||||
// Right-click handled by waypoint list
|
||||
};
|
||||
|
||||
// Waypoint list handlers
|
||||
const handleDeleteWaypoint = (id) => {
|
||||
setWaypoints((prev) => prev.filter((wp) => wp.id !== id));
|
||||
waypointsRef.current = waypointsRef.current.filter((wp) => wp.id !== id);
|
||||
if (selectedWaypoint === id) setSelectedWaypoint(null);
|
||||
};
|
||||
|
||||
const handleWaypointSelect = (id) => {
|
||||
setSelectedWaypoint(selectedWaypoint === id ? null : id);
|
||||
};
|
||||
|
||||
const handleWaypointDragStart = (e, index) => {
|
||||
setIsDragging(true);
|
||||
setDragIndex(index);
|
||||
};
|
||||
|
||||
const handleWaypointDragOver = (e, targetIndex) => {
|
||||
if (!isDragging || dragIndex === null || dragIndex === targetIndex) return;
|
||||
|
||||
const newWaypoints = [...waypoints];
|
||||
const draggedWaypoint = newWaypoints[dragIndex];
|
||||
newWaypoints.splice(dragIndex, 1);
|
||||
newWaypoints.splice(targetIndex, 0, draggedWaypoint);
|
||||
|
||||
setWaypoints(newWaypoints);
|
||||
waypointsRef.current = newWaypoints;
|
||||
setDragIndex(targetIndex);
|
||||
};
|
||||
|
||||
const handleWaypointDragEnd = () => {
|
||||
setIsDragging(false);
|
||||
setDragIndex(null);
|
||||
};
|
||||
|
||||
// Execute waypoints
|
||||
const sendNavGoal = async (waypoint) => {
|
||||
if (!callService) return;
|
||||
|
||||
try {
|
||||
// Create quaternion from heading (default to 0 if no heading)
|
||||
const heading = waypoint.theta || 0;
|
||||
const halfHeading = heading / 2;
|
||||
const qx = 0;
|
||||
const qy = 0;
|
||||
const qz = Math.sin(halfHeading);
|
||||
const qw = Math.cos(halfHeading);
|
||||
|
||||
const goal = {
|
||||
pose: {
|
||||
position: {
|
||||
x: waypoint.x,
|
||||
y: waypoint.y,
|
||||
z: 0,
|
||||
},
|
||||
orientation: {
|
||||
x: qx,
|
||||
y: qy,
|
||||
z: qz,
|
||||
w: qw,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// Send to Nav2 navigate_to_pose action
|
||||
await callService(
|
||||
'/navigate_to_pose',
|
||||
'nav2_msgs/NavigateToPose',
|
||||
{ pose: goal.pose }
|
||||
);
|
||||
|
||||
setActiveWaypoint(waypoint.id);
|
||||
return true;
|
||||
} catch (e) {
|
||||
console.error('Error sending nav goal:', e);
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
const executeWaypoints = async () => {
|
||||
if (waypoints.length === 0) return;
|
||||
|
||||
setExecuting(true);
|
||||
for (const waypoint of waypoints) {
|
||||
const success = await sendNavGoal(waypoint);
|
||||
if (!success) {
|
||||
console.error('Failed to send goal for waypoint:', waypoint);
|
||||
break;
|
||||
}
|
||||
// Wait a bit before sending next goal
|
||||
await new Promise((resolve) => setTimeout(resolve, 500));
|
||||
}
|
||||
setExecuting(false);
|
||||
setActiveWaypoint(null);
|
||||
};
|
||||
|
||||
const clearWaypoints = () => {
|
||||
setWaypoints([]);
|
||||
waypointsRef.current = [];
|
||||
setSelectedWaypoint(null);
|
||||
setActiveWaypoint(null);
|
||||
};
|
||||
|
||||
const sendSingleGoal = async () => {
|
||||
if (selectedWaypoint === null) return;
|
||||
|
||||
const wp = waypoints.find((w) => w.id === selectedWaypoint);
|
||||
if (wp) {
|
||||
await sendNavGoal(wp);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex h-full gap-3">
|
||||
{/* Map area with click handlers */}
|
||||
<div className="flex-1 flex flex-col space-y-3">
|
||||
<div className="flex-1 bg-gray-900 rounded-lg border border-cyan-950 overflow-hidden relative cursor-crosshair">
|
||||
<div
|
||||
ref={containerRef}
|
||||
className="w-full h-full"
|
||||
onClick={handleCanvasClick}
|
||||
onContextMenu={handleCanvasContextMenu}
|
||||
>
|
||||
{/* Virtual map display - waypoints overlaid */}
|
||||
<svg
|
||||
className="absolute inset-0 w-full h-full pointer-events-none"
|
||||
id="waypoint-overlay"
|
||||
>
|
||||
{/* Waypoint markers */}
|
||||
{waypoints.map((wp, idx) => {
|
||||
if (!mapDataRef.current) return null;
|
||||
|
||||
const robot = robotPoseRef.current;
|
||||
const zoom = 1;
|
||||
const centerX = containerRef.current?.clientWidth / 2 || 400;
|
||||
const centerY = containerRef.current?.clientHeight / 2 || 300;
|
||||
|
||||
const canvasX = centerX + (wp.x - robot.x) * zoom;
|
||||
const canvasY = centerY - (wp.y - robot.y) * zoom;
|
||||
|
||||
const isActive = wp.id === activeWaypoint;
|
||||
const isSelected = wp.id === selectedWaypoint;
|
||||
|
||||
return (
|
||||
<g key={wp.id}>
|
||||
{/* Waypoint circle */}
|
||||
<circle
|
||||
cx={canvasX}
|
||||
cy={canvasY}
|
||||
r="10"
|
||||
fill={isActive ? '#ef4444' : isSelected ? '#fbbf24' : '#06b6d4'}
|
||||
opacity="0.8"
|
||||
/>
|
||||
{/* Waypoint number */}
|
||||
<text
|
||||
x={canvasX}
|
||||
y={canvasY}
|
||||
textAnchor="middle"
|
||||
dominantBaseline="middle"
|
||||
fill="white"
|
||||
fontSize="10"
|
||||
fontWeight="bold"
|
||||
pointerEvents="none"
|
||||
>
|
||||
{idx + 1}
|
||||
</text>
|
||||
{/* Line to next waypoint */}
|
||||
{idx < waypoints.length - 1 && (
|
||||
<line
|
||||
x1={canvasX}
|
||||
y1={canvasY}
|
||||
x2={
|
||||
centerX +
|
||||
(waypoints[idx + 1].x - robot.x) * zoom
|
||||
}
|
||||
y2={
|
||||
centerY -
|
||||
(waypoints[idx + 1].y - robot.y) * zoom
|
||||
}
|
||||
stroke="#10b981"
|
||||
strokeWidth="2"
|
||||
opacity="0.6"
|
||||
/>
|
||||
)}
|
||||
</g>
|
||||
);
|
||||
})}
|
||||
|
||||
{/* Robot position marker */}
|
||||
<circle
|
||||
cx={containerRef.current?.clientWidth / 2 || 400}
|
||||
cy={containerRef.current?.clientHeight / 2 || 300}
|
||||
r="8"
|
||||
fill="#8b5cf6"
|
||||
opacity="1"
|
||||
/>
|
||||
</svg>
|
||||
|
||||
<div className="absolute inset-0 flex items-center justify-center pointer-events-none text-gray-600 text-sm">
|
||||
{waypoints.length === 0 && (
|
||||
<div className="text-center">
|
||||
<div>Click to place waypoints</div>
|
||||
<div className="text-xs text-gray-700">Right-click to delete</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Info panel */}
|
||||
<div className="bg-gray-950 rounded-lg border border-cyan-950 p-3 text-xs text-gray-600 space-y-1">
|
||||
<div className="flex justify-between">
|
||||
<span>Waypoints:</span>
|
||||
<span className="text-cyan-400">{waypoints.length}</span>
|
||||
</div>
|
||||
<div className="flex justify-between">
|
||||
<span>Robot Position:</span>
|
||||
<span className="text-cyan-400">
|
||||
({robotPose.x.toFixed(2)}, {robotPose.y.toFixed(2)})
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Waypoint list sidebar */}
|
||||
<div className="w-64 flex flex-col bg-gray-950 rounded-lg border border-cyan-950 space-y-3 p-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="text-cyan-700 text-xs font-bold tracking-widest">WAYPOINTS</div>
|
||||
<div className="text-gray-600 text-xs">{waypoints.length}</div>
|
||||
</div>
|
||||
|
||||
{/* Waypoint list */}
|
||||
<div className="flex-1 overflow-y-auto space-y-1">
|
||||
{waypoints.length === 0 ? (
|
||||
<div className="text-center text-gray-700 text-xs py-4">
|
||||
Click map to add waypoints
|
||||
</div>
|
||||
) : (
|
||||
waypoints.map((wp, idx) => (
|
||||
<div
|
||||
key={wp.id}
|
||||
draggable
|
||||
onDragStart={(e) => handleWaypointDragStart(e, idx)}
|
||||
onDragOver={(e) => {
|
||||
e.preventDefault();
|
||||
handleWaypointDragOver(e, idx);
|
||||
}}
|
||||
onDragEnd={handleWaypointDragEnd}
|
||||
onClick={() => handleWaypointSelect(wp.id)}
|
||||
onContextMenu={(e) => {
|
||||
e.preventDefault();
|
||||
handleDeleteWaypoint(wp.id);
|
||||
}}
|
||||
className={`p-2 rounded border text-xs cursor-move transition-colors ${
|
||||
wp.id === activeWaypoint
|
||||
? 'bg-red-950 border-red-700 text-red-300'
|
||||
: wp.id === selectedWaypoint
|
||||
? 'bg-amber-950 border-amber-700 text-amber-300'
|
||||
: 'bg-gray-900 border-gray-700 text-gray-400 hover:border-gray-600'
|
||||
}`}
|
||||
>
|
||||
<div className="flex justify-between items-start gap-2">
|
||||
<div className="font-bold">#{idx + 1}</div>
|
||||
<div className="text-right flex-1">
|
||||
<div className="text-gray-500">{wp.label}</div>
|
||||
<div className="text-gray-600">
|
||||
{wp.x.toFixed(2)}, {wp.y.toFixed(2)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Control buttons */}
|
||||
<div className="space-y-2 border-t border-gray-800 pt-3">
|
||||
<button
|
||||
onClick={sendSingleGoal}
|
||||
disabled={selectedWaypoint === null || executing}
|
||||
className="w-full px-2 py-1.5 text-xs font-bold tracking-widest rounded border border-cyan-800 bg-cyan-950 text-cyan-400 hover:bg-cyan-900 disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
|
||||
>
|
||||
SEND GOAL
|
||||
</button>
|
||||
|
||||
<button
|
||||
onClick={executeWaypoints}
|
||||
disabled={waypoints.length === 0 || executing}
|
||||
className="w-full px-2 py-1.5 text-xs font-bold tracking-widest rounded border border-green-800 bg-green-950 text-green-400 hover:bg-green-900 disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
|
||||
>
|
||||
{executing ? 'EXECUTING...' : 'EXECUTE ALL'}
|
||||
</button>
|
||||
|
||||
<button
|
||||
onClick={clearWaypoints}
|
||||
disabled={waypoints.length === 0}
|
||||
className="w-full px-2 py-1.5 text-xs font-bold tracking-widest rounded border border-red-800 bg-red-950 text-red-400 hover:bg-red-900 disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
|
||||
>
|
||||
CLEAR ALL
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Instructions */}
|
||||
<div className="text-xs text-gray-600 space-y-1 border-t border-gray-800 pt-3">
|
||||
<div className="font-bold text-gray-500">CONTROLS:</div>
|
||||
<div>• Click: Place waypoint</div>
|
||||
<div>• Right-click: Delete waypoint</div>
|
||||
<div>• Drag: Reorder waypoints</div>
|
||||
<div>• Click list: Select waypoint</div>
|
||||
</div>
|
||||
|
||||
{/* Topic info */}
|
||||
<div className="text-xs text-gray-600 border-t border-gray-800 pt-3">
|
||||
<div className="flex justify-between">
|
||||
<span>Service:</span>
|
||||
<span className="text-gray-500">/navigate_to_pose</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export { WaypointEditor };
|
||||
Loading…
x
Reference in New Issue
Block a user