Compare commits
3 Commits
84c8b6a0ae
...
fa0162fadc
| Author | SHA1 | Date | |
|---|---|---|---|
| fa0162fadc | |||
| d48edf4092 | |||
| 44771751e2 |
@ -23,6 +23,11 @@ rosidl_generate_interfaces(${PROJECT_NAME}
|
||||
"msg/Mood.msg"
|
||||
"msg/Person.msg"
|
||||
"msg/PersonArray.msg"
|
||||
# Issue #84 — personality system
|
||||
"msg/PersonalityState.msg"
|
||||
"srv/QueryMood.srv"
|
||||
# Issue #92 — multi-modal tracking fusion
|
||||
"msg/FusedTarget.msg"
|
||||
DEPENDENCIES std_msgs geometry_msgs builtin_interfaces
|
||||
)
|
||||
|
||||
|
||||
19
jetson/ros2_ws/src/saltybot_social_msgs/msg/FusedTarget.msg
Normal file
19
jetson/ros2_ws/src/saltybot_social_msgs/msg/FusedTarget.msg
Normal file
@ -0,0 +1,19 @@
|
||||
# FusedTarget.msg — output of the multi-modal tracking fusion node.
|
||||
#
|
||||
# Position and velocity are in the base_link frame (robot-centred,
|
||||
# +X forward, +Y left). z components are always 0.0 for ground-plane tracking.
|
||||
#
|
||||
# Confidence: 0.0 = no data / fully predicted; 1.0 = strong fused measurement.
|
||||
# active_source: "fused" | "uwb" | "camera" | "predicted"
|
||||
|
||||
std_msgs/Header header
|
||||
|
||||
geometry_msgs/Point position # filtered 2-D position (m), z=0
|
||||
geometry_msgs/Vector3 velocity # filtered 2-D velocity (m/s), z=0
|
||||
|
||||
float32 range_m # Euclidean distance from robot to fused position
|
||||
float32 bearing_rad # bearing in base_link (+ve = person to the left)
|
||||
float32 confidence # composite confidence [0.0, 1.0]
|
||||
|
||||
string active_source # "fused" | "uwb" | "camera" | "predicted"
|
||||
string tag_id # UWB tag address (empty when UWB not contributing)
|
||||
@ -0,0 +1,27 @@
|
||||
# PersonalityState.msg — published on /social/personality/state
|
||||
#
|
||||
# Snapshot of the personality node's current state: active mood, relationship
|
||||
# tier with the detected person, and a pre-generated greeting string.
|
||||
|
||||
std_msgs/Header header
|
||||
|
||||
# Active persona name (from SOUL.md)
|
||||
string persona_name
|
||||
|
||||
# Current mood: happy | curious | annoyed | playful
|
||||
string mood
|
||||
|
||||
# Person currently being addressed (empty if no one detected)
|
||||
string person_id
|
||||
|
||||
# Relationship tier with person_id: stranger | regular | favorite
|
||||
string relationship_tier
|
||||
|
||||
# Raw relationship score (higher = more familiar)
|
||||
float32 relationship_score
|
||||
|
||||
# How many times we have seen this person
|
||||
int32 interaction_count
|
||||
|
||||
# Ready-to-use greeting for person_id at current tier
|
||||
string greeting_text
|
||||
@ -3,16 +3,25 @@
|
||||
<package format="3">
|
||||
<name>saltybot_social_msgs</name>
|
||||
<version>0.1.0</version>
|
||||
<description>Custom ROS2 messages and services for saltybot social capabilities</description>
|
||||
<description>
|
||||
Custom ROS2 message and service definitions for saltybot social capabilities.
|
||||
Includes social perception types (face detection, person state, enrollment)
|
||||
and the personality system types (PersonalityState, QueryMood) from Issue #84.
|
||||
</description>
|
||||
<maintainer email="seb@vayrette.com">seb</maintainer>
|
||||
<license>MIT</license>
|
||||
|
||||
<buildtool_depend>ament_cmake</buildtool_depend>
|
||||
<build_depend>rosidl_default_generators</build_depend>
|
||||
|
||||
<depend>std_msgs</depend>
|
||||
<depend>geometry_msgs</depend>
|
||||
<depend>builtin_interfaces</depend>
|
||||
<build_depend>rosidl_default_generators</build_depend>
|
||||
|
||||
<exec_depend>rosidl_default_runtime</exec_depend>
|
||||
|
||||
<member_of_group>rosidl_interface_packages</member_of_group>
|
||||
|
||||
<export>
|
||||
<build_type>ament_cmake</build_type>
|
||||
</export>
|
||||
|
||||
15
jetson/ros2_ws/src/saltybot_social_msgs/srv/QueryMood.srv
Normal file
15
jetson/ros2_ws/src/saltybot_social_msgs/srv/QueryMood.srv
Normal file
@ -0,0 +1,15 @@
|
||||
# QueryMood.srv — ask the personality node for the current mood + greeting for a person
|
||||
#
|
||||
# Call with empty person_id to query the mood for whoever is currently tracked.
|
||||
|
||||
# Request
|
||||
string person_id # person to query; leave empty for current tracked person
|
||||
---
|
||||
# Response
|
||||
string mood # happy | curious | annoyed | playful
|
||||
string relationship_tier # stranger | regular | favorite
|
||||
float32 relationship_score
|
||||
int32 interaction_count
|
||||
string greeting_text # suggested greeting string
|
||||
bool success
|
||||
string message # error detail if success=false
|
||||
@ -0,0 +1,42 @@
|
||||
---
|
||||
# SOUL.md — Saltybot persona definition
|
||||
#
|
||||
# Hot-reload: personality_node watches this file and reloads on change.
|
||||
# Runtime override: ros2 param set /personality_node soul_file /path/to/SOUL.md
|
||||
|
||||
# ── Identity ──────────────────────────────────────────────────────────────────
|
||||
name: "Salty"
|
||||
speaking_style: "casual and upbeat, occasional puns"
|
||||
|
||||
# ── Personality dials (0–10) ──────────────────────────────────────────────────
|
||||
humor_level: 7 # 0 = deadpan/serious, 10 = comedian
|
||||
sass_level: 4 # 0 = pure politeness, 10 = maximum sass
|
||||
|
||||
# ── Default mood (when no person is present or score is neutral) ──────────────
|
||||
# One of: happy | curious | annoyed | playful
|
||||
base_mood: "playful"
|
||||
|
||||
# ── Relationship thresholds (interaction counts) ──────────────────────────────
|
||||
threshold_regular: 5 # interactions to graduate from stranger → regular
|
||||
threshold_favorite: 20 # interactions to graduate from regular → favorite
|
||||
|
||||
# ── Greeting templates (use {name} placeholder for person_id or display name) ─
|
||||
greeting_stranger: "Hello there! I'm Salty, nice to meet you!"
|
||||
greeting_regular: "Hey {name}! Good to see you again!"
|
||||
greeting_favorite: "Oh hey {name}!! You're literally my favorite person right now!"
|
||||
|
||||
# ── Mood-specific greeting prefixes ──────────────────────────────────────────
|
||||
mood_prefix_happy: "Great timing — "
|
||||
mood_prefix_curious: "Oh interesting, "
|
||||
mood_prefix_annoyed: "Well, "
|
||||
mood_prefix_playful: "Beep boop! "
|
||||
---
|
||||
|
||||
# Description (ignored by the YAML parser, for human reference only)
|
||||
|
||||
Salty is the personality of the saltybot social robot.
|
||||
She is curious about the world, genuinely happy to see people she knows,
|
||||
and has a good sense of humour — especially with regulars.
|
||||
|
||||
Edit this file to change her personality. The node hot-reloads within
|
||||
`reload_interval` seconds of any change.
|
||||
@ -0,0 +1,28 @@
|
||||
# personality_params.yaml — ROS2 parameter defaults for personality_node.
|
||||
#
|
||||
# Run with:
|
||||
# ros2 launch saltybot_social_personality personality.launch.py
|
||||
# Override inline:
|
||||
# ros2 launch saltybot_social_personality personality.launch.py soul_file:=/my/SOUL.md
|
||||
# Dynamic reconfigure at runtime:
|
||||
# ros2 param set /personality_node soul_file /path/to/SOUL.md
|
||||
# ros2 param set /personality_node publish_rate 5.0
|
||||
|
||||
# ── SOUL.md path ───────────────────────────────────────────────────────────────
|
||||
# Path to the SOUL.md persona file. Supports hot-reload.
|
||||
# Relative paths are resolved from the package share directory.
|
||||
soul_file: "" # empty = use bundled default config/SOUL.md
|
||||
|
||||
# ── SQLite database ────────────────────────────────────────────────────────────
|
||||
# Path for the per-person relationship memory database.
|
||||
# Created on first run; persists across restarts.
|
||||
db_path: "~/.ros/saltybot_personality.db"
|
||||
|
||||
# ── Hot-reload polling interval ────────────────────────────────────────────────
|
||||
# How often (seconds) to check if SOUL.md has been modified.
|
||||
# Lower = faster reactions to edits; higher = less disk I/O.
|
||||
reload_interval: 5.0 # seconds
|
||||
|
||||
# ── Personality state publication rate ────────────────────────────────────────
|
||||
# How often to publish /social/personality/state (PersonalityState).
|
||||
publish_rate: 2.0 # Hz
|
||||
@ -0,0 +1,99 @@
|
||||
"""
|
||||
personality.launch.py — Launch the saltybot personality node.
|
||||
|
||||
Usage
|
||||
-----
|
||||
# Defaults (bundled SOUL.md, ~/.ros/saltybot_personality.db):
|
||||
ros2 launch saltybot_social_personality personality.launch.py
|
||||
|
||||
# Custom persona file:
|
||||
ros2 launch saltybot_social_personality personality.launch.py \\
|
||||
soul_file:=/home/robot/my_persona/SOUL.md
|
||||
|
||||
# Custom DB + faster reload:
|
||||
ros2 launch saltybot_social_personality personality.launch.py \\
|
||||
db_path:=/data/saltybot.db reload_interval:=2.0
|
||||
|
||||
# Use a params file:
|
||||
ros2 launch saltybot_social_personality personality.launch.py \\
|
||||
params_file:=/my/personality_params.yaml
|
||||
|
||||
Dynamic reconfigure (no restart required)
|
||||
-----------------------------------------
|
||||
ros2 param set /personality_node soul_file /new/SOUL.md
|
||||
ros2 param set /personality_node publish_rate 5.0
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
from launch import LaunchDescription
|
||||
from launch.actions import DeclareLaunchArgument, OpaqueFunction
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def _launch_personality(context, *args, **kwargs):
|
||||
pkg_share = get_package_share_directory("saltybot_social_personality")
|
||||
params_file = LaunchConfiguration("params_file").perform(context)
|
||||
soul_file = LaunchConfiguration("soul_file").perform(context)
|
||||
db_path = LaunchConfiguration("db_path").perform(context)
|
||||
|
||||
# Default soul_file to bundled config if not specified
|
||||
if not soul_file:
|
||||
soul_file = os.path.join(pkg_share, "config", "SOUL.md")
|
||||
|
||||
# Expand ~ in db_path
|
||||
if db_path:
|
||||
db_path = os.path.expanduser(db_path)
|
||||
|
||||
inline_params = {
|
||||
"soul_file": soul_file,
|
||||
"db_path": db_path or os.path.expanduser("~/.ros/saltybot_personality.db"),
|
||||
"reload_interval": float(LaunchConfiguration("reload_interval").perform(context)),
|
||||
"publish_rate": float(LaunchConfiguration("publish_rate").perform(context)),
|
||||
}
|
||||
|
||||
node_params = [params_file, inline_params] if params_file else [inline_params]
|
||||
|
||||
return [Node(
|
||||
package = "saltybot_social_personality",
|
||||
executable = "personality_node",
|
||||
name = "personality_node",
|
||||
output = "screen",
|
||||
parameters = node_params,
|
||||
)]
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
pkg_share = get_package_share_directory("saltybot_social_personality")
|
||||
default_params = os.path.join(pkg_share, "config", "personality_params.yaml")
|
||||
|
||||
return LaunchDescription([
|
||||
DeclareLaunchArgument(
|
||||
"params_file",
|
||||
default_value=default_params,
|
||||
description="Full path to personality_params.yaml (base config)"),
|
||||
|
||||
DeclareLaunchArgument(
|
||||
"soul_file",
|
||||
default_value="",
|
||||
description="Path to SOUL.md persona file (empty = bundled default)"),
|
||||
|
||||
DeclareLaunchArgument(
|
||||
"db_path",
|
||||
default_value="~/.ros/saltybot_personality.db",
|
||||
description="SQLite relationship memory database path"),
|
||||
|
||||
DeclareLaunchArgument(
|
||||
"reload_interval",
|
||||
default_value="5.0",
|
||||
description="SOUL.md hot-reload polling interval (s)"),
|
||||
|
||||
DeclareLaunchArgument(
|
||||
"publish_rate",
|
||||
default_value="2.0",
|
||||
description="Personality state publish rate (Hz)"),
|
||||
|
||||
OpaqueFunction(function=_launch_personality),
|
||||
])
|
||||
32
jetson/ros2_ws/src/saltybot_social_personality/package.xml
Normal file
32
jetson/ros2_ws/src/saltybot_social_personality/package.xml
Normal file
@ -0,0 +1,32 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_social_personality</name>
|
||||
<version>0.1.0</version>
|
||||
<description>
|
||||
SOUL.md-driven personality system for saltybot.
|
||||
Loads a YAML/Markdown persona file, maintains per-person relationship memory
|
||||
in SQLite, computes mood (happy/curious/annoyed/playful), personalises
|
||||
greetings by tier (stranger/regular/favorite), and publishes personality
|
||||
state on /social/personality/state. Supports SOUL.md hot-reload and full
|
||||
ROS2 dynamic reconfigure. Issue #84.
|
||||
</description>
|
||||
<maintainer email="sl-controls@saltylab.local">sl-controls</maintainer>
|
||||
<license>MIT</license>
|
||||
|
||||
<depend>rclpy</depend>
|
||||
<depend>std_msgs</depend>
|
||||
<depend>rcl_interfaces</depend>
|
||||
<depend>saltybot_social_msgs</depend>
|
||||
|
||||
<buildtool_depend>ament_python</buildtool_depend>
|
||||
|
||||
<test_depend>ament_copyright</test_depend>
|
||||
<test_depend>ament_flake8</test_depend>
|
||||
<test_depend>ament_pep257</test_depend>
|
||||
<test_depend>python3-pytest</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1,187 @@
|
||||
"""
|
||||
mood_engine.py — Pure-function mood computation for the saltybot personality system.
|
||||
|
||||
No ROS2 imports — safe to unit-test without a live ROS2 environment.
|
||||
|
||||
Public API
|
||||
----------
|
||||
compute_mood(soul, score, interaction_count, recent_events) -> str
|
||||
get_relationship_tier(soul, interaction_count) -> str
|
||||
build_greeting(soul, tier, mood, person_id) -> str
|
||||
|
||||
Mood semantics
|
||||
--------------
|
||||
happy : positive valence, comfortable familiarity
|
||||
playful : high-energy, humorous (requires humor_level >= 7)
|
||||
curious : low familiarity or novel person — inquisitive
|
||||
annoyed : recent negative events or very low score
|
||||
|
||||
Tier semantics
|
||||
--------------
|
||||
stranger : interaction_count < threshold_regular
|
||||
regular : threshold_regular <= count < threshold_favorite
|
||||
favorite : count >= threshold_favorite
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# ── Mood / tier constants ──────────────────────────────────────────────────────
|
||||
|
||||
MOOD_HAPPY = "happy"
|
||||
MOOD_PLAYFUL = "playful"
|
||||
MOOD_CURIOUS = "curious"
|
||||
MOOD_ANNOYED = "annoyed"
|
||||
|
||||
TIER_STRANGER = "stranger"
|
||||
TIER_REGULAR = "regular"
|
||||
TIER_FAVORITE = "favorite"
|
||||
|
||||
# ── Event type constants (used by relationship_db and the node) ────────────────
|
||||
|
||||
EVENT_GREETING = "greeting"
|
||||
EVENT_POSITIVE = "positive"
|
||||
EVENT_NEGATIVE = "negative"
|
||||
EVENT_DETECTION = "detection"
|
||||
|
||||
# How far back (seconds) to consider "recent" for mood computation
|
||||
_RECENT_WINDOW_S = 120.0
|
||||
|
||||
|
||||
# ── Mood computation ──────────────────────────────────────────────────────────
|
||||
|
||||
def compute_mood(
|
||||
soul: dict,
|
||||
score: float,
|
||||
interaction_count: int,
|
||||
recent_events: list,
|
||||
) -> str:
|
||||
"""Compute the current mood for a given person.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
soul : dict
|
||||
Parsed SOUL.md configuration.
|
||||
score : float
|
||||
Relationship score for the current person (higher = more familiar).
|
||||
interaction_count : int
|
||||
Total number of times we have seen this person.
|
||||
recent_events : list of dict
|
||||
Each dict: ``{"type": str, "dt": float}`` where ``dt`` is seconds ago.
|
||||
Only events with ``dt < 120.0`` are considered "recent".
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
One of: ``"happy"``, ``"playful"``, ``"curious"``, ``"annoyed"``.
|
||||
"""
|
||||
base_mood = soul.get("base_mood", MOOD_PLAYFUL)
|
||||
humor_level = float(soul.get("humor_level", 5))
|
||||
|
||||
# Count recent negative/positive events
|
||||
recent_neg = sum(
|
||||
1 for e in recent_events
|
||||
if e.get("type") == EVENT_NEGATIVE and e.get("dt", 1e9) < _RECENT_WINDOW_S
|
||||
)
|
||||
recent_pos = sum(
|
||||
1 for e in recent_events
|
||||
if e.get("type") in (EVENT_POSITIVE, EVENT_GREETING)
|
||||
and e.get("dt", 1e9) < _RECENT_WINDOW_S
|
||||
)
|
||||
|
||||
# Hard override: multiple negatives → annoyed
|
||||
if recent_neg >= 2:
|
||||
return MOOD_ANNOYED
|
||||
|
||||
# No prior interactions or brand-new person → curious
|
||||
if interaction_count == 0 or score < 1.0:
|
||||
return MOOD_CURIOUS
|
||||
|
||||
# Stranger tier (low count) → curious
|
||||
threshold_regular = int(soul.get("threshold_regular", 5))
|
||||
if interaction_count < threshold_regular:
|
||||
return MOOD_CURIOUS
|
||||
|
||||
# Familiar person: check positive events and humor level
|
||||
if recent_pos >= 1 or score >= 20.0:
|
||||
if humor_level >= 7:
|
||||
return MOOD_PLAYFUL
|
||||
return MOOD_HAPPY
|
||||
|
||||
# High score / favorite
|
||||
threshold_fav = int(soul.get("threshold_favorite", 20))
|
||||
if interaction_count >= threshold_fav:
|
||||
if humor_level >= 7:
|
||||
return MOOD_PLAYFUL
|
||||
return MOOD_HAPPY
|
||||
|
||||
return base_mood
|
||||
|
||||
|
||||
# ── Tier classification ────────────────────────────────────────────────────────
|
||||
|
||||
def get_relationship_tier(soul: dict, interaction_count: int) -> str:
|
||||
"""Return the relationship tier string for a given interaction count.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
soul : dict
|
||||
Parsed SOUL.md configuration.
|
||||
interaction_count : int
|
||||
Total number of times we have seen this person.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
One of: ``"stranger"``, ``"regular"``, ``"favorite"``.
|
||||
"""
|
||||
threshold_regular = int(soul.get("threshold_regular", 5))
|
||||
threshold_favorite = int(soul.get("threshold_favorite", 20))
|
||||
if interaction_count >= threshold_favorite:
|
||||
return TIER_FAVORITE
|
||||
if interaction_count >= threshold_regular:
|
||||
return TIER_REGULAR
|
||||
return TIER_STRANGER
|
||||
|
||||
|
||||
# ── Greeting builder ──────────────────────────────────────────────────────────
|
||||
|
||||
def build_greeting(soul: dict, tier: str, mood: str, person_id: str = "") -> str:
|
||||
"""Compose a greeting string for a person.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
soul : dict
|
||||
Parsed SOUL.md configuration.
|
||||
tier : str
|
||||
Relationship tier (``"stranger"``, ``"regular"``, ``"favorite"``).
|
||||
mood : str
|
||||
Current mood (used to prefix the greeting).
|
||||
person_id : str
|
||||
Person identifier / display name. Substituted for ``{name}``
|
||||
in the template.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
A complete, ready-to-display greeting string.
|
||||
"""
|
||||
template_key = {
|
||||
TIER_STRANGER: "greeting_stranger",
|
||||
TIER_REGULAR: "greeting_regular",
|
||||
TIER_FAVORITE: "greeting_favorite",
|
||||
}.get(tier, "greeting_stranger")
|
||||
|
||||
template = soul.get(template_key, "Hello!")
|
||||
base_greeting = template.replace("{name}", person_id or "friend")
|
||||
|
||||
prefix_key = f"mood_prefix_{mood}"
|
||||
prefix = soul.get(prefix_key, "")
|
||||
|
||||
if prefix:
|
||||
# Avoid double punctuation / duplicate capital letters
|
||||
base_first = base_greeting[0].lower() if base_greeting else ""
|
||||
greeting = f"{prefix}{base_first}{base_greeting[1:]}"
|
||||
else:
|
||||
greeting = base_greeting
|
||||
|
||||
return greeting
|
||||
@ -0,0 +1,349 @@
|
||||
"""
|
||||
personality_node.py — ROS2 personality system for saltybot.
|
||||
|
||||
Overview
|
||||
--------
|
||||
Loads a SOUL.md persona file, maintains per-person relationship memory in
|
||||
SQLite, computes mood, and publishes personality state. All tunable params
|
||||
support ROS2 dynamic reconfigure (``ros2 param set``).
|
||||
|
||||
Subscriptions
|
||||
-------------
|
||||
/social/person_detected (std_msgs/String)
|
||||
JSON payload: ``{"person_id": "alice", "event_type": "greeting",
|
||||
"delta_score": 1.0}``
|
||||
event_type defaults to "detection" if absent.
|
||||
delta_score defaults to 0.0 if absent.
|
||||
|
||||
Publications
|
||||
------------
|
||||
/social/personality/state (saltybot_social_msgs/PersonalityState)
|
||||
Published at ``publish_rate`` Hz.
|
||||
|
||||
Services
|
||||
--------
|
||||
/social/personality/query_mood (saltybot_social_msgs/QueryMood)
|
||||
Query mood + greeting for any person_id.
|
||||
|
||||
Parameters (dynamic reconfigure via ros2 param set)
|
||||
-------------------
|
||||
soul_file (str) Path to SOUL.md persona file.
|
||||
db_path (str) SQLite database file path.
|
||||
reload_interval (float) How often to poll SOUL.md for changes (s).
|
||||
publish_rate (float) State publication rate (Hz).
|
||||
|
||||
Usage
|
||||
-----
|
||||
ros2 launch saltybot_social_personality personality.launch.py
|
||||
ros2 launch saltybot_social_personality personality.launch.py soul_file:=/my/SOUL.md
|
||||
ros2 param set /personality_node soul_file /tmp/new_SOUL.md
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rcl_interfaces.msg import SetParametersResult
|
||||
from std_msgs.msg import String, Header
|
||||
|
||||
from saltybot_social_msgs.msg import PersonalityState
|
||||
from saltybot_social_msgs.srv import QueryMood
|
||||
|
||||
from .soul_loader import load_soul, SoulWatcher
|
||||
from .mood_engine import (
|
||||
compute_mood, get_relationship_tier, build_greeting,
|
||||
EVENT_GREETING, EVENT_POSITIVE, EVENT_NEGATIVE, EVENT_DETECTION,
|
||||
)
|
||||
from .relationship_db import RelationshipDB
|
||||
|
||||
_DEFAULT_SOUL = os.path.join(
|
||||
os.path.dirname(__file__), "..", "config", "SOUL.md"
|
||||
)
|
||||
_DEFAULT_DB = os.path.expanduser("~/.ros/saltybot_personality.db")
|
||||
|
||||
|
||||
class PersonalityNode(Node):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("personality_node")
|
||||
|
||||
# ── Parameters ────────────────────────────────────────────────────────
|
||||
self.declare_parameter("soul_file", _DEFAULT_SOUL)
|
||||
self.declare_parameter("db_path", _DEFAULT_DB)
|
||||
self.declare_parameter("reload_interval", 5.0)
|
||||
self.declare_parameter("publish_rate", 2.0)
|
||||
|
||||
self._p = {}
|
||||
self._reload_ros_params()
|
||||
|
||||
# ── State ─────────────────────────────────────────────────────────────
|
||||
self._soul = {}
|
||||
self._current_person = "" # person_id currently being addressed
|
||||
self._watcher = None
|
||||
|
||||
# ── Database ──────────────────────────────────────────────────────────
|
||||
self._db = RelationshipDB(self._p["db_path"])
|
||||
|
||||
# ── Load initial SOUL.md ──────────────────────────────────────────────
|
||||
self._load_soul_safe()
|
||||
self._start_watcher()
|
||||
|
||||
# ── Dynamic reconfigure callback ─────────────────────────────────────
|
||||
self.add_on_set_parameters_callback(self._on_params_changed)
|
||||
|
||||
# ── Subscriptions ─────────────────────────────────────────────────────
|
||||
self.create_subscription(
|
||||
String, "/social/person_detected", self._person_detected_cb, 10
|
||||
)
|
||||
|
||||
# ── Publishers ────────────────────────────────────────────────────────
|
||||
self._state_pub = self.create_publisher(
|
||||
PersonalityState, "/social/personality/state", 10
|
||||
)
|
||||
|
||||
# ── Services ──────────────────────────────────────────────────────────
|
||||
self.create_service(
|
||||
QueryMood,
|
||||
"/social/personality/query_mood",
|
||||
self._query_mood_cb,
|
||||
)
|
||||
|
||||
# ── Timers ────────────────────────────────────────────────────────────
|
||||
self._pub_timer = self.create_timer(
|
||||
1.0 / self._p["publish_rate"], self._publish_state
|
||||
)
|
||||
|
||||
self.get_logger().info(
|
||||
f"PersonalityNode ready "
|
||||
f"persona={self._soul.get('name', '?')!r} "
|
||||
f"mood={self._current_mood()!r} "
|
||||
f"db={self._p['db_path']!r}"
|
||||
)
|
||||
|
||||
# ── Parameter helpers ──────────────────────────────────────────────────────
|
||||
|
||||
def _reload_ros_params(self):
|
||||
self._p = {
|
||||
"soul_file": self.get_parameter("soul_file").value,
|
||||
"db_path": self.get_parameter("db_path").value,
|
||||
"reload_interval": self.get_parameter("reload_interval").value,
|
||||
"publish_rate": self.get_parameter("publish_rate").value,
|
||||
}
|
||||
|
||||
def _on_params_changed(self, params):
|
||||
"""Dynamic reconfigure — apply changed params without restarting node."""
|
||||
for param in params:
|
||||
if param.name == "soul_file":
|
||||
# Restart watcher on new soul_file
|
||||
self._stop_watcher()
|
||||
self._p["soul_file"] = param.value
|
||||
self._load_soul_safe()
|
||||
self._start_watcher()
|
||||
self.get_logger().info(f"soul_file changed → {param.value!r}")
|
||||
elif param.name in self._p:
|
||||
self._p[param.name] = param.value
|
||||
if param.name == "publish_rate" and self._pub_timer:
|
||||
self._pub_timer.cancel()
|
||||
self._pub_timer = self.create_timer(
|
||||
1.0 / max(0.1, param.value), self._publish_state
|
||||
)
|
||||
return SetParametersResult(successful=True)
|
||||
|
||||
# ── SOUL.md ────────────────────────────────────────────────────────────────
|
||||
|
||||
def _load_soul_safe(self):
|
||||
try:
|
||||
path = os.path.realpath(self._p["soul_file"])
|
||||
self._soul = load_soul(path)
|
||||
self.get_logger().info(
|
||||
f"SOUL.md loaded: {self._soul.get('name', '?')!r} "
|
||||
f"humor={self._soul.get('humor_level')} "
|
||||
f"sass={self._soul.get('sass_level')} "
|
||||
f"base_mood={self._soul.get('base_mood')!r}"
|
||||
)
|
||||
except Exception as exc:
|
||||
self.get_logger().error(f"Failed to load SOUL.md: {exc}")
|
||||
if not self._soul:
|
||||
# Fall back to minimal defaults so the node stays alive
|
||||
self._soul = {
|
||||
"name": "Salty",
|
||||
"humor_level": 5,
|
||||
"sass_level": 3,
|
||||
"base_mood": "curious",
|
||||
"threshold_regular": 5,
|
||||
"threshold_favorite": 20,
|
||||
"greeting_stranger": "Hello!",
|
||||
"greeting_regular": "Hi {name}!",
|
||||
"greeting_favorite": "Hey {name}!!",
|
||||
}
|
||||
|
||||
def _start_watcher(self):
|
||||
if not self._soul:
|
||||
return
|
||||
self._watcher = SoulWatcher(
|
||||
path=self._p["soul_file"],
|
||||
on_reload=self._on_soul_reloaded,
|
||||
interval=self._p["reload_interval"],
|
||||
on_error=lambda exc: self.get_logger().warn(
|
||||
f"SOUL.md hot-reload error: {exc}"
|
||||
),
|
||||
)
|
||||
self._watcher.start()
|
||||
|
||||
def _stop_watcher(self):
|
||||
if self._watcher:
|
||||
self._watcher.stop()
|
||||
self._watcher = None
|
||||
|
||||
def _on_soul_reloaded(self, soul: dict):
|
||||
self._soul = soul
|
||||
self.get_logger().info(
|
||||
f"SOUL.md reloaded: persona={soul.get('name')!r} "
|
||||
f"humor={soul.get('humor_level')} base_mood={soul.get('base_mood')!r}"
|
||||
)
|
||||
|
||||
# ── Mood helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _current_mood(self) -> str:
|
||||
if not self._current_person or not self._soul:
|
||||
return self._soul.get("base_mood", "curious") if self._soul else "curious"
|
||||
person = self._db.get_person(self._current_person)
|
||||
recent = self._db.get_recent_events(self._current_person, window_s=120.0)
|
||||
return compute_mood(
|
||||
soul = self._soul,
|
||||
score = person["score"],
|
||||
interaction_count = person["interaction_count"],
|
||||
recent_events = recent,
|
||||
)
|
||||
|
||||
def _state_for_person(self, person_id: str) -> dict:
|
||||
"""Build a complete state dict for a given person_id."""
|
||||
person = self._db.get_person(person_id) if person_id else {
|
||||
"score": 0.0, "interaction_count": 0
|
||||
}
|
||||
recent = self._db.get_recent_events(person_id, window_s=120.0) if person_id else []
|
||||
|
||||
mood = compute_mood(
|
||||
soul = self._soul,
|
||||
score = person["score"],
|
||||
interaction_count = person["interaction_count"],
|
||||
recent_events = recent,
|
||||
)
|
||||
tier = get_relationship_tier(self._soul, person["interaction_count"])
|
||||
greeting = build_greeting(self._soul, tier, mood, person_id)
|
||||
|
||||
return {
|
||||
"person_id": person_id,
|
||||
"mood": mood,
|
||||
"tier": tier,
|
||||
"score": person["score"],
|
||||
"interaction_count": person["interaction_count"],
|
||||
"greeting": greeting,
|
||||
}
|
||||
|
||||
# ── Callbacks ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _person_detected_cb(self, msg: String):
|
||||
"""Handle incoming person detection / interaction event.
|
||||
|
||||
Expected JSON payload::
|
||||
|
||||
{
|
||||
"person_id": "alice", # required
|
||||
"event_type": "greeting", # optional, default "detection"
|
||||
"delta_score": 1.0 # optional, default 0.0
|
||||
}
|
||||
"""
|
||||
try:
|
||||
data = json.loads(msg.data)
|
||||
except json.JSONDecodeError as exc:
|
||||
self.get_logger().warn(f"Bad JSON on /social/person_detected: {exc}")
|
||||
return
|
||||
|
||||
person_id = data.get("person_id", "").strip()
|
||||
if not person_id:
|
||||
self.get_logger().warn("person_detected msg missing 'person_id'")
|
||||
return
|
||||
|
||||
event_type = data.get("event_type", EVENT_DETECTION)
|
||||
delta_score = float(data.get("delta_score", 0.0))
|
||||
|
||||
# Increment score by +1 for detection events automatically
|
||||
if event_type == EVENT_DETECTION and delta_score == 0.0:
|
||||
delta_score = 0.5
|
||||
|
||||
self._db.record_interaction(
|
||||
person_id = person_id,
|
||||
event_type = event_type,
|
||||
details = {k: v for k, v in data.items()
|
||||
if k not in ("person_id", "event_type", "delta_score")},
|
||||
delta_score = delta_score,
|
||||
)
|
||||
self._current_person = person_id
|
||||
|
||||
def _query_mood_cb(self, request: QueryMood.Request, response: QueryMood.Response):
|
||||
"""Service handler: return mood + greeting for a specific person."""
|
||||
if not self._soul:
|
||||
response.success = False
|
||||
response.message = "SOUL.md not loaded"
|
||||
return response
|
||||
|
||||
person_id = (request.person_id or self._current_person).strip()
|
||||
state = self._state_for_person(person_id)
|
||||
|
||||
response.mood = state["mood"]
|
||||
response.relationship_tier = state["tier"]
|
||||
response.relationship_score = float(state["score"])
|
||||
response.interaction_count = int(state["interaction_count"])
|
||||
response.greeting_text = state["greeting"]
|
||||
response.success = True
|
||||
response.message = ""
|
||||
return response
|
||||
|
||||
# ── Publish ────────────────────────────────────────────────────────────────
|
||||
|
||||
def _publish_state(self):
|
||||
if not self._soul:
|
||||
return
|
||||
|
||||
state = self._state_for_person(self._current_person)
|
||||
|
||||
msg = PersonalityState()
|
||||
msg.header = Header()
|
||||
msg.header.stamp = self.get_clock().now().to_msg()
|
||||
msg.header.frame_id = "personality"
|
||||
msg.persona_name = str(self._soul.get("name", "Salty"))
|
||||
msg.mood = state["mood"]
|
||||
msg.person_id = state["person_id"]
|
||||
msg.relationship_tier = state["tier"]
|
||||
msg.relationship_score = float(state["score"])
|
||||
msg.interaction_count = int(state["interaction_count"])
|
||||
msg.greeting_text = state["greeting"]
|
||||
|
||||
self._state_pub.publish(msg)
|
||||
|
||||
# ── Lifecycle ──────────────────────────────────────────────────────────────
|
||||
|
||||
def destroy_node(self):
|
||||
self._stop_watcher()
|
||||
self._db.close()
|
||||
super().destroy_node()
|
||||
|
||||
|
||||
# ── Entry point ────────────────────────────────────────────────────────────────
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = PersonalityNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.try_shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -0,0 +1,297 @@
|
||||
"""
|
||||
relationship_db.py — SQLite-backed per-person relationship memory.
|
||||
|
||||
No ROS2 imports — safe to unit-test without a live ROS2 environment.
|
||||
|
||||
Schema
|
||||
------
|
||||
people (
|
||||
person_id TEXT PRIMARY KEY,
|
||||
score REAL DEFAULT 0.0,
|
||||
interaction_count INTEGER DEFAULT 0,
|
||||
first_seen TEXT, -- ISO-8601 UTC timestamp
|
||||
last_seen TEXT, -- ISO-8601 UTC timestamp
|
||||
prefs TEXT -- JSON blob for learned preferences
|
||||
)
|
||||
|
||||
interactions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
person_id TEXT,
|
||||
ts TEXT, -- ISO-8601 UTC timestamp
|
||||
event_type TEXT, -- greeting | positive | negative | detection
|
||||
details TEXT -- free-form JSON blob
|
||||
)
|
||||
|
||||
Public API
|
||||
----------
|
||||
RelationshipDB(db_path)
|
||||
.get_person(person_id) -> dict
|
||||
.record_interaction(person_id, event_type, details, delta_score)
|
||||
.set_pref(person_id, key, value)
|
||||
.get_pref(person_id, key, default)
|
||||
.get_recent_events(person_id, window_s) -> list[dict]
|
||||
.all_people() -> list[dict]
|
||||
.close()
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
class RelationshipDB:
|
||||
"""Thread-safe SQLite relationship store.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
db_path : str
|
||||
Path to the SQLite file. Created (with parent dirs) if absent.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str):
|
||||
parent = os.path.dirname(db_path)
|
||||
if parent:
|
||||
os.makedirs(parent, exist_ok=True)
|
||||
self._path = db_path
|
||||
self._lock = threading.Lock()
|
||||
self._conn = sqlite3.connect(db_path, check_same_thread=False)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._migrate()
|
||||
|
||||
# ── Schema ────────────────────────────────────────────────────────────────
|
||||
|
||||
def _migrate(self):
|
||||
with self._conn:
|
||||
self._conn.executescript("""
|
||||
CREATE TABLE IF NOT EXISTS people (
|
||||
person_id TEXT PRIMARY KEY,
|
||||
score REAL DEFAULT 0.0,
|
||||
interaction_count INTEGER DEFAULT 0,
|
||||
first_seen TEXT,
|
||||
last_seen TEXT,
|
||||
prefs TEXT DEFAULT '{}'
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS interactions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
person_id TEXT NOT NULL,
|
||||
ts TEXT NOT NULL,
|
||||
event_type TEXT NOT NULL,
|
||||
details TEXT DEFAULT '{}'
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_interactions_person_ts
|
||||
ON interactions (person_id, ts);
|
||||
""")
|
||||
|
||||
# ── People ────────────────────────────────────────────────────────────────
|
||||
|
||||
def get_person(self, person_id: str) -> dict:
|
||||
"""Return the person record; inserts a default row if not found.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict with keys: person_id, score, interaction_count,
|
||||
first_seen, last_seen, prefs (dict)
|
||||
"""
|
||||
with self._lock:
|
||||
row = self._conn.execute(
|
||||
"SELECT * FROM people WHERE person_id = ?", (person_id,)
|
||||
).fetchone()
|
||||
|
||||
if row is None:
|
||||
now = _now_iso()
|
||||
self._conn.execute(
|
||||
"INSERT INTO people (person_id, first_seen, last_seen) VALUES (?,?,?)",
|
||||
(person_id, now, now),
|
||||
)
|
||||
self._conn.commit()
|
||||
return {
|
||||
"person_id": person_id,
|
||||
"score": 0.0,
|
||||
"interaction_count": 0,
|
||||
"first_seen": now,
|
||||
"last_seen": now,
|
||||
"prefs": {},
|
||||
}
|
||||
|
||||
prefs = {}
|
||||
try:
|
||||
prefs = json.loads(row["prefs"] or "{}")
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return {
|
||||
"person_id": row["person_id"],
|
||||
"score": float(row["score"]),
|
||||
"interaction_count": int(row["interaction_count"]),
|
||||
"first_seen": row["first_seen"],
|
||||
"last_seen": row["last_seen"],
|
||||
"prefs": prefs,
|
||||
}
|
||||
|
||||
def all_people(self) -> list:
|
||||
"""Return all person records as a list of dicts."""
|
||||
with self._lock:
|
||||
rows = self._conn.execute("SELECT * FROM people ORDER BY score DESC").fetchall()
|
||||
result = []
|
||||
for row in rows:
|
||||
prefs = {}
|
||||
try:
|
||||
prefs = json.loads(row["prefs"] or "{}")
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
result.append({
|
||||
"person_id": row["person_id"],
|
||||
"score": float(row["score"]),
|
||||
"interaction_count": int(row["interaction_count"]),
|
||||
"first_seen": row["first_seen"],
|
||||
"last_seen": row["last_seen"],
|
||||
"prefs": prefs,
|
||||
})
|
||||
return result
|
||||
|
||||
# ── Interactions ──────────────────────────────────────────────────────────
|
||||
|
||||
def record_interaction(
|
||||
self,
|
||||
person_id: str,
|
||||
event_type: str,
|
||||
details: dict | None = None,
|
||||
delta_score: float = 0.0,
|
||||
):
|
||||
"""Record an interaction event and update the person's score.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
person_id : str
|
||||
event_type : str
|
||||
One of: ``"greeting"``, ``"positive"``, ``"negative"``,
|
||||
``"detection"``.
|
||||
details : dict, optional
|
||||
Arbitrary key/value data stored as JSON.
|
||||
delta_score : float
|
||||
Amount to add to the person's score (can be negative).
|
||||
Interaction count is always incremented by 1.
|
||||
"""
|
||||
now = _now_iso()
|
||||
details_json = json.dumps(details or {})
|
||||
|
||||
with self._lock:
|
||||
# Ensure person exists
|
||||
self.get_person.__wrapped__(self, person_id) if hasattr(
|
||||
self.get_person, "__wrapped__"
|
||||
) else None
|
||||
|
||||
# Upsert person row
|
||||
self._conn.execute("""
|
||||
INSERT INTO people (person_id, first_seen, last_seen)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(person_id) DO UPDATE SET
|
||||
last_seen = excluded.last_seen
|
||||
""", (person_id, now, now))
|
||||
|
||||
# Increment count + score
|
||||
self._conn.execute("""
|
||||
UPDATE people
|
||||
SET interaction_count = interaction_count + 1,
|
||||
score = score + ?,
|
||||
last_seen = ?
|
||||
WHERE person_id = ?
|
||||
""", (delta_score, now, person_id))
|
||||
|
||||
# Insert interaction log row
|
||||
self._conn.execute("""
|
||||
INSERT INTO interactions (person_id, ts, event_type, details)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""", (person_id, now, event_type, details_json))
|
||||
|
||||
self._conn.commit()
|
||||
|
||||
def get_recent_events(self, person_id: str, window_s: float = 120.0) -> list:
|
||||
"""Return interaction events for *person_id* within the last *window_s* seconds.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list of dict
|
||||
Each dict: ``{"type": str, "dt": float, "ts": str, "details": dict}``
|
||||
where ``dt`` is seconds ago (positive = in the past).
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
cutoff = (
|
||||
datetime.now(timezone.utc) - timedelta(seconds=window_s)
|
||||
).isoformat()
|
||||
|
||||
with self._lock:
|
||||
rows = self._conn.execute("""
|
||||
SELECT ts, event_type, details FROM interactions
|
||||
WHERE person_id = ? AND ts >= ?
|
||||
ORDER BY ts DESC
|
||||
""", (person_id, cutoff)).fetchall()
|
||||
|
||||
now_dt = datetime.now(timezone.utc)
|
||||
result = []
|
||||
for row in rows:
|
||||
try:
|
||||
row_dt = datetime.fromisoformat(row["ts"])
|
||||
# Make timezone-aware if needed
|
||||
if row_dt.tzinfo is None:
|
||||
row_dt = row_dt.replace(tzinfo=timezone.utc)
|
||||
dt_secs = (now_dt - row_dt).total_seconds()
|
||||
except (ValueError, TypeError):
|
||||
dt_secs = window_s
|
||||
|
||||
details = {}
|
||||
try:
|
||||
details = json.loads(row["details"] or "{}")
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
result.append({
|
||||
"type": row["event_type"],
|
||||
"dt": dt_secs,
|
||||
"ts": row["ts"],
|
||||
"details": details,
|
||||
})
|
||||
return result
|
||||
|
||||
# ── Preferences ───────────────────────────────────────────────────────────
|
||||
|
||||
def set_pref(self, person_id: str, key: str, value):
|
||||
"""Set a learned preference for a person.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
person_id, key : str
|
||||
value : JSON-serialisable
|
||||
"""
|
||||
person = self.get_person(person_id)
|
||||
prefs = person["prefs"]
|
||||
prefs[key] = value
|
||||
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"UPDATE people SET prefs = ? WHERE person_id = ?",
|
||||
(json.dumps(prefs), person_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def get_pref(self, person_id: str, key: str, default=None):
|
||||
"""Return a specific learned preference for a person."""
|
||||
return self.get_person(person_id)["prefs"].get(key, default)
|
||||
|
||||
# ── Lifecycle ─────────────────────────────────────────────────────────────
|
||||
|
||||
def close(self):
|
||||
"""Close the database connection."""
|
||||
with self._lock:
|
||||
self._conn.close()
|
||||
@ -0,0 +1,196 @@
|
||||
"""
|
||||
soul_loader.py — SOUL.md persona parser and hot-reload watcher.
|
||||
|
||||
SOUL.md format
|
||||
--------------
|
||||
The file uses YAML front-matter (delimited by ``---`` lines) with an optional
|
||||
Markdown description body that is ignored by the parser. Example::
|
||||
|
||||
---
|
||||
name: "Salty"
|
||||
humor_level: 7
|
||||
sass_level: 4
|
||||
base_mood: "playful"
|
||||
...
|
||||
---
|
||||
# Optional description text (ignored)
|
||||
|
||||
Public API
|
||||
----------
|
||||
load_soul(path) -> dict (raises on parse error)
|
||||
SoulWatcher(path, cb, interval)
|
||||
.start()
|
||||
.stop()
|
||||
.reload_now() -> dict
|
||||
|
||||
Pure module — no ROS2 imports.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
|
||||
import yaml
|
||||
|
||||
# Keys that are required in every SOUL.md file
|
||||
_REQUIRED_KEYS = {
|
||||
"name",
|
||||
"humor_level",
|
||||
"sass_level",
|
||||
"base_mood",
|
||||
"threshold_regular",
|
||||
"threshold_favorite",
|
||||
"greeting_stranger",
|
||||
"greeting_regular",
|
||||
"greeting_favorite",
|
||||
}
|
||||
|
||||
_VALID_MOODS = {"happy", "curious", "annoyed", "playful"}
|
||||
|
||||
|
||||
def _extract_frontmatter(text: str) -> str:
|
||||
"""Return the YAML block between the first pair of ``---`` delimiters.
|
||||
|
||||
Raises ``ValueError`` if the file does not contain valid front-matter.
|
||||
"""
|
||||
lines = text.splitlines()
|
||||
delimiters = [i for i, l in enumerate(lines) if l.strip() == "---"]
|
||||
if len(delimiters) < 2:
|
||||
# No delimiter found — treat the whole file as plain YAML
|
||||
return text
|
||||
start = delimiters[0] + 1
|
||||
end = delimiters[1]
|
||||
return "\n".join(lines[start:end])
|
||||
|
||||
|
||||
def load_soul(path: str) -> dict:
|
||||
"""Parse a SOUL.md file and return the validated config dict.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
Absolute path to the SOUL.md file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
Validated persona configuration.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the file does not exist.
|
||||
ValueError
|
||||
If the YAML is malformed or required keys are missing.
|
||||
"""
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(f"SOUL.md not found: {path}")
|
||||
|
||||
with open(path, "r", encoding="utf-8") as fh:
|
||||
raw = fh.read()
|
||||
|
||||
yaml_text = _extract_frontmatter(raw)
|
||||
|
||||
try:
|
||||
data = yaml.safe_load(yaml_text)
|
||||
except yaml.YAMLError as exc:
|
||||
raise ValueError(f"SOUL.md YAML parse error in {path}: {exc}") from exc
|
||||
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"SOUL.md top level must be a YAML mapping, got {type(data)}")
|
||||
|
||||
# Validate required keys
|
||||
missing = _REQUIRED_KEYS - data.keys()
|
||||
if missing:
|
||||
raise ValueError(f"SOUL.md missing required keys: {sorted(missing)}")
|
||||
|
||||
# Validate ranges
|
||||
for key in ("humor_level", "sass_level"):
|
||||
val = data.get(key)
|
||||
if not isinstance(val, (int, float)) or not (0 <= val <= 10):
|
||||
raise ValueError(f"SOUL.md '{key}' must be a number 0–10, got {val!r}")
|
||||
|
||||
if data.get("base_mood") not in _VALID_MOODS:
|
||||
raise ValueError(
|
||||
f"SOUL.md 'base_mood' must be one of {sorted(_VALID_MOODS)}, "
|
||||
f"got {data.get('base_mood')!r}"
|
||||
)
|
||||
|
||||
for key in ("threshold_regular", "threshold_favorite"):
|
||||
val = data.get(key)
|
||||
if not isinstance(val, int) or val < 0:
|
||||
raise ValueError(f"SOUL.md '{key}' must be a non-negative integer, got {val!r}")
|
||||
|
||||
if data["threshold_regular"] > data["threshold_favorite"]:
|
||||
raise ValueError(
|
||||
"SOUL.md 'threshold_regular' must be <= 'threshold_favorite'"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class SoulWatcher:
|
||||
"""Background thread that polls SOUL.md for changes and calls a callback.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
Path to the SOUL.md file to watch.
|
||||
on_reload : callable
|
||||
``on_reload(soul_dict)`` called whenever a valid new SOUL.md is loaded.
|
||||
interval : float
|
||||
Polling interval in seconds (default 5.0).
|
||||
on_error : callable, optional
|
||||
``on_error(exception)`` called when a reload attempt fails.
|
||||
"""
|
||||
|
||||
def __init__(self, path: str, on_reload, interval: float = 5.0, on_error=None):
|
||||
self._path = path
|
||||
self._on_reload = on_reload
|
||||
self._interval = interval
|
||||
self._on_error = on_error
|
||||
self._thread = None
|
||||
self._stop_evt = threading.Event()
|
||||
self._last_mtime = 0.0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def start(self):
|
||||
"""Start the background polling thread."""
|
||||
if self._thread and self._thread.is_alive():
|
||||
return
|
||||
self._stop_evt.clear()
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self):
|
||||
"""Signal the watcher thread to stop and block until it exits."""
|
||||
self._stop_evt.set()
|
||||
if self._thread:
|
||||
self._thread.join(timeout=self._interval + 1.0)
|
||||
|
||||
def reload_now(self) -> dict:
|
||||
"""Force an immediate reload and return the new soul dict."""
|
||||
soul = load_soul(self._path)
|
||||
self._last_mtime = os.path.getmtime(self._path)
|
||||
self._on_reload(soul)
|
||||
return soul
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run(self):
|
||||
while not self._stop_evt.wait(self._interval):
|
||||
try:
|
||||
mtime = os.path.getmtime(self._path)
|
||||
except OSError:
|
||||
continue
|
||||
if mtime != self._last_mtime:
|
||||
try:
|
||||
soul = load_soul(self._path)
|
||||
except (FileNotFoundError, ValueError) as exc:
|
||||
if self._on_error:
|
||||
self._on_error(exc)
|
||||
continue
|
||||
self._last_mtime = mtime
|
||||
self._on_reload(soul)
|
||||
4
jetson/ros2_ws/src/saltybot_social_personality/setup.cfg
Normal file
4
jetson/ros2_ws/src/saltybot_social_personality/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/saltybot_social_personality
|
||||
[install]
|
||||
install_scripts=$base/lib/saltybot_social_personality
|
||||
28
jetson/ros2_ws/src/saltybot_social_personality/setup.py
Normal file
28
jetson/ros2_ws/src/saltybot_social_personality/setup.py
Normal file
@ -0,0 +1,28 @@
|
||||
from setuptools import setup
|
||||
|
||||
package_name = "saltybot_social_personality"
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
version="0.1.0",
|
||||
packages=[package_name],
|
||||
data_files=[
|
||||
("share/ament_index/resource_index/packages", [f"resource/{package_name}"]),
|
||||
(f"share/{package_name}", ["package.xml"]),
|
||||
(f"share/{package_name}/launch", ["launch/personality.launch.py"]),
|
||||
(f"share/{package_name}/config", ["config/SOUL.md",
|
||||
"config/personality_params.yaml"]),
|
||||
],
|
||||
install_requires=["setuptools", "pyyaml"],
|
||||
zip_safe=True,
|
||||
maintainer="sl-controls",
|
||||
maintainer_email="sl-controls@saltylab.local",
|
||||
description="SOUL.md-driven personality system for saltybot social interaction",
|
||||
license="MIT",
|
||||
tests_require=["pytest"],
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"personality_node = saltybot_social_personality.personality_node:main",
|
||||
],
|
||||
},
|
||||
)
|
||||
@ -0,0 +1,475 @@
|
||||
"""
|
||||
test_personality.py — Unit tests for the saltybot personality system.
|
||||
|
||||
No ROS2 runtime required. Tests pure functions from:
|
||||
- soul_loader.py
|
||||
- mood_engine.py
|
||||
- relationship_db.py
|
||||
|
||||
Run with:
|
||||
pytest jetson/ros2_ws/src/saltybot_social_personality/test/test_personality.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import textwrap
|
||||
|
||||
import pytest
|
||||
|
||||
# ── Imports (pure modules, no ROS2) ──────────────────────────────────────────
|
||||
|
||||
import sys
|
||||
sys.path.insert(
|
||||
0,
|
||||
os.path.join(os.path.dirname(__file__), "..", "saltybot_social_personality"),
|
||||
)
|
||||
|
||||
from soul_loader import load_soul, _extract_frontmatter
|
||||
from mood_engine import (
|
||||
compute_mood, get_relationship_tier, build_greeting,
|
||||
MOOD_HAPPY, MOOD_PLAYFUL, MOOD_CURIOUS, MOOD_ANNOYED,
|
||||
TIER_STRANGER, TIER_REGULAR, TIER_FAVORITE,
|
||||
EVENT_NEGATIVE, EVENT_POSITIVE, EVENT_GREETING,
|
||||
)
|
||||
from relationship_db import RelationshipDB
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _minimal_soul(**overrides) -> dict:
|
||||
"""Return a valid minimal soul dict, optionally overriding keys."""
|
||||
base = {
|
||||
"name": "Salty",
|
||||
"humor_level": 7,
|
||||
"sass_level": 4,
|
||||
"base_mood": "playful",
|
||||
"threshold_regular": 5,
|
||||
"threshold_favorite": 20,
|
||||
"greeting_stranger": "Hello there!",
|
||||
"greeting_regular": "Hey {name}!",
|
||||
"greeting_favorite": "Oh hey {name}!!",
|
||||
}
|
||||
base.update(overrides)
|
||||
return base
|
||||
|
||||
|
||||
def _write_soul(content: str) -> str:
|
||||
"""Write a SOUL.md string to a temp file and return the path."""
|
||||
fh = tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".md", delete=False, encoding="utf-8"
|
||||
)
|
||||
fh.write(content)
|
||||
fh.close()
|
||||
return fh.name
|
||||
|
||||
|
||||
_VALID_SOUL_CONTENT = textwrap.dedent("""\
|
||||
---
|
||||
name: "TestBot"
|
||||
speaking_style: "casual"
|
||||
humor_level: 7
|
||||
sass_level: 3
|
||||
base_mood: "playful"
|
||||
threshold_regular: 5
|
||||
threshold_favorite: 20
|
||||
greeting_stranger: "Hello stranger!"
|
||||
greeting_regular: "Hey {name}!"
|
||||
greeting_favorite: "Oh hey {name}!!"
|
||||
mood_prefix_playful: "Beep boop! "
|
||||
mood_prefix_happy: "Great — "
|
||||
mood_prefix_curious: "Hmm, "
|
||||
mood_prefix_annoyed: "Ugh, "
|
||||
---
|
||||
# Description (ignored)
|
||||
This is the description body.
|
||||
""")
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# soul_loader tests
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestExtractFrontmatter:
|
||||
def test_delimited(self):
|
||||
content = "---\nkey: val\n---\n# body"
|
||||
assert _extract_frontmatter(content) == "key: val"
|
||||
|
||||
def test_no_delimiters_returns_whole(self):
|
||||
content = "key: val\nother: 123"
|
||||
assert _extract_frontmatter(content) == content
|
||||
|
||||
def test_single_delimiter_returns_whole(self):
|
||||
content = "---\nkey: val\n"
|
||||
result = _extract_frontmatter(content)
|
||||
assert "key: val" in result
|
||||
|
||||
def test_body_stripped(self):
|
||||
content = "---\nname: X\n---\n# Body text\nMore body"
|
||||
assert "Body text" not in _extract_frontmatter(content)
|
||||
assert "name: X" in _extract_frontmatter(content)
|
||||
|
||||
|
||||
class TestLoadSoul:
|
||||
def test_valid_file_loads(self):
|
||||
path = _write_soul(_VALID_SOUL_CONTENT)
|
||||
try:
|
||||
soul = load_soul(path)
|
||||
assert soul["name"] == "TestBot"
|
||||
assert soul["humor_level"] == 7
|
||||
assert soul["base_mood"] == "playful"
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_missing_file_raises(self):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
load_soul("/nonexistent/SOUL.md")
|
||||
|
||||
def test_missing_required_key_raises(self):
|
||||
content = "---\nname: X\nhumor_level: 5\n---" # missing many keys
|
||||
path = _write_soul(content)
|
||||
try:
|
||||
with pytest.raises(ValueError, match="missing required keys"):
|
||||
load_soul(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_humor_out_of_range_raises(self):
|
||||
soul_str = _VALID_SOUL_CONTENT.replace("humor_level: 7", "humor_level: 11")
|
||||
path = _write_soul(soul_str)
|
||||
try:
|
||||
with pytest.raises(ValueError, match="humor_level"):
|
||||
load_soul(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_invalid_mood_raises(self):
|
||||
soul_str = _VALID_SOUL_CONTENT.replace(
|
||||
'base_mood: "playful"', 'base_mood: "grumpy"'
|
||||
)
|
||||
path = _write_soul(soul_str)
|
||||
try:
|
||||
with pytest.raises(ValueError, match="base_mood"):
|
||||
load_soul(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_threshold_order_enforced(self):
|
||||
soul_str = _VALID_SOUL_CONTENT.replace(
|
||||
"threshold_regular: 5", "threshold_regular: 25"
|
||||
)
|
||||
path = _write_soul(soul_str)
|
||||
try:
|
||||
with pytest.raises(ValueError, match="threshold_regular"):
|
||||
load_soul(path)
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
def test_extra_keys_allowed(self):
|
||||
content = _VALID_SOUL_CONTENT.replace(
|
||||
"---\n# Description",
|
||||
"custom_key: 42\n---\n# Description"
|
||||
)
|
||||
path = _write_soul(content)
|
||||
try:
|
||||
soul = load_soul(path)
|
||||
assert soul.get("custom_key") == 42
|
||||
finally:
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# mood_engine tests
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestGetRelationshipTier:
|
||||
def test_zero_interactions_stranger(self):
|
||||
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
|
||||
assert get_relationship_tier(soul, 0) == TIER_STRANGER
|
||||
|
||||
def test_below_regular_stranger(self):
|
||||
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
|
||||
assert get_relationship_tier(soul, 4) == TIER_STRANGER
|
||||
|
||||
def test_at_regular_threshold(self):
|
||||
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
|
||||
assert get_relationship_tier(soul, 5) == TIER_REGULAR
|
||||
|
||||
def test_above_regular_below_favorite(self):
|
||||
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
|
||||
assert get_relationship_tier(soul, 10) == TIER_REGULAR
|
||||
|
||||
def test_at_favorite_threshold(self):
|
||||
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
|
||||
assert get_relationship_tier(soul, 20) == TIER_FAVORITE
|
||||
|
||||
def test_above_favorite(self):
|
||||
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
|
||||
assert get_relationship_tier(soul, 100) == TIER_FAVORITE
|
||||
|
||||
|
||||
class TestComputeMood:
|
||||
def test_unknown_person_returns_curious(self):
|
||||
soul = _minimal_soul(humor_level=7)
|
||||
mood = compute_mood(soul, score=0.0, interaction_count=0, recent_events=[])
|
||||
assert mood == MOOD_CURIOUS
|
||||
|
||||
def test_stranger_low_count_returns_curious(self):
|
||||
soul = _minimal_soul(threshold_regular=5)
|
||||
mood = compute_mood(soul, score=2.0, interaction_count=3, recent_events=[])
|
||||
assert mood == MOOD_CURIOUS
|
||||
|
||||
def test_two_negative_events_returns_annoyed(self):
|
||||
soul = _minimal_soul(threshold_regular=5)
|
||||
events = [
|
||||
{"type": EVENT_NEGATIVE, "dt": 30.0},
|
||||
{"type": EVENT_NEGATIVE, "dt": 60.0},
|
||||
]
|
||||
mood = compute_mood(soul, score=10.0, interaction_count=10, recent_events=events)
|
||||
assert mood == MOOD_ANNOYED
|
||||
|
||||
def test_one_negative_not_annoyed(self):
|
||||
soul = _minimal_soul(humor_level=7, threshold_regular=5, threshold_favorite=20)
|
||||
events = [{"type": EVENT_NEGATIVE, "dt": 30.0}]
|
||||
# 1 negative is not enough → should still be happy/playful based on score
|
||||
mood = compute_mood(soul, score=25.0, interaction_count=25, recent_events=events)
|
||||
assert mood != MOOD_ANNOYED
|
||||
|
||||
def test_high_humor_regular_returns_playful(self):
|
||||
soul = _minimal_soul(humor_level=8, threshold_regular=5, threshold_favorite=20)
|
||||
events = [{"type": EVENT_POSITIVE, "dt": 10.0}]
|
||||
mood = compute_mood(soul, score=10.0, interaction_count=8, recent_events=events)
|
||||
assert mood == MOOD_PLAYFUL
|
||||
|
||||
def test_low_humor_regular_returns_happy(self):
|
||||
soul = _minimal_soul(humor_level=4, threshold_regular=5, threshold_favorite=20)
|
||||
events = [{"type": EVENT_POSITIVE, "dt": 10.0}]
|
||||
mood = compute_mood(soul, score=10.0, interaction_count=8, recent_events=events)
|
||||
assert mood == MOOD_HAPPY
|
||||
|
||||
def test_stale_negative_ignored(self):
|
||||
soul = _minimal_soul(humor_level=8, threshold_regular=5, threshold_favorite=20)
|
||||
# dt > 120s → outside the recent window → should not trigger annoyed
|
||||
events = [
|
||||
{"type": EVENT_NEGATIVE, "dt": 200.0},
|
||||
{"type": EVENT_NEGATIVE, "dt": 300.0},
|
||||
]
|
||||
mood = compute_mood(soul, score=15.0, interaction_count=10, recent_events=events)
|
||||
assert mood != MOOD_ANNOYED
|
||||
|
||||
def test_favorite_high_humor_playful(self):
|
||||
soul = _minimal_soul(humor_level=9, threshold_regular=5, threshold_favorite=20)
|
||||
mood = compute_mood(soul, score=50.0, interaction_count=30, recent_events=[])
|
||||
assert mood == MOOD_PLAYFUL
|
||||
|
||||
def test_favorite_low_humor_happy(self):
|
||||
soul = _minimal_soul(humor_level=3, threshold_regular=5, threshold_favorite=20)
|
||||
mood = compute_mood(soul, score=50.0, interaction_count=30, recent_events=[])
|
||||
assert mood == MOOD_HAPPY
|
||||
|
||||
|
||||
class TestBuildGreeting:
|
||||
def _soul(self, **kw):
|
||||
return _minimal_soul(
|
||||
mood_prefix_happy="Great — ",
|
||||
mood_prefix_curious="Hmm, ",
|
||||
mood_prefix_annoyed="Well, ",
|
||||
mood_prefix_playful="Beep boop! ",
|
||||
**kw,
|
||||
)
|
||||
|
||||
def test_stranger_greeting(self):
|
||||
soul = self._soul()
|
||||
g = build_greeting(soul, TIER_STRANGER, MOOD_CURIOUS, "")
|
||||
assert "hello" in g.lower()
|
||||
|
||||
def test_regular_greeting_contains_name(self):
|
||||
soul = self._soul()
|
||||
g = build_greeting(soul, TIER_REGULAR, MOOD_HAPPY, "alice")
|
||||
assert "alice" in g
|
||||
|
||||
def test_favorite_greeting_contains_name(self):
|
||||
soul = self._soul()
|
||||
g = build_greeting(soul, TIER_FAVORITE, MOOD_PLAYFUL, "bob")
|
||||
assert "bob" in g
|
||||
|
||||
def test_mood_prefix_applied(self):
|
||||
soul = self._soul()
|
||||
g = build_greeting(soul, TIER_REGULAR, MOOD_PLAYFUL, "alice")
|
||||
assert g.startswith("Beep boop!")
|
||||
|
||||
def test_no_prefix_key_no_prefix(self):
|
||||
soul = _minimal_soul() # no mood_prefix_* keys
|
||||
g = build_greeting(soul, TIER_REGULAR, MOOD_HAPPY, "alice")
|
||||
assert g.startswith("Hey")
|
||||
|
||||
def test_empty_person_id_uses_friend(self):
|
||||
soul = self._soul()
|
||||
g = build_greeting(soul, TIER_REGULAR, MOOD_HAPPY, "")
|
||||
assert "friend" in g
|
||||
|
||||
def test_happy_prefix(self):
|
||||
soul = self._soul()
|
||||
g = build_greeting(soul, TIER_REGULAR, MOOD_HAPPY, "carol")
|
||||
assert g.startswith("Great")
|
||||
|
||||
def test_annoyed_prefix(self):
|
||||
soul = self._soul()
|
||||
g = build_greeting(soul, TIER_REGULAR, MOOD_ANNOYED, "dave")
|
||||
assert g.startswith("Well")
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# relationship_db tests
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestRelationshipDB:
|
||||
@pytest.fixture
|
||||
def db(self, tmp_path):
|
||||
path = str(tmp_path / "test.db")
|
||||
d = RelationshipDB(path)
|
||||
yield d
|
||||
d.close()
|
||||
|
||||
def test_get_person_creates_default(self, db):
|
||||
p = db.get_person("alice")
|
||||
assert p["person_id"] == "alice"
|
||||
assert p["score"] == pytest.approx(0.0)
|
||||
assert p["interaction_count"] == 0
|
||||
|
||||
def test_get_person_idempotent(self, db):
|
||||
p1 = db.get_person("bob")
|
||||
p2 = db.get_person("bob")
|
||||
assert p1["person_id"] == p2["person_id"]
|
||||
|
||||
def test_record_interaction_increments_count(self, db):
|
||||
db.record_interaction("alice", "detection")
|
||||
db.record_interaction("alice", "detection")
|
||||
p = db.get_person("alice")
|
||||
assert p["interaction_count"] == 2
|
||||
|
||||
def test_record_interaction_updates_score(self, db):
|
||||
db.record_interaction("alice", "positive", delta_score=5.0)
|
||||
p = db.get_person("alice")
|
||||
assert p["score"] == pytest.approx(5.0)
|
||||
|
||||
def test_negative_delta_reduces_score(self, db):
|
||||
db.record_interaction("carol", "positive", delta_score=10.0)
|
||||
db.record_interaction("carol", "negative", delta_score=-3.0)
|
||||
p = db.get_person("carol")
|
||||
assert p["score"] == pytest.approx(7.0)
|
||||
|
||||
def test_score_zero_by_default(self, db):
|
||||
p = db.get_person("dave")
|
||||
assert p["score"] == pytest.approx(0.0)
|
||||
|
||||
def test_set_and_get_pref(self, db):
|
||||
db.set_pref("alice", "language", "en")
|
||||
assert db.get_pref("alice", "language") == "en"
|
||||
|
||||
def test_get_pref_default(self, db):
|
||||
assert db.get_pref("nobody", "language", "fr") == "fr"
|
||||
|
||||
def test_multiple_prefs_stored(self, db):
|
||||
db.set_pref("alice", "lang", "en")
|
||||
db.set_pref("alice", "name", "Alice")
|
||||
assert db.get_pref("alice", "lang") == "en"
|
||||
assert db.get_pref("alice", "name") == "Alice"
|
||||
|
||||
def test_all_people_returns_list(self, db):
|
||||
db.record_interaction("a", "detection")
|
||||
db.record_interaction("b", "detection")
|
||||
people = db.all_people()
|
||||
ids = {p["person_id"] for p in people}
|
||||
assert {"a", "b"} <= ids
|
||||
|
||||
def test_get_recent_events_returns_events(self, db):
|
||||
db.record_interaction("alice", "greeting", delta_score=1.0)
|
||||
events = db.get_recent_events("alice", window_s=60.0)
|
||||
assert len(events) == 1
|
||||
assert events[0]["type"] == "greeting"
|
||||
|
||||
def test_get_recent_events_empty_for_new_person(self, db):
|
||||
events = db.get_recent_events("nobody", window_s=60.0)
|
||||
assert events == []
|
||||
|
||||
def test_event_dt_positive(self, db):
|
||||
db.record_interaction("alice", "detection")
|
||||
events = db.get_recent_events("alice", window_s=60.0)
|
||||
assert events[0]["dt"] >= 0.0
|
||||
|
||||
def test_multiple_people_isolated(self, db):
|
||||
db.record_interaction("alice", "positive", delta_score=10.0)
|
||||
db.record_interaction("bob", "negative", delta_score=-5.0)
|
||||
assert db.get_person("alice")["score"] == pytest.approx(10.0)
|
||||
assert db.get_person("bob")["score"] == pytest.approx(-5.0)
|
||||
|
||||
def test_details_stored(self, db):
|
||||
db.record_interaction("alice", "greeting", details={"location": "lab"})
|
||||
events = db.get_recent_events("alice", window_s=60.0)
|
||||
assert events[0]["details"].get("location") == "lab"
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# Integration: soul → tier → mood → greeting pipeline
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestIntegrationPipeline:
|
||||
|
||||
def test_stranger_pipeline(self, tmp_path):
|
||||
db_path = str(tmp_path / "int.db")
|
||||
db = RelationshipDB(db_path)
|
||||
soul = _minimal_soul(
|
||||
humor_level=7, threshold_regular=5, threshold_favorite=20,
|
||||
mood_prefix_curious="Hmm, "
|
||||
)
|
||||
# No prior interactions
|
||||
person = db.get_person("stranger_001")
|
||||
events = db.get_recent_events("stranger_001")
|
||||
tier = get_relationship_tier(soul, person["interaction_count"])
|
||||
mood = compute_mood(soul, person["score"], person["interaction_count"], events)
|
||||
greeting = build_greeting(soul, tier, mood, "stranger_001")
|
||||
|
||||
assert tier == TIER_STRANGER
|
||||
assert mood == MOOD_CURIOUS
|
||||
assert "hello" in greeting.lower()
|
||||
db.close()
|
||||
|
||||
def test_regular_positive_pipeline(self, tmp_path):
|
||||
db_path = str(tmp_path / "int2.db")
|
||||
db = RelationshipDB(db_path)
|
||||
soul = _minimal_soul(
|
||||
humor_level=8, threshold_regular=5, threshold_favorite=20,
|
||||
mood_prefix_playful="Beep! "
|
||||
)
|
||||
# Simulate 6 positive interactions (> threshold_regular=5)
|
||||
for _ in range(6):
|
||||
db.record_interaction("alice", "positive", delta_score=2.0)
|
||||
|
||||
person = db.get_person("alice")
|
||||
events = db.get_recent_events("alice")
|
||||
tier = get_relationship_tier(soul, person["interaction_count"])
|
||||
mood = compute_mood(soul, person["score"], person["interaction_count"], events)
|
||||
greeting = build_greeting(soul, tier, mood, "alice")
|
||||
|
||||
assert tier == TIER_REGULAR
|
||||
assert mood == MOOD_PLAYFUL # humor_level=8, recent positive events
|
||||
assert "alice" in greeting
|
||||
assert greeting.startswith("Beep!")
|
||||
db.close()
|
||||
|
||||
def test_favorite_pipeline(self, tmp_path):
|
||||
db_path = str(tmp_path / "int3.db")
|
||||
db = RelationshipDB(db_path)
|
||||
soul = _minimal_soul(
|
||||
humor_level=5, threshold_regular=5, threshold_favorite=20
|
||||
)
|
||||
for _ in range(25):
|
||||
db.record_interaction("bob", "positive", delta_score=1.0)
|
||||
|
||||
person = db.get_person("bob")
|
||||
tier = get_relationship_tier(soul, person["interaction_count"])
|
||||
assert tier == TIER_FAVORITE
|
||||
greeting = build_greeting(soul, tier, "happy", "bob")
|
||||
assert "bob" in greeting
|
||||
assert "Oh hey" in greeting
|
||||
db.close()
|
||||
5
jetson/ros2_ws/src/saltybot_social_tracking/.gitignore
vendored
Normal file
5
jetson/ros2_ws/src/saltybot_social_tracking/.gitignore
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.egg-info/
|
||||
.pytest_cache/
|
||||
@ -0,0 +1,48 @@
|
||||
# tracking_params.yaml — saltybot_social_tracking / TrackingFusionNode
|
||||
#
|
||||
# Run with:
|
||||
# ros2 launch saltybot_social_tracking tracking.launch.py
|
||||
#
|
||||
# Topics consumed:
|
||||
# /uwb/target (geometry_msgs/PoseStamped) — UWB triangulated position
|
||||
# /person/target (geometry_msgs/PoseStamped) — camera-detected position
|
||||
#
|
||||
# Topic produced:
|
||||
# /social/tracking/fused_target (saltybot_social_msgs/FusedTarget)
|
||||
|
||||
# ── Source staleness timeouts ──────────────────────────────────────────────────
|
||||
# UWB driver publishes at ~10 Hz; 1.5 s = 15 missed cycles before declared stale.
|
||||
uwb_timeout: 1.5 # seconds
|
||||
|
||||
# Camera detector publishes at ~30 Hz; 1.0 s = 30 missed frames before stale.
|
||||
cam_timeout: 1.0 # seconds
|
||||
|
||||
# How long the Kalman filter may coast (dead-reckoning) with no live source
|
||||
# before the node stops publishing.
|
||||
# At 10 m/s (EUC top-speed) the robot drifts ≈30 m over 3 s — beyond the UWB
|
||||
# follow-range, so 3 s is a reasonable hard stop.
|
||||
predict_timeout: 3.0 # seconds
|
||||
|
||||
# ── Kalman filter tuning ───────────────────────────────────────────────────────
|
||||
# process_noise: acceleration noise std-dev (m/s²).
|
||||
# EUC riders can brake or accelerate at ~3–5 m/s²; 3.0 is a good starting point.
|
||||
# Increase if the filtered track lags behind fast direction changes.
|
||||
# Decrease if the track is noisy.
|
||||
process_noise: 3.0 # m/s²
|
||||
|
||||
# UWB position measurement noise (std-dev, metres).
|
||||
# DW3000 TWR accuracy ≈ ±10–20 cm; 0.20 accounts for system-level error.
|
||||
meas_noise_uwb: 0.20 # m
|
||||
|
||||
# Camera position noise (std-dev, metres).
|
||||
# Depth reprojection error with RealSense D435i at 1–3 m ≈ ±5–15 cm.
|
||||
meas_noise_cam: 0.12 # m
|
||||
|
||||
# ── Control loop ──────────────────────────────────────────────────────────────
|
||||
control_rate: 20.0 # Hz — KF predict + publish rate
|
||||
|
||||
# ── Source arbiter ────────────────────────────────────────────────────────────
|
||||
# Minimum normalised confidence for a source to be considered live.
|
||||
# Range [0, 1]; lower = more permissive; default 0.15 keeps slightly stale
|
||||
# sources active rather than dropping to "predicted" prematurely.
|
||||
confidence_threshold: 0.15
|
||||
@ -0,0 +1,44 @@
|
||||
"""tracking.launch.py — launch the TrackingFusionNode with default params."""
|
||||
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
from launch import LaunchDescription
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
|
||||
pkg_share = get_package_share_directory("saltybot_social_tracking")
|
||||
default_params = os.path.join(pkg_share, "config", "tracking_params.yaml")
|
||||
|
||||
return LaunchDescription([
|
||||
DeclareLaunchArgument(
|
||||
"params_file",
|
||||
default_value=default_params,
|
||||
description="Path to tracking fusion parameter YAML file",
|
||||
),
|
||||
DeclareLaunchArgument(
|
||||
"control_rate",
|
||||
default_value="20.0",
|
||||
description="KF predict + publish rate (Hz)",
|
||||
),
|
||||
DeclareLaunchArgument(
|
||||
"predict_timeout",
|
||||
default_value="3.0",
|
||||
description="Max KF coast time before stopping publish (s)",
|
||||
),
|
||||
Node(
|
||||
package="saltybot_social_tracking",
|
||||
executable="tracking_fusion_node",
|
||||
name="tracking_fusion",
|
||||
output="screen",
|
||||
parameters=[
|
||||
LaunchConfiguration("params_file"),
|
||||
{
|
||||
"control_rate": LaunchConfiguration("control_rate"),
|
||||
"predict_timeout": LaunchConfiguration("predict_timeout"),
|
||||
},
|
||||
],
|
||||
),
|
||||
])
|
||||
31
jetson/ros2_ws/src/saltybot_social_tracking/package.xml
Normal file
31
jetson/ros2_ws/src/saltybot_social_tracking/package.xml
Normal file
@ -0,0 +1,31 @@
|
||||
<?xml version="1.0"?>
|
||||
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||
<package format="3">
|
||||
<name>saltybot_social_tracking</name>
|
||||
<version>0.1.0</version>
|
||||
<description>
|
||||
Multi-modal tracking fusion for saltybot.
|
||||
Fuses UWB triangulated position (/uwb/target) and camera-detected position
|
||||
(/person/target) using a 4-state Kalman filter to produce a smooth, low-latency
|
||||
fused estimate at /social/tracking/fused_target.
|
||||
Handles EUC rider speeds (20-30 km/h), signal handoff, and predictive coasting.
|
||||
</description>
|
||||
<maintainer email="sl-controls@saltylab.local">sl-controls</maintainer>
|
||||
<license>MIT</license>
|
||||
|
||||
<depend>rclpy</depend>
|
||||
<depend>geometry_msgs</depend>
|
||||
<depend>std_msgs</depend>
|
||||
<depend>saltybot_social_msgs</depend>
|
||||
|
||||
<buildtool_depend>ament_python</buildtool_depend>
|
||||
|
||||
<test_depend>ament_copyright</test_depend>
|
||||
<test_depend>ament_flake8</test_depend>
|
||||
<test_depend>ament_pep257</test_depend>
|
||||
<test_depend>python3-pytest</test_depend>
|
||||
|
||||
<export>
|
||||
<build_type>ament_python</build_type>
|
||||
</export>
|
||||
</package>
|
||||
@ -0,0 +1,134 @@
|
||||
"""
|
||||
kalman_tracker.py — 4-state linear Kalman filter for 2-D position+velocity tracking.
|
||||
|
||||
State vector: [x, y, vx, vy]
|
||||
Observation: [x_meas, y_meas]
|
||||
|
||||
Process model: constant velocity with Wiener process acceleration noise.
|
||||
Tuned to handle EUC rider speeds (20–30 km/h ≈ 5.5–8.3 m/s) with fast
|
||||
acceleration transients.
|
||||
|
||||
Pure module — no ROS2 dependency; fully unit-testable.
|
||||
"""
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
class KalmanTracker:
|
||||
"""
|
||||
4-state Kalman filter: state = [x, y, vx, vy].
|
||||
|
||||
Parameters
|
||||
----------
|
||||
process_noise : acceleration noise standard deviation (m/s²).
|
||||
Higher values allow the filter to track rapid velocity
|
||||
changes (EUC acceleration events). Default 3.0 m/s².
|
||||
meas_noise_uwb : UWB position measurement noise std-dev (m). Default 0.20 m.
|
||||
meas_noise_cam : Camera position measurement noise std-dev (m). Default 0.12 m.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
process_noise: float = 3.0,
|
||||
meas_noise_uwb: float = 0.20,
|
||||
meas_noise_cam: float = 0.12,
|
||||
):
|
||||
self._q = float(process_noise)
|
||||
self._r_uwb = float(meas_noise_uwb)
|
||||
self._r_cam = float(meas_noise_cam)
|
||||
|
||||
# State [x, y, vx, vy]
|
||||
self._x = np.zeros(4)
|
||||
|
||||
# Covariance — large initial uncertainty (10 m², 10 (m/s)²)
|
||||
self._P = np.diag([10.0, 10.0, 10.0, 10.0])
|
||||
|
||||
# Observation matrix: H * x = [x, y]
|
||||
self._H = np.array([[1.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 1.0, 0.0, 0.0]])
|
||||
|
||||
self._initialized = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def initialized(self) -> bool:
|
||||
return self._initialized
|
||||
|
||||
def initialize(self, x: float, y: float) -> None:
|
||||
"""Seed the filter at position (x, y) with zero velocity."""
|
||||
self._x = np.array([x, y, 0.0, 0.0])
|
||||
self._P = np.diag([1.0, 1.0, 5.0, 5.0])
|
||||
self._initialized = True
|
||||
|
||||
def predict(self, dt: float) -> None:
|
||||
"""
|
||||
Advance the filter state by dt seconds.
|
||||
|
||||
Uses a discrete Wiener process acceleration model so that positional
|
||||
uncertainty grows as O(dt^4/4) and velocity uncertainty as O(dt^2).
|
||||
This lets the filter coast accurately through short signal outages
|
||||
while still being responsive to EUC velocity changes.
|
||||
"""
|
||||
if dt <= 0.0:
|
||||
return
|
||||
|
||||
F = np.array([[1.0, 0.0, dt, 0.0],
|
||||
[0.0, 1.0, 0.0, dt],
|
||||
[0.0, 0.0, 1.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 1.0]])
|
||||
|
||||
q = self._q
|
||||
dt2 = dt * dt
|
||||
dt3 = dt2 * dt
|
||||
dt4 = dt3 * dt
|
||||
Q = (q * q) * np.array([
|
||||
[dt4 / 4.0, 0.0, dt3 / 2.0, 0.0 ],
|
||||
[0.0, dt4 / 4.0, 0.0, dt3 / 2.0],
|
||||
[dt3 / 2.0, 0.0, dt2, 0.0 ],
|
||||
[0.0, dt3 / 2.0, 0.0, dt2 ],
|
||||
])
|
||||
|
||||
self._x = F @ self._x
|
||||
self._P = F @ self._P @ F.T + Q
|
||||
|
||||
def update(self, x_meas: float, y_meas: float, source: str = "uwb") -> None:
|
||||
"""
|
||||
Apply a position measurement (x_meas, y_meas).
|
||||
|
||||
source : "uwb" or "camera" — selects the appropriate noise covariance.
|
||||
"""
|
||||
r = self._r_uwb if source == "uwb" else self._r_cam
|
||||
R = np.diag([r * r, r * r])
|
||||
|
||||
z = np.array([x_meas, y_meas])
|
||||
innov = z - self._H @ self._x # innovation
|
||||
S = self._H @ self._P @ self._H.T + R # innovation covariance
|
||||
K = self._P @ self._H.T @ np.linalg.inv(S) # Kalman gain
|
||||
self._x = self._x + K @ innov
|
||||
self._P = (np.eye(4) - K @ self._H) @ self._P
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# State accessors
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def position(self) -> tuple:
|
||||
"""Current filtered position (x, y) in metres."""
|
||||
return float(self._x[0]), float(self._x[1])
|
||||
|
||||
@property
|
||||
def velocity(self) -> tuple:
|
||||
"""Current filtered velocity (vx, vy) in m/s."""
|
||||
return float(self._x[2]), float(self._x[3])
|
||||
|
||||
def position_uncertainty_m(self) -> float:
|
||||
"""RMS positional uncertainty (m) from diagonal of covariance."""
|
||||
return float(math.sqrt((self._P[0, 0] + self._P[1, 1]) / 2.0))
|
||||
|
||||
def covariance_copy(self) -> np.ndarray:
|
||||
"""Return a copy of the current 4×4 covariance matrix."""
|
||||
return self._P.copy()
|
||||
@ -0,0 +1,155 @@
|
||||
"""
|
||||
source_arbiter.py — Source confidence scoring and selection for tracking fusion.
|
||||
|
||||
Two sensor sources are supported:
|
||||
UWB — geometry_msgs/PoseStamped from /uwb/target (triangulated, ~10 Hz)
|
||||
Camera — geometry_msgs/PoseStamped from /person/target (depth+YOLO, ~30 Hz)
|
||||
|
||||
Confidence model
|
||||
----------------
|
||||
Each source's confidence is its raw measurement quality multiplied by a
|
||||
linear staleness factor that drops to zero at its respective timeout:
|
||||
|
||||
conf = quality * max(0, 1 - age / timeout)
|
||||
|
||||
UWB quality is always 1.0 (the ranging hardware confidence is not exposed
|
||||
by the driver in origin/main; the UWB node already applies Kalman filtering).
|
||||
|
||||
Camera quality defaults to 1.0; callers may pass a lower value when the
|
||||
detection confidence is available.
|
||||
|
||||
Source selection
|
||||
----------------
|
||||
Both fresh → "fused" (confidence-weighted position blend)
|
||||
UWB only → "uwb"
|
||||
Camera only → "camera"
|
||||
Neither fresh → "predicted" (Kalman coasts)
|
||||
|
||||
Pure module — no ROS2 dependency; fully unit-testable.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
|
||||
def _staleness_factor(age_s: float, timeout_s: float) -> float:
|
||||
"""Linear decay: 1.0 at age=0, 0.0 at age=timeout, clamped."""
|
||||
if timeout_s <= 0.0:
|
||||
return 0.0
|
||||
return max(0.0, 1.0 - age_s / timeout_s)
|
||||
|
||||
|
||||
def uwb_confidence(age_s: float, timeout_s: float, quality: float = 1.0) -> float:
|
||||
"""
|
||||
UWB source confidence.
|
||||
|
||||
age_s : seconds since last UWB measurement (≥0; use large value if never)
|
||||
timeout_s: staleness threshold (s); confidence reaches 0 at this age
|
||||
quality : inherent measurement quality [0, 1] (default 1.0)
|
||||
"""
|
||||
return quality * _staleness_factor(age_s, timeout_s)
|
||||
|
||||
|
||||
def camera_confidence(
|
||||
age_s: float, timeout_s: float, quality: float = 1.0
|
||||
) -> float:
|
||||
"""
|
||||
Camera source confidence.
|
||||
|
||||
age_s : seconds since last camera detection (≥0; use large value if never)
|
||||
timeout_s: staleness threshold (s)
|
||||
quality : YOLO detection confidence or other quality score [0, 1]
|
||||
"""
|
||||
return quality * _staleness_factor(age_s, timeout_s)
|
||||
|
||||
|
||||
def select_source(
|
||||
uwb_conf: float,
|
||||
cam_conf: float,
|
||||
threshold: float = 0.15,
|
||||
) -> str:
|
||||
"""
|
||||
Choose the active tracking source label.
|
||||
|
||||
Returns one of: "fused", "uwb", "camera", "predicted".
|
||||
|
||||
threshold: minimum confidence for a source to be considered live.
|
||||
Sources below threshold are ignored.
|
||||
"""
|
||||
uwb_ok = uwb_conf >= threshold
|
||||
cam_ok = cam_conf >= threshold
|
||||
|
||||
if uwb_ok and cam_ok:
|
||||
return "fused"
|
||||
if uwb_ok:
|
||||
return "uwb"
|
||||
if cam_ok:
|
||||
return "camera"
|
||||
return "predicted"
|
||||
|
||||
|
||||
def fuse_positions(
|
||||
uwb_x: float,
|
||||
uwb_y: float,
|
||||
uwb_conf: float,
|
||||
cam_x: float,
|
||||
cam_y: float,
|
||||
cam_conf: float,
|
||||
) -> tuple:
|
||||
"""
|
||||
Confidence-weighted position fusion.
|
||||
|
||||
Returns (fused_x, fused_y).
|
||||
|
||||
When total confidence is zero (shouldn't happen in "fused" state, but
|
||||
guarded), returns the UWB position as fallback.
|
||||
"""
|
||||
total = uwb_conf + cam_conf
|
||||
if total <= 0.0:
|
||||
return uwb_x, uwb_y
|
||||
w = uwb_conf / total
|
||||
return (
|
||||
w * uwb_x + (1.0 - w) * cam_x,
|
||||
w * uwb_y + (1.0 - w) * cam_y,
|
||||
)
|
||||
|
||||
|
||||
def composite_confidence(
|
||||
uwb_conf: float,
|
||||
cam_conf: float,
|
||||
source: str,
|
||||
kf_uncertainty_m: float,
|
||||
max_kf_uncertainty_m: float = 3.0,
|
||||
) -> float:
|
||||
"""
|
||||
Compute a single composite confidence value [0, 1] for the fused output.
|
||||
|
||||
source : current source label (from select_source)
|
||||
kf_uncertainty_m : current KF positional RMS uncertainty
|
||||
max_kf_uncertainty_m: uncertainty at which confidence collapses to 0
|
||||
"""
|
||||
if source == "predicted":
|
||||
# Decay with growing KF uncertainty; no sensor feeds are live
|
||||
raw = max(0.0, 1.0 - kf_uncertainty_m / max_kf_uncertainty_m)
|
||||
return min(0.4, raw) # cap at 0.4 — caller should know this is dead-reckoning
|
||||
|
||||
if source == "fused":
|
||||
raw = max(uwb_conf, cam_conf)
|
||||
elif source == "uwb":
|
||||
raw = uwb_conf
|
||||
else: # "camera"
|
||||
raw = cam_conf
|
||||
|
||||
# Scale by KF health (full confidence only if KF is tight)
|
||||
kf_health = max(0.0, 1.0 - kf_uncertainty_m / max_kf_uncertainty_m)
|
||||
return raw * (0.5 + 0.5 * kf_health)
|
||||
|
||||
|
||||
def bearing_and_range(x: float, y: float) -> tuple:
|
||||
"""
|
||||
Compute bearing (rad, +ve = left) and range (m) to position (x, y).
|
||||
|
||||
Consistent with person_follower_node conventions:
|
||||
bearing = atan2(y, x) (base_link frame: +X forward, +Y left)
|
||||
range = sqrt(x² + y²)
|
||||
"""
|
||||
return math.atan2(y, x), math.sqrt(x * x + y * y)
|
||||
@ -0,0 +1,257 @@
|
||||
"""
|
||||
tracking_fusion_node.py — Multi-modal tracking fusion for saltybot.
|
||||
|
||||
Subscribes
|
||||
----------
|
||||
/uwb/target (geometry_msgs/PoseStamped) — UWB-triangulated position (~10 Hz)
|
||||
/person/target (geometry_msgs/PoseStamped) — camera-detected position (~30 Hz)
|
||||
|
||||
Publishes
|
||||
---------
|
||||
/social/tracking/fused_target (saltybot_social_msgs/FusedTarget) at control_rate Hz
|
||||
|
||||
Algorithm
|
||||
---------
|
||||
1. Each incoming measurement updates a 4-state Kalman filter [x, y, vx, vy].
|
||||
2. A 20 Hz timer runs predict+select+publish:
|
||||
a. KF predict(dt)
|
||||
b. Compute per-source confidence from measurement age + staleness model
|
||||
c. If either source is live:
|
||||
- "fused" → confidence-weighted position blend → KF update
|
||||
- "uwb" → UWB position → KF update
|
||||
- "camera" → camera position → KF update
|
||||
d. Build FusedTarget from KF state + composite confidence
|
||||
3. If all sources are lost but within predict_timeout, keep publishing with
|
||||
active_source="predicted" and degrading confidence.
|
||||
4. Beyond predict_timeout, no message is published (node stays alive).
|
||||
|
||||
Kalman tuning for EUC speeds (20–30 km/h ≈ 5.5–8.3 m/s)
|
||||
---------------------------------------------------------
|
||||
process_noise=3.0 m/s² — allows rapid acceleration events
|
||||
predict_timeout=3.0 s — coasts ≈30 m at 10 m/s; acceptable dead-reckoning
|
||||
|
||||
Parameters
|
||||
----------
|
||||
uwb_timeout : UWB staleness threshold (s) default 1.5
|
||||
cam_timeout : Camera staleness threshold (s) default 1.0
|
||||
predict_timeout : Max KF coast before no publish (s) default 3.0
|
||||
process_noise : KF acceleration noise std-dev (m/s²) default 3.0
|
||||
meas_noise_uwb : UWB position noise std-dev (m) default 0.20
|
||||
meas_noise_cam : Camera position noise std-dev (m) default 0.12
|
||||
control_rate : Publish / KF predict rate (Hz) default 20.0
|
||||
confidence_threshold: Min source confidence to use (0–1) default 0.15
|
||||
|
||||
Usage
|
||||
-----
|
||||
ros2 launch saltybot_social_tracking tracking.launch.py
|
||||
"""
|
||||
|
||||
import math
|
||||
import time
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from geometry_msgs.msg import PoseStamped
|
||||
from std_msgs.msg import Header
|
||||
|
||||
from saltybot_social_msgs.msg import FusedTarget
|
||||
from saltybot_social_tracking.kalman_tracker import KalmanTracker
|
||||
from saltybot_social_tracking.source_arbiter import (
|
||||
uwb_confidence,
|
||||
camera_confidence,
|
||||
select_source,
|
||||
fuse_positions,
|
||||
composite_confidence,
|
||||
bearing_and_range,
|
||||
)
|
||||
|
||||
_BIG_AGE = 1e9 # sentinel: source never received
|
||||
|
||||
|
||||
class TrackingFusionNode(Node):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("tracking_fusion")
|
||||
|
||||
# ── Parameters ────────────────────────────────────────────────────────
|
||||
self.declare_parameter("uwb_timeout", 1.5)
|
||||
self.declare_parameter("cam_timeout", 1.0)
|
||||
self.declare_parameter("predict_timeout", 3.0)
|
||||
self.declare_parameter("process_noise", 3.0)
|
||||
self.declare_parameter("meas_noise_uwb", 0.20)
|
||||
self.declare_parameter("meas_noise_cam", 0.12)
|
||||
self.declare_parameter("control_rate", 20.0)
|
||||
self.declare_parameter("confidence_threshold", 0.15)
|
||||
|
||||
self._p = self._load_params()
|
||||
|
||||
# ── State ─────────────────────────────────────────────────────────────
|
||||
self._kf = KalmanTracker(
|
||||
process_noise=self._p["process_noise"],
|
||||
meas_noise_uwb=self._p["meas_noise_uwb"],
|
||||
meas_noise_cam=self._p["meas_noise_cam"],
|
||||
)
|
||||
self._last_uwb_t: float = 0.0 # monotonic; 0 = never received
|
||||
self._last_cam_t: float = 0.0
|
||||
self._uwb_x: float = 0.0
|
||||
self._uwb_y: float = 0.0
|
||||
self._cam_x: float = 0.0
|
||||
self._cam_y: float = 0.0
|
||||
self._uwb_tag_id: str = ""
|
||||
self._last_predict_t: float = 0.0 # monotonic time of last predict call
|
||||
self._last_any_t: float = 0.0 # monotonic time of last live measurement
|
||||
|
||||
# ── Subscriptions ─────────────────────────────────────────────────────
|
||||
self.create_subscription(
|
||||
PoseStamped, "/uwb/target", self._uwb_cb, 10)
|
||||
self.create_subscription(
|
||||
PoseStamped, "/person/target", self._cam_cb, 10)
|
||||
|
||||
# ── Publisher ─────────────────────────────────────────────────────────
|
||||
self._pub = self.create_publisher(FusedTarget, "/social/tracking/fused_target", 10)
|
||||
|
||||
# ── Timer ─────────────────────────────────────────────────────────────
|
||||
self._timer = self.create_timer(
|
||||
1.0 / self._p["control_rate"], self._control_cb)
|
||||
|
||||
self.get_logger().info(
|
||||
f"TrackingFusion ready "
|
||||
f"rate={self._p['control_rate']}Hz "
|
||||
f"uwb_timeout={self._p['uwb_timeout']}s "
|
||||
f"cam_timeout={self._p['cam_timeout']}s "
|
||||
f"predict_timeout={self._p['predict_timeout']}s "
|
||||
f"process_noise={self._p['process_noise']}m/s²"
|
||||
)
|
||||
|
||||
# ── Parameter helpers ──────────────────────────────────────────────────────
|
||||
|
||||
def _load_params(self) -> dict:
|
||||
return {
|
||||
"uwb_timeout": self.get_parameter("uwb_timeout").value,
|
||||
"cam_timeout": self.get_parameter("cam_timeout").value,
|
||||
"predict_timeout": self.get_parameter("predict_timeout").value,
|
||||
"process_noise": self.get_parameter("process_noise").value,
|
||||
"meas_noise_uwb": self.get_parameter("meas_noise_uwb").value,
|
||||
"meas_noise_cam": self.get_parameter("meas_noise_cam").value,
|
||||
"control_rate": self.get_parameter("control_rate").value,
|
||||
"confidence_threshold": self.get_parameter("confidence_threshold").value,
|
||||
}
|
||||
|
||||
# ── Measurement callbacks ──────────────────────────────────────────────────
|
||||
|
||||
def _uwb_cb(self, msg: PoseStamped) -> None:
|
||||
self._uwb_x = msg.pose.position.x
|
||||
self._uwb_y = msg.pose.position.y
|
||||
self._uwb_tag_id = "" # PoseStamped has no tag field; tag reported via /uwb/bearing
|
||||
t = time.monotonic()
|
||||
self._last_uwb_t = t
|
||||
self._last_any_t = t
|
||||
|
||||
# Seed KF on first measurement
|
||||
if not self._kf.initialized:
|
||||
self._kf.initialize(self._uwb_x, self._uwb_y)
|
||||
self._last_predict_t = t
|
||||
|
||||
def _cam_cb(self, msg: PoseStamped) -> None:
|
||||
self._cam_x = msg.pose.position.x
|
||||
self._cam_y = msg.pose.position.y
|
||||
t = time.monotonic()
|
||||
self._last_cam_t = t
|
||||
self._last_any_t = t
|
||||
|
||||
if not self._kf.initialized:
|
||||
self._kf.initialize(self._cam_x, self._cam_y)
|
||||
self._last_predict_t = t
|
||||
|
||||
# ── Control loop ───────────────────────────────────────────────────────────
|
||||
|
||||
def _control_cb(self) -> None:
|
||||
self._p = self._load_params()
|
||||
|
||||
if not self._kf.initialized:
|
||||
return # no data yet — nothing to publish
|
||||
|
||||
now = time.monotonic()
|
||||
dt = now - self._last_predict_t if self._last_predict_t > 0.0 else (
|
||||
1.0 / self._p["control_rate"]
|
||||
)
|
||||
self._last_predict_t = now
|
||||
|
||||
# KF predict
|
||||
self._kf.predict(dt)
|
||||
|
||||
# Source confidence
|
||||
uwb_age = (now - self._last_uwb_t) if self._last_uwb_t > 0.0 else _BIG_AGE
|
||||
cam_age = (now - self._last_cam_t) if self._last_cam_t > 0.0 else _BIG_AGE
|
||||
|
||||
u_conf = uwb_confidence(uwb_age, self._p["uwb_timeout"])
|
||||
c_conf = camera_confidence(cam_age, self._p["cam_timeout"])
|
||||
|
||||
threshold = self._p["confidence_threshold"]
|
||||
source = select_source(u_conf, c_conf, threshold)
|
||||
|
||||
if source == "predicted":
|
||||
# Check predict_timeout — stop publishing if too stale
|
||||
last_live_age = (
|
||||
(now - self._last_any_t) if self._last_any_t > 0.0 else _BIG_AGE
|
||||
)
|
||||
if last_live_age > self._p["predict_timeout"]:
|
||||
return # silently stop publishing
|
||||
|
||||
# Apply measurement update if a live source exists
|
||||
if source == "fused":
|
||||
fx, fy = fuse_positions(
|
||||
self._uwb_x, self._uwb_y, u_conf,
|
||||
self._cam_x, self._cam_y, c_conf,
|
||||
)
|
||||
self._kf.update(fx, fy, source="uwb") # use tighter noise for blended
|
||||
elif source == "uwb":
|
||||
self._kf.update(self._uwb_x, self._uwb_y, source="uwb")
|
||||
elif source == "camera":
|
||||
self._kf.update(self._cam_x, self._cam_y, source="camera")
|
||||
# "predicted" → no update; KF coasts
|
||||
|
||||
# Build and publish message
|
||||
kx, ky = self._kf.position
|
||||
vx, vy = self._kf.velocity
|
||||
kf_unc = self._kf.position_uncertainty_m()
|
||||
conf = composite_confidence(u_conf, c_conf, source, kf_unc)
|
||||
bearing, range_m = bearing_and_range(kx, ky)
|
||||
|
||||
hdr = Header()
|
||||
hdr.stamp = self.get_clock().now().to_msg()
|
||||
hdr.frame_id = "base_link"
|
||||
|
||||
msg = FusedTarget()
|
||||
msg.header = hdr
|
||||
msg.position.x = kx
|
||||
msg.position.y = ky
|
||||
msg.position.z = 0.0
|
||||
msg.velocity.x = vx
|
||||
msg.velocity.y = vy
|
||||
msg.velocity.z = 0.0
|
||||
msg.range_m = float(range_m)
|
||||
msg.bearing_rad = float(bearing)
|
||||
msg.confidence = float(conf)
|
||||
msg.active_source = source
|
||||
msg.tag_id = self._uwb_tag_id if "uwb" in source else ""
|
||||
|
||||
self._pub.publish(msg)
|
||||
|
||||
|
||||
# ── Entry point ────────────────────────────────────────────────────────────────
|
||||
|
||||
def main(args=None):
|
||||
rclpy.init(args=args)
|
||||
node = TrackingFusionNode()
|
||||
try:
|
||||
rclpy.spin(node)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
node.destroy_node()
|
||||
rclpy.try_shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
5
jetson/ros2_ws/src/saltybot_social_tracking/setup.cfg
Normal file
5
jetson/ros2_ws/src/saltybot_social_tracking/setup.cfg
Normal file
@ -0,0 +1,5 @@
|
||||
[develop]
|
||||
script_dir=$base/lib/saltybot_social_tracking
|
||||
|
||||
[install]
|
||||
install_scripts=$base/lib/saltybot_social_tracking
|
||||
32
jetson/ros2_ws/src/saltybot_social_tracking/setup.py
Normal file
32
jetson/ros2_ws/src/saltybot_social_tracking/setup.py
Normal file
@ -0,0 +1,32 @@
|
||||
from setuptools import setup, find_packages
|
||||
import os
|
||||
from glob import glob
|
||||
|
||||
package_name = "saltybot_social_tracking"
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
version="0.1.0",
|
||||
packages=find_packages(exclude=["test"]),
|
||||
data_files=[
|
||||
("share/ament_index/resource_index/packages",
|
||||
[f"resource/{package_name}"]),
|
||||
(f"share/{package_name}", ["package.xml"]),
|
||||
(os.path.join("share", package_name, "config"),
|
||||
glob("config/*.yaml")),
|
||||
(os.path.join("share", package_name, "launch"),
|
||||
glob("launch/*.py")),
|
||||
],
|
||||
install_requires=["setuptools"],
|
||||
zip_safe=True,
|
||||
maintainer="sl-controls",
|
||||
maintainer_email="sl-controls@saltylab.local",
|
||||
description="Multi-modal tracking fusion (UWB + camera Kalman filter)",
|
||||
license="MIT",
|
||||
tests_require=["pytest"],
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
f"tracking_fusion_node = {package_name}.tracking_fusion_node:main",
|
||||
],
|
||||
},
|
||||
)
|
||||
@ -0,0 +1,438 @@
|
||||
"""
|
||||
test_tracking_fusion.py — Unit tests for saltybot_social_tracking pure modules.
|
||||
|
||||
Tests cover:
|
||||
- KalmanTracker: initialization, predict, update, state accessors
|
||||
- source_arbiter: confidence functions, source selection, fusion, bearing
|
||||
|
||||
No ROS2 runtime required.
|
||||
"""
|
||||
|
||||
import math
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Allow running: python -m pytest test/test_tracking_fusion.py
|
||||
# from the package root without installing the package.
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from saltybot_social_tracking.kalman_tracker import KalmanTracker
|
||||
from saltybot_social_tracking.source_arbiter import (
|
||||
uwb_confidence,
|
||||
camera_confidence,
|
||||
select_source,
|
||||
fuse_positions,
|
||||
composite_confidence,
|
||||
bearing_and_range,
|
||||
)
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# KalmanTracker tests
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestKalmanTrackerInit:
|
||||
|
||||
def test_not_initialized_by_default(self):
|
||||
kf = KalmanTracker()
|
||||
assert not kf.initialized
|
||||
|
||||
def test_initialize_sets_position(self):
|
||||
kf = KalmanTracker()
|
||||
kf.initialize(3.0, 1.5)
|
||||
assert kf.initialized
|
||||
x, y = kf.position
|
||||
assert abs(x - 3.0) < 1e-9
|
||||
assert abs(y - 1.5) < 1e-9
|
||||
|
||||
def test_initialize_sets_zero_velocity(self):
|
||||
kf = KalmanTracker()
|
||||
kf.initialize(1.0, -2.0)
|
||||
vx, vy = kf.velocity
|
||||
assert abs(vx) < 1e-9
|
||||
assert abs(vy) < 1e-9
|
||||
|
||||
def test_initialize_origin(self):
|
||||
kf = KalmanTracker()
|
||||
kf.initialize(0.0, 0.0)
|
||||
assert kf.initialized
|
||||
x, y = kf.position
|
||||
assert x == 0.0 and y == 0.0
|
||||
|
||||
|
||||
class TestKalmanTrackerPredict:
|
||||
|
||||
def test_predict_zero_dt_no_change(self):
|
||||
kf = KalmanTracker()
|
||||
kf.initialize(2.0, 1.0)
|
||||
kf.predict(0.0)
|
||||
x, y = kf.position
|
||||
assert abs(x - 2.0) < 1e-9
|
||||
assert abs(y - 1.0) < 1e-9
|
||||
|
||||
def test_predict_negative_dt_no_change(self):
|
||||
kf = KalmanTracker()
|
||||
kf.initialize(2.0, 1.0)
|
||||
kf.predict(-0.1)
|
||||
x, y = kf.position
|
||||
assert abs(x - 2.0) < 1e-9
|
||||
|
||||
def test_predict_constant_velocity(self):
|
||||
"""After a position update gives the filter a velocity, predict should extrapolate."""
|
||||
kf = KalmanTracker(process_noise=0.001, meas_noise_uwb=0.001)
|
||||
kf.initialize(0.0, 0.0)
|
||||
# Force filter to track a moving target to build up velocity estimate
|
||||
dt = 0.05
|
||||
for i in range(40):
|
||||
t = i * dt
|
||||
kf.predict(dt)
|
||||
kf.update(2.0 * t, 0.0, "uwb") # 2 m/s in x
|
||||
|
||||
# After many updates the velocity estimate should be close to 2 m/s
|
||||
vx, vy = kf.velocity
|
||||
assert abs(vx - 2.0) < 0.3, f"vx={vx:.3f}"
|
||||
assert abs(vy) < 0.2
|
||||
|
||||
def test_predict_grows_uncertainty(self):
|
||||
kf = KalmanTracker()
|
||||
kf.initialize(0.0, 0.0)
|
||||
unc_before = kf.position_uncertainty_m()
|
||||
kf.predict(1.0)
|
||||
unc_after = kf.position_uncertainty_m()
|
||||
assert unc_after > unc_before
|
||||
|
||||
def test_predict_multiple_steps(self):
|
||||
kf = KalmanTracker()
|
||||
kf.initialize(0.0, 0.0)
|
||||
kf.predict(0.1)
|
||||
kf.predict(0.1)
|
||||
kf.predict(0.1)
|
||||
# No assertion on exact value; just verify no exception and state is finite
|
||||
x, y = kf.position
|
||||
assert math.isfinite(x) and math.isfinite(y)
|
||||
|
||||
|
||||
class TestKalmanTrackerUpdate:
|
||||
|
||||
def test_update_pulls_position_toward_measurement(self):
|
||||
kf = KalmanTracker()
|
||||
kf.initialize(0.0, 0.0)
|
||||
kf.update(5.0, 5.0, "uwb")
|
||||
x, y = kf.position
|
||||
assert x > 0.0 and y > 0.0
|
||||
assert x < 5.0 and y < 5.0 # blended, not jumped
|
||||
|
||||
def test_update_reduces_uncertainty(self):
|
||||
kf = KalmanTracker()
|
||||
kf.initialize(0.0, 0.0)
|
||||
kf.predict(1.0) # uncertainty grows
|
||||
unc_mid = kf.position_uncertainty_m()
|
||||
kf.update(0.1, 0.1, "uwb") # measurement corrects
|
||||
unc_after = kf.position_uncertainty_m()
|
||||
assert unc_after < unc_mid
|
||||
|
||||
def test_update_converges_to_true_position(self):
|
||||
"""Many updates from same point should converge to that point."""
|
||||
kf = KalmanTracker(meas_noise_uwb=0.01)
|
||||
kf.initialize(0.0, 0.0)
|
||||
for _ in range(50):
|
||||
kf.update(3.0, -1.0, "uwb")
|
||||
x, y = kf.position
|
||||
assert abs(x - 3.0) < 0.05, f"x={x:.4f}"
|
||||
assert abs(y - (-1.0)) < 0.05, f"y={y:.4f}"
|
||||
|
||||
def test_update_camera_source_different_noise(self):
|
||||
"""Camera and UWB updates should both move state (noise model differs)."""
|
||||
kf1 = KalmanTracker(meas_noise_uwb=0.20, meas_noise_cam=0.10)
|
||||
kf1.initialize(0.0, 0.0)
|
||||
kf1.update(5.0, 0.0, "uwb")
|
||||
x_uwb, _ = kf1.position
|
||||
|
||||
kf2 = KalmanTracker(meas_noise_uwb=0.20, meas_noise_cam=0.10)
|
||||
kf2.initialize(0.0, 0.0)
|
||||
kf2.update(5.0, 0.0, "camera")
|
||||
x_cam, _ = kf2.position
|
||||
|
||||
# Camera has lower noise → stronger pull toward measurement
|
||||
assert x_cam > x_uwb
|
||||
|
||||
def test_update_unknown_source_defaults_to_camera_noise(self):
|
||||
kf = KalmanTracker()
|
||||
kf.initialize(0.0, 0.0)
|
||||
kf.update(2.0, 0.0, "other") # unknown source — should not raise
|
||||
x, _ = kf.position
|
||||
assert x > 0.0
|
||||
|
||||
def test_position_uncertainty_finite(self):
|
||||
kf = KalmanTracker()
|
||||
kf.initialize(1.0, 1.0)
|
||||
kf.predict(0.05)
|
||||
kf.update(1.1, 0.9, "uwb")
|
||||
assert math.isfinite(kf.position_uncertainty_m())
|
||||
assert kf.position_uncertainty_m() >= 0.0
|
||||
|
||||
def test_covariance_copy_is_independent(self):
|
||||
kf = KalmanTracker()
|
||||
kf.initialize(0.0, 0.0)
|
||||
cov = kf.covariance_copy()
|
||||
cov[0, 0] = 9999.0 # mutate copy
|
||||
assert kf.covariance_copy()[0, 0] != 9999.0
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# source_arbiter tests
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestUwbConfidence:
|
||||
|
||||
def test_zero_age_returns_quality(self):
|
||||
assert abs(uwb_confidence(0.0, 1.5) - 1.0) < 1e-9
|
||||
|
||||
def test_at_timeout_returns_zero(self):
|
||||
assert uwb_confidence(1.5, 1.5) == pytest.approx(0.0)
|
||||
|
||||
def test_beyond_timeout_returns_zero(self):
|
||||
assert uwb_confidence(2.0, 1.5) == 0.0
|
||||
|
||||
def test_half_timeout_returns_half(self):
|
||||
assert uwb_confidence(0.75, 1.5) == pytest.approx(0.5)
|
||||
|
||||
def test_quality_scales_result(self):
|
||||
assert uwb_confidence(0.0, 1.5, quality=0.7) == pytest.approx(0.7)
|
||||
|
||||
def test_large_age_returns_zero(self):
|
||||
assert uwb_confidence(1e9, 1.5) == 0.0
|
||||
|
||||
def test_zero_timeout_returns_zero(self):
|
||||
assert uwb_confidence(0.0, 0.0) == 0.0
|
||||
|
||||
|
||||
class TestCameraConfidence:
|
||||
|
||||
def test_zero_age_full_quality(self):
|
||||
assert camera_confidence(0.0, 1.0, quality=1.0) == pytest.approx(1.0)
|
||||
|
||||
def test_at_timeout_zero(self):
|
||||
assert camera_confidence(1.0, 1.0) == pytest.approx(0.0)
|
||||
|
||||
def test_beyond_timeout_zero(self):
|
||||
assert camera_confidence(2.0, 1.0) == 0.0
|
||||
|
||||
def test_quality_scales(self):
|
||||
# age=0, quality=0.8
|
||||
assert camera_confidence(0.0, 1.0, quality=0.8) == pytest.approx(0.8)
|
||||
|
||||
def test_halfway(self):
|
||||
assert camera_confidence(0.5, 1.0) == pytest.approx(0.5)
|
||||
|
||||
|
||||
class TestSelectSource:
|
||||
|
||||
def test_both_above_threshold_fused(self):
|
||||
assert select_source(0.8, 0.6) == "fused"
|
||||
|
||||
def test_only_uwb_above_threshold(self):
|
||||
assert select_source(0.8, 0.0) == "uwb"
|
||||
|
||||
def test_only_cam_above_threshold(self):
|
||||
assert select_source(0.0, 0.7) == "camera"
|
||||
|
||||
def test_both_below_threshold(self):
|
||||
assert select_source(0.0, 0.0) == "predicted"
|
||||
|
||||
def test_threshold_boundary_uwb(self):
|
||||
# Exactly at threshold — should be treated as live
|
||||
assert select_source(0.15, 0.0, threshold=0.15) == "uwb"
|
||||
|
||||
def test_threshold_boundary_below(self):
|
||||
assert select_source(0.14, 0.0, threshold=0.15) == "predicted"
|
||||
|
||||
def test_custom_threshold(self):
|
||||
assert select_source(0.5, 0.0, threshold=0.6) == "predicted"
|
||||
assert select_source(0.5, 0.0, threshold=0.4) == "uwb"
|
||||
|
||||
|
||||
class TestFusePositions:
|
||||
|
||||
def test_equal_confidence_returns_midpoint(self):
|
||||
x, y = fuse_positions(0.0, 0.0, 1.0, 4.0, 4.0, 1.0)
|
||||
assert abs(x - 2.0) < 1e-9
|
||||
assert abs(y - 2.0) < 1e-9
|
||||
|
||||
def test_full_uwb_weight_returns_uwb(self):
|
||||
x, y = fuse_positions(3.0, 1.0, 1.0, 0.0, 0.0, 0.0)
|
||||
assert abs(x - 3.0) < 1e-9
|
||||
|
||||
def test_full_cam_weight_returns_cam(self):
|
||||
x, y = fuse_positions(0.0, 0.0, 0.0, -2.0, 5.0, 1.0)
|
||||
assert abs(x - (-2.0)) < 1e-9
|
||||
assert abs(y - 5.0) < 1e-9
|
||||
|
||||
def test_weighted_blend(self):
|
||||
# UWB at (0,0) conf=3, camera at (4,0) conf=1 → x = 3/4*0 + 1/4*4 = 1
|
||||
x, y = fuse_positions(0.0, 0.0, 3.0, 4.0, 0.0, 1.0)
|
||||
assert abs(x - 1.0) < 1e-9
|
||||
|
||||
def test_zero_total_returns_uwb_fallback(self):
|
||||
x, y = fuse_positions(7.0, 2.0, 0.0, 3.0, 1.0, 0.0)
|
||||
assert abs(x - 7.0) < 1e-9
|
||||
|
||||
|
||||
class TestCompositeConfidence:
|
||||
|
||||
def test_fused_source_high_confidence(self):
|
||||
conf = composite_confidence(0.9, 0.8, "fused", 0.05)
|
||||
assert conf > 0.7
|
||||
|
||||
def test_predicted_source_capped(self):
|
||||
conf = composite_confidence(0.0, 0.0, "predicted", 0.1)
|
||||
assert conf <= 0.4
|
||||
|
||||
def test_predicted_high_uncertainty_low_confidence(self):
|
||||
conf = composite_confidence(0.0, 0.0, "predicted", 3.0, max_kf_uncertainty_m=3.0)
|
||||
assert conf == pytest.approx(0.0)
|
||||
|
||||
def test_uwb_only(self):
|
||||
conf = composite_confidence(0.8, 0.0, "uwb", 0.05)
|
||||
assert conf > 0.3
|
||||
|
||||
def test_camera_only(self):
|
||||
conf = composite_confidence(0.0, 0.7, "camera", 0.05)
|
||||
assert conf > 0.2
|
||||
|
||||
def test_high_kf_uncertainty_reduces_confidence(self):
|
||||
low_unc = composite_confidence(0.9, 0.0, "uwb", 0.1)
|
||||
high_unc = composite_confidence(0.9, 0.0, "uwb", 2.9)
|
||||
assert low_unc > high_unc
|
||||
|
||||
|
||||
class TestBearingAndRange:
|
||||
|
||||
def test_straight_ahead(self):
|
||||
bearing, rng = bearing_and_range(2.0, 0.0)
|
||||
assert abs(bearing) < 1e-9
|
||||
assert abs(rng - 2.0) < 1e-9
|
||||
|
||||
def test_left_of_robot(self):
|
||||
# +Y = left in base_link frame; bearing should be positive
|
||||
bearing, rng = bearing_and_range(0.0, 1.0)
|
||||
assert abs(bearing - math.pi / 2.0) < 1e-9
|
||||
assert abs(rng - 1.0) < 1e-9
|
||||
|
||||
def test_right_of_robot(self):
|
||||
bearing, rng = bearing_and_range(0.0, -1.0)
|
||||
assert abs(bearing - (-math.pi / 2.0)) < 1e-9
|
||||
|
||||
def test_diagonal(self):
|
||||
bearing, rng = bearing_and_range(1.0, 1.0)
|
||||
assert abs(bearing - math.pi / 4.0) < 1e-9
|
||||
assert abs(rng - math.sqrt(2.0)) < 1e-9
|
||||
|
||||
def test_at_origin(self):
|
||||
bearing, rng = bearing_and_range(0.0, 0.0)
|
||||
assert rng == pytest.approx(0.0)
|
||||
assert math.isfinite(bearing) # atan2(0,0) = 0 in most implementations
|
||||
|
||||
def test_range_always_non_negative(self):
|
||||
for x, y in [(-1, 0), (0, -1), (-2, -3), (5, -5)]:
|
||||
_, rng = bearing_and_range(x, y)
|
||||
assert rng >= 0.0
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Integration scenario tests
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestIntegrationScenarios:
|
||||
|
||||
def test_euc_speed_velocity_tracking(self):
|
||||
"""Verify KF can track EUC speed (8 m/s) within 0.5 m/s after warm-up."""
|
||||
kf = KalmanTracker(process_noise=3.0, meas_noise_uwb=0.20)
|
||||
kf.initialize(0.0, 0.0)
|
||||
dt = 1.0 / 10.0 # 10 Hz UWB rate
|
||||
speed = 8.0 # m/s (≈29 km/h)
|
||||
for i in range(60):
|
||||
t = i * dt
|
||||
kf.predict(dt)
|
||||
kf.update(speed * t, 0.0, "uwb")
|
||||
vx, vy = kf.velocity
|
||||
assert abs(vx - speed) < 0.6, f"vx={vx:.2f} expected≈{speed}"
|
||||
assert abs(vy) < 0.3
|
||||
|
||||
def test_signal_loss_recovery(self):
|
||||
"""
|
||||
After 1 s of signal loss the filter should still have a reasonable
|
||||
position estimate (not diverged to infinity).
|
||||
"""
|
||||
kf = KalmanTracker(process_noise=3.0)
|
||||
kf.initialize(2.0, 0.5)
|
||||
# Warm up with 2 m/s x motion
|
||||
dt = 0.05
|
||||
for i in range(20):
|
||||
kf.predict(dt)
|
||||
kf.update(2.0 * (i + 1) * dt, 0.0, "uwb")
|
||||
# Coast for 1 second (20 × 50 ms) without measurements
|
||||
for _ in range(20):
|
||||
kf.predict(dt)
|
||||
x, y = kf.position
|
||||
assert math.isfinite(x) and math.isfinite(y)
|
||||
assert abs(x) < 20.0 # shouldn't have drifted more than 20 m
|
||||
|
||||
def test_uwb_to_camera_handoff(self):
|
||||
"""
|
||||
Simulate UWB going stale and camera taking over — Kalman should
|
||||
smoothly continue tracking without a jump.
|
||||
"""
|
||||
kf = KalmanTracker(meas_noise_uwb=0.20, meas_noise_cam=0.12)
|
||||
kf.initialize(0.0, 0.0)
|
||||
dt = 0.05
|
||||
# Phase 1: UWB active
|
||||
for i in range(20):
|
||||
kf.predict(dt)
|
||||
kf.update(float(i) * 0.1, 0.0, "uwb")
|
||||
x_at_handoff, _ = kf.position
|
||||
|
||||
# Phase 2: Camera takes over from same trajectory
|
||||
for i in range(20, 40):
|
||||
kf.predict(dt)
|
||||
kf.update(float(i) * 0.1, 0.0, "camera")
|
||||
x_after, _ = kf.position
|
||||
|
||||
# Position should have continued progressing (not stuck or reset)
|
||||
assert x_after > x_at_handoff
|
||||
|
||||
def test_confidence_degradation_during_coast(self):
|
||||
"""Composite confidence should drop as KF uncertainty grows during coast."""
|
||||
kf = KalmanTracker(process_noise=3.0)
|
||||
kf.initialize(2.0, 0.0)
|
||||
|
||||
# Fresh: tight uncertainty → high confidence
|
||||
unc_fresh = kf.position_uncertainty_m()
|
||||
conf_fresh = composite_confidence(0.0, 0.0, "predicted", unc_fresh)
|
||||
|
||||
# After 2 s coast
|
||||
for _ in range(40):
|
||||
kf.predict(0.05)
|
||||
unc_stale = kf.position_uncertainty_m()
|
||||
conf_stale = composite_confidence(0.0, 0.0, "predicted", unc_stale)
|
||||
|
||||
assert conf_fresh >= conf_stale
|
||||
|
||||
def test_fused_source_confidence_weighted_position(self):
|
||||
"""Fused position should sit between UWB and camera, closer to higher-conf source."""
|
||||
# UWB at x=0 with high conf, camera at x=10 with low conf
|
||||
uwb_c = 0.9
|
||||
cam_c = 0.1
|
||||
fx, fy = fuse_positions(0.0, 0.0, uwb_c, 10.0, 0.0, cam_c)
|
||||
# Should be much closer to UWB (0) than camera (10)
|
||||
assert fx < 3.0, f"fused_x={fx:.2f}"
|
||||
|
||||
def test_select_source_transitions(self):
|
||||
"""Verify correct source transitions as confidences change."""
|
||||
assert select_source(0.9, 0.8) == "fused"
|
||||
assert select_source(0.9, 0.0) == "uwb"
|
||||
assert select_source(0.0, 0.8) == "camera"
|
||||
assert select_source(0.0, 0.0) == "predicted"
|
||||
Loading…
x
Reference in New Issue
Block a user