Compare commits
7 Commits
0821845210
...
84c8b6a0ae
| Author | SHA1 | Date | |
|---|---|---|---|
| 84c8b6a0ae | |||
| dc746ccedc | |||
| d6a6965af6 | |||
| 35b940e1f5 | |||
| 5143e5bfac | |||
| 5c4f18e46c | |||
| f61a03b3c5 |
206
esp32/social_expression/social_expression.ino
Normal file
206
esp32/social_expression/social_expression.ino
Normal file
@ -0,0 +1,206 @@
|
||||
/*
|
||||
* social_expression.ino — LED mood display for saltybot social layer.
|
||||
*
|
||||
* Hardware
|
||||
* ────────
|
||||
* MCU : ESP32-C3 (e.g. Seeed XIAO ESP32-C3 or equivalent)
|
||||
* LEDs : NeoPixel ring or strip, DATA_PIN → GPIO5, power via 5V rail
|
||||
*
|
||||
* Serial protocol (from Orin Nano, 115200 8N1)
|
||||
* ─────────────────────────────────────────────
|
||||
* One JSON line per command, newline-terminated:
|
||||
* {"mood":"happy","intensity":0.8}\n
|
||||
*
|
||||
* mood values : "happy" | "curious" | "annoyed" | "playful" | "idle"
|
||||
* intensity : 0.0 (off) .. 1.0 (full brightness)
|
||||
*
|
||||
* If no command arrives for IDLE_TIMEOUT_MS, the node enters the
|
||||
* breathing idle animation automatically.
|
||||
*
|
||||
* Mood → colour mapping
|
||||
* ──────────────────────
|
||||
* happy → green #00FF00
|
||||
* curious → blue #0040FF
|
||||
* annoyed → red #FF1000
|
||||
* playful → rainbow (cycling hue across all pixels)
|
||||
* idle → soft white breathing (sine modulated)
|
||||
*
|
||||
* Dependencies (Arduino Library Manager)
|
||||
* ───────────────────────────────────────
|
||||
* Adafruit NeoPixel ≥ 1.11.0
|
||||
* ArduinoJson ≥ 7.0.0
|
||||
*
|
||||
* Build: Arduino IDE 2 / arduino-cli with board "ESP32C3 Dev Module"
|
||||
* Board manager URL: https://raw.githubusercontent.com/espressif/arduino-esp32/gh-pages/package_esp32_index.json
|
||||
*/
|
||||
|
||||
#include <Adafruit_NeoPixel.h>
|
||||
#include <ArduinoJson.h>
|
||||
|
||||
// ── Hardware config ──────────────────────────────────────────────────────────
|
||||
#define LED_PIN 5 // GPIO connected to NeoPixel data line
|
||||
#define LED_COUNT 16 // pixels in ring / strip — adjust as needed
|
||||
#define LED_BRIGHTNESS 120 // global cap 0-255 (≈47%), protects PSU
|
||||
|
||||
// ── Timing constants ─────────────────────────────────────────────────────────
|
||||
#define IDLE_TIMEOUT_MS 3000 // fall back to breathing after 3 s silence
|
||||
#define LOOP_INTERVAL_MS 20 // animation tick ≈ 50 Hz
|
||||
|
||||
// ── Colours (R, G, B) ────────────────────────────────────────────────────────
|
||||
static const uint8_t COL_HAPPY[3] = { 0, 220, 0 }; // green
|
||||
static const uint8_t COL_CURIOUS[3] = { 0, 64, 255 }; // blue
|
||||
static const uint8_t COL_ANNOYED[3] = {255, 16, 0 }; // red
|
||||
|
||||
// ── State ────────────────────────────────────────────────────────────────────
|
||||
enum Mood { MOOD_IDLE, MOOD_HAPPY, MOOD_CURIOUS, MOOD_ANNOYED, MOOD_PLAYFUL };
|
||||
|
||||
static Mood g_mood = MOOD_IDLE;
|
||||
static float g_intensity = 1.0f;
|
||||
static uint32_t g_last_cmd_ms = 0; // millis() of last received command
|
||||
|
||||
// Animation counters
|
||||
static uint16_t g_rainbow_hue = 0; // 0..65535, cycles for playful
|
||||
static uint32_t g_last_tick = 0;
|
||||
|
||||
// Serial receive buffer
|
||||
static char g_serial_buf[128];
|
||||
static uint8_t g_buf_pos = 0;
|
||||
|
||||
// ── NeoPixel object ──────────────────────────────────────────────────────────
|
||||
Adafruit_NeoPixel strip(LED_COUNT, LED_PIN, NEO_GRB + NEO_KHZ800);
|
||||
|
||||
// ── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
/* Dim an 8-bit colour channel by a 0.0..1.0 intensity factor.
 * The product is truncated (not rounded) back to uint8_t. */
static uint8_t scale(uint8_t v, float intensity) {
    float dimmed = (float)v * intensity;
    return (uint8_t)dimmed;
}
|
||||
|
||||
/*
 * Breathing: sine envelope over time, period ~4 s.
 * Returns a brightness factor in ~0.08 .. 0.43 — the floor keeps the
 * LEDs from ever going fully dark, so the robot always looks "alive".
 */
static float breath_factor(uint32_t now_ms) {
    // Normalised position within the 4 s period.
    float phase = (float)(now_ms % 4000) / 4000.0f; // 0..1 per period
    // sin remapped from [-1,1] to [0,1], then scaled into [0.08, 0.43].
    return 0.08f + 0.35f * (0.5f + 0.5f * sinf(2.0f * M_PI * phase));
}
|
||||
|
||||
/*
 * Write the same (r, g, b) colour to every pixel in the strip's frame
 * buffer. Does not call strip.show(); the caller pushes the frame.
 */
static void set_all(uint8_t r, uint8_t g, uint8_t b) {
    for (int i = 0; i < LED_COUNT; i++) {
        strip.setPixelColor(i, strip.Color(r, g, b));
    }
}
|
||||
|
||||
// ── Animation drivers ────────────────────────────────────────────────────────
|
||||
|
||||
/* Fill the whole strip with a single mood colour, dimmed by intensity. */
static void animate_solid(const uint8_t col[3], float intensity) {
    uint8_t r = scale(col[0], intensity);
    uint8_t g = scale(col[1], intensity);
    uint8_t b = scale(col[2], intensity);
    set_all(r, g, b);
}
|
||||
|
||||
/*
 * Idle animation: warm-white breathing driven by the sine envelope.
 * intensity further scales the envelope so commands can dim it.
 */
static void animate_breathing(uint32_t now_ms, float intensity) {
    float bf = breath_factor(now_ms) * intensity;
    // Warm white: R=255 G=200 B=120
    set_all(scale(255, bf), scale(200, bf), scale(120, bf));
}
|
||||
|
||||
/*
 * Playful animation: the full HSV hue wheel spread across the strip,
 * slowly rotating. intensity scales the HSV value (brightness) channel.
 */
static void animate_rainbow(float intensity) {
    // Spread full wheel across all pixels
    for (int i = 0; i < LED_COUNT; i++) {
        // Pixel i sits i/LED_COUNT of the way around the 16-bit hue
        // circle, offset by the global rotation counter.
        uint16_t hue = g_rainbow_hue + (uint16_t)((float)i / LED_COUNT * 65536.0f);
        uint32_t rgb = strip.ColorHSV(hue, 255,
                                      (uint8_t)(255.0f * intensity));
        strip.setPixelColor(i, rgb);
    }
    // Advance hue each tick (full cycle in ~6 s at 50 Hz)
    g_rainbow_hue += 218;  // 65536/218 ≈ 300 ticks; uint16_t wrap is intended
}
|
||||
|
||||
// ── Serial parser ────────────────────────────────────────────────────────────
|
||||
|
||||
/*
 * Parse one JSON command line, e.g. {"mood":"happy","intensity":0.8},
 * and update the global mood state.
 *
 * Malformed JSON is silently ignored (state unchanged). Unknown or
 * missing mood strings map to MOOD_IDLE. intensity defaults to 1.0
 * when absent and is clamped to 0.0..1.0.
 */
static void parse_command(const char *line) {
    StaticJsonDocument<128> doc;
    DeserializationError err = deserializeJson(doc, line);
    if (err) return;  // not valid JSON — drop the line

    // "|" is ArduinoJson's default-value operator for missing keys.
    const char *mood_str = doc["mood"] | "";
    float intensity = doc["intensity"] | 1.0f;
    if (intensity < 0.0f) intensity = 0.0f;
    if (intensity > 1.0f) intensity = 1.0f;

    if (strcmp(mood_str, "happy") == 0) g_mood = MOOD_HAPPY;
    else if (strcmp(mood_str, "curious") == 0) g_mood = MOOD_CURIOUS;
    else if (strcmp(mood_str, "annoyed") == 0) g_mood = MOOD_ANNOYED;
    else if (strcmp(mood_str, "playful") == 0) g_mood = MOOD_PLAYFUL;
    else g_mood = MOOD_IDLE;

    g_intensity = intensity;
    g_last_cmd_ms = millis();  // refresh the idle-timeout reference
}
|
||||
|
||||
static void read_serial(void) {
|
||||
while (Serial.available()) {
|
||||
char c = (char)Serial.read();
|
||||
if (c == '\n' || c == '\r') {
|
||||
if (g_buf_pos > 0) {
|
||||
g_serial_buf[g_buf_pos] = '\0';
|
||||
parse_command(g_serial_buf);
|
||||
g_buf_pos = 0;
|
||||
}
|
||||
} else if (g_buf_pos < (sizeof(g_serial_buf) - 1)) {
|
||||
g_serial_buf[g_buf_pos++] = c;
|
||||
} else {
|
||||
// Buffer overflow — discard line
|
||||
g_buf_pos = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Arduino entry points ─────────────────────────────────────────────────────
|
||||
|
||||
/*
 * Arduino init: bring up the serial link and the NeoPixel strip,
 * starting with all pixels off.
 */
void setup(void) {
    Serial.begin(115200);  // must match the Orin-side baud rate

    strip.begin();
    strip.setBrightness(LED_BRIGHTNESS);  // global cap — protects the 5 V rail
    strip.clear();
    strip.show();  // push the all-off frame immediately

    g_last_tick = millis();
}
|
||||
|
||||
/*
 * Main loop: drain serial input, apply the idle timeout, then render one
 * animation frame every LOOP_INTERVAL_MS (~50 Hz).
 */
void loop(void) {
    uint32_t now = millis();

    read_serial();

    // Fall back to idle if no command for IDLE_TIMEOUT_MS.
    // Unsigned subtraction stays correct across millis() wraparound.
    if ((now - g_last_cmd_ms) > IDLE_TIMEOUT_MS) {
        g_mood = MOOD_IDLE;
    }

    // Throttle animation ticks
    if ((now - g_last_tick) < LOOP_INTERVAL_MS) return;
    g_last_tick = now;

    switch (g_mood) {
        case MOOD_HAPPY:
            animate_solid(COL_HAPPY, g_intensity);
            break;
        case MOOD_CURIOUS:
            animate_solid(COL_CURIOUS, g_intensity);
            break;
        case MOOD_ANNOYED:
            animate_solid(COL_ANNOYED, g_intensity);
            break;
        case MOOD_PLAYFUL:
            animate_rainbow(g_intensity);
            break;
        case MOOD_IDLE:
        default:
            // A zero intensity would blank the idle breathing entirely;
            // fall back to full envelope so the robot still looks alive.
            animate_breathing(now, g_intensity > 0.0f ? g_intensity : 1.0f);
            break;
    }

    strip.show();  // push the rendered frame to the LEDs
}
|
||||
@ -0,0 +1,30 @@
|
||||
# attention_params.yaml — Social attention controller configuration.
|
||||
#
|
||||
# ── Proportional gain ─────────────────────────────────────────────────────────
|
||||
# angular_z = clamp(kp_angular * bearing_rad, ±max_angular_vel)
|
||||
#
|
||||
# kp_angular : 1.0 → 1.0 rad/s per 1 rad of bearing error; with
#     max_angular_vel 0.8, output saturates at 0.8 rad error (~46°).
|
||||
# Increase for snappier tracking; decrease if robot oscillates.
|
||||
kp_angular: 1.0 # rad/s per rad
|
||||
|
||||
# ── Velocity limits ──────────────────────────────────────────────────────────
|
||||
# Hard cap on rotation output. In social mode the balance loop must not be
|
||||
# disturbed by large sudden angular commands; keep this ≤ 1.0 rad/s.
|
||||
max_angular_vel: 0.8 # rad/s
|
||||
|
||||
# ── Dead zone ────────────────────────────────────────────────────────────────
|
||||
# If |bearing| ≤ dead_zone_rad, rotation is suppressed.
|
||||
# 0.15 rad ≈ 8.6° — prevents chattering when person is roughly centred.
|
||||
dead_zone_rad: 0.15 # radians
|
||||
|
||||
# ── Control loop ─────────────────────────────────────────────────────────────
|
||||
control_rate: 20.0 # Hz
|
||||
|
||||
# ── Lost-target timeout ───────────────────────────────────────────────────────
|
||||
# If no /social/persons message arrives for this many seconds, publish zero
|
||||
# and enter IDLE state. The cmd_vel bridge deadman will also kick in.
|
||||
lost_timeout_s: 2.0 # seconds
|
||||
|
||||
# ── Master enable ─────────────────────────────────────────────────────────────
|
||||
# Toggle at runtime: ros2 param set /attention_node attention_enabled false
|
||||
attention_enabled: true
|
||||
@ -0,0 +1,22 @@
|
||||
# expression_params.yaml — LED mood display bridge configuration.
|
||||
#
|
||||
# serial_port : udev symlink or device path for the ESP32-C3 USB-CDC port.
|
||||
# Recommended: create a udev rule:
|
||||
# SUBSYSTEM=="tty", ATTRS{idVendor}=="303a", ATTRS{idProduct}=="1001",
|
||||
# SYMLINK+="esp32-social"
|
||||
#
|
||||
# baud_rate : must match social_expression.ino (default 115200)
|
||||
#
|
||||
# idle_timeout_s : if no /social/mood message arrives for this many seconds,
|
||||
# the node sends {"mood":"idle","intensity":1.0} so the ESP32 breathing
|
||||
# animation is synchronised with node awareness.
|
||||
#
|
||||
# control_rate : how often to check the idle timeout (Hz).
|
||||
# Does NOT gate the mood messages — those are forwarded immediately.
|
||||
|
||||
expression_node:
|
||||
ros__parameters:
|
||||
serial_port: /dev/esp32-social
|
||||
baud_rate: 115200
|
||||
idle_timeout_s: 3.0
|
||||
control_rate: 10.0
|
||||
@ -1,3 +1,18 @@
|
||||
"""
|
||||
social.launch.py — Launch the full saltybot social stack.
|
||||
|
||||
Includes:
|
||||
person_state_tracker — multi-modal person identity fusion (Issue #82)
|
||||
expression_node — /social/mood → ESP32-C3 NeoPixel serial (Issue #86)
|
||||
attention_node — /social/persons → /cmd_vel rotation (Issue #86)
|
||||
|
||||
Usage:
|
||||
ros2 launch saltybot_social social.launch.py
|
||||
ros2 launch saltybot_social social.launch.py serial_port:=/dev/ttyUSB1
|
||||
"""
|
||||
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
from launch import LaunchDescription
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
@ -5,7 +20,12 @@ from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
pkg = get_package_share_directory("saltybot_social")
|
||||
exp_cfg = os.path.join(pkg, "config", "expression_params.yaml")
|
||||
att_cfg = os.path.join(pkg, "config", "attention_params.yaml")
|
||||
|
||||
return LaunchDescription([
|
||||
# person_state_tracker args (Issue #82)
|
||||
DeclareLaunchArgument(
|
||||
'engagement_distance',
|
||||
default_value='2.0',
|
||||
@ -21,6 +41,19 @@ def generate_launch_description():
|
||||
default_value='false',
|
||||
description='Whether UWB anchor data is available'
|
||||
),
|
||||
|
||||
# expression_node args (Issue #86)
|
||||
DeclareLaunchArgument("serial_port", default_value="/dev/esp32-social"),
|
||||
DeclareLaunchArgument("baud_rate", default_value="115200"),
|
||||
DeclareLaunchArgument("idle_timeout_s", default_value="3.0"),
|
||||
|
||||
# attention_node args (Issue #86)
|
||||
DeclareLaunchArgument("kp_angular", default_value="1.0"),
|
||||
DeclareLaunchArgument("max_angular_vel", default_value="0.8"),
|
||||
DeclareLaunchArgument("dead_zone_rad", default_value="0.15"),
|
||||
DeclareLaunchArgument("lost_timeout_s", default_value="2.0"),
|
||||
DeclareLaunchArgument("attention_enabled", default_value="true"),
|
||||
|
||||
Node(
|
||||
package='saltybot_social',
|
||||
executable='person_state_tracker',
|
||||
@ -32,4 +65,36 @@ def generate_launch_description():
|
||||
'uwb_enabled': LaunchConfiguration('uwb_enabled'),
|
||||
}],
|
||||
),
|
||||
|
||||
Node(
|
||||
package="saltybot_social",
|
||||
executable="expression_node",
|
||||
name="expression_node",
|
||||
output="screen",
|
||||
parameters=[
|
||||
exp_cfg,
|
||||
{
|
||||
"serial_port": LaunchConfiguration("serial_port"),
|
||||
"baud_rate": LaunchConfiguration("baud_rate"),
|
||||
"idle_timeout_s": LaunchConfiguration("idle_timeout_s"),
|
||||
},
|
||||
],
|
||||
),
|
||||
|
||||
Node(
|
||||
package="saltybot_social",
|
||||
executable="attention_node",
|
||||
name="attention_node",
|
||||
output="screen",
|
||||
parameters=[
|
||||
att_cfg,
|
||||
{
|
||||
"kp_angular": LaunchConfiguration("kp_angular"),
|
||||
"max_angular_vel": LaunchConfiguration("max_angular_vel"),
|
||||
"dead_zone_rad": LaunchConfiguration("dead_zone_rad"),
|
||||
"lost_timeout_s": LaunchConfiguration("lost_timeout_s"),
|
||||
"attention_enabled": LaunchConfiguration("attention_enabled"),
|
||||
},
|
||||
],
|
||||
),
|
||||
])
|
||||
|
||||
@ -3,7 +3,12 @@
|
||||
<package format="3">
|
||||
<name>saltybot_social</name>
|
||||
<version>0.1.0</version>
|
||||
<description>Multi-modal person identity fusion and state tracking for saltybot</description>
|
||||
<description>
|
||||
Social interaction layer for saltybot.
|
||||
person_state_tracker: multi-modal person identity fusion (Issue #82).
|
||||
expression_node: bridges /social/mood to ESP32-C3 NeoPixel ring over serial (Issue #86).
|
||||
attention_node: rotates robot toward active speaker via /social/persons bearing (Issue #86).
|
||||
</description>
|
||||
<maintainer email="seb@vayrette.com">seb</maintainer>
|
||||
<license>MIT</license>
|
||||
<depend>rclpy</depend>
|
||||
|
||||
@ -0,0 +1,205 @@
|
||||
"""
|
||||
attention_node.py — Social attention controller for saltybot.
|
||||
|
||||
Rotates the robot to face the active speaker or most confident detected person.
|
||||
Publishes rotation-only cmd_vel (linear.x = 0) so balance is not disturbed.
|
||||
|
||||
Subscribes
|
||||
──────────
|
||||
/social/persons (saltybot_social_msgs/PersonArray)
|
||||
Contains all tracked persons; active_id marks the current speaker.
|
||||
bearing_rad: signed bearing in base_link (+ve = left/CCW).
|
||||
|
||||
Publishes
|
||||
─────────
|
||||
/cmd_vel (geometry_msgs/Twist)
|
||||
angular.z = clamp(kp_angular * bearing, ±max_angular_vel)
|
||||
linear.x = 0 always (social mode, not follow mode)
|
||||
|
||||
Control law
|
||||
───────────
|
||||
1. Pick target: person with active_id; fall back to highest-confidence person.
|
||||
2. Dead zone: if |bearing| ≤ dead_zone_rad → angular_z = 0.
|
||||
3. Proportional: angular_z = kp_angular * bearing.
|
||||
4. Clamp to ±max_angular_vel.
|
||||
5. If no fresh person data for lost_timeout_s → publish zero and stay idle.
|
||||
|
||||
State machine
|
||||
─────────────
|
||||
ATTENDING — fresh target; publishing proportional rotation
|
||||
IDLE — no target; publishing zero
|
||||
|
||||
Parameters
|
||||
──────────
|
||||
kp_angular (float) 1.0 proportional gain (rad/s per rad error)
|
||||
max_angular_vel (float) 0.8 hard cap (rad/s)
|
||||
dead_zone_rad (float) 0.15 ~8.6° dead zone around robot heading
|
||||
control_rate (float) 20.0 Hz
|
||||
lost_timeout_s (float) 2.0 seconds before going idle
|
||||
attention_enabled (bool) True runtime kill switch
|
||||
"""
|
||||
|
||||
import math
|
||||
import time
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from geometry_msgs.msg import Twist
|
||||
|
||||
from saltybot_social_msgs.msg import PersonArray
|
||||
|
||||
|
||||
# ── Pure helpers (used by tests) ──────────────────────────────────────────────
|
||||
|
||||
def clamp(v: float, lo: float, hi: float) -> float:
    """Restrict v to the closed interval [lo, hi]."""
    if v < lo:
        return lo
    if v > hi:
        return hi
    return v
||||
|
||||
|
||||
def select_target(persons, active_id: int):
    """
    Choose the attention target from a PersonArray's persons list.

    Priority:
      1. The person whose track_id matches active_id (declared speaker).
      2. Otherwise the highest-confidence person.

    Returns the chosen Person message, or None when persons is empty.
    """
    if not persons:
        return None

    # Declared active speaker wins (first match, as in message order).
    speaker = next((p for p in persons if p.track_id == active_id), None)
    if speaker is not None:
        return speaker

    # No speaker declared (or not in the list): most confident detection.
    return max(persons, key=lambda person: person.confidence)
||||
|
||||
|
||||
def compute_attention_cmd(
    bearing_rad: float,
    dead_zone_rad: float,
    kp_angular: float,
    max_angular_vel: float,
) -> float:
    """
    Proportional rotation command toward bearing.

    Returns angular_z (rad/s): zero while |bearing| lies inside the dead
    zone, otherwise kp_angular * bearing clamped to ±max_angular_vel.
    """
    inside_dead_zone = -dead_zone_rad <= bearing_rad <= dead_zone_rad
    if inside_dead_zone:
        return 0.0
    raw = kp_angular * bearing_rad
    return max(-max_angular_vel, min(max_angular_vel, raw))
|
||||
|
||||
|
||||
# ── ROS2 node ─────────────────────────────────────────────────────────────────
|
||||
|
||||
class AttentionNode(Node):
    """Rotation-only social attention controller.

    Subscribes to /social/persons and publishes Twist on /cmd_vel with
    linear.x always zero (default Twist), so the balance loop is never
    disturbed by translation commands.
    """

    def __init__(self):
        super().__init__("attention_node")

        self.declare_parameter("kp_angular", 1.0)
        self.declare_parameter("max_angular_vel", 0.8)
        self.declare_parameter("dead_zone_rad", 0.15)
        self.declare_parameter("control_rate", 20.0)
        self.declare_parameter("lost_timeout_s", 2.0)
        self.declare_parameter("attention_enabled", True)

        self._reload_params()

        # State
        self._last_persons_t = 0.0  # monotonic; 0 = never received
        self._persons = []          # latest PersonArray.persons snapshot
        self._active_id = -1        # latest PersonArray.active_id
        self._state = "idle"        # "idle" | "attending" — used for edge-triggered logging

        self.create_subscription(
            PersonArray, "/social/persons", self._persons_cb, 10
        )

        self._cmd_pub = self.create_publisher(Twist, "/cmd_vel", 10)

        rate = self._p["control_rate"]
        self._timer = self.create_timer(1.0 / rate, self._control_cb)

        self.get_logger().info(
            f"AttentionNode ready kp={self._p['kp_angular']} "
            f"dead_zone={math.degrees(self._p['dead_zone_rad']):.1f}° "
            f"max_ω={self._p['max_angular_vel']} rad/s"
        )

    def _reload_params(self):
        """Snapshot current ROS parameter values into self._p.

        Called at startup and again on every control tick, so
        `ros2 param set` changes take effect without a restart.
        """
        self._p = {
            "kp_angular": self.get_parameter("kp_angular").value,
            "max_angular_vel": self.get_parameter("max_angular_vel").value,
            "dead_zone_rad": self.get_parameter("dead_zone_rad").value,
            "control_rate": self.get_parameter("control_rate").value,
            "lost_timeout_s": self.get_parameter("lost_timeout_s").value,
            "attention_enabled": self.get_parameter("attention_enabled").value,
        }

    # ── Callbacks ─────────────────────────────────────────────────────────────

    def _persons_cb(self, msg: PersonArray):
        """Cache the latest person list and stamp its arrival time."""
        self._persons = msg.persons
        self._active_id = msg.active_id
        self._last_persons_t = time.monotonic()
        if self._state == "idle" and msg.persons:
            self._state = "attending"
            self.get_logger().info("Attention: person detected — attending")

    def _control_cb(self):
        """Control tick: publish exactly one Twist per invocation.

        Publishes a zero Twist when disabled, when person data is stale
        (older than lost_timeout_s), or when no target can be selected;
        otherwise publishes the proportional rotation command toward the
        selected person's bearing.
        """
        self._reload_params()
        twist = Twist()  # zero-initialised; stays zero on all bail-out paths

        if not self._p["attention_enabled"]:
            self._cmd_pub.publish(twist)
            return

        now = time.monotonic()
        fresh = (
            self._last_persons_t > 0.0
            and (now - self._last_persons_t) < self._p["lost_timeout_s"]
        )

        if not fresh:
            if self._state == "attending":
                self._state = "idle"
                self.get_logger().info("Attention: no person — idle")
            self._cmd_pub.publish(twist)
            return

        target = select_target(self._persons, self._active_id)
        if target is None:
            self._cmd_pub.publish(twist)
            return

        self._state = "attending"
        twist.angular.z = compute_attention_cmd(
            bearing_rad=target.bearing_rad,
            dead_zone_rad=self._p["dead_zone_rad"],
            kp_angular=self._p["kp_angular"],
            max_angular_vel=self._p["max_angular_vel"],
        )
        self._cmd_pub.publish(twist)
||||
|
||||
|
||||
# ── Entry point ───────────────────────────────────────────────────────────────
|
||||
|
||||
def main(args=None):
    """CLI entry point: spin the attention node until interrupted."""
    rclpy.init(args=args)
    node = AttentionNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass  # Ctrl-C is the normal shutdown path
    finally:
        # Always release the node; try_shutdown() is safe even if the
        # context is already down.
        node.destroy_node()
        rclpy.try_shutdown()


if __name__ == "__main__":
    main()
|
||||
@ -0,0 +1,157 @@
|
||||
"""
|
||||
expression_node.py — LED mood display bridge for saltybot social layer.
|
||||
|
||||
Subscribes
|
||||
──────────
|
||||
/social/mood (saltybot_social_msgs/Mood)
|
||||
mood : "happy" | "curious" | "annoyed" | "playful" | "idle"
|
||||
intensity : 0.0 .. 1.0
|
||||
|
||||
Behaviour
|
||||
─────────
|
||||
* On every /social/mood message, serialises the command as a single-line
|
||||
JSON frame and writes it over the ESP32-C3 serial link:
|
||||
|
||||
{"mood":"happy","intensity":0.80}\n
|
||||
|
||||
* If no mood message arrives for idle_timeout_s, sends an explicit
|
||||
idle command so the ESP32 breathing animation is synchronised with the
|
||||
node's own awareness of whether the robot is active.
|
||||
|
||||
* Re-opens the serial port automatically if the device disconnects.
|
||||
|
||||
Parameters
|
||||
──────────
|
||||
serial_port (str) /dev/ttyUSB0 — ESP32-C3 USB-CDC device
|
||||
baud_rate (int) 115200
|
||||
idle_timeout_s (float) 3.0 — seconds before sending idle
|
||||
control_rate (float) 10.0 — Hz; how often to check idle
|
||||
"""
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
|
||||
from saltybot_social_msgs.msg import Mood
|
||||
|
||||
try:
|
||||
import serial
|
||||
_SERIAL_AVAILABLE = True
|
||||
except ImportError:
|
||||
_SERIAL_AVAILABLE = False
|
||||
|
||||
|
||||
class ExpressionNode(Node):
    """Bridge /social/mood messages to the ESP32-C3 LED ring over serial.

    Each Mood message is forwarded immediately as one JSON line; a timer
    sends an explicit idle frame after idle_timeout_s of silence. A
    background thread keeps the serial port open and reconnects on error.
    """

    def __init__(self):
        super().__init__("expression_node")

        self.declare_parameter("serial_port", "/dev/ttyUSB0")
        self.declare_parameter("baud_rate", 115200)
        self.declare_parameter("idle_timeout_s", 3.0)
        self.declare_parameter("control_rate", 10.0)

        self._port = self.get_parameter("serial_port").value
        self._baud = self.get_parameter("baud_rate").value
        self._idle_to = self.get_parameter("idle_timeout_s").value
        rate = self.get_parameter("control_rate").value

        self._last_cmd_t = 0.0  # monotonic, 0 = never received
        self._lock = threading.Lock()  # guards self._ser across threads
        self._ser = None  # serial.Serial handle; None while disconnected

        self.create_subscription(Mood, "/social/mood", self._mood_cb, 10)

        self._timer = self.create_timer(1.0 / rate, self._tick_cb)

        if _SERIAL_AVAILABLE:
            # Reconnect loop runs for the node's lifetime; daemon=True so
            # it never blocks process exit.
            threading.Thread(target=self._open_serial, daemon=True).start()
        else:
            self.get_logger().warn(
                "pyserial not available — serial output disabled (dry-run mode)"
            )

        self.get_logger().info(
            f"ExpressionNode ready port={self._port} baud={self._baud} "
            f"idle_timeout={self._idle_to}s"
        )

    # ── Serial management ─────────────────────────────────────────────────────

    def _open_serial(self):
        """Background thread: keep serial port open, reconnect on error."""
        while rclpy.ok():
            try:
                with self._lock:
                    self._ser = serial.Serial(
                        self._port, self._baud, timeout=1.0
                    )
                self.get_logger().info(f"Opened serial port {self._port}")
                # Hold serial open until error
                while rclpy.ok():
                    with self._lock:
                        if self._ser is None or not self._ser.is_open:
                            break  # _send() dropped the port after a write error
                    time.sleep(0.5)
            except Exception as exc:
                self.get_logger().warn(
                    f"Serial {self._port} error: {exc} — retry in 3 s"
                )
                with self._lock:
                    if self._ser and self._ser.is_open:
                        self._ser.close()
                    self._ser = None
                time.sleep(3.0)

    def _send(self, mood: str, intensity: float):
        """Serialise and write one JSON line to ESP32-C3."""
        line = json.dumps({"mood": mood, "intensity": round(intensity, 3)}) + "\n"
        data = line.encode("ascii")
        with self._lock:
            if self._ser and self._ser.is_open:
                try:
                    self._ser.write(data)
                except Exception as exc:
                    # Drop the handle; _open_serial() will reconnect.
                    self.get_logger().warn(f"Serial write error: {exc}")
                    self._ser.close()
                    self._ser = None
            else:
                # Dry-run: log the frame
                self.get_logger().debug(f"[dry-run] → {line.strip()}")

    # ── Callbacks ─────────────────────────────────────────────────────────────

    def _mood_cb(self, msg: Mood):
        """Forward a mood message immediately and refresh the idle timer."""
        intensity = max(0.0, min(1.0, float(msg.intensity)))
        self._send(msg.mood, intensity)
        self._last_cmd_t = time.monotonic()

    def _tick_cb(self):
        """Periodic check — send idle if we've been quiet too long."""
        if self._last_cmd_t == 0.0:
            return  # nothing received yet; the ESP32 idles on its own
        if (time.monotonic() - self._last_cmd_t) >= self._idle_to:
            self._send("idle", 1.0)
            # Reset so we don't flood the ESP32 with idle frames
            self._last_cmd_t = time.monotonic()
||||
|
||||
|
||||
# ── Entry point ───────────────────────────────────────────────────────────────
|
||||
|
||||
def main(args=None):
    """CLI entry point: spin the expression node until interrupted."""
    rclpy.init(args=args)
    node = ExpressionNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass  # Ctrl-C is the normal shutdown path
    finally:
        # Always release the node; try_shutdown() is safe even if the
        # context is already down.
        node.destroy_node()
        rclpy.try_shutdown()


if __name__ == "__main__":
    main()
|
||||
@ -21,12 +21,14 @@ setup(
|
||||
zip_safe=True,
|
||||
maintainer='seb',
|
||||
maintainer_email='seb@vayrette.com',
|
||||
description='Multi-modal person identity fusion and state tracking for saltybot',
|
||||
description='Social interaction layer — person state tracking, LED expression + attention',
|
||||
license='MIT',
|
||||
tests_require=['pytest'],
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'person_state_tracker = saltybot_social.person_state_tracker_node:main',
|
||||
'expression_node = saltybot_social.expression_node:main',
|
||||
'attention_node = saltybot_social.attention_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
140
jetson/ros2_ws/src/saltybot_social/test/test_attention.py
Normal file
140
jetson/ros2_ws/src/saltybot_social/test/test_attention.py
Normal file
@ -0,0 +1,140 @@
|
||||
"""
|
||||
test_attention.py — Unit tests for attention_node helpers.
|
||||
|
||||
No ROS2 / serial / GPU dependencies — runs with plain pytest.
|
||||
"""
|
||||
|
||||
import math
|
||||
import pytest
|
||||
|
||||
from saltybot_social.attention_node import (
|
||||
clamp,
|
||||
compute_attention_cmd,
|
||||
select_target,
|
||||
)
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
class FakePerson:
    """Minimal stand-in for saltybot_social_msgs/Person used by the tests.

    Carries only the attributes the helpers under test read:
    track_id, bearing_rad, confidence, is_speaking.
    """

    def __init__(self, track_id, bearing_rad, confidence=0.9, is_speaking=False):
        self.track_id = track_id
        self.bearing_rad = bearing_rad
        self.confidence = confidence
        self.is_speaking = is_speaking
|
||||
|
||||
|
||||
# ── clamp ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestClamp:
    """clamp(): interior, below-range, above-range, and boundary values."""

    def test_within_range(self):
        assert clamp(0.5, 0.0, 1.0) == 0.5

    def test_below_lo(self):
        assert clamp(-0.5, 0.0, 1.0) == 0.0

    def test_above_hi(self):
        assert clamp(1.5, 0.0, 1.0) == 1.0

    def test_at_boundary(self):
        # Boundaries are inclusive: the value passes through unchanged.
        assert clamp(0.0, 0.0, 1.0) == 0.0
        assert clamp(1.0, 0.0, 1.0) == 1.0
||||
|
||||
|
||||
# ── select_target ─────────────────────────────────────────────────────────────
|
||||
|
||||
class TestSelectTarget:
    """select_target(): active-id priority and confidence fallback."""

    def test_empty_returns_none(self):
        assert select_target([], active_id=-1) is None

    def test_picks_active_id(self):
        # Active speaker wins even with lower confidence.
        persons = [FakePerson(1, 0.1, confidence=0.6),
                   FakePerson(2, 0.5, confidence=0.9)]
        result = select_target(persons, active_id=1)
        assert result.track_id == 1

    def test_falls_back_to_highest_confidence(self):
        persons = [FakePerson(1, 0.1, confidence=0.6),
                   FakePerson(2, 0.5, confidence=0.9),
                   FakePerson(3, 0.2, confidence=0.4)]
        result = select_target(persons, active_id=-1)
        assert result.track_id == 2

    def test_single_person_no_active(self):
        persons = [FakePerson(5, 0.3)]
        result = select_target(persons, active_id=-1)
        assert result.track_id == 5

    def test_active_id_not_in_list_falls_back(self):
        # Stale active_id must not break selection — confidence decides.
        persons = [FakePerson(1, 0.0, confidence=0.5),
                   FakePerson(2, 0.2, confidence=0.95)]
        result = select_target(persons, active_id=99)
        assert result.track_id == 2
||||
|
||||
|
||||
# ── compute_attention_cmd ─────────────────────────────────────────────────────
|
||||
|
||||
class TestComputeAttentionCmd:
    """compute_attention_cmd(): dead zone, sign, clamping, linearity."""

    def test_inside_dead_zone_zero(self):
        cmd = compute_attention_cmd(
            bearing_rad=0.05,
            dead_zone_rad=0.15,
            kp_angular=1.0,
            max_angular_vel=0.8,
        )
        assert cmd == 0.0

    def test_exactly_at_dead_zone_boundary_zero(self):
        # Dead zone is inclusive (|bearing| <= dead_zone_rad).
        cmd = compute_attention_cmd(
            bearing_rad=0.15,
            dead_zone_rad=0.15,
            kp_angular=1.0,
            max_angular_vel=0.8,
        )
        assert cmd == 0.0

    def test_positive_bearing_gives_positive_angular_z(self):
        # Positive bearing = target to the left → CCW rotation.
        cmd = compute_attention_cmd(
            bearing_rad=0.4,
            dead_zone_rad=0.15,
            kp_angular=1.0,
            max_angular_vel=0.8,
        )
        assert cmd > 0.0

    def test_negative_bearing_gives_negative_angular_z(self):
        cmd = compute_attention_cmd(
            bearing_rad=-0.4,
            dead_zone_rad=0.15,
            kp_angular=1.0,
            max_angular_vel=0.8,
        )
        assert cmd < 0.0

    def test_clamps_to_max_angular_vel(self):
        cmd = compute_attention_cmd(
            bearing_rad=math.pi,  # 180° off
            dead_zone_rad=0.15,
            kp_angular=2.0,
            max_angular_vel=0.8,
        )
        assert abs(cmd) <= 0.8

    def test_proportional_scaling(self):
        """Double bearing → double cmd (within linear region)."""
        kp = 1.0
        dz = 0.05
        cmd1 = compute_attention_cmd(0.3, dz, kp, 10.0)
        cmd2 = compute_attention_cmd(0.6, dz, kp, 10.0)
        assert abs(cmd2 - 2.0 * cmd1) < 0.01

    def test_zero_bearing_zero_cmd(self):
        cmd = compute_attention_cmd(
            bearing_rad=0.0,
            dead_zone_rad=0.1,
            kp_angular=1.0,
            max_angular_vel=0.8,
        )
        assert cmd == 0.0
||||
@ -0,0 +1,6 @@
|
||||
enrollment_node:
|
||||
ros__parameters:
|
||||
db_path: '/mnt/nvme/saltybot/gallery/persons.db'
|
||||
voice_samples_dir: '/mnt/nvme/saltybot/gallery/voice'
|
||||
auto_enroll_phrase: 'remember me my name is'
|
||||
n_samples_default: 10
|
||||
@ -0,0 +1,22 @@
|
||||
"""Launch file for saltybot social enrollment node."""
|
||||
|
||||
import os
|
||||
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
from launch import LaunchDescription
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
pkg_share = get_package_share_directory('saltybot_social_enrollment')
|
||||
config_file = os.path.join(pkg_share, 'config', 'enrollment_params.yaml')
|
||||
|
||||
return LaunchDescription([
|
||||
Node(
|
||||
package='saltybot_social_enrollment',
|
||||
executable='enrollment_node',
|
||||
name='enrollment_node',
|
||||
parameters=[config_file],
|
||||
output='screen',
|
||||
),
|
||||
])
|
||||
23
jetson/ros2_ws/src/saltybot_social_enrollment/package.xml
Normal file
23
jetson/ros2_ws/src/saltybot_social_enrollment/package.xml
Normal file
@ -0,0 +1,23 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_social_enrollment</name>
|
||||
<version>0.1.0</version>
|
||||
<description>Person enrollment system for saltybot social interaction</description>
|
||||
<maintainer email="seb@vayrette.com">seb</maintainer>
|
||||
<license>MIT</license>
|
||||
|
||||
<buildtool_depend>ament_python</buildtool_depend>
|
||||
|
||||
<depend>rclpy</depend>
|
||||
<depend>sensor_msgs</depend>
|
||||
<depend>cv_bridge</depend>
|
||||
<depend>std_msgs</depend>
|
||||
<depend>saltybot_social_msgs</depend>
|
||||
|
||||
<exec_depend>python3-numpy</exec_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1,184 @@
|
||||
#!/usr/bin/env python3
|
||||
"""enrollment_cli.py -- Gallery management CLI for saltybot social.
|
||||
|
||||
Usage:
|
||||
ros2 run saltybot_social_enrollment enrollment_cli enroll --name "Alice" [--samples 15] [--mode face]
|
||||
ros2 run saltybot_social_enrollment enrollment_cli list
|
||||
ros2 run saltybot_social_enrollment enrollment_cli delete --id 3
|
||||
ros2 run saltybot_social_enrollment enrollment_cli rename --id 2 --name "Bob"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
|
||||
from saltybot_social_msgs.srv import (
|
||||
EnrollPerson,
|
||||
ListPersons,
|
||||
DeletePerson,
|
||||
UpdatePerson,
|
||||
)
|
||||
|
||||
|
||||
class EnrollmentCLI(Node):
    """Thin ROS2 node wrapping the gallery management services.

    Each public method performs exactly one blocking service call
    (via spin_until_future_complete) and prints a human-readable result.
    enroll/delete/rename return True on success so callers can propagate
    an exit status; list_persons only prints.
    """

    def __init__(self):
        super().__init__('enrollment_cli')
        # One client per service exposed by enrollment_node.
        self._enroll_client = self.create_client(
            EnrollPerson, '/social/enroll'
        )
        self._list_client = self.create_client(
            ListPersons, '/social/persons/list'
        )
        self._delete_client = self.create_client(
            DeletePerson, '/social/persons/delete'
        )
        self._update_client = self.create_client(
            UpdatePerson, '/social/persons/update'
        )

    def enroll(self, name: str, n_samples: int = 10, mode: str = 'face') -> bool:
        """Request enrollment of `name`; blocks up to 60 s for completion.

        Returns True if the service reported success, False on timeout,
        unavailable service, or a failed enrollment.
        """
        if not self._enroll_client.wait_for_service(timeout_sec=5.0):
            print('ERROR: /social/enroll service not available')
            return False

        req = EnrollPerson.Request()
        req.name = name
        req.mode = mode
        req.n_samples = n_samples

        print(f'Enrolling "{name}" ({mode}, {n_samples} samples)...')
        future = self._enroll_client.call_async(req)
        # Enrollment captures multiple camera samples, hence the long timeout.
        rclpy.spin_until_future_complete(self, future, timeout_sec=60.0)

        if future.result() is None:
            print('ERROR: Enrollment timed out')
            return False

        resp = future.result()
        if resp.success:
            print(f'OK: Enrolled "{name}" as person_id={resp.person_id}')
        else:
            print(f'FAILED: {resp.message}')
        return resp.success

    def list_persons(self) -> None:
        """Print a table of enrolled persons (id, name, samples, embedding dim)."""
        if not self._list_client.wait_for_service(timeout_sec=5.0):
            print('ERROR: /social/persons/list service not available')
            return

        req = ListPersons.Request()
        future = self._list_client.call_async(req)
        rclpy.spin_until_future_complete(self, future, timeout_sec=10.0)

        if future.result() is None:
            print('ERROR: List request timed out')
            return

        resp = future.result()
        if not resp.persons:
            print('Gallery is empty.')
            return

        print(f'{"ID":>4} {"Name":<20} {"Samples":>7} {"Embedding Dim":>13}')
        print('-' * 50)
        for p in resp.persons:
            # An empty embedding field means none was stored for this person.
            dim = len(p.embedding) if p.embedding else 0
            print(f'{p.person_id:>4} {p.person_name:<20} {p.sample_count:>7} {dim:>13}')

    def delete(self, person_id: int) -> bool:
        """Delete one person from the gallery by id. Returns True on success."""
        if not self._delete_client.wait_for_service(timeout_sec=5.0):
            print('ERROR: /social/persons/delete service not available')
            return False

        req = DeletePerson.Request()
        req.person_id = person_id

        future = self._delete_client.call_async(req)
        rclpy.spin_until_future_complete(self, future, timeout_sec=10.0)

        if future.result() is None:
            print('ERROR: Delete request timed out')
            return False

        resp = future.result()
        if resp.success:
            print(f'OK: {resp.message}')
        else:
            print(f'FAILED: {resp.message}')
        return resp.success

    def rename(self, person_id: int, new_name: str) -> bool:
        """Change a person's display name by id. Returns True on success."""
        if not self._update_client.wait_for_service(timeout_sec=5.0):
            print('ERROR: /social/persons/update service not available')
            return False

        req = UpdatePerson.Request()
        req.person_id = person_id
        req.new_name = new_name

        future = self._update_client.call_async(req)
        rclpy.spin_until_future_complete(self, future, timeout_sec=10.0)

        if future.result() is None:
            print('ERROR: Rename request timed out')
            return False

        resp = future.result()
        if resp.success:
            print(f'OK: {resp.message}')
        else:
            print(f'FAILED: {resp.message}')
        return resp.success
|
||||
|
||||
|
||||
def main(args=None):
    """CLI entry point: parse arguments and dispatch to EnrollmentCLI.

    Fixes over the original:
    - ROS-specific arguments (``--ros-args``, remappings) appended by
      ``ros2 run`` are stripped before argparse sees them, instead of
      crashing the parser.
    - The success booleans returned by enroll/delete/rename are no longer
      discarded: the process exits with status 1 on failure, so the CLI
      is usable from scripts.
    """
    parser = argparse.ArgumentParser(
        description='saltybot person enrollment CLI'
    )
    subparsers = parser.add_subparsers(dest='command', required=True)

    # enroll
    enroll_p = subparsers.add_parser('enroll', help='Enroll a new person')
    enroll_p.add_argument('--name', required=True, help='Person name')
    enroll_p.add_argument('--samples', type=int, default=10,
                          help='Number of face samples (default: 10)')
    enroll_p.add_argument('--mode', default='face',
                          help='Enrollment mode (default: face)')

    # list
    subparsers.add_parser('list', help='List enrolled persons')

    # delete
    delete_p = subparsers.add_parser('delete', help='Delete a person')
    delete_p.add_argument('--id', type=int, required=True,
                          help='Person ID to delete')

    # rename
    rename_p = subparsers.add_parser('rename', help='Rename a person')
    rename_p.add_argument('--id', type=int, required=True,
                          help='Person ID to rename')
    rename_p.add_argument('--name', required=True, help='New name')

    # Drop ROS arguments before argparse: `ros2 run` may append
    # `--ros-args ...`, which the original parse_args(sys.argv[1:]) choked on.
    from rclpy.utilities import remove_ros_args
    parsed = parser.parse_args(remove_ros_args(sys.argv)[1:])

    rclpy.init(args=args)
    cli = EnrollmentCLI()

    ok = True
    try:
        if parsed.command == 'enroll':
            ok = cli.enroll(parsed.name, parsed.samples, parsed.mode)
        elif parsed.command == 'list':
            # list_persons only prints; it has no success/failure result.
            cli.list_persons()
        elif parsed.command == 'delete':
            ok = cli.delete(parsed.id)
        elif parsed.command == 'rename':
            ok = cli.rename(parsed.id, parsed.name)
    finally:
        cli.destroy_node()
        rclpy.try_shutdown()

    if not ok:
        sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -0,0 +1,302 @@
|
||||
"""enrollment_node.py -- ROS2 person enrollment node for saltybot social.
|
||||
|
||||
Coordinates person enrollment:
|
||||
- Forwards /social/enroll to face_recognizer's service
|
||||
- Owns persistent SQLite gallery (PersonDB)
|
||||
- Voice-triggered enrollment via "remember me my name is X"
|
||||
- Gallery management services (list/delete/update)
|
||||
- Syncs DB from /social/faces/embeddings topic
|
||||
"""
|
||||
|
||||
import threading
|
||||
import numpy as np
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile, ReliabilityPolicy, DurabilityPolicy
|
||||
from std_msgs.msg import String
|
||||
|
||||
from saltybot_social_msgs.msg import (
|
||||
FaceDetectionArray,
|
||||
FaceEmbedding,
|
||||
FaceEmbeddingArray,
|
||||
)
|
||||
from saltybot_social_msgs.srv import (
|
||||
EnrollPerson,
|
||||
ListPersons,
|
||||
DeletePerson,
|
||||
UpdatePerson,
|
||||
)
|
||||
|
||||
from saltybot_social_enrollment.person_db import PersonDB
|
||||
|
||||
|
||||
class EnrollmentNode(Node):
    """Coordinates person enrollment for the saltybot social layer.

    Responsibilities (see module docstring):
    - Forwards /social/enroll requests to face_recognizer's enroll service.
    - Owns the persistent SQLite gallery (PersonDB).
    - Triggers enrollment from speech transcripts containing the
      configured phrase ("remember me my name is ...").
    - Exposes gallery management services (list/delete/update).
    - Syncs the DB from (and republishes to) /social/faces/embeddings.
    """

    def __init__(self):
        super().__init__('enrollment_node')

        # Parameters
        self.declare_parameter('db_path', '/mnt/nvme/saltybot/gallery/persons.db')
        self.declare_parameter('voice_samples_dir', '/mnt/nvme/saltybot/gallery/voice')
        self.declare_parameter('auto_enroll_phrase', 'remember me my name is')
        self.declare_parameter('n_samples_default', 10)

        db_path = self.get_parameter('db_path').value
        # NOTE(review): _voice_dir is stored but not used anywhere in this
        # class — presumably reserved for voice-sample recording; confirm.
        self._voice_dir = self.get_parameter('voice_samples_dir').value
        self._phrase = self.get_parameter('auto_enroll_phrase').value
        self._n_samples = self.get_parameter('n_samples_default').value

        self._db = PersonDB(db_path)
        self.get_logger().info(f'PersonDB initialized at {db_path}')

        # Client to face_recognizer's enroll service
        self._enroll_client = self.create_client(
            EnrollPerson, '/social/face_recognizer/enroll'
        )

        # QoS profiles: best-effort for high-rate sensor-ish streams,
        # reliable depth-1 for the embeddings gallery sync.
        best_effort_qos = QoSProfile(
            depth=10,
            reliability=ReliabilityPolicy.BEST_EFFORT,
            durability=DurabilityPolicy.VOLATILE,
        )
        reliable_qos = QoSProfile(
            depth=1,
            reliability=ReliabilityPolicy.RELIABLE,
            durability=DurabilityPolicy.VOLATILE,
        )
        status_qos = QoSProfile(
            depth=1,
            reliability=ReliabilityPolicy.BEST_EFFORT,
            durability=DurabilityPolicy.VOLATILE,
        )

        # Subscriptions
        self.create_subscription(
            FaceDetectionArray, '/social/faces/detections',
            self._on_detections, best_effort_qos
        )
        # NOTE(review): this node also PUBLISHES on this same topic (below),
        # so _on_embeddings will receive its own messages — verify the
        # get_person() guard is sufficient to prevent re-add loops.
        self.create_subscription(
            FaceEmbeddingArray, '/social/faces/embeddings',
            self._on_embeddings, reliable_qos
        )
        self.create_subscription(
            String, '/social/speech/transcript',
            self._on_transcript, best_effort_qos
        )
        self.create_subscription(
            String, '/social/speech/command',
            self._on_command, best_effort_qos
        )

        # Services
        self.create_service(EnrollPerson, '/social/enroll', self._handle_enroll)
        self.create_service(ListPersons, '/social/persons/list', self._handle_list)
        self.create_service(DeletePerson, '/social/persons/delete', self._handle_delete)
        self.create_service(UpdatePerson, '/social/persons/update', self._handle_update)

        # Publishers
        self._pub_embeddings = self.create_publisher(
            FaceEmbeddingArray, '/social/faces/embeddings', reliable_qos
        )
        self._pub_status = self.create_publisher(
            String, '/social/enrollment/status', status_qos
        )

        self.get_logger().info('EnrollmentNode ready')

    # ---- Voice-triggered enrollment ----

    def _on_transcript(self, msg: String):
        """Detect the auto-enroll phrase in a speech transcript.

        Everything after the phrase (at most three words, title-cased)
        is taken as the person's name; matching is case-insensitive.
        """
        text = msg.data.lower()
        phrase = self._phrase.lower()
        if phrase in text:
            idx = text.index(phrase) + len(phrase)
            name = text[idx:].strip()
            # Clean up: take first 3 words max as the name
            words = name.split()
            if words:
                name = ' '.join(words[:3]).title()
                self.get_logger().info(f'Voice enrollment triggered: "{name}"')
                self._trigger_voice_enroll(name)

    def _on_command(self, msg: String):
        # Reserved for explicit voice commands (e.g., "enroll Alice")
        pass

    def _trigger_voice_enroll(self, name: str):
        """Fire-and-forget enrollment request to face_recognizer for `name`."""
        if not self._enroll_client.wait_for_service(timeout_sec=1.0):
            self.get_logger().warn(
                'face_recognizer enroll service not available'
            )
            self._publish_status(f'Enrollment failed: face_recognizer unavailable')
            return

        req = EnrollPerson.Request()
        req.name = name
        req.mode = 'face'
        req.n_samples = self._n_samples
        future = self._enroll_client.call_async(req)
        # Non-blocking: result handled in _on_enroll_done once the
        # service responds.
        future.add_done_callback(
            lambda f: self._on_enroll_done(f, name)
        )
        self._publish_status(f'Enrolling "{name}"... look at the camera')

    def _on_enroll_done(self, future, name: str):
        """Report async enrollment outcome on the status topic and log."""
        try:
            resp = future.result()
            if resp.success:
                status = f'Enrolled "{name}" (id={resp.person_id})'
                self.get_logger().info(status)
            else:
                status = f'Enrollment failed for "{name}": {resp.message}'
                self.get_logger().warn(status)
            self._publish_status(status)
        except Exception as e:
            # future.result() raises if the service call itself failed.
            self.get_logger().error(f'Enroll call failed: {e}')
            self._publish_status(f'Enrollment error: {e}')

    # ---- Face detection callback (during enrollment) ----

    def _on_detections(self, msg: FaceDetectionArray):
        # Reserved for future direct enrollment (without face_recognizer)
        pass

    # ---- Embeddings sync from face_recognizer ----

    def _on_embeddings(self, msg: FaceEmbeddingArray):
        """Add any person published by face_recognizer that the DB lacks.

        NOTE(review): add_person assigns a fresh AUTOINCREMENT id which
        may differ from emb.person_id; if the ids diverge, the
        get_person(emb.person_id) check could keep missing and re-adding
        the same person — confirm id assignment is shared with
        face_recognizer.
        """
        for emb in msg.embeddings:
            existing = self._db.get_person(emb.person_id)
            if existing is None:
                arr = np.array(emb.embedding, dtype=np.float32)
                self._db.add_person(
                    emb.person_name, arr, emb.sample_count
                )
                self.get_logger().info(
                    f'Synced new person from face_recognizer: '
                    f'{emb.person_name} (id={emb.person_id})'
                )

    # ---- Service handlers ----

    def _handle_enroll(self, request, response):
        """Forward enrollment to face_recognizer service."""
        if not self._enroll_client.wait_for_service(timeout_sec=2.0):
            response.success = False
            response.message = 'face_recognizer service unavailable'
            return response

        # Use threading.Event to bridge async call in service callback.
        # NOTE(review): event.wait() blocks the callback thread; with a
        # single-threaded executor the client response can never be
        # delivered and this times out/deadlocks — confirm this node is
        # spun by a MultiThreadedExecutor or uses separate callback groups.
        event = threading.Event()
        result_holder = {}

        req = EnrollPerson.Request()
        req.name = request.name
        req.mode = request.mode
        req.n_samples = request.n_samples

        future = self._enroll_client.call_async(req)

        def _done(f):
            try:
                result_holder['resp'] = f.result()
            except Exception as e:
                result_holder['err'] = str(e)
            event.set()

        future.add_done_callback(_done)
        success = event.wait(timeout=35.0)

        if not success:
            response.success = False
            response.message = 'Enrollment timed out'
        elif 'resp' in result_holder:
            resp = result_holder['resp']
            response.success = resp.success
            response.message = resp.message
            response.person_id = resp.person_id
        else:
            response.success = False
            response.message = result_holder.get('err', 'Unknown error')

        return response

    def _handle_list(self, request, response):
        """Return every gallery entry as a FaceEmbedding message."""
        persons = self._db.get_all()
        response.persons = []
        for p in persons:
            fe = FaceEmbedding()
            fe.person_id = p['id']
            fe.person_name = p['name']
            fe.sample_count = p['sample_count']
            # Split the float UNIX timestamp into sec/nanosec fields.
            fe.enrolled_at.sec = int(p['enrolled_at'])
            fe.enrolled_at.nanosec = int(
                (p['enrolled_at'] - int(p['enrolled_at'])) * 1e9
            )
            if p['embedding'] is not None:
                fe.embedding = p['embedding'].tolist()
            response.persons.append(fe)
        return response

    def _handle_delete(self, request, response):
        """Delete a person by id; republishes the gallery on success."""
        success = self._db.delete_person(request.person_id)
        response.success = success
        if success:
            response.message = f'Deleted person {request.person_id}'
            self.get_logger().info(response.message)
            self._publish_embeddings_from_db()
        else:
            response.message = f'Person {request.person_id} not found'
        return response

    def _handle_update(self, request, response):
        """Rename a person by id; republishes the gallery on success."""
        success = self._db.update_name(request.person_id, request.new_name)
        response.success = success
        if success:
            response.message = f'Updated person {request.person_id} name to "{request.new_name}"'
            self.get_logger().info(response.message)
            self._publish_embeddings_from_db()
        else:
            response.message = f'Person {request.person_id} not found'
        return response

    # ---- Helpers ----

    def _publish_status(self, text: str):
        """Publish a human-readable status line on /social/enrollment/status."""
        msg = String()
        msg.data = text
        self._pub_status.publish(msg)

    def _publish_embeddings_from_db(self):
        """Publish the full gallery (entries that have embeddings) so
        face_recognizer can refresh its in-memory gallery after a
        delete/rename."""
        persons = self._db.get_all()
        arr = FaceEmbeddingArray()
        arr.header.stamp = self.get_clock().now().to_msg()
        for p in persons:
            # Entries without a stored embedding are skipped entirely.
            if p['embedding'] is not None:
                fe = FaceEmbedding()
                fe.person_id = p['id']
                fe.person_name = p['name']
                fe.sample_count = p['sample_count']
                fe.enrolled_at.sec = int(p['enrolled_at'])
                fe.enrolled_at.nanosec = int(
                    (p['enrolled_at'] - int(p['enrolled_at'])) * 1e9
                )
                fe.embedding = p['embedding'].tolist()
                arr.embeddings.append(fe)
        self._pub_embeddings.publish(arr)
|
||||
|
||||
|
||||
def main(args=None):
    """Entry point: spin EnrollmentNode until interrupted, then clean up."""
    rclpy.init(args=args)
    node = EnrollmentNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Ctrl-C is the normal shutdown path, not an error.
        pass
    finally:
        node.destroy_node()
        # try_shutdown() is safe even if the context was already shut down.
        rclpy.try_shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -0,0 +1,138 @@
|
||||
"""person_db.py -- Persistent SQLite person gallery for saltybot enrollment."""
|
||||
|
||||
import sqlite3
|
||||
import json
|
||||
import time
|
||||
import numpy as np
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class PersonDB:
|
||||
"""Thread-safe SQLite-backed person gallery.
|
||||
|
||||
Schema:
|
||||
persons(id INTEGER PRIMARY KEY, name TEXT, enrolled_at REAL,
|
||||
sample_count INTEGER, embedding BLOB, metadata TEXT)
|
||||
voice_samples(id INTEGER PRIMARY KEY, person_id INTEGER REFERENCES persons,
|
||||
recorded_at REAL, sample_path TEXT)
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str):
|
||||
self._db_path = db_path
|
||||
self._lock = threading.Lock()
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
self._init_db()
|
||||
|
||||
def _init_db(self):
|
||||
with self._connect() as conn:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS persons (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
enrolled_at REAL NOT NULL,
|
||||
sample_count INTEGER DEFAULT 1,
|
||||
embedding BLOB,
|
||||
metadata TEXT DEFAULT '{}'
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS voice_samples (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
person_id INTEGER REFERENCES persons(id),
|
||||
recorded_at REAL NOT NULL,
|
||||
sample_path TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
|
||||
def _connect(self):
|
||||
return sqlite3.connect(self._db_path)
|
||||
|
||||
def add_person(self, name: str, embedding: np.ndarray, sample_count: int = 1,
|
||||
metadata: dict = None) -> int:
|
||||
"""Add a new person. Returns new person_id."""
|
||||
emb_blob = embedding.astype(np.float32).tobytes() if embedding is not None else None
|
||||
now = time.time()
|
||||
with self._lock:
|
||||
with self._connect() as conn:
|
||||
cur = conn.execute(
|
||||
"INSERT INTO persons (name, enrolled_at, sample_count, embedding, metadata) "
|
||||
"VALUES (?, ?, ?, ?, ?)",
|
||||
(name, now, sample_count, emb_blob, json.dumps(metadata or {}))
|
||||
)
|
||||
return cur.lastrowid
|
||||
|
||||
def update_embedding(self, person_id: int, embedding: np.ndarray,
|
||||
sample_count: int) -> bool:
|
||||
emb_blob = embedding.astype(np.float32).tobytes()
|
||||
with self._lock:
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"UPDATE persons SET embedding=?, sample_count=? WHERE id=?",
|
||||
(emb_blob, sample_count, person_id)
|
||||
)
|
||||
return conn.total_changes > 0
|
||||
|
||||
def update_name(self, person_id: int, new_name: str) -> bool:
|
||||
with self._lock:
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"UPDATE persons SET name=? WHERE id=?",
|
||||
(new_name, person_id)
|
||||
)
|
||||
return conn.total_changes > 0
|
||||
|
||||
def delete_person(self, person_id: int) -> bool:
|
||||
with self._lock:
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"DELETE FROM voice_samples WHERE person_id=?",
|
||||
(person_id,)
|
||||
)
|
||||
conn.execute(
|
||||
"DELETE FROM persons WHERE id=?",
|
||||
(person_id,)
|
||||
)
|
||||
return conn.total_changes > 0
|
||||
|
||||
def get_all(self) -> list:
|
||||
"""Returns list of dicts with id, name, enrolled_at, sample_count, embedding."""
|
||||
with self._lock:
|
||||
with self._connect() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT id, name, enrolled_at, sample_count, embedding FROM persons"
|
||||
).fetchall()
|
||||
result = []
|
||||
for row in rows:
|
||||
emb = np.frombuffer(row[4], dtype=np.float32).copy() if row[4] else None
|
||||
result.append({
|
||||
'id': row[0], 'name': row[1], 'enrolled_at': row[2],
|
||||
'sample_count': row[3], 'embedding': emb
|
||||
})
|
||||
return result
|
||||
|
||||
def get_person(self, person_id: int) -> dict | None:
|
||||
with self._lock:
|
||||
with self._connect() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT id, name, enrolled_at, sample_count, embedding "
|
||||
"FROM persons WHERE id=?",
|
||||
(person_id,)
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
emb = np.frombuffer(row[4], dtype=np.float32).copy() if row[4] else None
|
||||
return {
|
||||
'id': row[0], 'name': row[1], 'enrolled_at': row[2],
|
||||
'sample_count': row[3], 'embedding': emb
|
||||
}
|
||||
|
||||
def add_voice_sample(self, person_id: int, sample_path: str) -> int:
|
||||
with self._lock:
|
||||
with self._connect() as conn:
|
||||
cur = conn.execute(
|
||||
"INSERT INTO voice_samples (person_id, recorded_at, sample_path) "
|
||||
"VALUES (?, ?, ?)",
|
||||
(person_id, time.time(), sample_path)
|
||||
)
|
||||
return cur.lastrowid
|
||||
4
jetson/ros2_ws/src/saltybot_social_enrollment/setup.cfg
Normal file
4
jetson/ros2_ws/src/saltybot_social_enrollment/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/saltybot_social_enrollment
|
||||
[install]
|
||||
install_scripts=$base/lib/saltybot_social_enrollment
|
||||
29
jetson/ros2_ws/src/saltybot_social_enrollment/setup.py
Normal file
29
jetson/ros2_ws/src/saltybot_social_enrollment/setup.py
Normal file
@ -0,0 +1,29 @@
|
||||
"""Packaging for the saltybot_social_enrollment ament_python package."""
from setuptools import setup
import os
from glob import glob

package_name = 'saltybot_social_enrollment'

setup(
    name=package_name,
    version='0.1.0',
    packages=[package_name],
    data_files=[
        # ament resource-index marker so ROS2 tooling can discover the package.
        ('share/ament_index/resource_index/packages', ['resource/' + package_name]),
        ('share/' + package_name, ['package.xml']),
        # Install launch files and parameter YAMLs into the share directory
        # (the launch file resolves them via get_package_share_directory).
        (os.path.join('share', package_name, 'launch'), glob('launch/*.py')),
        (os.path.join('share', package_name, 'config'), glob('config/*.yaml')),
    ],
    install_requires=['setuptools'],
    zip_safe=True,
    maintainer='seb',
    maintainer_email='seb@vayrette.com',
    description='Person enrollment system for saltybot social interaction',
    license='MIT',
    entry_points={
        # Executables available via `ros2 run saltybot_social_enrollment <name>`.
        'console_scripts': [
            'enrollment_node = saltybot_social_enrollment.enrollment_node:main',
            'enrollment_cli = saltybot_social_enrollment.enrollment_cli:main',
        ],
    },
)
|
||||
@ -0,0 +1,11 @@
|
||||
face_recognizer:
  ros__parameters:
    # TensorRT engines are the primary models; the ONNX paths are the
    # fallbacks used when the corresponding .engine file is unavailable.
    scrfd_engine_path: '/mnt/nvme/saltybot/models/scrfd_2.5g.engine'
    scrfd_onnx_path: '/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx'
    arcface_engine_path: '/mnt/nvme/saltybot/models/arcface_r50.engine'
    arcface_onnx_path: '/mnt/nvme/saltybot/models/arcface_r50.onnx'
    # Directory for persistent face gallery storage.
    gallery_dir: '/mnt/nvme/saltybot/gallery'
    # Cosine similarity threshold; matches below this are treated as unknown.
    recognition_threshold: 0.35
    # Publish annotated frames to /social/faces/debug_image (debugging only).
    publish_debug_image: false
    # Maximum faces processed per frame.
    max_faces: 10
    # SCRFD detection confidence threshold.
    scrfd_conf_thresh: 0.5
|
||||
@ -0,0 +1,80 @@
|
||||
"""
|
||||
face_recognition.launch.py -- Launch file for the SCRFD + ArcFace face recognition node.
|
||||
|
||||
Launches the face_recognizer node with configurable model paths and parameters.
|
||||
The RealSense camera must be running separately (e.g., via realsense.launch.py).
|
||||
"""
|
||||
|
||||
from launch import LaunchDescription
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
    """Generate launch description for face recognition pipeline.

    Each tunable is declared once in `arg_specs` and used both as a
    DeclareLaunchArgument (command-line overridable) and as a node
    parameter of the same name, so the two hand-written lists from the
    original cannot drift apart.

    Returns:
        LaunchDescription with all argument declarations plus the
        face_recognizer node.
    """
    # (name, default value, description) for every launch argument.
    arg_specs = [
        ('scrfd_engine_path',
         '/mnt/nvme/saltybot/models/scrfd_2.5g.engine',
         'Path to SCRFD TensorRT engine file'),
        ('scrfd_onnx_path',
         '/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx',
         'Path to SCRFD ONNX model file (fallback)'),
        ('arcface_engine_path',
         '/mnt/nvme/saltybot/models/arcface_r50.engine',
         'Path to ArcFace TensorRT engine file'),
        ('arcface_onnx_path',
         '/mnt/nvme/saltybot/models/arcface_r50.onnx',
         'Path to ArcFace ONNX model file (fallback)'),
        ('gallery_dir',
         '/mnt/nvme/saltybot/gallery',
         'Directory for persistent face gallery storage'),
        ('recognition_threshold',
         '0.35',
         'Cosine similarity threshold for face recognition'),
        ('publish_debug_image',
         'false',
         'Publish annotated debug image to /social/faces/debug_image'),
        ('max_faces',
         '10',
         'Maximum faces to process per frame'),
        ('scrfd_conf_thresh',
         '0.5',
         'SCRFD detection confidence threshold'),
    ]

    declared_args = [
        DeclareLaunchArgument(name, default_value=default, description=desc)
        for name, default, desc in arg_specs
    ]
    # Every declared argument is forwarded to the node under the same name.
    node_params = {name: LaunchConfiguration(name) for name, _, _ in arg_specs}

    return LaunchDescription(declared_args + [
        Node(
            package='saltybot_social_face',
            executable='face_recognition',
            name='face_recognizer',
            output='screen',
            parameters=[node_params],
        ),
    ])
|
||||
27
jetson/ros2_ws/src/saltybot_social_face/package.xml
Normal file
27
jetson/ros2_ws/src/saltybot_social_face/package.xml
Normal file
@ -0,0 +1,27 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_social_face</name>
|
||||
<version>0.1.0</version>
|
||||
<description>SCRFD face detection and ArcFace recognition for SaltyBot social interactions</description>
|
||||
<maintainer email="seb@vayrette.com">seb</maintainer>
|
||||
<license>MIT</license>
|
||||
|
||||
<depend>rclpy</depend>
|
||||
<depend>sensor_msgs</depend>
|
||||
<depend>cv_bridge</depend>
|
||||
<depend>image_transport</depend>
|
||||
<depend>saltybot_social_msgs</depend>
|
||||
|
||||
<exec_depend>python3-numpy</exec_depend>
|
||||
<exec_depend>python3-opencv</exec_depend>
|
||||
|
||||
<test_depend>ament_copyright</test_depend>
|
||||
<test_depend>ament_flake8</test_depend>
|
||||
<test_depend>ament_pep257</test_depend>
|
||||
<test_depend>python3-pytest</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1 @@
|
||||
"""SaltyBot social face detection and recognition package."""
|
||||
@ -0,0 +1,316 @@
|
||||
"""
|
||||
arcface_recognizer.py -- ArcFace face embedding extraction and gallery matching.
|
||||
|
||||
Performs face alignment using 5-point landmarks (insightface standard reference),
|
||||
extracts 512-dimensional embeddings via ArcFace (TRT FP16 or ONNX fallback),
|
||||
and matches against a persistent gallery using cosine similarity.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# InsightFace standard reference landmarks for 112x112 alignment
|
||||
ARCFACE_SRC = np.array([
|
||||
[38.2946, 51.6963], # left eye
|
||||
[73.5318, 51.5014], # right eye
|
||||
[56.0252, 71.7366], # nose
|
||||
[41.5493, 92.3655], # left mouth
|
||||
[70.7299, 92.2041], # right mouth
|
||||
], dtype=np.float32)
|
||||
|
||||
|
||||
# -- Inference backends --------------------------------------------------------
|
||||
|
||||
class _TRTBackend:
    """TensorRT inference engine for ArcFace.

    Deserializes a prebuilt .engine file and pre-allocates one pinned
    host buffer plus one device buffer per IO tensor; infer() reuses
    these buffers on every call.
    """

    def __init__(self, engine_path: str):
        # Imported lazily so the module can load on machines without
        # TensorRT/CUDA (the ONNX backend is the fallback).
        import tensorrt as trt
        import pycuda.driver as cuda
        import pycuda.autoinit  # noqa: F401

        self._cuda = cuda
        trt_logger = trt.Logger(trt.Logger.WARNING)
        with open(engine_path, 'rb') as f, trt.Runtime(trt_logger) as runtime:
            self._engine = runtime.deserialize_cuda_engine(f.read())
        self._context = self._engine.create_execution_context()

        # NOTE(review): allocation uses the engine's static tensor shapes;
        # assumes no dynamic (-1) dimensions — confirm the ArcFace engine
        # is built with a fixed batch size.
        self._inputs = []
        self._outputs = []
        self._bindings = []
        for i in range(self._engine.num_io_tensors):
            name = self._engine.get_tensor_name(i)
            dtype = trt.nptype(self._engine.get_tensor_dtype(name))
            shape = tuple(self._engine.get_tensor_shape(name))
            nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
            # Page-locked host memory enables async host<->device copies.
            host_mem = cuda.pagelocked_empty(shape, dtype)
            device_mem = cuda.mem_alloc(nbytes)
            self._bindings.append(int(device_mem))
            if self._engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                self._inputs.append({'host': host_mem, 'device': device_mem})
            else:
                self._outputs.append({'host': host_mem, 'device': device_mem,
                                      'shape': shape})
        self._stream = cuda.Stream()

    def infer(self, input_data: np.ndarray) -> np.ndarray:
        """Run inference and return the embedding vector.

        Only the first input and first output tensor are used.
        NOTE(review): np.copyto into the SHAPED pinned buffer from a
        raveled 1-D source relies on broadcasting; verify the shapes are
        compatible — many TRT samples allocate the host buffer flat.
        """
        np.copyto(self._inputs[0]['host'], input_data.ravel())
        self._cuda.memcpy_htod_async(
            self._inputs[0]['device'], self._inputs[0]['host'], self._stream)
        # NOTE(review): execute_async_v2 pairs device pointers to tensors by
        # binding index (and was removed in TensorRT 10) — confirm the
        # deployed TRT version and that tensor index order matches.
        self._context.execute_async_v2(self._bindings, self._stream.handle)
        for out in self._outputs:
            self._cuda.memcpy_dtoh_async(out['host'], out['device'], self._stream)
        self._stream.synchronize()
        # copy() so the returned array survives the next infer() overwrite.
        return self._outputs[0]['host'].reshape(self._outputs[0]['shape']).copy()
|
||||
|
||||
|
||||
class _ONNXBackend:
    """ONNX Runtime inference (CUDA EP with CPU fallback)."""

    def __init__(self, onnx_path: str):
        # Imported lazily so the module loads on machines without onnxruntime.
        import onnxruntime as ort

        # Request CUDA first; onnxruntime silently drops to the CPU provider
        # when the CUDA execution provider is not available.
        self._session = ort.InferenceSession(
            onnx_path,
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
        )
        self._input_name = self._session.get_inputs()[0].name

    def infer(self, input_data: np.ndarray) -> np.ndarray:
        """Run inference and return the embedding vector."""
        outputs = self._session.run(None, {self._input_name: input_data})
        return outputs[0]
|
||||
|
||||
|
||||
# -- Face alignment ------------------------------------------------------------
|
||||
|
||||
def align_face(bgr: np.ndarray, landmarks_10: list[float]) -> np.ndarray:
    """Align a face to 112x112 using 5-point landmarks.

    Estimates a similarity transform from the detected landmarks onto the
    ARCFACE_SRC reference points and warps the image with it. If the
    estimate fails, falls back to a square crop around the landmark centroid.

    Args:
        bgr: Source BGR image.
        landmarks_10: Flat list of 10 floats [x0,y0, x1,y1, ..., x4,y4].

    Returns:
        Aligned BGR face crop of shape (112, 112, 3).
    """
    pts = np.array(landmarks_10, dtype=np.float32).reshape(5, 2)
    M, _ = cv2.estimateAffinePartial2D(pts, ARCFACE_SRC)

    if M is not None:
        # BORDER_REPLICATE avoids black margins when the warp samples
        # outside the source image.
        return cv2.warpAffine(bgr, M, (112, 112), borderMode=cv2.BORDER_REPLICATE)

    # Fallback: crop a square around the landmark centroid sized 1.5x the
    # landmark spread, clipped to the image, then resize to 112x112.
    cx = np.mean(pts[:, 0])
    cy = np.mean(pts[:, 1])
    spread = max(np.ptp(pts[:, 0]), np.ptp(pts[:, 1])) * 1.5
    half = spread / 2
    x1 = max(0, int(cx - half))
    y1 = max(0, int(cy - half))
    x2 = min(bgr.shape[1], int(cx + half))
    y2 = min(bgr.shape[0], int(cy + half))
    return cv2.resize(bgr[y1:y2, x1:x2], (112, 112),
                      interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
|
||||
# -- Main recognizer class -----------------------------------------------------
|
||||
|
||||
class ArcFaceRecognizer:
    """ArcFace face embedding extractor and gallery matcher.

    Tries to load a TensorRT engine first, then an ONNX model. If neither
    loads, `is_loaded` is False, `embed` returns a zero vector, and `match`
    never succeeds — callers keep working, recognition is just disabled.

    Args:
        engine_path: Path to TensorRT engine file.
        onnx_path: Path to ONNX model file (used if engine not available).
    """

    def __init__(self, engine_path: str = '', onnx_path: str = ''):
        # Quoted forward reference: annotations on attribute targets are
        # evaluated at runtime, and the backend classes may not be resolvable
        # in every import context.
        self._backend: 'Optional[_TRTBackend | _ONNXBackend]' = None
        # person_id -> {'name', 'embedding', 'samples', 'enrolled_at'}
        self.gallery: dict[int, dict] = {}

        # Try TRT first (fastest on Jetson-class hardware), then ONNX.
        if engine_path and os.path.isfile(engine_path):
            try:
                self._backend = _TRTBackend(engine_path)
                logger.info('ArcFace TensorRT backend loaded: %s', engine_path)
                return
            except Exception as e:
                logger.warning('ArcFace TRT load failed (%s), trying ONNX', e)

        if onnx_path and os.path.isfile(onnx_path):
            try:
                self._backend = _ONNXBackend(onnx_path)
                logger.info('ArcFace ONNX backend loaded: %s', onnx_path)
                return
            except Exception as e:
                logger.error('ArcFace ONNX load failed: %s', e)

        logger.error('No ArcFace model loaded. Recognition will be unavailable.')

    @property
    def is_loaded(self) -> bool:
        """Return True if a backend is loaded and ready."""
        return self._backend is not None

    @staticmethod
    def _meta_path_for(gallery_path: str) -> str:
        """Return the JSON metadata sidecar path for a gallery file.

        Uses os.path.splitext rather than str.replace so a '.npz' substring
        elsewhere in the path cannot corrupt the sidecar location.
        """
        base, _ = os.path.splitext(gallery_path)
        return base + '_meta.json'

    @staticmethod
    def _l2_normalize(vec: np.ndarray) -> np.ndarray:
        """Return vec scaled to unit L2 norm; zero vectors pass through."""
        norm = np.linalg.norm(vec)
        return vec / norm if norm > 0 else vec

    def embed(self, bgr_face_112x112: np.ndarray) -> np.ndarray:
        """Extract 512-dim L2-normalized embedding from a 112x112 aligned face.

        Args:
            bgr_face_112x112: Aligned face crop, BGR, shape (112, 112, 3).

        Returns:
            L2-normalized embedding of shape (512,); zeros if no backend.
        """
        if self._backend is None:
            return np.zeros(512, dtype=np.float32)

        # Preprocess: BGR->RGB, scale to [-1, 1], NCHW -> [1, 3, 112, 112].
        rgb = cv2.cvtColor(bgr_face_112x112, cv2.COLOR_BGR2RGB).astype(np.float32)
        rgb = (rgb / 255.0 - 0.5) / 0.5
        blob = np.ascontiguousarray(rgb.transpose(2, 0, 1)[np.newaxis])

        output = self._backend.infer(blob)
        embedding = output.flatten()[:512].astype(np.float32)
        return self._l2_normalize(embedding)

    def align_and_embed(self, bgr_image: np.ndarray, landmarks_10: list[float]) -> np.ndarray:
        """Align face using landmarks and extract embedding.

        Args:
            bgr_image: Full BGR image.
            landmarks_10: Flat list of 10 floats from SCRFD detection.

        Returns:
            L2-normalized embedding of shape (512,).
        """
        aligned = align_face(bgr_image, landmarks_10)
        return self.embed(aligned)

    def load_gallery(self, gallery_path: str) -> None:
        """Load gallery from .npz file with JSON metadata sidecar.

        Missing file is not an error: the gallery simply starts empty.
        Embeddings are re-normalized defensively on load.

        Args:
            gallery_path: Path to the .npz gallery file.
        """
        import json

        if not os.path.isfile(gallery_path):
            logger.info('No gallery file at %s, starting empty.', gallery_path)
            self.gallery = {}
            return

        meta: dict = {}
        meta_path = self._meta_path_for(gallery_path)
        if os.path.isfile(meta_path):
            with open(meta_path, 'r') as f:
                meta = json.load(f)

        self.gallery = {}
        # Context manager closes the NpzFile's underlying file handle.
        with np.load(gallery_path, allow_pickle=False) as data:
            for key in data.files:
                pid = int(key)
                embedding = self._l2_normalize(data[key].astype(np.float32))
                info = meta.get(str(pid), {})
                self.gallery[pid] = {
                    'name': info.get('name', f'person_{pid}'),
                    'embedding': embedding,
                    'samples': info.get('samples', 1),
                    'enrolled_at': info.get('enrolled_at', 0.0),
                }
        logger.info('Gallery loaded: %d persons from %s', len(self.gallery), gallery_path)

    def save_gallery(self, gallery_path: str) -> None:
        """Save gallery to .npz file with JSON metadata sidecar.

        Creates the parent directory if needed.

        Args:
            gallery_path: Path to the .npz gallery file.
        """
        import json

        arrays = {}
        meta = {}
        for pid, info in self.gallery.items():
            # npz keys must be strings; ids round-trip via int() on load.
            arrays[str(pid)] = info['embedding']
            meta[str(pid)] = {
                'name': info['name'],
                'samples': info['samples'],
                'enrolled_at': info['enrolled_at'],
            }

        os.makedirs(os.path.dirname(gallery_path) or '.', exist_ok=True)
        np.savez(gallery_path, **arrays)

        with open(self._meta_path_for(gallery_path), 'w') as f:
            json.dump(meta, f, indent=2)
        logger.info('Gallery saved: %d persons to %s', len(self.gallery), gallery_path)

    def match(self, embedding: np.ndarray, threshold: float = 0.35) -> tuple[int, str, float]:
        """Match an embedding against the gallery.

        Cosine similarity reduces to a dot product because both sides are
        L2-normalized.

        Args:
            embedding: L2-normalized query embedding of shape (512,).
            threshold: Minimum cosine similarity for a match.

        Returns:
            (person_id, person_name, score) or (-1, '', 0.0) if no match.
        """
        if not self.gallery:
            return (-1, '', 0.0)

        best_pid = -1
        best_name = ''
        best_score = 0.0

        for pid, info in self.gallery.items():
            score = float(np.dot(embedding, info['embedding']))
            if score > best_score:
                best_score = score
                best_pid = pid
                best_name = info['name']

        if best_score >= threshold:
            return (best_pid, best_name, best_score)
        return (-1, '', 0.0)

    def enroll(self, person_id: int, person_name: str, embeddings_list: list[np.ndarray]) -> None:
        """Enroll a person by averaging multiple embeddings.

        The mean embedding is re-normalized so gallery entries stay unit
        length. An empty list is a no-op.

        Args:
            person_id: Unique integer ID for this person.
            person_name: Human-readable name.
            embeddings_list: List of L2-normalized embeddings to average.
        """
        import time as _time

        if not embeddings_list:
            return

        mean_emb = self._l2_normalize(
            np.mean(embeddings_list, axis=0).astype(np.float32))

        self.gallery[person_id] = {
            'name': person_name,
            'embedding': mean_emb,
            'samples': len(embeddings_list),
            'enrolled_at': _time.time(),
        }
        logger.info('Enrolled person %d (%s) with %d samples.',
                    person_id, person_name, len(embeddings_list))
|
||||
@ -0,0 +1,78 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
enrollment_cli.py -- CLI tool for enrolling persons via the /social/enroll service.
|
||||
|
||||
Usage:
|
||||
ros2 run saltybot_social_face enrollment_cli -- --name Alice --mode face --samples 10
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
|
||||
from saltybot_social_msgs.srv import EnrollPerson
|
||||
|
||||
|
||||
class EnrollmentCLI(Node):
    """Simple CLI node that calls the EnrollPerson service.

    All work happens in __init__: wait for /social/enroll, send a single
    request, block until the result arrives (or times out), log the outcome.
    """

    def __init__(self, name: str, mode: str, n_samples: int):
        """Run one enrollment request.

        Args:
            name: Person's name to enroll.
            mode: Enrollment mode ('face', 'voice', or 'both').
            n_samples: Number of samples the server should collect.
        """
        super().__init__('enrollment_cli')
        self._client = self.create_client(EnrollPerson, '/social/enroll')

        self.get_logger().info('Waiting for /social/enroll service...')
        if not self._client.wait_for_service(timeout_sec=10.0):
            self.get_logger().error('Service /social/enroll not available.')
            return

        request = EnrollPerson.Request()
        request.name = name
        request.mode = mode
        request.n_samples = n_samples

        # rclpy loggers accept only a single message string — unlike stdlib
        # logging they do NOT expand %-style positional args, which raised
        # TypeError here before. Format eagerly instead.
        self.get_logger().info(
            f'Enrolling "{name}" (mode={mode}, samples={n_samples})...')

        future = self._client.call_async(request)
        rclpy.spin_until_future_complete(self, future, timeout_sec=120.0)

        result = future.result()
        if result is not None:
            if result.success:
                self.get_logger().info(
                    f'Enrollment successful: person_id={result.person_id}, '
                    f'{result.message}')
            else:
                self.get_logger().error(f'Enrollment failed: {result.message}')
        else:
            self.get_logger().error('Enrollment service call timed out or failed.')
|
||||
|
||||
|
||||
def main(args=None):
    """Entry point for enrollment CLI: parse flags, run the one-shot node."""
    parser = argparse.ArgumentParser(
        description='Enroll a person for face recognition.')
    parser.add_argument('--name', type=str, required=True,
                        help='Name of the person to enroll.')
    parser.add_argument('--mode', type=str, default='face',
                        choices=['face', 'voice', 'both'],
                        help='Enrollment mode (default: face).')
    parser.add_argument('--samples', type=int, default=10,
                        help='Number of face samples to collect (default: 10).')

    # parse_known_args lets ROS2 remapping arguments pass through untouched.
    opts, ros_args = parser.parse_known_args(args=sys.argv[1:])

    rclpy.init(args=ros_args)
    node = EnrollmentCLI(opts.name, opts.mode, opts.samples)
    try:
        # The node performs the entire enrollment inside its constructor;
        # nothing left to spin here.
        pass
    finally:
        node.destroy_node()
        rclpy.shutdown()


if __name__ == '__main__':
    main()
|
||||
@ -0,0 +1,206 @@
|
||||
"""
|
||||
face_gallery.py -- Persistent face embedding gallery backed by numpy .npz + JSON.
|
||||
|
||||
Thread-safe gallery storage for face recognition. Embeddings are stored in a
|
||||
.npz file, with a sidecar metadata.json containing names, sample counts, and
|
||||
enrollment timestamps. Auto-increment IDs start at 1.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FaceGallery:
    """Persistent, thread-safe face embedding gallery.

    Embeddings are stored in <gallery_dir>/gallery.npz keyed by stringified
    person id; names, sample counts, and enrollment timestamps live in a
    metadata.json sidecar. Every public method takes a single internal lock,
    so one instance may be shared across threads. Auto-increment IDs start
    at 1.

    Args:
        gallery_dir: Directory for gallery.npz and metadata.json files.
    """

    def __init__(self, gallery_dir: str):
        self._gallery_dir = gallery_dir
        self._npz_path = os.path.join(gallery_dir, 'gallery.npz')
        self._meta_path = os.path.join(gallery_dir, 'metadata.json')
        # person_id -> {'name', 'embedding', 'samples', 'enrolled_at'}
        self._gallery: dict[int, dict] = {}
        self._next_id = 1
        self._lock = threading.Lock()

    def load(self) -> None:
        """Load gallery from disk, replacing the in-memory gallery.

        A missing file is not an error: the gallery simply starts empty.
        Stored embeddings are re-normalized defensively.
        """
        with self._lock:
            self._gallery = {}
            self._next_id = 1

            if not os.path.isfile(self._npz_path):
                logger.info('No gallery file at %s, starting empty.', self._npz_path)
                return

            meta: dict = {}
            if os.path.isfile(self._meta_path):
                with open(self._meta_path, 'r') as f:
                    meta = json.load(f)

            # Context manager closes the NpzFile's underlying file handle,
            # which np.load otherwise leaves open.
            with np.load(self._npz_path, allow_pickle=False) as data:
                for key in data.files:
                    pid = int(key)
                    embedding = data[key].astype(np.float32)
                    norm = np.linalg.norm(embedding)
                    if norm > 0:
                        embedding = embedding / norm
                    info = meta.get(str(pid), {})
                    self._gallery[pid] = {
                        'name': info.get('name', f'person_{pid}'),
                        'embedding': embedding,
                        'samples': info.get('samples', 1),
                        'enrolled_at': info.get('enrolled_at', 0.0),
                    }
                    # Keep the auto-increment counter past every loaded id.
                    if pid >= self._next_id:
                        self._next_id = pid + 1

            logger.info('Gallery loaded: %d persons from %s',
                        len(self._gallery), self._npz_path)

    def save(self) -> None:
        """Save gallery to disk (npz + JSON sidecar), creating the dir."""
        with self._lock:
            os.makedirs(self._gallery_dir, exist_ok=True)

            arrays = {}
            meta = {}
            for pid, info in self._gallery.items():
                # npz keys must be strings; ids round-trip via int() on load.
                arrays[str(pid)] = info['embedding']
                meta[str(pid)] = {
                    'name': info['name'],
                    'samples': info['samples'],
                    'enrolled_at': info['enrolled_at'],
                }

            np.savez(self._npz_path, **arrays)
            with open(self._meta_path, 'w') as f:
                json.dump(meta, f, indent=2)

            logger.info('Gallery saved: %d persons to %s',
                        len(self._gallery), self._npz_path)

    def add_person(self, name: str, embedding: np.ndarray, samples: int = 1) -> int:
        """Add a new person to the gallery.

        The embedding is re-normalized before storage so match() can use a
        plain dot product as cosine similarity.

        Args:
            name: Person's name.
            embedding: 512-dim embedding (normalized here if it is not).
            samples: Number of samples used to compute the embedding.

        Returns:
            Assigned person_id (auto-increment integer).
        """
        with self._lock:
            pid = self._next_id
            self._next_id += 1
            emb = embedding.astype(np.float32)
            norm = np.linalg.norm(emb)
            if norm > 0:
                emb = emb / norm
            self._gallery[pid] = {
                'name': name,
                'embedding': emb,
                'samples': samples,
                'enrolled_at': time.time(),
            }
            logger.info('Added person %d (%s), %d samples.', pid, name, samples)
            return pid

    def update_name(self, person_id: int, new_name: str) -> bool:
        """Update a person's name.

        Args:
            person_id: The ID of the person to update.
            new_name: New name string.

        Returns:
            True if the person was found and updated.
        """
        with self._lock:
            if person_id not in self._gallery:
                return False
            self._gallery[person_id]['name'] = new_name
            return True

    def delete_person(self, person_id: int) -> bool:
        """Remove a person from the gallery.

        Args:
            person_id: The ID of the person to delete.

        Returns:
            True if the person was found and removed.
        """
        with self._lock:
            if person_id not in self._gallery:
                return False
            del self._gallery[person_id]
            logger.info('Deleted person %d.', person_id)
            return True

    def get_all(self) -> list[dict]:
        """Get all gallery entries.

        Embeddings are copied so callers cannot mutate gallery state.

        Returns:
            List of dicts with keys: person_id, name, embedding, samples,
            enrolled_at.
        """
        with self._lock:
            result = []
            for pid, info in self._gallery.items():
                result.append({
                    'person_id': pid,
                    'name': info['name'],
                    'embedding': info['embedding'].copy(),
                    'samples': info['samples'],
                    'enrolled_at': info['enrolled_at'],
                })
            return result

    def match(self, query_embedding: np.ndarray, threshold: float = 0.35) -> tuple[int, str, float]:
        """Match a query embedding against the gallery using cosine similarity.

        The query is normalized here, so callers may pass raw embeddings.

        Args:
            query_embedding: 512-dim embedding.
            threshold: Minimum cosine similarity for a match.

        Returns:
            (person_id, name, score) or (-1, '', 0.0) if no match.
        """
        with self._lock:
            if not self._gallery:
                return (-1, '', 0.0)

            best_pid = -1
            best_name = ''
            best_score = 0.0

            query = query_embedding.astype(np.float32)
            norm = np.linalg.norm(query)
            if norm > 0:
                query = query / norm

            for pid, info in self._gallery.items():
                score = float(np.dot(query, info['embedding']))
                if score > best_score:
                    best_score = score
                    best_pid = pid
                    best_name = info['name']

            if best_score >= threshold:
                return (best_pid, best_name, best_score)
            return (-1, '', 0.0)

    def __len__(self) -> int:
        """Number of enrolled persons."""
        with self._lock:
            return len(self._gallery)
|
||||
@ -0,0 +1,431 @@
|
||||
"""
|
||||
face_recognition_node.py -- ROS2 node for SCRFD face detection + ArcFace recognition.
|
||||
|
||||
Pipeline:
|
||||
1. Subscribe to /camera/color/image_raw (RealSense D435i color stream).
|
||||
2. Run SCRFD face detection (TensorRT FP16 or ONNX fallback).
|
||||
3. For each detected face, align and extract ArcFace embedding.
|
||||
4. Match embedding against persistent gallery.
|
||||
5. Publish FaceDetectionArray with identified faces.
|
||||
|
||||
Services:
|
||||
/social/enroll -- Enroll a new person (collects N face samples).
|
||||
/social/persons/list -- List all enrolled persons.
|
||||
/social/persons/delete -- Delete a person from the gallery.
|
||||
/social/persons/update -- Update a person's name.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy, DurabilityPolicy
|
||||
|
||||
import cv2
|
||||
from cv_bridge import CvBridge
|
||||
|
||||
from sensor_msgs.msg import Image
|
||||
from builtin_interfaces.msg import Time
|
||||
|
||||
from saltybot_social_msgs.msg import (
|
||||
FaceDetection,
|
||||
FaceDetectionArray,
|
||||
FaceEmbedding,
|
||||
FaceEmbeddingArray,
|
||||
)
|
||||
from saltybot_social_msgs.srv import (
|
||||
EnrollPerson,
|
||||
ListPersons,
|
||||
DeletePerson,
|
||||
UpdatePerson,
|
||||
)
|
||||
|
||||
from .scrfd_detector import SCRFDDetector
|
||||
from .arcface_recognizer import ArcFaceRecognizer
|
||||
from .face_gallery import FaceGallery
|
||||
|
||||
|
||||
class FaceRecognitionNode(Node):
    """ROS2 node: SCRFD face detection + ArcFace gallery matching.

    Subscribes to the color camera stream, detects faces with SCRFD, extracts
    ArcFace embeddings, matches them against a persistent gallery, and
    publishes FaceDetectionArray messages. Enrollment and gallery management
    are exposed as services under /social/*.
    """

    def __init__(self):
        super().__init__('face_recognizer')
        self._bridge = CvBridge()
        self._frame_count = 0
        self._fps_t0 = time.monotonic()

        # -- Parameters --------------------------------------------------------
        self.declare_parameter('scrfd_engine_path',
                               '/mnt/nvme/saltybot/models/scrfd_2.5g.engine')
        self.declare_parameter('scrfd_onnx_path',
                               '/mnt/nvme/saltybot/models/scrfd_2.5g_bnkps.onnx')
        self.declare_parameter('arcface_engine_path',
                               '/mnt/nvme/saltybot/models/arcface_r50.engine')
        self.declare_parameter('arcface_onnx_path',
                               '/mnt/nvme/saltybot/models/arcface_r50.onnx')
        self.declare_parameter('gallery_dir', '/mnt/nvme/saltybot/gallery')
        self.declare_parameter('recognition_threshold', 0.35)
        self.declare_parameter('publish_debug_image', False)
        self.declare_parameter('max_faces', 10)
        self.declare_parameter('scrfd_conf_thresh', 0.5)

        self._recognition_threshold = self.get_parameter('recognition_threshold').value
        self._pub_debug_flag = self.get_parameter('publish_debug_image').value
        self._max_faces = self.get_parameter('max_faces').value

        # -- Models ------------------------------------------------------------
        self._detector = SCRFDDetector(
            engine_path=self.get_parameter('scrfd_engine_path').value,
            onnx_path=self.get_parameter('scrfd_onnx_path').value,
            conf_thresh=self.get_parameter('scrfd_conf_thresh').value,
        )
        self._recognizer = ArcFaceRecognizer(
            engine_path=self.get_parameter('arcface_engine_path').value,
            onnx_path=self.get_parameter('arcface_onnx_path').value,
        )

        # -- Gallery -----------------------------------------------------------
        gallery_dir = self.get_parameter('gallery_dir').value
        self._gallery = FaceGallery(gallery_dir)
        self._gallery.load()
        # rclpy loggers take a single message string — they do NOT expand
        # %-style positional args like stdlib logging, so format eagerly.
        self.get_logger().info(f'Gallery loaded: {len(self._gallery)} persons.')

        # -- Enrollment state --------------------------------------------------
        # Active enrollment request, or None:
        # {name, samples_needed, collected: [embeddings], done, result_pid}
        self._enrolling = None
        # Per-frame "largest face wins" sample tracking.
        self._enroll_best_area = 0.0
        self._enroll_best_embedding = None

        # -- QoS profiles ------------------------------------------------------
        best_effort_qos = QoSProfile(
            reliability=ReliabilityPolicy.BEST_EFFORT,
            history=HistoryPolicy.KEEP_LAST,
            depth=1,
        )
        # TRANSIENT_LOCAL gives latched-like delivery of the gallery state.
        reliable_qos = QoSProfile(
            reliability=ReliabilityPolicy.RELIABLE,
            durability=DurabilityPolicy.TRANSIENT_LOCAL,
            history=HistoryPolicy.KEEP_LAST,
            depth=1,
        )

        # -- Subscribers -------------------------------------------------------
        self.create_subscription(
            Image,
            '/camera/color/image_raw',
            self._on_image,
            best_effort_qos,
        )

        # -- Publishers --------------------------------------------------------
        self._pub_detections = self.create_publisher(
            FaceDetectionArray, '/social/faces/detections', best_effort_qos)
        self._pub_embeddings = self.create_publisher(
            FaceEmbeddingArray, '/social/faces/embeddings', reliable_qos)
        if self._pub_debug_flag:
            self._pub_debug_img = self.create_publisher(
                Image, '/social/faces/debug_image', best_effort_qos)

        # -- Services ----------------------------------------------------------
        self.create_service(EnrollPerson, '/social/enroll', self._handle_enroll)
        self.create_service(ListPersons, '/social/persons/list', self._handle_list)
        self.create_service(DeletePerson, '/social/persons/delete', self._handle_delete)
        self.create_service(UpdatePerson, '/social/persons/update', self._handle_update)

        # Publish initial gallery state so late subscribers get it via
        # the transient-local publisher.
        self._publish_gallery_embeddings()

        self.get_logger().info('FaceRecognitionNode ready.')

    # -- Image callback --------------------------------------------------------

    def _on_image(self, msg: Image):
        """Process incoming camera frame: detect, recognize, publish."""
        try:
            bgr = self._bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')
        except Exception as e:
            self.get_logger().error(f'Image decode error: {e}',
                                    throttle_duration_sec=5.0)
            return

        # Detect faces
        detections = self._detector.detect(bgr)

        # Keep only the N highest-scoring faces.
        if len(detections) > self._max_faces:
            detections = sorted(detections, key=lambda d: d['score'], reverse=True)
            detections = detections[:self._max_faces]

        det_array = FaceDetectionArray()
        det_array.header = msg.header

        for det in detections:
            # Extract embedding and match against gallery.
            embedding = self._recognizer.align_and_embed(bgr, det['kps'])
            pid, pname, score = self._gallery.match(
                embedding, self._recognition_threshold)

            # During enrollment, offer every face; only the largest is kept.
            if self._enrolling is not None:
                self._enrollment_collect(det, embedding)

            face_msg = FaceDetection()
            face_msg.header = msg.header
            face_msg.face_id = pid
            face_msg.person_name = pname
            # Cast to Python floats: rosidl field setters reject numpy
            # scalars — confirm field types in FaceDetection.msg.
            face_msg.confidence = float(det['score'])
            face_msg.recognition_score = float(score)

            bbox = det['bbox']
            face_msg.bbox_x = float(bbox[0])
            face_msg.bbox_y = float(bbox[1])
            face_msg.bbox_w = float(bbox[2] - bbox[0])
            face_msg.bbox_h = float(bbox[3] - bbox[1])

            kps = det['kps']
            for i in range(10):
                face_msg.landmarks[i] = float(kps[i])

            det_array.faces.append(face_msg)

        # Finalize this frame's enrollment sample (was previously only
        # attempted from the service wait loop, so per-frame collection
        # never happened from here).
        if self._enrolling is not None:
            self._enrollment_frame_end()

        self._pub_detections.publish(det_array)

        # Debug image
        if self._pub_debug_flag and hasattr(self, '_pub_debug_img'):
            debug_img = self._draw_debug(bgr, detections, det_array.faces)
            self._pub_debug_img.publish(
                self._bridge.cv2_to_imgmsg(debug_img, encoding='bgr8'))

        # FPS logging every 30 frames.
        self._frame_count += 1
        if self._frame_count % 30 == 0:
            elapsed = time.monotonic() - self._fps_t0
            fps = 30.0 / elapsed if elapsed > 0 else 0.0
            self._fps_t0 = time.monotonic()
            self.get_logger().info(f'FPS: {fps:.1f} | faces: {len(detections)}')

    # -- Enrollment logic ------------------------------------------------------

    def _enrollment_collect(self, det: dict, embedding: np.ndarray):
        """Track the embedding of the largest face seen this frame."""
        if self._enrolling is None:
            return

        # Only collect from the largest face (by bbox area).
        bbox = det['bbox']
        area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
        if area > self._enroll_best_area:
            self._enroll_best_area = area
            self._enroll_best_embedding = embedding

    def _enrollment_frame_end(self):
        """Commit this frame's best embedding; finish enrollment when full.

        On completion the state dict is marked done and carries result_pid;
        it is NOT cleared here — _handle_enroll still holds a reference and
        reads the result before clearing self._enrolling (the old code set
        it to None here, crashing the service loop and losing the id).
        """
        if self._enrolling is None or self._enroll_best_embedding is None:
            return

        self._enrolling['collected'].append(self._enroll_best_embedding)
        self._enroll_best_area = 0.0
        self._enroll_best_embedding = None

        collected = len(self._enrolling['collected'])
        needed = self._enrolling['samples_needed']
        name = self._enrolling['name']
        self.get_logger().info(
            f'Enrollment: {collected}/{needed} samples for "{name}".')

        if collected >= needed:
            # Finalize: average and re-normalize the collected embeddings.
            embeddings = self._enrolling['collected']
            mean_emb = np.mean(embeddings, axis=0).astype(np.float32)
            norm = np.linalg.norm(mean_emb)
            if norm > 0:
                mean_emb = mean_emb / norm

            pid = self._gallery.add_person(name, mean_emb, samples=len(embeddings))
            self._gallery.save()
            self._publish_gallery_embeddings()

            self.get_logger().info(f'Enrollment complete: person {pid} ({name}).')

            self._enrolling['result_pid'] = pid
            self._enrolling['done'] = True

    # -- Service handlers ------------------------------------------------------

    def _handle_enroll(self, request, response):
        """Handle EnrollPerson service: collect N face samples, then enroll.

        NOTE(review): spinning inside a service callback only makes progress
        with a MultiThreadedExecutor and a reentrant callback group; with a
        single-threaded executor the image callback can never fire while this
        handler blocks — confirm the launch/executor configuration.
        """
        name = request.name.strip()
        if not name:
            response.success = False
            response.message = 'Name cannot be empty.'
            response.person_id = -1
            return response

        n_samples = request.n_samples if request.n_samples > 0 else 10

        self.get_logger().info(
            f'Starting enrollment for "{name}" ({n_samples} samples).')

        # Keep a local reference: _enrollment_frame_end marks this dict done,
        # and we must read result_pid from it after the loop.
        state = {
            'name': name,
            'samples_needed': n_samples,
            'collected': [],
            'done': False,
            'result_pid': -1,
        }
        self._enrolling = state
        self._enroll_best_area = 0.0
        self._enroll_best_embedding = None

        timeout_sec = n_samples * 2.0 + 10.0  # generous timeout
        t0 = time.monotonic()

        while not state['done']:
            if time.monotonic() - t0 > timeout_sec:
                self._enrolling = None
                response.success = False
                response.message = f'Enrollment timed out after {timeout_sec:.0f}s.'
                response.person_id = -1
                return response

            rclpy.spin_once(self, timeout_sec=0.1)

        # Clear the shared state only after reading the result.
        self._enrolling = None
        response.success = True
        response.message = f'Enrolled "{name}" with {n_samples} samples.'
        response.person_id = state['result_pid']
        return response

    def _handle_list(self, request, response):
        """Handle ListPersons service: return all gallery entries."""
        entries = self._gallery.get_all()
        for entry in entries:
            emb_msg = FaceEmbedding()
            emb_msg.person_id = entry['person_id']
            emb_msg.person_name = entry['name']
            emb_msg.embedding = entry['embedding'].tolist()
            emb_msg.sample_count = entry['samples']

            # Split the float epoch timestamp into sec/nanosec.
            secs = int(entry['enrolled_at'])
            nsecs = int((entry['enrolled_at'] - secs) * 1e9)
            emb_msg.enrolled_at = Time(sec=secs, nanosec=nsecs)

            response.persons.append(emb_msg)
        return response

    def _handle_delete(self, request, response):
        """Handle DeletePerson service: remove a person from the gallery."""
        if self._gallery.delete_person(request.person_id):
            self._gallery.save()
            self._publish_gallery_embeddings()
            response.success = True
            response.message = f'Deleted person {request.person_id}.'
        else:
            response.success = False
            response.message = f'Person {request.person_id} not found.'
        return response

    def _handle_update(self, request, response):
        """Handle UpdatePerson service: rename a person."""
        new_name = request.new_name.strip()
        if not new_name:
            response.success = False
            response.message = 'New name cannot be empty.'
            return response

        if self._gallery.update_name(request.person_id, new_name):
            self._gallery.save()
            self._publish_gallery_embeddings()
            response.success = True
            response.message = f'Updated person {request.person_id} to "{new_name}".'
        else:
            response.success = False
            response.message = f'Person {request.person_id} not found.'
        return response

    # -- Gallery publishing ----------------------------------------------------

    def _publish_gallery_embeddings(self):
        """Publish current gallery as FaceEmbeddingArray (latched-like)."""
        entries = self._gallery.get_all()
        msg = FaceEmbeddingArray()
        msg.header.stamp = self.get_clock().now().to_msg()

        for entry in entries:
            emb_msg = FaceEmbedding()
            emb_msg.person_id = entry['person_id']
            emb_msg.person_name = entry['name']
            emb_msg.embedding = entry['embedding'].tolist()
            emb_msg.sample_count = entry['samples']

            secs = int(entry['enrolled_at'])
            nsecs = int((entry['enrolled_at'] - secs) * 1e9)
            emb_msg.enrolled_at = Time(sec=secs, nanosec=nsecs)

            msg.embeddings.append(emb_msg)

        self._pub_embeddings.publish(msg)

    # -- Debug image -----------------------------------------------------------

    def _draw_debug(self, bgr: np.ndarray, detections: list[dict],
                    face_msgs: list) -> np.ndarray:
        """Draw bounding boxes, landmarks, and names on a copy of the image."""
        vis = bgr.copy()
        for det, face_msg in zip(detections, face_msgs):
            bbox = det['bbox']
            x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])

            # Color: green if recognized, yellow if unknown.
            if face_msg.face_id >= 0:
                color = (0, 255, 0)
                label = f'{face_msg.person_name} ({face_msg.recognition_score:.2f})'
            else:
                color = (0, 255, 255)
                label = f'unknown ({face_msg.confidence:.2f})'

            cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2)
            cv2.putText(vis, label, (x1, y1 - 8),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

            # Draw the five landmark points in red.
            kps = det['kps']
            for k in range(5):
                px, py = int(kps[k * 2]), int(kps[k * 2 + 1])
                cv2.circle(vis, (px, py), 2, (0, 0, 255), -1)

        return vis
|
||||
|
||||
|
||||
# -- Entry point ---------------------------------------------------------------
|
||||
|
||||
def main(args=None):
    """ROS2 entry point for face_recognition node.

    Spins the node until interrupted, then shuts down rclpy cleanly.
    """
    rclpy.init(args=args)
    node = FaceRecognitionNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the node; not an error.
        pass
    finally:
        node.destroy_node()
        rclpy.shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -0,0 +1,350 @@
|
||||
"""
|
||||
scrfd_detector.py -- SCRFD face detection with TensorRT FP16 + ONNX fallback.
|
||||
|
||||
SCRFD (Sample and Computation Redistribution for Face Detection) produces
|
||||
9 output tensors across 3 strides (8, 16, 32), each with score, bbox, and
|
||||
keypoint branches. This module handles anchor generation, bbox/keypoint
|
||||
decoding, and NMS to produce final face detections.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_STRIDES = [8, 16, 32]
|
||||
_NUM_ANCHORS = 2 # anchors per cell per stride
|
||||
|
||||
|
||||
# -- Inference backends --------------------------------------------------------
|
||||
|
||||
class _TRTBackend:
    """TensorRT inference engine for SCRFD.

    Deserializes a prebuilt engine, allocates one page-locked host buffer
    and one device buffer per I/O tensor, and runs asynchronous inference
    on a dedicated CUDA stream.
    """

    def __init__(self, engine_path: str):
        import tensorrt as trt
        import pycuda.driver as cuda
        import pycuda.autoinit  # noqa: F401  (creates the CUDA context)

        self._cuda = cuda
        trt_logger = trt.Logger(trt.Logger.WARNING)
        with open(engine_path, 'rb') as f, trt.Runtime(trt_logger) as runtime:
            self._engine = runtime.deserialize_cuda_engine(f.read())
        self._context = self._engine.create_execution_context()

        self._inputs = []
        self._outputs = []
        self._output_names = []
        self._bindings = []
        for i in range(self._engine.num_io_tensors):
            name = self._engine.get_tensor_name(i)
            dtype = trt.nptype(self._engine.get_tensor_dtype(name))
            shape = tuple(self._engine.get_tensor_shape(name))
            byte_count = int(np.prod(shape)) * np.dtype(dtype).itemsize

            # Pinned host memory enables fast async H2D/D2H transfers.
            host_buf = cuda.pagelocked_empty(shape, dtype)
            dev_buf = cuda.mem_alloc(byte_count)
            self._bindings.append(int(dev_buf))

            if self._engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                self._inputs.append({'host': host_buf, 'device': dev_buf})
            else:
                self._outputs.append({'host': host_buf, 'device': dev_buf,
                                      'shape': shape})
                self._output_names.append(name)
        self._stream = cuda.Stream()

    def infer(self, input_data: np.ndarray) -> list[np.ndarray]:
        """Run inference and return output tensors."""
        # Stage input into pinned memory, then copy host -> device.
        np.copyto(self._inputs[0]['host'], input_data.ravel())
        self._cuda.memcpy_htod_async(
            self._inputs[0]['device'], self._inputs[0]['host'], self._stream)
        self._context.execute_async_v2(self._bindings, self._stream.handle)
        # Queue all device -> host copies, then wait once on the stream.
        for out in self._outputs:
            self._cuda.memcpy_dtoh_async(out['host'], out['device'], self._stream)
        self._stream.synchronize()
        return [out['host'].reshape(out['shape']).copy()
                for out in self._outputs]
|
||||
|
||||
|
||||
class _ONNXBackend:
    """ONNX Runtime inference (CUDA EP with CPU fallback)."""

    def __init__(self, onnx_path: str):
        import onnxruntime as ort

        # Prefer GPU; onnxruntime falls back to CPU if CUDA is unavailable.
        self._session = ort.InferenceSession(
            onnx_path,
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self._input_name = self._session.get_inputs()[0].name
        self._output_names = [o.name for o in self._session.get_outputs()]

    def infer(self, input_data: np.ndarray) -> list[np.ndarray]:
        """Run inference and return output tensors."""
        return self._session.run(None, {self._input_name: input_data})
|
||||
|
||||
|
||||
# -- NMS ----------------------------------------------------------------------
|
||||
|
||||
def _nms(boxes: np.ndarray, scores: np.ndarray, iou_thresh: float) -> list[int]:
|
||||
"""Non-maximum suppression. boxes: [N, 4] as x1,y1,x2,y2."""
|
||||
if len(boxes) == 0:
|
||||
return []
|
||||
|
||||
x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
|
||||
areas = (x2 - x1) * (y2 - y1)
|
||||
order = scores.argsort()[::-1]
|
||||
keep = []
|
||||
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
keep.append(int(i))
|
||||
if order.size == 1:
|
||||
break
|
||||
xx1 = np.maximum(x1[i], x1[order[1:]])
|
||||
yy1 = np.maximum(y1[i], y1[order[1:]])
|
||||
xx2 = np.minimum(x2[i], x2[order[1:]])
|
||||
yy2 = np.minimum(y2[i], y2[order[1:]])
|
||||
inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
|
||||
iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-6)
|
||||
remaining = np.where(iou <= iou_thresh)[0]
|
||||
order = order[remaining + 1]
|
||||
|
||||
return keep
|
||||
|
||||
|
||||
# -- Anchor generation ---------------------------------------------------------
|
||||
|
||||
def _generate_anchors(input_h: int, input_w: int) -> dict[int, np.ndarray]:
    """Generate anchor centers for each stride.

    Returns dict mapping stride -> array of shape [H*W*num_anchors, 2],
    where each row is (cx, cy) in input pixel coordinates. The per-cell
    anchor copies are interleaved — [anchor0_cell0, anchor1_cell0,
    anchor0_cell1, ...] — matching SCRFD's output ordering.

    Note: the original implementation first built `centers` with np.tile
    and then immediately rebuilt it with np.repeat; the tile result was
    dead code (and had the wrong interleaving), so it is removed here.
    """
    anchors: dict[int, np.ndarray] = {}
    for stride in _STRIDES:
        feat_h = input_h // stride
        feat_w = input_w // stride
        grid_y, grid_x = np.mgrid[:feat_h, :feat_w]
        # One (cx, cy) row per cell; np.repeat keeps the _NUM_ANCHORS copies
        # of a cell adjacent, which is the required interleaving.
        centers = np.repeat(
            np.stack([grid_x.ravel(), grid_y.ravel()], axis=1).astype(np.float32),
            _NUM_ANCHORS, axis=0,
        )
        # Shift to cell centres and scale to input-pixel coordinates.
        anchors[stride] = (centers + 0.5) * stride
    return anchors
|
||||
|
||||
|
||||
# -- Main detector class -------------------------------------------------------
|
||||
|
||||
class SCRFDDetector:
|
||||
"""SCRFD face detector with TensorRT FP16 and ONNX fallback.
|
||||
|
||||
Args:
|
||||
engine_path: Path to TensorRT engine file.
|
||||
onnx_path: Path to ONNX model file (used if engine not available).
|
||||
conf_thresh: Minimum confidence for detections.
|
||||
nms_iou: IoU threshold for NMS.
|
||||
input_size: Model input resolution (square).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_path: str = '',
|
||||
onnx_path: str = '',
|
||||
conf_thresh: float = 0.5,
|
||||
nms_iou: float = 0.4,
|
||||
input_size: int = 640,
|
||||
):
|
||||
self._conf_thresh = conf_thresh
|
||||
self._nms_iou = nms_iou
|
||||
self._input_size = input_size
|
||||
self._backend: Optional[_TRTBackend | _ONNXBackend] = None
|
||||
self._anchors = _generate_anchors(input_size, input_size)
|
||||
|
||||
# Try TRT first, then ONNX
|
||||
if engine_path and os.path.isfile(engine_path):
|
||||
try:
|
||||
self._backend = _TRTBackend(engine_path)
|
||||
logger.info('SCRFD TensorRT backend loaded: %s', engine_path)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning('SCRFD TRT load failed (%s), trying ONNX', e)
|
||||
|
||||
if onnx_path and os.path.isfile(onnx_path):
|
||||
try:
|
||||
self._backend = _ONNXBackend(onnx_path)
|
||||
logger.info('SCRFD ONNX backend loaded: %s', onnx_path)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error('SCRFD ONNX load failed: %s', e)
|
||||
|
||||
logger.error('No SCRFD model loaded. Detection will be unavailable.')
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
"""Return True if a backend is loaded and ready."""
|
||||
return self._backend is not None
|
||||
|
||||
def detect(self, bgr: np.ndarray) -> list[dict]:
|
||||
"""Detect faces in a BGR image.
|
||||
|
||||
Args:
|
||||
bgr: Input image in BGR format, shape (H, W, 3).
|
||||
|
||||
Returns:
|
||||
List of dicts with keys:
|
||||
bbox: [x1, y1, x2, y2] in original image coordinates
|
||||
kps: [x0,y0, x1,y1, ..., x4,y4] — 10 floats, 5 landmarks
|
||||
score: detection confidence
|
||||
"""
|
||||
if self._backend is None:
|
||||
return []
|
||||
|
||||
orig_h, orig_w = bgr.shape[:2]
|
||||
tensor, scale, pad_w, pad_h = self._preprocess(bgr)
|
||||
outputs = self._backend.infer(tensor)
|
||||
detections = self._decode_outputs(outputs)
|
||||
detections = self._rescale(detections, scale, pad_w, pad_h, orig_w, orig_h)
|
||||
return detections
|
||||
|
||||
def _preprocess(self, bgr: np.ndarray) -> tuple[np.ndarray, float, int, int]:
|
||||
"""Resize to input_size x input_size with letterbox padding, normalize."""
|
||||
h, w = bgr.shape[:2]
|
||||
size = self._input_size
|
||||
scale = min(size / h, size / w)
|
||||
new_w, new_h = int(w * scale), int(h * scale)
|
||||
pad_w = (size - new_w) // 2
|
||||
pad_h = (size - new_h) // 2
|
||||
|
||||
resized = cv2.resize(bgr, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
|
||||
canvas = np.full((size, size, 3), 0, dtype=np.uint8)
|
||||
canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized
|
||||
|
||||
# Normalize: subtract 127.5, divide 128.0
|
||||
blob = canvas.astype(np.float32)
|
||||
blob = (blob - 127.5) / 128.0
|
||||
# HWC -> NCHW
|
||||
blob = blob.transpose(2, 0, 1)[np.newaxis]
|
||||
blob = np.ascontiguousarray(blob)
|
||||
return blob, scale, pad_w, pad_h
|
||||
|
||||
def _decode_outputs(self, outputs: list[np.ndarray]) -> list[dict]:
|
||||
"""Decode SCRFD 9-output format into face detections.
|
||||
|
||||
SCRFD outputs 9 tensors, 3 per stride (score, bbox, kps):
|
||||
score_8, bbox_8, kps_8, score_16, bbox_16, kps_16, score_32, bbox_32, kps_32
|
||||
"""
|
||||
all_scores = []
|
||||
all_bboxes = []
|
||||
all_kps = []
|
||||
|
||||
for idx, stride in enumerate(_STRIDES):
|
||||
score_out = outputs[idx * 3].squeeze() # [H*W*num_anchors]
|
||||
bbox_out = outputs[idx * 3 + 1].squeeze() # [H*W*num_anchors, 4]
|
||||
kps_out = outputs[idx * 3 + 2].squeeze() # [H*W*num_anchors, 10]
|
||||
|
||||
if score_out.ndim == 0:
|
||||
continue
|
||||
|
||||
# Ensure proper shapes
|
||||
if score_out.ndim == 1:
|
||||
n = score_out.shape[0]
|
||||
else:
|
||||
n = score_out.shape[0]
|
||||
score_out = score_out.ravel()
|
||||
|
||||
if bbox_out.ndim == 1:
|
||||
bbox_out = bbox_out.reshape(-1, 4)
|
||||
if kps_out.ndim == 1:
|
||||
kps_out = kps_out.reshape(-1, 10)
|
||||
|
||||
# Filter by confidence
|
||||
mask = score_out > self._conf_thresh
|
||||
if not mask.any():
|
||||
continue
|
||||
|
||||
scores = score_out[mask]
|
||||
bboxes = bbox_out[mask]
|
||||
kps = kps_out[mask]
|
||||
|
||||
anchors = self._anchors[stride]
|
||||
# Trim or pad anchors to match output count
|
||||
if anchors.shape[0] > n:
|
||||
anchors = anchors[:n]
|
||||
elif anchors.shape[0] < n:
|
||||
continue
|
||||
|
||||
anchors = anchors[mask]
|
||||
|
||||
# Decode bboxes: center = anchor + pred[:2]*stride, size = exp(pred[2:])*stride
|
||||
cx = anchors[:, 0] + bboxes[:, 0] * stride
|
||||
cy = anchors[:, 1] + bboxes[:, 1] * stride
|
||||
w = np.exp(bboxes[:, 2]) * stride
|
||||
h = np.exp(bboxes[:, 3]) * stride
|
||||
x1 = cx - w / 2.0
|
||||
y1 = cy - h / 2.0
|
||||
x2 = cx + w / 2.0
|
||||
y2 = cy + h / 2.0
|
||||
decoded_bboxes = np.stack([x1, y1, x2, y2], axis=1)
|
||||
|
||||
# Decode keypoints: kp = anchor + pred * stride
|
||||
decoded_kps = np.zeros_like(kps)
|
||||
for k in range(5):
|
||||
decoded_kps[:, k * 2] = anchors[:, 0] + kps[:, k * 2] * stride
|
||||
decoded_kps[:, k * 2 + 1] = anchors[:, 1] + kps[:, k * 2 + 1] * stride
|
||||
|
||||
all_scores.append(scores)
|
||||
all_bboxes.append(decoded_bboxes)
|
||||
all_kps.append(decoded_kps)
|
||||
|
||||
if not all_scores:
|
||||
return []
|
||||
|
||||
scores = np.concatenate(all_scores)
|
||||
bboxes = np.concatenate(all_bboxes)
|
||||
kps = np.concatenate(all_kps)
|
||||
|
||||
# NMS
|
||||
keep = _nms(bboxes, scores, self._nms_iou)
|
||||
|
||||
results = []
|
||||
for i in keep:
|
||||
results.append({
|
||||
'bbox': bboxes[i].tolist(),
|
||||
'kps': kps[i].tolist(),
|
||||
'score': float(scores[i]),
|
||||
})
|
||||
return results
|
||||
|
||||
def _rescale(
|
||||
self,
|
||||
detections: list[dict],
|
||||
scale: float,
|
||||
pad_w: int,
|
||||
pad_h: int,
|
||||
orig_w: int,
|
||||
orig_h: int,
|
||||
) -> list[dict]:
|
||||
"""Rescale detections from model input space to original image space."""
|
||||
for det in detections:
|
||||
bbox = det['bbox']
|
||||
bbox[0] = max(0.0, (bbox[0] - pad_w) / scale)
|
||||
bbox[1] = max(0.0, (bbox[1] - pad_h) / scale)
|
||||
bbox[2] = min(float(orig_w), (bbox[2] - pad_w) / scale)
|
||||
bbox[3] = min(float(orig_h), (bbox[3] - pad_h) / scale)
|
||||
det['bbox'] = bbox
|
||||
|
||||
kps = det['kps']
|
||||
for k in range(5):
|
||||
kps[k * 2] = (kps[k * 2] - pad_w) / scale
|
||||
kps[k * 2 + 1] = (kps[k * 2 + 1] - pad_h) / scale
|
||||
det['kps'] = kps
|
||||
return detections
|
||||
@ -0,0 +1,112 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
build_face_trt_engines.py -- Build TensorRT FP16 engines for SCRFD and ArcFace.
|
||||
|
||||
Converts ONNX model files to optimized TensorRT engines with FP16 precision
|
||||
for fast inference on Jetson Orin Nano Super.
|
||||
|
||||
Usage:
|
||||
python3 build_face_trt_engines.py \
|
||||
--scrfd-onnx /path/to/scrfd_2.5g_bnkps.onnx \
|
||||
--arcface-onnx /path/to/arcface_r50.onnx \
|
||||
--output-dir /mnt/nvme/saltybot/models \
|
||||
--fp16 --workspace-mb 1024
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
|
||||
|
||||
def build_engine(onnx_path: str, engine_path: str, fp16: bool, workspace_mb: int):
    """Build a TensorRT engine from an ONNX model.

    Args:
        onnx_path: Path to the source ONNX model file.
        engine_path: Output path for the serialized TensorRT engine.
        fp16: Enable FP16 precision.
        workspace_mb: Maximum workspace size in megabytes.

    Raises:
        RuntimeError: If the ONNX model cannot be parsed or the build fails.
    """
    import tensorrt as trt

    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)

    print(f'Parsing ONNX model: {onnx_path}')
    started = time.monotonic()
    with open(onnx_path, 'rb') as f:
        parsed_ok = parser.parse(f.read())
    if not parsed_ok:
        # Surface every parser diagnostic before bailing out.
        for i in range(parser.num_errors):
            print(f'  ONNX parse error: {parser.get_error(i)}')
        raise RuntimeError(f'Failed to parse {onnx_path}')
    print(f'  Parsed in {time.monotonic() - started:.1f}s')

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
                                 workspace_mb * (1 << 20))
    if fp16:
        if builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
            print('  FP16 enabled.')
        else:
            print('  Warning: FP16 not supported on this platform, using FP32.')

    print('Building engine (this may take several minutes)...')
    started = time.monotonic()
    serialized = builder.build_serialized_network(network, config)
    build_time = time.monotonic() - started
    if serialized is None:
        raise RuntimeError('Engine build failed.')

    # Write the engine, creating the output directory if needed.
    os.makedirs(os.path.dirname(engine_path) or '.', exist_ok=True)
    with open(engine_path, 'wb') as f:
        f.write(serialized)

    size_mb = os.path.getsize(engine_path) / (1 << 20)
    print(f'  Engine saved: {engine_path} ({size_mb:.1f} MB, built in {build_time:.1f}s)')
|
||||
|
||||
|
||||
def main():
    """Main entry point for TRT engine building.

    Parses CLI arguments and builds an engine for each ONNX model supplied.
    """
    parser = argparse.ArgumentParser(
        description='Build TensorRT FP16 engines for SCRFD and ArcFace.')
    parser.add_argument('--scrfd-onnx', type=str, default='',
                        help='Path to SCRFD ONNX model.')
    parser.add_argument('--arcface-onnx', type=str, default='',
                        help='Path to ArcFace ONNX model.')
    parser.add_argument('--output-dir', type=str,
                        default='/mnt/nvme/saltybot/models',
                        help='Output directory for engine files.')
    parser.add_argument('--fp16', action='store_true', default=True,
                        help='Enable FP16 precision (default: True).')
    parser.add_argument('--no-fp16', action='store_false', dest='fp16',
                        help='Disable FP16 (use FP32 only).')
    parser.add_argument('--workspace-mb', type=int, default=1024,
                        help='TRT workspace size in MB (default: 1024).')
    args = parser.parse_args()

    if not args.scrfd_onnx and not args.arcface_onnx:
        parser.error('At least one of --scrfd-onnx or --arcface-onnx is required.')

    # (label, source ONNX path, engine filename) for each possible build.
    jobs = [
        ('SCRFD', args.scrfd_onnx, 'scrfd_2.5g.engine'),
        ('ArcFace', args.arcface_onnx, 'arcface_r50.engine'),
    ]
    for label, onnx_path, engine_name in jobs:
        if not onnx_path:
            continue
        print(f'\n=== Building {label} engine ===')
        build_engine(onnx_path,
                     os.path.join(args.output_dir, engine_name),
                     args.fp16, args.workspace_mb)

    print('\nDone.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
4
jetson/ros2_ws/src/saltybot_social_face/setup.cfg
Normal file
4
jetson/ros2_ws/src/saltybot_social_face/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/saltybot_social_face
|
||||
[install]
|
||||
install_scripts=$base/lib/saltybot_social_face
|
||||
30
jetson/ros2_ws/src/saltybot_social_face/setup.py
Normal file
30
jetson/ros2_ws/src/saltybot_social_face/setup.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""Setup for saltybot_social_face package."""
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
package_name = 'saltybot_social_face'
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
version='0.1.0',
|
||||
packages=find_packages(exclude=['test']),
|
||||
data_files=[
|
||||
('share/ament_index/resource_index/packages', ['resource/' + package_name]),
|
||||
('share/' + package_name, ['package.xml']),
|
||||
('share/' + package_name + '/launch', ['launch/face_recognition.launch.py']),
|
||||
('share/' + package_name + '/config', ['config/face_recognition_params.yaml']),
|
||||
],
|
||||
install_requires=['setuptools'],
|
||||
zip_safe=True,
|
||||
maintainer='seb',
|
||||
maintainer_email='seb@vayrette.com',
|
||||
description='SCRFD face detection and ArcFace recognition for SaltyBot social interactions',
|
||||
license='MIT',
|
||||
tests_require=['pytest'],
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'face_recognition = saltybot_social_face.face_recognition_node:main',
|
||||
'enrollment_cli = saltybot_social_face.enrollment_cli:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
@ -8,6 +8,7 @@ find_package(geometry_msgs REQUIRED)
|
||||
find_package(builtin_interfaces REQUIRED)
|
||||
|
||||
rosidl_generate_interfaces(${PROJECT_NAME}
|
||||
# Issue #80 — face detection + recognition
|
||||
"msg/FaceDetection.msg"
|
||||
"msg/FaceDetectionArray.msg"
|
||||
"msg/FaceEmbedding.msg"
|
||||
@ -18,7 +19,14 @@ rosidl_generate_interfaces(${PROJECT_NAME}
|
||||
"srv/ListPersons.srv"
|
||||
"srv/DeletePerson.srv"
|
||||
"srv/UpdatePerson.srv"
|
||||
# Issue #86 — LED expression + motor attention
|
||||
"msg/Mood.msg"
|
||||
"msg/Person.msg"
|
||||
"msg/PersonArray.msg"
|
||||
# Issue #92 — multi-modal tracking fusion
|
||||
"msg/FusedTarget.msg"
|
||||
DEPENDENCIES std_msgs geometry_msgs builtin_interfaces
|
||||
)
|
||||
|
||||
ament_export_dependencies(rosidl_default_runtime)
|
||||
ament_package()
|
||||
|
||||
19
jetson/ros2_ws/src/saltybot_social_msgs/msg/FusedTarget.msg
Normal file
19
jetson/ros2_ws/src/saltybot_social_msgs/msg/FusedTarget.msg
Normal file
@ -0,0 +1,19 @@
|
||||
# FusedTarget.msg — output of the multi-modal tracking fusion node.
|
||||
#
|
||||
# Position and velocity are in the base_link frame (robot-centred,
|
||||
# +X forward, +Y left). z components are always 0.0 for ground-plane tracking.
|
||||
#
|
||||
# Confidence: 0.0 = no data / fully predicted; 1.0 = strong fused measurement.
|
||||
# active_source: "fused" | "uwb" | "camera" | "predicted"
|
||||
|
||||
std_msgs/Header header
|
||||
|
||||
geometry_msgs/Point position # filtered 2-D position (m), z=0
|
||||
geometry_msgs/Vector3 velocity # filtered 2-D velocity (m/s), z=0
|
||||
|
||||
float32 range_m # Euclidean distance from robot to fused position
|
||||
float32 bearing_rad # bearing in base_link (+ve = person to the left)
|
||||
float32 confidence # composite confidence [0.0, 1.0]
|
||||
|
||||
string active_source # "fused" | "uwb" | "camera" | "predicted"
|
||||
string tag_id # UWB tag address (empty when UWB not contributing)
|
||||
7
jetson/ros2_ws/src/saltybot_social_msgs/msg/Mood.msg
Normal file
7
jetson/ros2_ws/src/saltybot_social_msgs/msg/Mood.msg
Normal file
@ -0,0 +1,7 @@
|
||||
# Mood.msg — social expression command sent to the LED display node.
|
||||
#
|
||||
# mood : one of "happy", "curious", "annoyed", "playful", "idle"
|
||||
# intensity : 0.0 (off) to 1.0 (full brightness)
|
||||
|
||||
string mood
|
||||
float32 intensity
|
||||
17
jetson/ros2_ws/src/saltybot_social_msgs/msg/Person.msg
Normal file
17
jetson/ros2_ws/src/saltybot_social_msgs/msg/Person.msg
Normal file
@ -0,0 +1,17 @@
|
||||
# Person.msg — single tracked person for social attention.
|
||||
#
|
||||
# bearing_rad : signed bearing in base_link frame (rad)
|
||||
# positive = person to the left (CCW), negative = right (CW)
|
||||
# distance_m : estimated distance in metres; 0 if unknown
|
||||
# confidence : detection confidence 0..1
|
||||
# is_speaking : true if mic DOA or VAD identifies this person as speaking
|
||||
# source : "camera" | "mic_doa"
|
||||
|
||||
std_msgs/Header header
|
||||
|
||||
int32 track_id
|
||||
float32 bearing_rad
|
||||
float32 distance_m
|
||||
float32 confidence
|
||||
bool is_speaking
|
||||
string source
|
||||
@ -0,0 +1,9 @@
|
||||
# PersonArray.msg — all detected persons plus the current attention target.
|
||||
#
|
||||
# persons : all currently tracked persons
|
||||
# active_id : track_id of the active/speaking person; -1 if none
|
||||
|
||||
std_msgs/Header header
|
||||
|
||||
saltybot_social_msgs/Person[] persons
|
||||
int32 active_id
|
||||
@ -3,16 +3,25 @@
|
||||
<package format="3">
|
||||
<name>saltybot_social_msgs</name>
|
||||
<version>0.1.0</version>
|
||||
<description>Custom ROS2 messages and services for saltybot social capabilities</description>
|
||||
<description>
|
||||
Custom ROS2 message and service definitions for saltybot social capabilities.
|
||||
Includes social perception types (face detection, person state, enrollment)
|
||||
and multi-modal tracking fusion types (FusedTarget) from Issue #92.
|
||||
</description>
|
||||
<maintainer email="seb@vayrette.com">seb</maintainer>
|
||||
<license>MIT</license>
|
||||
|
||||
<buildtool_depend>ament_cmake</buildtool_depend>
|
||||
<build_depend>rosidl_default_generators</build_depend>
|
||||
|
||||
<depend>std_msgs</depend>
|
||||
<depend>geometry_msgs</depend>
|
||||
<depend>builtin_interfaces</depend>
|
||||
<!-- duplicate <build_depend>rosidl_default_generators</build_depend> removed; already declared above -->
|
||||
|
||||
<exec_depend>rosidl_default_runtime</exec_depend>
|
||||
|
||||
<member_of_group>rosidl_interface_packages</member_of_group>
|
||||
|
||||
<export>
|
||||
<build_type>ament_cmake</build_type>
|
||||
</export>
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
sensor_msgs/Image crop
|
||||
---
|
||||
bool success
|
||||
float32[512] embedding
|
||||
5
jetson/ros2_ws/src/saltybot_social_tracking/.gitignore
vendored
Normal file
5
jetson/ros2_ws/src/saltybot_social_tracking/.gitignore
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.egg-info/
|
||||
.pytest_cache/
|
||||
@ -0,0 +1,48 @@
|
||||
# tracking_params.yaml — saltybot_social_tracking / TrackingFusionNode
|
||||
#
|
||||
# Run with:
|
||||
# ros2 launch saltybot_social_tracking tracking.launch.py
|
||||
#
|
||||
# Topics consumed:
|
||||
# /uwb/target (geometry_msgs/PoseStamped) — UWB triangulated position
|
||||
# /person/target (geometry_msgs/PoseStamped) — camera-detected position
|
||||
#
|
||||
# Topic produced:
|
||||
# /social/tracking/fused_target (saltybot_social_msgs/FusedTarget)
|
||||
|
||||
# ── Source staleness timeouts ──────────────────────────────────────────────────
|
||||
# UWB driver publishes at ~10 Hz; 1.5 s = 15 missed cycles before declared stale.
|
||||
uwb_timeout: 1.5 # seconds
|
||||
|
||||
# Camera detector publishes at ~30 Hz; 1.0 s = 30 missed frames before stale.
|
||||
cam_timeout: 1.0 # seconds
|
||||
|
||||
# How long the Kalman filter may coast (dead-reckoning) with no live source
|
||||
# before the node stops publishing.
|
||||
# At 10 m/s (EUC top-speed) the robot drifts ≈30 m over 3 s — beyond the UWB
|
||||
# follow-range, so 3 s is a reasonable hard stop.
|
||||
predict_timeout: 3.0 # seconds
|
||||
|
||||
# ── Kalman filter tuning ───────────────────────────────────────────────────────
|
||||
# process_noise: acceleration noise std-dev (m/s²).
|
||||
# EUC riders can brake or accelerate at ~3–5 m/s²; 3.0 is a good starting point.
|
||||
# Increase if the filtered track lags behind fast direction changes.
|
||||
# Decrease if the track is noisy.
|
||||
process_noise: 3.0 # m/s²
|
||||
|
||||
# UWB position measurement noise (std-dev, metres).
|
||||
# DW3000 TWR accuracy ≈ ±10–20 cm; 0.20 accounts for system-level error.
|
||||
meas_noise_uwb: 0.20 # m
|
||||
|
||||
# Camera position noise (std-dev, metres).
|
||||
# Depth reprojection error with RealSense D435i at 1–3 m ≈ ±5–15 cm.
|
||||
meas_noise_cam: 0.12 # m
|
||||
|
||||
# ── Control loop ──────────────────────────────────────────────────────────────
|
||||
control_rate: 20.0 # Hz — KF predict + publish rate
|
||||
|
||||
# ── Source arbiter ────────────────────────────────────────────────────────────
|
||||
# Minimum normalised confidence for a source to be considered live.
|
||||
# Range [0, 1]; lower = more permissive; default 0.15 keeps slightly stale
|
||||
# sources active rather than dropping to "predicted" prematurely.
|
||||
confidence_threshold: 0.15
|
||||
@ -0,0 +1,44 @@
|
||||
"""tracking.launch.py — launch the TrackingFusionNode with default params."""
|
||||
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
from launch import LaunchDescription
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
pkg_share = get_package_share_directory("saltybot_social_tracking")
|
||||
default_params = os.path.join(pkg_share, "config", "tracking_params.yaml")
|
||||
|
||||
return LaunchDescription([
|
||||
DeclareLaunchArgument(
|
||||
"params_file",
|
||||
default_value=default_params,
|
||||
description="Path to tracking fusion parameter YAML file",
|
||||
),
|
||||
DeclareLaunchArgument(
|
||||
"control_rate",
|
||||
default_value="20.0",
|
||||
description="KF predict + publish rate (Hz)",
|
||||
),
|
||||
DeclareLaunchArgument(
|
||||
"predict_timeout",
|
||||
default_value="3.0",
|
||||
description="Max KF coast time before stopping publish (s)",
|
||||
),
|
||||
Node(
|
||||
package="saltybot_social_tracking",
|
||||
executable="tracking_fusion_node",
|
||||
name="tracking_fusion",
|
||||
output="screen",
|
||||
parameters=[
|
||||
LaunchConfiguration("params_file"),
|
||||
{
|
||||
"control_rate": LaunchConfiguration("control_rate"),
|
||||
"predict_timeout": LaunchConfiguration("predict_timeout"),
|
||||
},
|
||||
],
|
||||
),
|
||||
])
|
||||
31
jetson/ros2_ws/src/saltybot_social_tracking/package.xml
Normal file
31
jetson/ros2_ws/src/saltybot_social_tracking/package.xml
Normal file
@ -0,0 +1,31 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_social_tracking</name>
|
||||
<version>0.1.0</version>
|
||||
<description>
|
||||
Multi-modal tracking fusion for saltybot.
|
||||
Fuses UWB triangulated position (/uwb/target) and camera-detected position
|
||||
(/person/target) using a 4-state Kalman filter to produce a smooth, low-latency
|
||||
fused estimate at /social/tracking/fused_target.
|
||||
Handles EUC rider speeds (20-30 km/h), signal handoff, and predictive coasting.
|
||||
</description>
|
||||
<maintainer email="sl-controls@saltylab.local">sl-controls</maintainer>
|
||||
<license>MIT</license>
|
||||
|
||||
<depend>rclpy</depend>
|
||||
<depend>geometry_msgs</depend>
|
||||
<depend>std_msgs</depend>
|
||||
<depend>saltybot_social_msgs</depend>
|
||||
|
||||
<buildtool_depend>ament_python</buildtool_depend>
|
||||
|
||||
<test_depend>ament_copyright</test_depend>
|
||||
<test_depend>ament_flake8</test_depend>
|
||||
<test_depend>ament_pep257</test_depend>
|
||||
<test_depend>python3-pytest</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1,134 @@
|
||||
"""
|
||||
kalman_tracker.py — 4-state linear Kalman filter for 2-D position+velocity tracking.
|
||||
|
||||
State vector: [x, y, vx, vy]
|
||||
Observation: [x_meas, y_meas]
|
||||
|
||||
Process model: constant velocity with Wiener process acceleration noise.
|
||||
Tuned to handle EUC rider speeds (20–30 km/h ≈ 5.5–8.3 m/s) with fast
|
||||
acceleration transients.
|
||||
|
||||
Pure module — no ROS2 dependency; fully unit-testable.
|
||||
"""
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
class KalmanTracker:
|
||||
"""
|
||||
4-state Kalman filter: state = [x, y, vx, vy].
|
||||
|
||||
Parameters
|
||||
----------
|
||||
process_noise : acceleration noise standard deviation (m/s²).
|
||||
Higher values allow the filter to track rapid velocity
|
||||
changes (EUC acceleration events). Default 3.0 m/s².
|
||||
meas_noise_uwb : UWB position measurement noise std-dev (m). Default 0.20 m.
|
||||
meas_noise_cam : Camera position measurement noise std-dev (m). Default 0.12 m.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
process_noise: float = 3.0,
|
||||
meas_noise_uwb: float = 0.20,
|
||||
meas_noise_cam: float = 0.12,
|
||||
):
|
||||
self._q = float(process_noise)
|
||||
self._r_uwb = float(meas_noise_uwb)
|
||||
self._r_cam = float(meas_noise_cam)
|
||||
|
||||
# State [x, y, vx, vy]
|
||||
self._x = np.zeros(4)
|
||||
|
||||
# Covariance — large initial uncertainty (10 m², 10 (m/s)²)
|
||||
self._P = np.diag([10.0, 10.0, 10.0, 10.0])
|
||||
|
||||
# Observation matrix: H * x = [x, y]
|
||||
self._H = np.array([[1.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 1.0, 0.0, 0.0]])
|
||||
|
||||
self._initialized = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def initialized(self) -> bool:
|
||||
return self._initialized
|
||||
|
||||
def initialize(self, x: float, y: float) -> None:
|
||||
"""Seed the filter at position (x, y) with zero velocity."""
|
||||
self._x = np.array([x, y, 0.0, 0.0])
|
||||
self._P = np.diag([1.0, 1.0, 5.0, 5.0])
|
||||
self._initialized = True
|
||||
|
||||
def predict(self, dt: float) -> None:
|
||||
"""
|
||||
Advance the filter state by dt seconds.
|
||||
|
||||
Uses a discrete Wiener process acceleration model so that positional
|
||||
uncertainty grows as O(dt^4/4) and velocity uncertainty as O(dt^2).
|
||||
This lets the filter coast accurately through short signal outages
|
||||
while still being responsive to EUC velocity changes.
|
||||
"""
|
||||
if dt <= 0.0:
|
||||
return
|
||||
|
||||
F = np.array([[1.0, 0.0, dt, 0.0],
|
||||
[0.0, 1.0, 0.0, dt],
|
||||
[0.0, 0.0, 1.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 1.0]])
|
||||
|
||||
q = self._q
|
||||
dt2 = dt * dt
|
||||
dt3 = dt2 * dt
|
||||
dt4 = dt3 * dt
|
||||
Q = (q * q) * np.array([
|
||||
[dt4 / 4.0, 0.0, dt3 / 2.0, 0.0 ],
|
||||
[0.0, dt4 / 4.0, 0.0, dt3 / 2.0],
|
||||
[dt3 / 2.0, 0.0, dt2, 0.0 ],
|
||||
[0.0, dt3 / 2.0, 0.0, dt2 ],
|
||||
])
|
||||
|
||||
self._x = F @ self._x
|
||||
self._P = F @ self._P @ F.T + Q
|
||||
|
||||
def update(self, x_meas: float, y_meas: float, source: str = "uwb") -> None:
|
||||
"""
|
||||
Apply a position measurement (x_meas, y_meas).
|
||||
|
||||
source : "uwb" or "camera" — selects the appropriate noise covariance.
|
||||
"""
|
||||
r = self._r_uwb if source == "uwb" else self._r_cam
|
||||
R = np.diag([r * r, r * r])
|
||||
|
||||
z = np.array([x_meas, y_meas])
|
||||
innov = z - self._H @ self._x # innovation
|
||||
S = self._H @ self._P @ self._H.T + R # innovation covariance
|
||||
K = self._P @ self._H.T @ np.linalg.inv(S) # Kalman gain
|
||||
self._x = self._x + K @ innov
|
||||
self._P = (np.eye(4) - K @ self._H) @ self._P
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# State accessors
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def position(self) -> tuple:
|
||||
"""Current filtered position (x, y) in metres."""
|
||||
return float(self._x[0]), float(self._x[1])
|
||||
|
||||
@property
|
||||
def velocity(self) -> tuple:
|
||||
"""Current filtered velocity (vx, vy) in m/s."""
|
||||
return float(self._x[2]), float(self._x[3])
|
||||
|
||||
def position_uncertainty_m(self) -> float:
|
||||
"""RMS positional uncertainty (m) from diagonal of covariance."""
|
||||
return float(math.sqrt((self._P[0, 0] + self._P[1, 1]) / 2.0))
|
||||
|
||||
def covariance_copy(self) -> np.ndarray:
|
||||
"""Return a copy of the current 4×4 covariance matrix."""
|
||||
return self._P.copy()
|
||||
@ -0,0 +1,155 @@
|
||||
"""
|
||||
source_arbiter.py — Source confidence scoring and selection for tracking fusion.
|
||||
|
||||
Two sensor sources are supported:
|
||||
UWB — geometry_msgs/PoseStamped from /uwb/target (triangulated, ~10 Hz)
|
||||
Camera — geometry_msgs/PoseStamped from /person/target (depth+YOLO, ~30 Hz)
|
||||
|
||||
Confidence model
|
||||
----------------
|
||||
Each source's confidence is its raw measurement quality multiplied by a
|
||||
linear staleness factor that drops to zero at its respective timeout:
|
||||
|
||||
conf = quality * max(0, 1 - age / timeout)
|
||||
|
||||
UWB quality is always 1.0 (the ranging hardware confidence is not exposed
|
||||
by the driver in origin/main; the UWB node already applies Kalman filtering).
|
||||
|
||||
Camera quality defaults to 1.0; callers may pass a lower value when the
|
||||
detection confidence is available.
|
||||
|
||||
Source selection
|
||||
----------------
|
||||
Both fresh → "fused" (confidence-weighted position blend)
|
||||
UWB only → "uwb"
|
||||
Camera only → "camera"
|
||||
Neither fresh → "predicted" (Kalman coasts)
|
||||
|
||||
Pure module — no ROS2 dependency; fully unit-testable.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
|
||||
def _staleness_factor(age_s: float, timeout_s: float) -> float:
|
||||
"""Linear decay: 1.0 at age=0, 0.0 at age=timeout, clamped."""
|
||||
if timeout_s <= 0.0:
|
||||
return 0.0
|
||||
return max(0.0, 1.0 - age_s / timeout_s)
|
||||
|
||||
|
||||
def uwb_confidence(age_s: float, timeout_s: float, quality: float = 1.0) -> float:
    """
    UWB source confidence: quality scaled by a linear staleness decay.

    age_s    : seconds since the last UWB measurement (>=0; pass a large
               value if one was never received)
    timeout_s: staleness threshold (s); confidence reaches 0 at this age
    quality  : inherent measurement quality in [0, 1] (default 1.0)
    """
    # Staleness decay inlined: 1.0 at age 0, clamped to 0.0 at/after timeout.
    if timeout_s <= 0.0:
        return quality * 0.0
    return quality * max(0.0, 1.0 - age_s / timeout_s)
|
||||
|
||||
|
||||
def camera_confidence(
    age_s: float, timeout_s: float, quality: float = 1.0
) -> float:
    """
    Camera source confidence: quality scaled by a linear staleness decay.

    age_s    : seconds since the last camera detection (>=0; pass a large
               value if one was never received)
    timeout_s: staleness threshold (s); confidence reaches 0 at this age
    quality  : YOLO detection confidence or other quality score in [0, 1]
    """
    # Staleness decay inlined: 1.0 at age 0, clamped to 0.0 at/after timeout.
    if timeout_s <= 0.0:
        return quality * 0.0
    return quality * max(0.0, 1.0 - age_s / timeout_s)
|
||||
|
||||
|
||||
def select_source(
    uwb_conf: float,
    cam_conf: float,
    threshold: float = 0.15,
) -> str:
    """
    Choose the active tracking source label from per-source confidences.

    Returns one of: "fused" (both live), "uwb", "camera", "predicted" (none).

    threshold: minimum confidence for a source to count as live; a source
               exactly at the threshold is considered live.
    """
    uwb_live = uwb_conf >= threshold
    cam_live = cam_conf >= threshold

    if uwb_live:
        return "fused" if cam_live else "uwb"
    if cam_live:
        return "camera"
    return "predicted"
|
||||
|
||||
|
||||
def fuse_positions(
    uwb_x: float,
    uwb_y: float,
    uwb_conf: float,
    cam_x: float,
    cam_y: float,
    cam_conf: float,
) -> tuple:
    """
    Confidence-weighted blend of the UWB and camera position estimates.

    Returns (fused_x, fused_y).

    When the combined confidence is zero (should not occur while the
    arbiter reports "fused", but guarded anyway) the UWB position is
    returned unchanged as a fallback.
    """
    total = uwb_conf + cam_conf
    if total <= 0.0:
        return uwb_x, uwb_y

    w_uwb = uwb_conf / total
    w_cam = 1.0 - w_uwb
    fused_x = w_uwb * uwb_x + w_cam * cam_x
    fused_y = w_uwb * uwb_y + w_cam * cam_y
    return fused_x, fused_y
|
||||
|
||||
|
||||
def composite_confidence(
    uwb_conf: float,
    cam_conf: float,
    source: str,
    kf_uncertainty_m: float,
    max_kf_uncertainty_m: float = 3.0,
) -> float:
    """
    Compute a single composite confidence value [0, 1] for the fused output.

    uwb_conf / cam_conf : per-source confidences from the staleness model
    source              : current source label (from select_source)
    kf_uncertainty_m    : current KF positional RMS uncertainty (m)
    max_kf_uncertainty_m: uncertainty at which confidence collapses to 0.
                          Non-positive values are treated as "KF health
                          unknown" (health = 0) instead of raising
                          ZeroDivisionError, which the original code did.
    """
    # KF health factor: 1.0 when the filter is tight, 0.0 at/after the
    # maximum allowed uncertainty.  Guard the division so a misconfigured
    # max_kf_uncertainty_m <= 0 degrades gracefully instead of raising.
    if max_kf_uncertainty_m > 0.0:
        kf_health = max(0.0, 1.0 - kf_uncertainty_m / max_kf_uncertainty_m)
    else:
        kf_health = 0.0

    if source == "predicted":
        # No sensor feeds are live: decay with growing KF uncertainty and
        # cap at 0.4 so the caller can tell this is dead-reckoning.
        return min(0.4, kf_health)

    if source == "fused":
        raw = max(uwb_conf, cam_conf)
    elif source == "uwb":
        raw = uwb_conf
    else:  # "camera"
        raw = cam_conf

    # Scale by KF health (full confidence only if the KF is tight).
    return raw * (0.5 + 0.5 * kf_health)
|
||||
|
||||
|
||||
def bearing_and_range(x: float, y: float) -> tuple:
    """
    Bearing (rad, positive = left) and range (m) to the point (x, y).

    Matches person_follower_node conventions in the base_link frame
    (+X forward, +Y left):
        bearing = atan2(y, x)
        range   = sqrt(x² + y²)
    """
    heading = math.atan2(y, x)
    distance = math.sqrt(x * x + y * y)
    return heading, distance
|
||||
@ -0,0 +1,257 @@
|
||||
"""
|
||||
tracking_fusion_node.py — Multi-modal tracking fusion for saltybot.
|
||||
|
||||
Subscribes
|
||||
----------
|
||||
/uwb/target (geometry_msgs/PoseStamped) — UWB-triangulated position (~10 Hz)
|
||||
/person/target (geometry_msgs/PoseStamped) — camera-detected position (~30 Hz)
|
||||
|
||||
Publishes
|
||||
---------
|
||||
/social/tracking/fused_target (saltybot_social_msgs/FusedTarget) at control_rate Hz
|
||||
|
||||
Algorithm
|
||||
---------
|
||||
1. Each incoming measurement updates a 4-state Kalman filter [x, y, vx, vy].
|
||||
2. A 20 Hz timer runs predict+select+publish:
|
||||
a. KF predict(dt)
|
||||
b. Compute per-source confidence from measurement age + staleness model
|
||||
c. If either source is live:
|
||||
- "fused" → confidence-weighted position blend → KF update
|
||||
- "uwb" → UWB position → KF update
|
||||
- "camera" → camera position → KF update
|
||||
d. Build FusedTarget from KF state + composite confidence
|
||||
3. If all sources are lost but within predict_timeout, keep publishing with
|
||||
active_source="predicted" and degrading confidence.
|
||||
4. Beyond predict_timeout, no message is published (node stays alive).
|
||||
|
||||
Kalman tuning for EUC speeds (20–30 km/h ≈ 5.5–8.3 m/s)
|
||||
---------------------------------------------------------
|
||||
process_noise=3.0 m/s² — allows rapid acceleration events
|
||||
predict_timeout=3.0 s — coasts ≈30 m at 10 m/s; acceptable dead-reckoning
|
||||
|
||||
Parameters
|
||||
----------
|
||||
uwb_timeout : UWB staleness threshold (s) default 1.5
|
||||
cam_timeout : Camera staleness threshold (s) default 1.0
|
||||
predict_timeout : Max KF coast before no publish (s) default 3.0
|
||||
process_noise : KF acceleration noise std-dev (m/s²) default 3.0
|
||||
meas_noise_uwb : UWB position noise std-dev (m) default 0.20
|
||||
meas_noise_cam : Camera position noise std-dev (m) default 0.12
|
||||
control_rate : Publish / KF predict rate (Hz) default 20.0
|
||||
confidence_threshold: Min source confidence to use (0–1) default 0.15
|
||||
|
||||
Usage
|
||||
-----
|
||||
ros2 launch saltybot_social_tracking tracking.launch.py
|
||||
"""
|
||||
|
||||
import math
|
||||
import time
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from geometry_msgs.msg import PoseStamped
|
||||
from std_msgs.msg import Header
|
||||
|
||||
from saltybot_social_msgs.msg import FusedTarget
|
||||
from saltybot_social_tracking.kalman_tracker import KalmanTracker
|
||||
from saltybot_social_tracking.source_arbiter import (
|
||||
uwb_confidence,
|
||||
camera_confidence,
|
||||
select_source,
|
||||
fuse_positions,
|
||||
composite_confidence,
|
||||
bearing_and_range,
|
||||
)
|
||||
|
||||
_BIG_AGE = 1e9 # sentinel: source never received
|
||||
|
||||
|
||||
class TrackingFusionNode(Node):
    """
    Fuses UWB (/uwb/target) and camera (/person/target) position estimates
    through a shared 4-state Kalman filter and publishes the blended result
    on /social/tracking/fused_target at control_rate Hz.
    """

    def __init__(self):
        super().__init__("tracking_fusion")

        # ── Parameters ────────────────────────────────────────────────────────
        self.declare_parameter("uwb_timeout", 1.5)
        self.declare_parameter("cam_timeout", 1.0)
        self.declare_parameter("predict_timeout", 3.0)
        self.declare_parameter("process_noise", 3.0)
        self.declare_parameter("meas_noise_uwb", 0.20)
        self.declare_parameter("meas_noise_cam", 0.12)
        self.declare_parameter("control_rate", 20.0)
        self.declare_parameter("confidence_threshold", 0.15)

        self._p = self._load_params()

        # ── State ─────────────────────────────────────────────────────────────
        # Shared Kalman filter; seeded lazily by the first measurement callback.
        self._kf = KalmanTracker(
            process_noise=self._p["process_noise"],
            meas_noise_uwb=self._p["meas_noise_uwb"],
            meas_noise_cam=self._p["meas_noise_cam"],
        )
        self._last_uwb_t: float = 0.0  # monotonic; 0 = never received
        self._last_cam_t: float = 0.0  # monotonic; 0 = never received
        self._uwb_x: float = 0.0       # latest raw UWB position (m)
        self._uwb_y: float = 0.0
        self._cam_x: float = 0.0       # latest raw camera position (m)
        self._cam_y: float = 0.0
        self._uwb_tag_id: str = ""     # always empty here — see _uwb_cb
        self._last_predict_t: float = 0.0  # monotonic time of last predict call
        self._last_any_t: float = 0.0  # monotonic time of last live measurement

        # ── Subscriptions ─────────────────────────────────────────────────────
        self.create_subscription(
            PoseStamped, "/uwb/target", self._uwb_cb, 10)
        self.create_subscription(
            PoseStamped, "/person/target", self._cam_cb, 10)

        # ── Publisher ─────────────────────────────────────────────────────────
        self._pub = self.create_publisher(FusedTarget, "/social/tracking/fused_target", 10)

        # ── Timer ─────────────────────────────────────────────────────────────
        # Drives the predict → arbitrate → update → publish loop.
        self._timer = self.create_timer(
            1.0 / self._p["control_rate"], self._control_cb)

        self.get_logger().info(
            f"TrackingFusion ready "
            f"rate={self._p['control_rate']}Hz "
            f"uwb_timeout={self._p['uwb_timeout']}s "
            f"cam_timeout={self._p['cam_timeout']}s "
            f"predict_timeout={self._p['predict_timeout']}s "
            f"process_noise={self._p['process_noise']}m/s²"
        )

    # ── Parameter helpers ──────────────────────────────────────────────────────

    def _load_params(self) -> dict:
        """Snapshot all declared parameters into a plain dict (re-read every tick)."""
        return {
            "uwb_timeout": self.get_parameter("uwb_timeout").value,
            "cam_timeout": self.get_parameter("cam_timeout").value,
            "predict_timeout": self.get_parameter("predict_timeout").value,
            "process_noise": self.get_parameter("process_noise").value,
            "meas_noise_uwb": self.get_parameter("meas_noise_uwb").value,
            "meas_noise_cam": self.get_parameter("meas_noise_cam").value,
            "control_rate": self.get_parameter("control_rate").value,
            "confidence_threshold": self.get_parameter("confidence_threshold").value,
        }

    # ── Measurement callbacks ──────────────────────────────────────────────────

    def _uwb_cb(self, msg: PoseStamped) -> None:
        """Cache the latest UWB position and seed the KF on first contact."""
        self._uwb_x = msg.pose.position.x
        self._uwb_y = msg.pose.position.y
        self._uwb_tag_id = ""  # PoseStamped has no tag field; tag reported via /uwb/bearing
        t = time.monotonic()
        self._last_uwb_t = t
        self._last_any_t = t

        # Seed KF on first measurement
        if not self._kf.initialized:
            self._kf.initialize(self._uwb_x, self._uwb_y)
            self._last_predict_t = t

    def _cam_cb(self, msg: PoseStamped) -> None:
        """Cache the latest camera position and seed the KF on first contact."""
        self._cam_x = msg.pose.position.x
        self._cam_y = msg.pose.position.y
        t = time.monotonic()
        self._last_cam_t = t
        self._last_any_t = t

        if not self._kf.initialized:
            self._kf.initialize(self._cam_x, self._cam_y)
            self._last_predict_t = t

    # ── Control loop ───────────────────────────────────────────────────────────

    def _control_cb(self) -> None:
        """Timer tick: KF predict, source arbitration, optional update, publish."""
        # Re-read parameters each tick so live `ros2 param set` takes effect.
        self._p = self._load_params()

        if not self._kf.initialized:
            return  # no data yet — nothing to publish

        now = time.monotonic()
        # First tick after seeding has no previous predict time; assume one
        # nominal control period.
        dt = now - self._last_predict_t if self._last_predict_t > 0.0 else (
            1.0 / self._p["control_rate"]
        )
        self._last_predict_t = now

        # KF predict
        self._kf.predict(dt)

        # Source confidence (age-based linear staleness decay)
        uwb_age = (now - self._last_uwb_t) if self._last_uwb_t > 0.0 else _BIG_AGE
        cam_age = (now - self._last_cam_t) if self._last_cam_t > 0.0 else _BIG_AGE

        u_conf = uwb_confidence(uwb_age, self._p["uwb_timeout"])
        c_conf = camera_confidence(cam_age, self._p["cam_timeout"])

        threshold = self._p["confidence_threshold"]
        source = select_source(u_conf, c_conf, threshold)

        if source == "predicted":
            # Check predict_timeout — stop publishing if too stale
            last_live_age = (
                (now - self._last_any_t) if self._last_any_t > 0.0 else _BIG_AGE
            )
            if last_live_age > self._p["predict_timeout"]:
                return  # silently stop publishing

        # Apply measurement update if a live source exists
        if source == "fused":
            fx, fy = fuse_positions(
                self._uwb_x, self._uwb_y, u_conf,
                self._cam_x, self._cam_y, c_conf,
            )
            # NOTE(review): blended update uses the UWB noise model, which at
            # the defaults (0.20 m) is LOOSER than the camera's (0.12 m) —
            # the earlier "tighter noise" comment was wrong; confirm intent.
            self._kf.update(fx, fy, source="uwb")
        elif source == "uwb":
            self._kf.update(self._uwb_x, self._uwb_y, source="uwb")
        elif source == "camera":
            self._kf.update(self._cam_x, self._cam_y, source="camera")
        # "predicted" → no update; KF coasts

        # Build and publish message
        kx, ky = self._kf.position
        vx, vy = self._kf.velocity
        kf_unc = self._kf.position_uncertainty_m()
        conf = composite_confidence(u_conf, c_conf, source, kf_unc)
        bearing, range_m = bearing_and_range(kx, ky)

        hdr = Header()
        hdr.stamp = self.get_clock().now().to_msg()
        hdr.frame_id = "base_link"

        msg = FusedTarget()
        msg.header = hdr
        msg.position.x = kx
        msg.position.y = ky
        msg.position.z = 0.0
        msg.velocity.x = vx
        msg.velocity.y = vy
        msg.velocity.z = 0.0
        msg.range_m = float(range_m)
        msg.bearing_rad = float(bearing)
        msg.confidence = float(conf)
        msg.active_source = source
        # tag_id only passes through for the pure-"uwb" source ("fused" does
        # not contain "uwb"); _uwb_tag_id is currently always "" — see _uwb_cb.
        msg.tag_id = self._uwb_tag_id if "uwb" in source else ""

        self._pub.publish(msg)
|
||||
|
||||
|
||||
# ── Entry point ────────────────────────────────────────────────────────────────
|
||||
|
||||
def main(args=None):
    """CLI entry point: initialise rclpy, spin the fusion node, clean up."""
    rclpy.init(args=args)
    fusion_node = TrackingFusionNode()
    try:
        rclpy.spin(fusion_node)
    except KeyboardInterrupt:
        pass  # Ctrl-C is the normal way to stop the node
    finally:
        fusion_node.destroy_node()
        rclpy.try_shutdown()


if __name__ == "__main__":
    main()
|
||||
5
jetson/ros2_ws/src/saltybot_social_tracking/setup.cfg
Normal file
5
jetson/ros2_ws/src/saltybot_social_tracking/setup.cfg
Normal file
@ -0,0 +1,5 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/saltybot_social_tracking
|
||||
|
||||
[install]
|
||||
install_scripts=$base/lib/saltybot_social_tracking
|
||||
32
jetson/ros2_ws/src/saltybot_social_tracking/setup.py
Normal file
32
jetson/ros2_ws/src/saltybot_social_tracking/setup.py
Normal file
@ -0,0 +1,32 @@
|
||||
from setuptools import setup, find_packages
import os
from glob import glob

package_name = "saltybot_social_tracking"

# ament_python package descriptor; installs the resource marker, manifest,
# config YAMLs and launch files into the share directory, and exposes the
# fusion node as a `ros2 run` executable.
setup(
    name=package_name,
    version="0.1.0",
    packages=find_packages(exclude=["test"]),
    data_files=[
        # Marker file so ament can index the package.
        ("share/ament_index/resource_index/packages",
         [f"resource/{package_name}"]),
        (f"share/{package_name}", ["package.xml"]),
        (os.path.join("share", package_name, "config"),
         glob("config/*.yaml")),
        (os.path.join("share", package_name, "launch"),
         glob("launch/*.py")),
    ],
    install_requires=["setuptools"],
    zip_safe=True,
    maintainer="sl-controls",
    maintainer_email="sl-controls@saltylab.local",
    description="Multi-modal tracking fusion (UWB + camera Kalman filter)",
    license="MIT",
    tests_require=["pytest"],
    entry_points={
        # `ros2 run saltybot_social_tracking tracking_fusion_node`
        "console_scripts": [
            f"tracking_fusion_node = {package_name}.tracking_fusion_node:main",
        ],
    },
)
|
||||
@ -0,0 +1,438 @@
|
||||
"""
|
||||
test_tracking_fusion.py — Unit tests for saltybot_social_tracking pure modules.
|
||||
|
||||
Tests cover:
|
||||
- KalmanTracker: initialization, predict, update, state accessors
|
||||
- source_arbiter: confidence functions, source selection, fusion, bearing
|
||||
|
||||
No ROS2 runtime required.
|
||||
"""
|
||||
|
||||
import math
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Allow running: python -m pytest test/test_tracking_fusion.py
|
||||
# from the package root without installing the package.
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from saltybot_social_tracking.kalman_tracker import KalmanTracker
|
||||
from saltybot_social_tracking.source_arbiter import (
|
||||
uwb_confidence,
|
||||
camera_confidence,
|
||||
select_source,
|
||||
fuse_positions,
|
||||
composite_confidence,
|
||||
bearing_and_range,
|
||||
)
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# KalmanTracker tests
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestKalmanTrackerInit:
    """Seeding behaviour: initialized flag, position, and zero velocity."""

    def test_not_initialized_by_default(self):
        kf = KalmanTracker()
        assert not kf.initialized

    def test_initialize_sets_position(self):
        kf = KalmanTracker()
        kf.initialize(3.0, 1.5)
        assert kf.initialized
        x, y = kf.position
        assert abs(x - 3.0) < 1e-9
        assert abs(y - 1.5) < 1e-9

    def test_initialize_sets_zero_velocity(self):
        kf = KalmanTracker()
        kf.initialize(1.0, -2.0)
        vx, vy = kf.velocity
        assert abs(vx) < 1e-9
        assert abs(vy) < 1e-9

    def test_initialize_origin(self):
        # (0, 0) is a valid seed — must not be confused with "uninitialized".
        kf = KalmanTracker()
        kf.initialize(0.0, 0.0)
        assert kf.initialized
        x, y = kf.position
        assert x == 0.0 and y == 0.0
|
||||
|
||||
|
||||
class TestKalmanTrackerPredict:
    """predict(): dt guards, velocity extrapolation, uncertainty growth."""

    def test_predict_zero_dt_no_change(self):
        kf = KalmanTracker()
        kf.initialize(2.0, 1.0)
        kf.predict(0.0)
        x, y = kf.position
        assert abs(x - 2.0) < 1e-9
        assert abs(y - 1.0) < 1e-9

    def test_predict_negative_dt_no_change(self):
        # Negative dt (e.g. clock glitch) must be a no-op, not a rewind.
        kf = KalmanTracker()
        kf.initialize(2.0, 1.0)
        kf.predict(-0.1)
        x, y = kf.position
        assert abs(x - 2.0) < 1e-9

    def test_predict_constant_velocity(self):
        """After a position update gives the filter a velocity, predict should extrapolate."""
        kf = KalmanTracker(process_noise=0.001, meas_noise_uwb=0.001)
        kf.initialize(0.0, 0.0)
        # Force filter to track a moving target to build up velocity estimate
        dt = 0.05
        for i in range(40):
            t = i * dt
            kf.predict(dt)
            kf.update(2.0 * t, 0.0, "uwb")  # 2 m/s in x

        # After many updates the velocity estimate should be close to 2 m/s
        vx, vy = kf.velocity
        assert abs(vx - 2.0) < 0.3, f"vx={vx:.3f}"
        assert abs(vy) < 0.2

    def test_predict_grows_uncertainty(self):
        # Coasting with no measurements must inflate the covariance.
        kf = KalmanTracker()
        kf.initialize(0.0, 0.0)
        unc_before = kf.position_uncertainty_m()
        kf.predict(1.0)
        unc_after = kf.position_uncertainty_m()
        assert unc_after > unc_before

    def test_predict_multiple_steps(self):
        kf = KalmanTracker()
        kf.initialize(0.0, 0.0)
        kf.predict(0.1)
        kf.predict(0.1)
        kf.predict(0.1)
        # No assertion on exact value; just verify no exception and state is finite
        x, y = kf.position
        assert math.isfinite(x) and math.isfinite(y)
|
||||
|
||||
|
||||
class TestKalmanTrackerUpdate:
    """update(): measurement blending, convergence, and noise-model selection."""

    def test_update_pulls_position_toward_measurement(self):
        kf = KalmanTracker()
        kf.initialize(0.0, 0.0)
        kf.update(5.0, 5.0, "uwb")
        x, y = kf.position
        assert x > 0.0 and y > 0.0
        assert x < 5.0 and y < 5.0  # blended, not jumped

    def test_update_reduces_uncertainty(self):
        kf = KalmanTracker()
        kf.initialize(0.0, 0.0)
        kf.predict(1.0)  # uncertainty grows
        unc_mid = kf.position_uncertainty_m()
        kf.update(0.1, 0.1, "uwb")  # measurement corrects
        unc_after = kf.position_uncertainty_m()
        assert unc_after < unc_mid

    def test_update_converges_to_true_position(self):
        """Many updates from same point should converge to that point."""
        kf = KalmanTracker(meas_noise_uwb=0.01)
        kf.initialize(0.0, 0.0)
        for _ in range(50):
            kf.update(3.0, -1.0, "uwb")
        x, y = kf.position
        assert abs(x - 3.0) < 0.05, f"x={x:.4f}"
        assert abs(y - (-1.0)) < 0.05, f"y={y:.4f}"

    def test_update_camera_source_different_noise(self):
        """Camera and UWB updates should both move state (noise model differs)."""
        kf1 = KalmanTracker(meas_noise_uwb=0.20, meas_noise_cam=0.10)
        kf1.initialize(0.0, 0.0)
        kf1.update(5.0, 0.0, "uwb")
        x_uwb, _ = kf1.position

        kf2 = KalmanTracker(meas_noise_uwb=0.20, meas_noise_cam=0.10)
        kf2.initialize(0.0, 0.0)
        kf2.update(5.0, 0.0, "camera")
        x_cam, _ = kf2.position

        # Camera has lower noise → stronger pull toward measurement
        assert x_cam > x_uwb

    def test_update_unknown_source_defaults_to_camera_noise(self):
        # Any source string other than "uwb" falls through to the camera model.
        kf = KalmanTracker()
        kf.initialize(0.0, 0.0)
        kf.update(2.0, 0.0, "other")  # unknown source — should not raise
        x, _ = kf.position
        assert x > 0.0

    def test_position_uncertainty_finite(self):
        kf = KalmanTracker()
        kf.initialize(1.0, 1.0)
        kf.predict(0.05)
        kf.update(1.1, 0.9, "uwb")
        assert math.isfinite(kf.position_uncertainty_m())
        assert kf.position_uncertainty_m() >= 0.0

    def test_covariance_copy_is_independent(self):
        # covariance_copy() must return a defensive copy, not a view.
        kf = KalmanTracker()
        kf.initialize(0.0, 0.0)
        cov = kf.covariance_copy()
        cov[0, 0] = 9999.0  # mutate copy
        assert kf.covariance_copy()[0, 0] != 9999.0
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# source_arbiter tests
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestUwbConfidence:
    """uwb_confidence(): linear staleness decay and quality scaling."""

    def test_zero_age_returns_quality(self):
        assert abs(uwb_confidence(0.0, 1.5) - 1.0) < 1e-9

    def test_at_timeout_returns_zero(self):
        assert uwb_confidence(1.5, 1.5) == pytest.approx(0.0)

    def test_beyond_timeout_returns_zero(self):
        assert uwb_confidence(2.0, 1.5) == 0.0

    def test_half_timeout_returns_half(self):
        assert uwb_confidence(0.75, 1.5) == pytest.approx(0.5)

    def test_quality_scales_result(self):
        assert uwb_confidence(0.0, 1.5, quality=0.7) == pytest.approx(0.7)

    def test_large_age_returns_zero(self):
        # The _BIG_AGE sentinel used by the node must yield zero confidence.
        assert uwb_confidence(1e9, 1.5) == 0.0

    def test_zero_timeout_returns_zero(self):
        assert uwb_confidence(0.0, 0.0) == 0.0
|
||||
|
||||
|
||||
class TestCameraConfidence:
    """camera_confidence(): same decay model as UWB, with quality input."""

    def test_zero_age_full_quality(self):
        assert camera_confidence(0.0, 1.0, quality=1.0) == pytest.approx(1.0)

    def test_at_timeout_zero(self):
        assert camera_confidence(1.0, 1.0) == pytest.approx(0.0)

    def test_beyond_timeout_zero(self):
        assert camera_confidence(2.0, 1.0) == 0.0

    def test_quality_scales(self):
        # age=0, quality=0.8
        assert camera_confidence(0.0, 1.0, quality=0.8) == pytest.approx(0.8)

    def test_halfway(self):
        assert camera_confidence(0.5, 1.0) == pytest.approx(0.5)
|
||||
|
||||
|
||||
class TestSelectSource:
    """select_source(): label selection and threshold boundary behaviour."""

    def test_both_above_threshold_fused(self):
        assert select_source(0.8, 0.6) == "fused"

    def test_only_uwb_above_threshold(self):
        assert select_source(0.8, 0.0) == "uwb"

    def test_only_cam_above_threshold(self):
        assert select_source(0.0, 0.7) == "camera"

    def test_both_below_threshold(self):
        assert select_source(0.0, 0.0) == "predicted"

    def test_threshold_boundary_uwb(self):
        # Exactly at threshold — should be treated as live
        assert select_source(0.15, 0.0, threshold=0.15) == "uwb"

    def test_threshold_boundary_below(self):
        assert select_source(0.14, 0.0, threshold=0.15) == "predicted"

    def test_custom_threshold(self):
        assert select_source(0.5, 0.0, threshold=0.6) == "predicted"
        assert select_source(0.5, 0.0, threshold=0.4) == "uwb"
|
||||
|
||||
|
||||
class TestFusePositions:
    """fuse_positions(): confidence-weighted blending and zero-total fallback."""

    def test_equal_confidence_returns_midpoint(self):
        x, y = fuse_positions(0.0, 0.0, 1.0, 4.0, 4.0, 1.0)
        assert abs(x - 2.0) < 1e-9
        assert abs(y - 2.0) < 1e-9

    def test_full_uwb_weight_returns_uwb(self):
        x, y = fuse_positions(3.0, 1.0, 1.0, 0.0, 0.0, 0.0)
        assert abs(x - 3.0) < 1e-9

    def test_full_cam_weight_returns_cam(self):
        x, y = fuse_positions(0.0, 0.0, 0.0, -2.0, 5.0, 1.0)
        assert abs(x - (-2.0)) < 1e-9
        assert abs(y - 5.0) < 1e-9

    def test_weighted_blend(self):
        # UWB at (0,0) conf=3, camera at (4,0) conf=1 → x = 3/4*0 + 1/4*4 = 1
        x, y = fuse_positions(0.0, 0.0, 3.0, 4.0, 0.0, 1.0)
        assert abs(x - 1.0) < 1e-9

    def test_zero_total_returns_uwb_fallback(self):
        x, y = fuse_positions(7.0, 2.0, 0.0, 3.0, 1.0, 0.0)
        assert abs(x - 7.0) < 1e-9
|
||||
|
||||
|
||||
class TestCompositeConfidence:
    """composite_confidence(): per-source scaling, prediction cap, KF health."""

    def test_fused_source_high_confidence(self):
        conf = composite_confidence(0.9, 0.8, "fused", 0.05)
        assert conf > 0.7

    def test_predicted_source_capped(self):
        # Dead-reckoning output is capped at 0.4 regardless of KF health.
        conf = composite_confidence(0.0, 0.0, "predicted", 0.1)
        assert conf <= 0.4

    def test_predicted_high_uncertainty_low_confidence(self):
        conf = composite_confidence(0.0, 0.0, "predicted", 3.0, max_kf_uncertainty_m=3.0)
        assert conf == pytest.approx(0.0)

    def test_uwb_only(self):
        conf = composite_confidence(0.8, 0.0, "uwb", 0.05)
        assert conf > 0.3

    def test_camera_only(self):
        conf = composite_confidence(0.0, 0.7, "camera", 0.05)
        assert conf > 0.2

    def test_high_kf_uncertainty_reduces_confidence(self):
        low_unc = composite_confidence(0.9, 0.0, "uwb", 0.1)
        high_unc = composite_confidence(0.9, 0.0, "uwb", 2.9)
        assert low_unc > high_unc
|
||||
|
||||
|
||||
class TestBearingAndRange:
    """bearing_and_range(): base_link conventions (+X forward, +Y left)."""

    def test_straight_ahead(self):
        bearing, rng = bearing_and_range(2.0, 0.0)
        assert abs(bearing) < 1e-9
        assert abs(rng - 2.0) < 1e-9

    def test_left_of_robot(self):
        # +Y = left in base_link frame; bearing should be positive
        bearing, rng = bearing_and_range(0.0, 1.0)
        assert abs(bearing - math.pi / 2.0) < 1e-9
        assert abs(rng - 1.0) < 1e-9

    def test_right_of_robot(self):
        bearing, rng = bearing_and_range(0.0, -1.0)
        assert abs(bearing - (-math.pi / 2.0)) < 1e-9

    def test_diagonal(self):
        bearing, rng = bearing_and_range(1.0, 1.0)
        assert abs(bearing - math.pi / 4.0) < 1e-9
        assert abs(rng - math.sqrt(2.0)) < 1e-9

    def test_at_origin(self):
        bearing, rng = bearing_and_range(0.0, 0.0)
        assert rng == pytest.approx(0.0)
        assert math.isfinite(bearing)  # atan2(0,0) = 0 in most implementations

    def test_range_always_non_negative(self):
        for x, y in [(-1, 0), (0, -1), (-2, -3), (5, -5)]:
            _, rng = bearing_and_range(x, y)
            assert rng >= 0.0
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Integration scenario tests
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestIntegrationScenarios:
    """End-to-end scenarios combining the Kalman tracker with the fusion helpers."""

    def test_euc_speed_velocity_tracking(self):
        """Verify KF can track EUC speed (8 m/s) within 0.5 m/s after warm-up."""
        tracker = KalmanTracker(process_noise=3.0, meas_noise_uwb=0.20)
        tracker.initialize(0.0, 0.0)
        step = 1.0 / 10.0   # 10 Hz UWB rate
        target_speed = 8.0  # m/s (≈29 km/h)
        for tick in range(60):
            tracker.predict(step)
            tracker.update(target_speed * (tick * step), 0.0, "uwb")
        vx, vy = tracker.velocity
        assert abs(vx - target_speed) < 0.6, f"vx={vx:.2f} expected≈{target_speed}"
        assert abs(vy) < 0.3

    def test_signal_loss_recovery(self):
        """
        After 1 s of signal loss the filter should still have a reasonable
        position estimate (not diverged to infinity).
        """
        tracker = KalmanTracker(process_noise=3.0)
        tracker.initialize(2.0, 0.5)
        step = 0.05
        # Warm up with 2 m/s motion along x.
        for tick in range(20):
            tracker.predict(step)
            tracker.update(2.0 * (tick + 1) * step, 0.0, "uwb")
        # Coast for 1 second (20 × 50 ms) with no measurements at all.
        for _ in range(20):
            tracker.predict(step)
        est_x, est_y = tracker.position
        assert math.isfinite(est_x) and math.isfinite(est_y)
        assert abs(est_x) < 20.0  # shouldn't have drifted more than 20 m

    def test_uwb_to_camera_handoff(self):
        """
        Simulate UWB going stale and camera taking over — Kalman should
        smoothly continue tracking without a jump.
        """
        tracker = KalmanTracker(meas_noise_uwb=0.20, meas_noise_cam=0.12)
        tracker.initialize(0.0, 0.0)
        step = 0.05

        # Phase 1: UWB drives the filter.
        for tick in range(20):
            tracker.predict(step)
            tracker.update(float(tick) * 0.1, 0.0, "uwb")
        handoff_x, _ = tracker.position

        # Phase 2: camera continues the very same trajectory.
        for tick in range(20, 40):
            tracker.predict(step)
            tracker.update(float(tick) * 0.1, 0.0, "camera")
        final_x, _ = tracker.position

        # Position should have continued progressing (not stuck or reset).
        assert final_x > handoff_x

    def test_confidence_degradation_during_coast(self):
        """Composite confidence should drop as KF uncertainty grows during coast."""
        tracker = KalmanTracker(process_noise=3.0)
        tracker.initialize(2.0, 0.0)

        # Fresh: tight uncertainty → high confidence.
        fresh_unc = tracker.position_uncertainty_m()
        fresh_conf = composite_confidence(0.0, 0.0, "predicted", fresh_unc)

        # After 2 s of coasting without measurements.
        for _ in range(40):
            tracker.predict(0.05)
        stale_unc = tracker.position_uncertainty_m()
        stale_conf = composite_confidence(0.0, 0.0, "predicted", stale_unc)

        assert fresh_conf >= stale_conf

    def test_fused_source_confidence_weighted_position(self):
        """Fused position should sit between UWB and camera, closer to higher-conf source."""
        # UWB at x=0 with high conf, camera at x=10 with low conf.
        high_conf_uwb = 0.9
        low_conf_cam = 0.1
        fused_x, _fused_y = fuse_positions(0.0, 0.0, high_conf_uwb, 10.0, 0.0, low_conf_cam)
        # Should be much closer to UWB (0) than camera (10).
        assert fused_x < 3.0, f"fused_x={fused_x:.2f}"

    def test_select_source_transitions(self):
        """Verify correct source transitions as confidences change."""
        expectations = {
            (0.9, 0.8): "fused",
            (0.9, 0.0): "uwb",
            (0.0, 0.8): "camera",
            (0.0, 0.0): "predicted",
        }
        for (uwb_conf, cam_conf), expected in expectations.items():
            assert select_source(uwb_conf, cam_conf) == expected
||||
Loading…
x
Reference in New Issue
Block a user