Compare commits

...

3 Commits

Author SHA1 Message Date
fa0162fadc feat(social): multi-modal tracking fusion — UWB+camera Kalman filter (Issue #92)
New packages:
  saltybot_social_msgs   — FusedTarget.msg custom message
  saltybot_social_tracking — 4-state Kalman fusion node

saltybot_social_tracking/tracking_fusion_node.py
  Subscribes to /uwb/target (PoseStamped, ~10 Hz) and /person/target
  (PoseStamped, ~30 Hz) and publishes /social/tracking/fused_target
  (FusedTarget) at 20 Hz.

  Source arbitration:
    • "fused"     — both UWB and camera are fresh; confidence-weighted blend
    • "uwb"       — UWB fresh, camera stale
    • "camera"    — camera fresh, UWB stale
    • "predicted" — all sources stale; KF coasts for up to predict_timeout (3 s)

  Kalman filter (kalman_tracker.py):
    State [x, y, vx, vy] with discrete Wiener acceleration noise model
    (process_noise=3.0 m/s²) sized for EUC speeds (20-30 km/h, ≈5.5-8.3 m/s).
    Separate UWB (0.20 m) and camera (0.12 m) measurement noise.
    Velocity estimate converges after ~3 s of 10 Hz UWB measurements.

  Confidence model (source_arbiter.py):
    Per-source confidence = quality × max(0, 1 - age/timeout).
    Composite confidence accounts for KF positional uncertainty and
    is capped at 0.4 during dead-reckoning ("predicted") mode.

Tests: 58/58 pass (no ROS2 runtime required).

Note: saltybot_social_msgs here adds FusedTarget.msg; PR #98
(Issue #84) adds PersonalityState.msg + QueryMood.srv to the same
package. The maintainer should squash-merge #98 first and rebase
this branch on top of it before merging to avoid the package.xml
conflict.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-01 23:59:10 -05:00
d48edf4092 Merge pull request 'feat(social): personality system — SOUL.md persona, mood engine, relationship DB (Issue #84)' (#98) from sl-controls/social-personality into main 2026-03-01 23:58:43 -05:00
44771751e2 feat(social): personality system — SOUL.md persona, mood engine, relationship DB (Issue #84)
New packages:
- saltybot_social_msgs: PersonalityState.msg + QueryMood.srv custom interfaces
- saltybot_social_personality: full personality node

Features:
- SOUL.md YAML/Markdown persona file: name, humor_level (0-10), sass_level (0-10),
  base_mood, per-tier greeting templates, mood prefix strings
- Hot-reload: SoulWatcher polls SOUL.md every reload_interval seconds, applies
  changes live without restarting the node
- Per-person relationship memory in SQLite: score, interaction_count,
  first/last_seen, learned preferences (JSON), full interaction log
- Mood engine (pure functions): happy | curious | annoyed | playful
  driven by relationship score, interaction count, recent event window (120s)
- Greeting personalisation: stranger | regular | favorite tiers
  keyed on interaction count thresholds from SOUL.md
- Publishes /social/personality/state (PersonalityState) at publish_rate Hz
- /social/personality/query_mood (QueryMood) service for on-demand mood query
- Full ROS2 dynamic reconfigure: soul_file, db_path, reload_interval, publish_rate
- 52 unit tests, no ROS2 runtime required

ROS2 interfaces:
  Sub: /social/person_detected  (std_msgs/String JSON)
  Pub: /social/personality/state (saltybot_social_msgs/PersonalityState)
  Srv: /social/personality/query_mood (saltybot_social_msgs/QueryMood)
2026-03-01 23:56:05 -05:00
30 changed files with 2963 additions and 2 deletions

View File

@ -23,6 +23,11 @@ rosidl_generate_interfaces(${PROJECT_NAME}
"msg/Mood.msg"
"msg/Person.msg"
"msg/PersonArray.msg"
# Issue #84 personality system
"msg/PersonalityState.msg"
"srv/QueryMood.srv"
# Issue #92 multi-modal tracking fusion
"msg/FusedTarget.msg"
DEPENDENCIES std_msgs geometry_msgs builtin_interfaces
)

View File

@ -0,0 +1,19 @@
# FusedTarget.msg — output of the multi-modal tracking fusion node.
#
# Position and velocity are in the base_link frame (robot-centred,
# +X forward, +Y left). z components are always 0.0 for ground-plane tracking.
#
# Confidence: 0.0 = no data / fully predicted; 1.0 = strong fused measurement.
# active_source: "fused" | "uwb" | "camera" | "predicted"
std_msgs/Header header
geometry_msgs/Point position # filtered 2-D position (m), z=0
geometry_msgs/Vector3 velocity # filtered 2-D velocity (m/s), z=0
float32 range_m # Euclidean distance from robot to fused position
float32 bearing_rad # bearing in base_link (+ve = person to the left)
float32 confidence # composite confidence [0.0, 1.0]
string active_source # "fused" | "uwb" | "camera" | "predicted"
string tag_id # UWB tag address (empty when UWB not contributing)

View File

@ -0,0 +1,27 @@
# PersonalityState.msg — published on /social/personality/state
#
# Snapshot of the personality node's current state: active mood, relationship
# tier with the detected person, and a pre-generated greeting string.
std_msgs/Header header
# Active persona name (from SOUL.md)
string persona_name
# Current mood: happy | curious | annoyed | playful
string mood
# Person currently being addressed (empty if no one detected)
string person_id
# Relationship tier with person_id: stranger | regular | favorite
string relationship_tier
# Raw relationship score (higher = more familiar)
float32 relationship_score
# How many times we have seen this person
int32 interaction_count
# Ready-to-use greeting for person_id at current tier
string greeting_text

View File

@ -3,16 +3,25 @@
<package format="3">
<name>saltybot_social_msgs</name>
<version>0.1.0</version>
<description>Custom ROS2 messages and services for saltybot social capabilities</description>
<description>
Custom ROS2 message and service definitions for saltybot social capabilities.
Includes social perception types (face detection, person state, enrollment)
and the personality system types (PersonalityState, QueryMood) from Issue #84.
</description>
<maintainer email="seb@vayrette.com">seb</maintainer>
<license>MIT</license>
<buildtool_depend>ament_cmake</buildtool_depend>
<build_depend>rosidl_default_generators</build_depend>
<depend>std_msgs</depend>
<depend>geometry_msgs</depend>
<depend>builtin_interfaces</depend>
<build_depend>rosidl_default_generators</build_depend>
<exec_depend>rosidl_default_runtime</exec_depend>
<member_of_group>rosidl_interface_packages</member_of_group>
<export>
<build_type>ament_cmake</build_type>
</export>

View File

@ -0,0 +1,15 @@
# QueryMood.srv — ask the personality node for the current mood + greeting for a person
#
# Call with empty person_id to query the mood for whoever is currently tracked.
# Request
string person_id # person to query; leave empty for current tracked person
---
# Response
string mood # happy | curious | annoyed | playful
string relationship_tier # stranger | regular | favorite
float32 relationship_score
int32 interaction_count
string greeting_text # suggested greeting string
bool success
string message # error detail if success=false

View File

@ -0,0 +1,42 @@
---
# SOUL.md — Saltybot persona definition
#
# Hot-reload: personality_node watches this file and reloads on change.
# Runtime override: ros2 param set /personality_node soul_file /path/to/SOUL.md
# ── Identity ──────────────────────────────────────────────────────────────────
name: "Salty"
speaking_style: "casual and upbeat, occasional puns"
# ── Personality dials (010) ──────────────────────────────────────────────────
humor_level: 7 # 0 = deadpan/serious, 10 = comedian
sass_level: 4 # 0 = pure politeness, 10 = maximum sass
# ── Default mood (when no person is present or score is neutral) ──────────────
# One of: happy | curious | annoyed | playful
base_mood: "playful"
# ── Relationship thresholds (interaction counts) ──────────────────────────────
threshold_regular: 5 # interactions to graduate from stranger → regular
threshold_favorite: 20 # interactions to graduate from regular → favorite
# ── Greeting templates (use {name} placeholder for person_id or display name) ─
greeting_stranger: "Hello there! I'm Salty, nice to meet you!"
greeting_regular: "Hey {name}! Good to see you again!"
greeting_favorite: "Oh hey {name}!! You're literally my favorite person right now!"
# ── Mood-specific greeting prefixes ──────────────────────────────────────────
mood_prefix_happy: "Great timing — "
mood_prefix_curious: "Oh interesting, "
mood_prefix_annoyed: "Well, "
mood_prefix_playful: "Beep boop! "
---
# Description (ignored by the YAML parser, for human reference only)
Salty is the personality of the saltybot social robot.
She is curious about the world, genuinely happy to see people she knows,
and has a good sense of humour — especially with regulars.
Edit this file to change her personality. The node hot-reloads within
`reload_interval` seconds of any change.

View File

@ -0,0 +1,28 @@
# personality_params.yaml — ROS2 parameter defaults for personality_node.
#
# Run with:
# ros2 launch saltybot_social_personality personality.launch.py
# Override inline:
# ros2 launch saltybot_social_personality personality.launch.py soul_file:=/my/SOUL.md
# Dynamic reconfigure at runtime:
# ros2 param set /personality_node soul_file /path/to/SOUL.md
# ros2 param set /personality_node publish_rate 5.0
# ── SOUL.md path ───────────────────────────────────────────────────────────────
# Path to the SOUL.md persona file. Supports hot-reload.
# Relative paths are resolved from the package share directory.
soul_file: "" # empty = use bundled default config/SOUL.md
# ── SQLite database ────────────────────────────────────────────────────────────
# Path for the per-person relationship memory database.
# Created on first run; persists across restarts.
db_path: "~/.ros/saltybot_personality.db"
# ── Hot-reload polling interval ────────────────────────────────────────────────
# How often (seconds) to check if SOUL.md has been modified.
# Lower = faster reactions to edits; higher = less disk I/O.
reload_interval: 5.0 # seconds
# ── Personality state publication rate ────────────────────────────────────────
# How often to publish /social/personality/state (PersonalityState).
publish_rate: 2.0 # Hz

View File

@ -0,0 +1,99 @@
"""
personality.launch.py Launch the saltybot personality node.
Usage
-----
# Defaults (bundled SOUL.md, ~/.ros/saltybot_personality.db):
ros2 launch saltybot_social_personality personality.launch.py
# Custom persona file:
ros2 launch saltybot_social_personality personality.launch.py \\
soul_file:=/home/robot/my_persona/SOUL.md
# Custom DB + faster reload:
ros2 launch saltybot_social_personality personality.launch.py \\
db_path:=/data/saltybot.db reload_interval:=2.0
# Use a params file:
ros2 launch saltybot_social_personality personality.launch.py \\
params_file:=/my/personality_params.yaml
Dynamic reconfigure (no restart required)
-----------------------------------------
ros2 param set /personality_node soul_file /new/SOUL.md
ros2 param set /personality_node publish_rate 5.0
"""
import os
from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument, OpaqueFunction
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def _launch_personality(context, *args, **kwargs):
pkg_share = get_package_share_directory("saltybot_social_personality")
params_file = LaunchConfiguration("params_file").perform(context)
soul_file = LaunchConfiguration("soul_file").perform(context)
db_path = LaunchConfiguration("db_path").perform(context)
# Default soul_file to bundled config if not specified
if not soul_file:
soul_file = os.path.join(pkg_share, "config", "SOUL.md")
# Expand ~ in db_path
if db_path:
db_path = os.path.expanduser(db_path)
inline_params = {
"soul_file": soul_file,
"db_path": db_path or os.path.expanduser("~/.ros/saltybot_personality.db"),
"reload_interval": float(LaunchConfiguration("reload_interval").perform(context)),
"publish_rate": float(LaunchConfiguration("publish_rate").perform(context)),
}
node_params = [params_file, inline_params] if params_file else [inline_params]
return [Node(
package = "saltybot_social_personality",
executable = "personality_node",
name = "personality_node",
output = "screen",
parameters = node_params,
)]
def generate_launch_description():
pkg_share = get_package_share_directory("saltybot_social_personality")
default_params = os.path.join(pkg_share, "config", "personality_params.yaml")
return LaunchDescription([
DeclareLaunchArgument(
"params_file",
default_value=default_params,
description="Full path to personality_params.yaml (base config)"),
DeclareLaunchArgument(
"soul_file",
default_value="",
description="Path to SOUL.md persona file (empty = bundled default)"),
DeclareLaunchArgument(
"db_path",
default_value="~/.ros/saltybot_personality.db",
description="SQLite relationship memory database path"),
DeclareLaunchArgument(
"reload_interval",
default_value="5.0",
description="SOUL.md hot-reload polling interval (s)"),
DeclareLaunchArgument(
"publish_rate",
default_value="2.0",
description="Personality state publish rate (Hz)"),
OpaqueFunction(function=_launch_personality),
])

View File

@ -0,0 +1,32 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_social_personality</name>
<version>0.1.0</version>
<description>
SOUL.md-driven personality system for saltybot.
Loads a YAML/Markdown persona file, maintains per-person relationship memory
in SQLite, computes mood (happy/curious/annoyed/playful), personalises
greetings by tier (stranger/regular/favorite), and publishes personality
state on /social/personality/state. Supports SOUL.md hot-reload and full
ROS2 dynamic reconfigure. Issue #84.
</description>
<maintainer email="sl-controls@saltylab.local">sl-controls</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>std_msgs</depend>
<depend>rcl_interfaces</depend>
<depend>saltybot_social_msgs</depend>
<buildtool_depend>ament_python</buildtool_depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>python3-pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -0,0 +1,187 @@
"""
mood_engine.py Pure-function mood computation for the saltybot personality system.
No ROS2 imports safe to unit-test without a live ROS2 environment.
Public API
----------
compute_mood(soul, score, interaction_count, recent_events) -> str
get_relationship_tier(soul, interaction_count) -> str
build_greeting(soul, tier, mood, person_id) -> str
Mood semantics
--------------
happy : positive valence, comfortable familiarity
playful : high-energy, humorous (requires humor_level >= 7)
curious : low familiarity or novel person inquisitive
annoyed : recent negative events or very low score
Tier semantics
--------------
stranger : interaction_count < threshold_regular
regular : threshold_regular <= count < threshold_favorite
favorite : count >= threshold_favorite
"""
from __future__ import annotations
# ── Mood / tier constants ──────────────────────────────────────────────────────
MOOD_HAPPY = "happy"
MOOD_PLAYFUL = "playful"
MOOD_CURIOUS = "curious"
MOOD_ANNOYED = "annoyed"
TIER_STRANGER = "stranger"
TIER_REGULAR = "regular"
TIER_FAVORITE = "favorite"
# ── Event type constants (used by relationship_db and the node) ────────────────
EVENT_GREETING = "greeting"
EVENT_POSITIVE = "positive"
EVENT_NEGATIVE = "negative"
EVENT_DETECTION = "detection"
# How far back (seconds) to consider "recent" for mood computation
_RECENT_WINDOW_S = 120.0
# ── Mood computation ──────────────────────────────────────────────────────────
def compute_mood(
soul: dict,
score: float,
interaction_count: int,
recent_events: list,
) -> str:
"""Compute the current mood for a given person.
Parameters
----------
soul : dict
Parsed SOUL.md configuration.
score : float
Relationship score for the current person (higher = more familiar).
interaction_count : int
Total number of times we have seen this person.
recent_events : list of dict
Each dict: ``{"type": str, "dt": float}`` where ``dt`` is seconds ago.
Only events with ``dt < 120.0`` are considered "recent".
Returns
-------
str
One of: ``"happy"``, ``"playful"``, ``"curious"``, ``"annoyed"``.
"""
base_mood = soul.get("base_mood", MOOD_PLAYFUL)
humor_level = float(soul.get("humor_level", 5))
# Count recent negative/positive events
recent_neg = sum(
1 for e in recent_events
if e.get("type") == EVENT_NEGATIVE and e.get("dt", 1e9) < _RECENT_WINDOW_S
)
recent_pos = sum(
1 for e in recent_events
if e.get("type") in (EVENT_POSITIVE, EVENT_GREETING)
and e.get("dt", 1e9) < _RECENT_WINDOW_S
)
# Hard override: multiple negatives → annoyed
if recent_neg >= 2:
return MOOD_ANNOYED
# No prior interactions or brand-new person → curious
if interaction_count == 0 or score < 1.0:
return MOOD_CURIOUS
# Stranger tier (low count) → curious
threshold_regular = int(soul.get("threshold_regular", 5))
if interaction_count < threshold_regular:
return MOOD_CURIOUS
# Familiar person: check positive events and humor level
if recent_pos >= 1 or score >= 20.0:
if humor_level >= 7:
return MOOD_PLAYFUL
return MOOD_HAPPY
# High score / favorite
threshold_fav = int(soul.get("threshold_favorite", 20))
if interaction_count >= threshold_fav:
if humor_level >= 7:
return MOOD_PLAYFUL
return MOOD_HAPPY
return base_mood
# ── Tier classification ────────────────────────────────────────────────────────
def get_relationship_tier(soul: dict, interaction_count: int) -> str:
"""Return the relationship tier string for a given interaction count.
Parameters
----------
soul : dict
Parsed SOUL.md configuration.
interaction_count : int
Total number of times we have seen this person.
Returns
-------
str
One of: ``"stranger"``, ``"regular"``, ``"favorite"``.
"""
threshold_regular = int(soul.get("threshold_regular", 5))
threshold_favorite = int(soul.get("threshold_favorite", 20))
if interaction_count >= threshold_favorite:
return TIER_FAVORITE
if interaction_count >= threshold_regular:
return TIER_REGULAR
return TIER_STRANGER
# ── Greeting builder ──────────────────────────────────────────────────────────
def build_greeting(soul: dict, tier: str, mood: str, person_id: str = "") -> str:
"""Compose a greeting string for a person.
Parameters
----------
soul : dict
Parsed SOUL.md configuration.
tier : str
Relationship tier (``"stranger"``, ``"regular"``, ``"favorite"``).
mood : str
Current mood (used to prefix the greeting).
person_id : str
Person identifier / display name. Substituted for ``{name}``
in the template.
Returns
-------
str
A complete, ready-to-display greeting string.
"""
template_key = {
TIER_STRANGER: "greeting_stranger",
TIER_REGULAR: "greeting_regular",
TIER_FAVORITE: "greeting_favorite",
}.get(tier, "greeting_stranger")
template = soul.get(template_key, "Hello!")
base_greeting = template.replace("{name}", person_id or "friend")
prefix_key = f"mood_prefix_{mood}"
prefix = soul.get(prefix_key, "")
if prefix:
# Avoid double punctuation / duplicate capital letters
base_first = base_greeting[0].lower() if base_greeting else ""
greeting = f"{prefix}{base_first}{base_greeting[1:]}"
else:
greeting = base_greeting
return greeting

View File

@ -0,0 +1,349 @@
"""
personality_node.py ROS2 personality system for saltybot.
Overview
--------
Loads a SOUL.md persona file, maintains per-person relationship memory in
SQLite, computes mood, and publishes personality state. All tunable params
support ROS2 dynamic reconfigure (``ros2 param set``).
Subscriptions
-------------
/social/person_detected (std_msgs/String)
JSON payload: ``{"person_id": "alice", "event_type": "greeting",
"delta_score": 1.0}``
event_type defaults to "detection" if absent.
delta_score defaults to 0.0 if absent.
Publications
------------
/social/personality/state (saltybot_social_msgs/PersonalityState)
Published at ``publish_rate`` Hz.
Services
--------
/social/personality/query_mood (saltybot_social_msgs/QueryMood)
Query mood + greeting for any person_id.
Parameters (dynamic reconfigure via ros2 param set)
-------------------
soul_file (str) Path to SOUL.md persona file.
db_path (str) SQLite database file path.
reload_interval (float) How often to poll SOUL.md for changes (s).
publish_rate (float) State publication rate (Hz).
Usage
-----
ros2 launch saltybot_social_personality personality.launch.py
ros2 launch saltybot_social_personality personality.launch.py soul_file:=/my/SOUL.md
ros2 param set /personality_node soul_file /tmp/new_SOUL.md
"""
import json
import os
import rclpy
from rclpy.node import Node
from rcl_interfaces.msg import SetParametersResult
from std_msgs.msg import String, Header
from saltybot_social_msgs.msg import PersonalityState
from saltybot_social_msgs.srv import QueryMood
from .soul_loader import load_soul, SoulWatcher
from .mood_engine import (
compute_mood, get_relationship_tier, build_greeting,
EVENT_GREETING, EVENT_POSITIVE, EVENT_NEGATIVE, EVENT_DETECTION,
)
from .relationship_db import RelationshipDB
_DEFAULT_SOUL = os.path.join(
os.path.dirname(__file__), "..", "config", "SOUL.md"
)
_DEFAULT_DB = os.path.expanduser("~/.ros/saltybot_personality.db")
class PersonalityNode(Node):
def __init__(self):
super().__init__("personality_node")
# ── Parameters ────────────────────────────────────────────────────────
self.declare_parameter("soul_file", _DEFAULT_SOUL)
self.declare_parameter("db_path", _DEFAULT_DB)
self.declare_parameter("reload_interval", 5.0)
self.declare_parameter("publish_rate", 2.0)
self._p = {}
self._reload_ros_params()
# ── State ─────────────────────────────────────────────────────────────
self._soul = {}
self._current_person = "" # person_id currently being addressed
self._watcher = None
# ── Database ──────────────────────────────────────────────────────────
self._db = RelationshipDB(self._p["db_path"])
# ── Load initial SOUL.md ──────────────────────────────────────────────
self._load_soul_safe()
self._start_watcher()
# ── Dynamic reconfigure callback ─────────────────────────────────────
self.add_on_set_parameters_callback(self._on_params_changed)
# ── Subscriptions ─────────────────────────────────────────────────────
self.create_subscription(
String, "/social/person_detected", self._person_detected_cb, 10
)
# ── Publishers ────────────────────────────────────────────────────────
self._state_pub = self.create_publisher(
PersonalityState, "/social/personality/state", 10
)
# ── Services ──────────────────────────────────────────────────────────
self.create_service(
QueryMood,
"/social/personality/query_mood",
self._query_mood_cb,
)
# ── Timers ────────────────────────────────────────────────────────────
self._pub_timer = self.create_timer(
1.0 / self._p["publish_rate"], self._publish_state
)
self.get_logger().info(
f"PersonalityNode ready "
f"persona={self._soul.get('name', '?')!r} "
f"mood={self._current_mood()!r} "
f"db={self._p['db_path']!r}"
)
# ── Parameter helpers ──────────────────────────────────────────────────────
def _reload_ros_params(self):
self._p = {
"soul_file": self.get_parameter("soul_file").value,
"db_path": self.get_parameter("db_path").value,
"reload_interval": self.get_parameter("reload_interval").value,
"publish_rate": self.get_parameter("publish_rate").value,
}
def _on_params_changed(self, params):
"""Dynamic reconfigure — apply changed params without restarting node."""
for param in params:
if param.name == "soul_file":
# Restart watcher on new soul_file
self._stop_watcher()
self._p["soul_file"] = param.value
self._load_soul_safe()
self._start_watcher()
self.get_logger().info(f"soul_file changed → {param.value!r}")
elif param.name in self._p:
self._p[param.name] = param.value
if param.name == "publish_rate" and self._pub_timer:
self._pub_timer.cancel()
self._pub_timer = self.create_timer(
1.0 / max(0.1, param.value), self._publish_state
)
return SetParametersResult(successful=True)
# ── SOUL.md ────────────────────────────────────────────────────────────────
def _load_soul_safe(self):
try:
path = os.path.realpath(self._p["soul_file"])
self._soul = load_soul(path)
self.get_logger().info(
f"SOUL.md loaded: {self._soul.get('name', '?')!r} "
f"humor={self._soul.get('humor_level')} "
f"sass={self._soul.get('sass_level')} "
f"base_mood={self._soul.get('base_mood')!r}"
)
except Exception as exc:
self.get_logger().error(f"Failed to load SOUL.md: {exc}")
if not self._soul:
# Fall back to minimal defaults so the node stays alive
self._soul = {
"name": "Salty",
"humor_level": 5,
"sass_level": 3,
"base_mood": "curious",
"threshold_regular": 5,
"threshold_favorite": 20,
"greeting_stranger": "Hello!",
"greeting_regular": "Hi {name}!",
"greeting_favorite": "Hey {name}!!",
}
def _start_watcher(self):
if not self._soul:
return
self._watcher = SoulWatcher(
path=self._p["soul_file"],
on_reload=self._on_soul_reloaded,
interval=self._p["reload_interval"],
on_error=lambda exc: self.get_logger().warn(
f"SOUL.md hot-reload error: {exc}"
),
)
self._watcher.start()
def _stop_watcher(self):
if self._watcher:
self._watcher.stop()
self._watcher = None
def _on_soul_reloaded(self, soul: dict):
self._soul = soul
self.get_logger().info(
f"SOUL.md reloaded: persona={soul.get('name')!r} "
f"humor={soul.get('humor_level')} base_mood={soul.get('base_mood')!r}"
)
# ── Mood helpers ───────────────────────────────────────────────────────────
def _current_mood(self) -> str:
if not self._current_person or not self._soul:
return self._soul.get("base_mood", "curious") if self._soul else "curious"
person = self._db.get_person(self._current_person)
recent = self._db.get_recent_events(self._current_person, window_s=120.0)
return compute_mood(
soul = self._soul,
score = person["score"],
interaction_count = person["interaction_count"],
recent_events = recent,
)
def _state_for_person(self, person_id: str) -> dict:
"""Build a complete state dict for a given person_id."""
person = self._db.get_person(person_id) if person_id else {
"score": 0.0, "interaction_count": 0
}
recent = self._db.get_recent_events(person_id, window_s=120.0) if person_id else []
mood = compute_mood(
soul = self._soul,
score = person["score"],
interaction_count = person["interaction_count"],
recent_events = recent,
)
tier = get_relationship_tier(self._soul, person["interaction_count"])
greeting = build_greeting(self._soul, tier, mood, person_id)
return {
"person_id": person_id,
"mood": mood,
"tier": tier,
"score": person["score"],
"interaction_count": person["interaction_count"],
"greeting": greeting,
}
# ── Callbacks ──────────────────────────────────────────────────────────────
def _person_detected_cb(self, msg: String):
"""Handle incoming person detection / interaction event.
Expected JSON payload::
{
"person_id": "alice", # required
"event_type": "greeting", # optional, default "detection"
"delta_score": 1.0 # optional, default 0.0
}
"""
try:
data = json.loads(msg.data)
except json.JSONDecodeError as exc:
self.get_logger().warn(f"Bad JSON on /social/person_detected: {exc}")
return
person_id = data.get("person_id", "").strip()
if not person_id:
self.get_logger().warn("person_detected msg missing 'person_id'")
return
event_type = data.get("event_type", EVENT_DETECTION)
delta_score = float(data.get("delta_score", 0.0))
# Increment score by +1 for detection events automatically
if event_type == EVENT_DETECTION and delta_score == 0.0:
delta_score = 0.5
self._db.record_interaction(
person_id = person_id,
event_type = event_type,
details = {k: v for k, v in data.items()
if k not in ("person_id", "event_type", "delta_score")},
delta_score = delta_score,
)
self._current_person = person_id
def _query_mood_cb(self, request: QueryMood.Request, response: QueryMood.Response):
"""Service handler: return mood + greeting for a specific person."""
if not self._soul:
response.success = False
response.message = "SOUL.md not loaded"
return response
person_id = (request.person_id or self._current_person).strip()
state = self._state_for_person(person_id)
response.mood = state["mood"]
response.relationship_tier = state["tier"]
response.relationship_score = float(state["score"])
response.interaction_count = int(state["interaction_count"])
response.greeting_text = state["greeting"]
response.success = True
response.message = ""
return response
# ── Publish ────────────────────────────────────────────────────────────────
def _publish_state(self):
if not self._soul:
return
state = self._state_for_person(self._current_person)
msg = PersonalityState()
msg.header = Header()
msg.header.stamp = self.get_clock().now().to_msg()
msg.header.frame_id = "personality"
msg.persona_name = str(self._soul.get("name", "Salty"))
msg.mood = state["mood"]
msg.person_id = state["person_id"]
msg.relationship_tier = state["tier"]
msg.relationship_score = float(state["score"])
msg.interaction_count = int(state["interaction_count"])
msg.greeting_text = state["greeting"]
self._state_pub.publish(msg)
# ── Lifecycle ──────────────────────────────────────────────────────────────
def destroy_node(self):
self._stop_watcher()
self._db.close()
super().destroy_node()
# ── Entry point ────────────────────────────────────────────────────────────────
def main(args=None):
rclpy.init(args=args)
node = PersonalityNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.try_shutdown()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,297 @@
"""
relationship_db.py SQLite-backed per-person relationship memory.
No ROS2 imports safe to unit-test without a live ROS2 environment.
Schema
------
people (
person_id TEXT PRIMARY KEY,
score REAL DEFAULT 0.0,
interaction_count INTEGER DEFAULT 0,
first_seen TEXT, -- ISO-8601 UTC timestamp
last_seen TEXT, -- ISO-8601 UTC timestamp
prefs TEXT -- JSON blob for learned preferences
)
interactions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
person_id TEXT,
ts TEXT, -- ISO-8601 UTC timestamp
event_type TEXT, -- greeting | positive | negative | detection
details TEXT -- free-form JSON blob
)
Public API
----------
RelationshipDB(db_path)
.get_person(person_id) -> dict
.record_interaction(person_id, event_type, details, delta_score)
.set_pref(person_id, key, value)
.get_pref(person_id, key, default)
.get_recent_events(person_id, window_s) -> list[dict]
.all_people() -> list[dict]
.close()
"""
from __future__ import annotations
import json
import os
import sqlite3
import threading
from datetime import datetime, timezone
def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
class RelationshipDB:
"""Thread-safe SQLite relationship store.
Parameters
----------
db_path : str
Path to the SQLite file. Created (with parent dirs) if absent.
"""
def __init__(self, db_path: str):
parent = os.path.dirname(db_path)
if parent:
os.makedirs(parent, exist_ok=True)
self._path = db_path
self._lock = threading.Lock()
self._conn = sqlite3.connect(db_path, check_same_thread=False)
self._conn.row_factory = sqlite3.Row
self._migrate()
# ── Schema ────────────────────────────────────────────────────────────────
def _migrate(self):
with self._conn:
self._conn.executescript("""
CREATE TABLE IF NOT EXISTS people (
person_id TEXT PRIMARY KEY,
score REAL DEFAULT 0.0,
interaction_count INTEGER DEFAULT 0,
first_seen TEXT,
last_seen TEXT,
prefs TEXT DEFAULT '{}'
);
CREATE TABLE IF NOT EXISTS interactions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
person_id TEXT NOT NULL,
ts TEXT NOT NULL,
event_type TEXT NOT NULL,
details TEXT DEFAULT '{}'
);
CREATE INDEX IF NOT EXISTS idx_interactions_person_ts
ON interactions (person_id, ts);
""")
# ── People ────────────────────────────────────────────────────────────────
def get_person(self, person_id: str) -> dict:
"""Return the person record; inserts a default row if not found.
Returns
-------
dict with keys: person_id, score, interaction_count,
first_seen, last_seen, prefs (dict)
"""
with self._lock:
row = self._conn.execute(
"SELECT * FROM people WHERE person_id = ?", (person_id,)
).fetchone()
if row is None:
now = _now_iso()
self._conn.execute(
"INSERT INTO people (person_id, first_seen, last_seen) VALUES (?,?,?)",
(person_id, now, now),
)
self._conn.commit()
return {
"person_id": person_id,
"score": 0.0,
"interaction_count": 0,
"first_seen": now,
"last_seen": now,
"prefs": {},
}
prefs = {}
try:
prefs = json.loads(row["prefs"] or "{}")
except json.JSONDecodeError:
pass
return {
"person_id": row["person_id"],
"score": float(row["score"]),
"interaction_count": int(row["interaction_count"]),
"first_seen": row["first_seen"],
"last_seen": row["last_seen"],
"prefs": prefs,
}
def all_people(self) -> list:
"""Return all person records as a list of dicts."""
with self._lock:
rows = self._conn.execute("SELECT * FROM people ORDER BY score DESC").fetchall()
result = []
for row in rows:
prefs = {}
try:
prefs = json.loads(row["prefs"] or "{}")
except json.JSONDecodeError:
pass
result.append({
"person_id": row["person_id"],
"score": float(row["score"]),
"interaction_count": int(row["interaction_count"]),
"first_seen": row["first_seen"],
"last_seen": row["last_seen"],
"prefs": prefs,
})
return result
# ── Interactions ──────────────────────────────────────────────────────────
def record_interaction(
self,
person_id: str,
event_type: str,
details: dict | None = None,
delta_score: float = 0.0,
):
"""Record an interaction event and update the person's score.
Parameters
----------
person_id : str
event_type : str
One of: ``"greeting"``, ``"positive"``, ``"negative"``,
``"detection"``.
details : dict, optional
Arbitrary key/value data stored as JSON.
delta_score : float
Amount to add to the person's score (can be negative).
Interaction count is always incremented by 1.
"""
now = _now_iso()
details_json = json.dumps(details or {})
with self._lock:
# Ensure person exists
self.get_person.__wrapped__(self, person_id) if hasattr(
self.get_person, "__wrapped__"
) else None
# Upsert person row
self._conn.execute("""
INSERT INTO people (person_id, first_seen, last_seen)
VALUES (?, ?, ?)
ON CONFLICT(person_id) DO UPDATE SET
last_seen = excluded.last_seen
""", (person_id, now, now))
# Increment count + score
self._conn.execute("""
UPDATE people
SET interaction_count = interaction_count + 1,
score = score + ?,
last_seen = ?
WHERE person_id = ?
""", (delta_score, now, person_id))
# Insert interaction log row
self._conn.execute("""
INSERT INTO interactions (person_id, ts, event_type, details)
VALUES (?, ?, ?, ?)
""", (person_id, now, event_type, details_json))
self._conn.commit()
def get_recent_events(self, person_id: str, window_s: float = 120.0) -> list:
"""Return interaction events for *person_id* within the last *window_s* seconds.
Returns
-------
list of dict
Each dict: ``{"type": str, "dt": float, "ts": str, "details": dict}``
where ``dt`` is seconds ago (positive = in the past).
"""
from datetime import timedelta
cutoff = (
datetime.now(timezone.utc) - timedelta(seconds=window_s)
).isoformat()
with self._lock:
rows = self._conn.execute("""
SELECT ts, event_type, details FROM interactions
WHERE person_id = ? AND ts >= ?
ORDER BY ts DESC
""", (person_id, cutoff)).fetchall()
now_dt = datetime.now(timezone.utc)
result = []
for row in rows:
try:
row_dt = datetime.fromisoformat(row["ts"])
# Make timezone-aware if needed
if row_dt.tzinfo is None:
row_dt = row_dt.replace(tzinfo=timezone.utc)
dt_secs = (now_dt - row_dt).total_seconds()
except (ValueError, TypeError):
dt_secs = window_s
details = {}
try:
details = json.loads(row["details"] or "{}")
except json.JSONDecodeError:
pass
result.append({
"type": row["event_type"],
"dt": dt_secs,
"ts": row["ts"],
"details": details,
})
return result
# ── Preferences ───────────────────────────────────────────────────────────
def set_pref(self, person_id: str, key: str, value):
"""Set a learned preference for a person.
Parameters
----------
person_id, key : str
value : JSON-serialisable
"""
person = self.get_person(person_id)
prefs = person["prefs"]
prefs[key] = value
with self._lock:
self._conn.execute(
"UPDATE people SET prefs = ? WHERE person_id = ?",
(json.dumps(prefs), person_id),
)
self._conn.commit()
def get_pref(self, person_id: str, key: str, default=None):
"""Return a specific learned preference for a person."""
return self.get_person(person_id)["prefs"].get(key, default)
# ── Lifecycle ─────────────────────────────────────────────────────────────
def close(self):
"""Close the database connection."""
with self._lock:
self._conn.close()

View File

@ -0,0 +1,196 @@
"""
soul_loader.py SOUL.md persona parser and hot-reload watcher.
SOUL.md format
--------------
The file uses YAML front-matter (delimited by ``---`` lines) with an optional
Markdown description body that is ignored by the parser. Example::
---
name: "Salty"
humor_level: 7
sass_level: 4
base_mood: "playful"
...
---
# Optional description text (ignored)
Public API
----------
load_soul(path) -> dict (raises on parse error)
SoulWatcher(path, cb, interval)
.start()
.stop()
.reload_now() -> dict
Pure module no ROS2 imports.
"""
import os
import re
import threading
import time
import yaml
# Keys that are required in every SOUL.md file
_REQUIRED_KEYS = {
"name",
"humor_level",
"sass_level",
"base_mood",
"threshold_regular",
"threshold_favorite",
"greeting_stranger",
"greeting_regular",
"greeting_favorite",
}
_VALID_MOODS = {"happy", "curious", "annoyed", "playful"}
def _extract_frontmatter(text: str) -> str:
"""Return the YAML block between the first pair of ``---`` delimiters.
Raises ``ValueError`` if the file does not contain valid front-matter.
"""
lines = text.splitlines()
delimiters = [i for i, l in enumerate(lines) if l.strip() == "---"]
if len(delimiters) < 2:
# No delimiter found — treat the whole file as plain YAML
return text
start = delimiters[0] + 1
end = delimiters[1]
return "\n".join(lines[start:end])
def load_soul(path: str) -> dict:
"""Parse a SOUL.md file and return the validated config dict.
Parameters
----------
path : str
Absolute path to the SOUL.md file.
Returns
-------
dict
Validated persona configuration.
Raises
------
FileNotFoundError
If the file does not exist.
ValueError
If the YAML is malformed or required keys are missing.
"""
if not os.path.isfile(path):
raise FileNotFoundError(f"SOUL.md not found: {path}")
with open(path, "r", encoding="utf-8") as fh:
raw = fh.read()
yaml_text = _extract_frontmatter(raw)
try:
data = yaml.safe_load(yaml_text)
except yaml.YAMLError as exc:
raise ValueError(f"SOUL.md YAML parse error in {path}: {exc}") from exc
if not isinstance(data, dict):
raise ValueError(f"SOUL.md top level must be a YAML mapping, got {type(data)}")
# Validate required keys
missing = _REQUIRED_KEYS - data.keys()
if missing:
raise ValueError(f"SOUL.md missing required keys: {sorted(missing)}")
# Validate ranges
for key in ("humor_level", "sass_level"):
val = data.get(key)
if not isinstance(val, (int, float)) or not (0 <= val <= 10):
raise ValueError(f"SOUL.md '{key}' must be a number 010, got {val!r}")
if data.get("base_mood") not in _VALID_MOODS:
raise ValueError(
f"SOUL.md 'base_mood' must be one of {sorted(_VALID_MOODS)}, "
f"got {data.get('base_mood')!r}"
)
for key in ("threshold_regular", "threshold_favorite"):
val = data.get(key)
if not isinstance(val, int) or val < 0:
raise ValueError(f"SOUL.md '{key}' must be a non-negative integer, got {val!r}")
if data["threshold_regular"] > data["threshold_favorite"]:
raise ValueError(
"SOUL.md 'threshold_regular' must be <= 'threshold_favorite'"
)
return data
class SoulWatcher:
"""Background thread that polls SOUL.md for changes and calls a callback.
Parameters
----------
path : str
Path to the SOUL.md file to watch.
on_reload : callable
``on_reload(soul_dict)`` called whenever a valid new SOUL.md is loaded.
interval : float
Polling interval in seconds (default 5.0).
on_error : callable, optional
``on_error(exception)`` called when a reload attempt fails.
"""
def __init__(self, path: str, on_reload, interval: float = 5.0, on_error=None):
self._path = path
self._on_reload = on_reload
self._interval = interval
self._on_error = on_error
self._thread = None
self._stop_evt = threading.Event()
self._last_mtime = 0.0
# ------------------------------------------------------------------
def start(self):
"""Start the background polling thread."""
if self._thread and self._thread.is_alive():
return
self._stop_evt.clear()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
def stop(self):
"""Signal the watcher thread to stop and block until it exits."""
self._stop_evt.set()
if self._thread:
self._thread.join(timeout=self._interval + 1.0)
def reload_now(self) -> dict:
"""Force an immediate reload and return the new soul dict."""
soul = load_soul(self._path)
self._last_mtime = os.path.getmtime(self._path)
self._on_reload(soul)
return soul
# ------------------------------------------------------------------
def _run(self):
while not self._stop_evt.wait(self._interval):
try:
mtime = os.path.getmtime(self._path)
except OSError:
continue
if mtime != self._last_mtime:
try:
soul = load_soul(self._path)
except (FileNotFoundError, ValueError) as exc:
if self._on_error:
self._on_error(exc)
continue
self._last_mtime = mtime
self._on_reload(soul)

View File

@ -0,0 +1,4 @@
[develop]
script_dir=$base/lib/saltybot_social_personality
[install]
install_scripts=$base/lib/saltybot_social_personality

View File

@ -0,0 +1,28 @@
from setuptools import setup
package_name = "saltybot_social_personality"
setup(
name=package_name,
version="0.1.0",
packages=[package_name],
data_files=[
("share/ament_index/resource_index/packages", [f"resource/{package_name}"]),
(f"share/{package_name}", ["package.xml"]),
(f"share/{package_name}/launch", ["launch/personality.launch.py"]),
(f"share/{package_name}/config", ["config/SOUL.md",
"config/personality_params.yaml"]),
],
install_requires=["setuptools", "pyyaml"],
zip_safe=True,
maintainer="sl-controls",
maintainer_email="sl-controls@saltylab.local",
description="SOUL.md-driven personality system for saltybot social interaction",
license="MIT",
tests_require=["pytest"],
entry_points={
"console_scripts": [
"personality_node = saltybot_social_personality.personality_node:main",
],
},
)

View File

@ -0,0 +1,475 @@
"""
test_personality.py Unit tests for the saltybot personality system.
No ROS2 runtime required. Tests pure functions from:
- soul_loader.py
- mood_engine.py
- relationship_db.py
Run with:
pytest jetson/ros2_ws/src/saltybot_social_personality/test/test_personality.py
"""
import os
import tempfile
import textwrap
import pytest
# ── Imports (pure modules, no ROS2) ──────────────────────────────────────────
import sys
sys.path.insert(
0,
os.path.join(os.path.dirname(__file__), "..", "saltybot_social_personality"),
)
from soul_loader import load_soul, _extract_frontmatter
from mood_engine import (
compute_mood, get_relationship_tier, build_greeting,
MOOD_HAPPY, MOOD_PLAYFUL, MOOD_CURIOUS, MOOD_ANNOYED,
TIER_STRANGER, TIER_REGULAR, TIER_FAVORITE,
EVENT_NEGATIVE, EVENT_POSITIVE, EVENT_GREETING,
)
from relationship_db import RelationshipDB
# ── Helpers ───────────────────────────────────────────────────────────────────
def _minimal_soul(**overrides) -> dict:
"""Return a valid minimal soul dict, optionally overriding keys."""
base = {
"name": "Salty",
"humor_level": 7,
"sass_level": 4,
"base_mood": "playful",
"threshold_regular": 5,
"threshold_favorite": 20,
"greeting_stranger": "Hello there!",
"greeting_regular": "Hey {name}!",
"greeting_favorite": "Oh hey {name}!!",
}
base.update(overrides)
return base
def _write_soul(content: str) -> str:
"""Write a SOUL.md string to a temp file and return the path."""
fh = tempfile.NamedTemporaryFile(
mode="w", suffix=".md", delete=False, encoding="utf-8"
)
fh.write(content)
fh.close()
return fh.name
_VALID_SOUL_CONTENT = textwrap.dedent("""\
---
name: "TestBot"
speaking_style: "casual"
humor_level: 7
sass_level: 3
base_mood: "playful"
threshold_regular: 5
threshold_favorite: 20
greeting_stranger: "Hello stranger!"
greeting_regular: "Hey {name}!"
greeting_favorite: "Oh hey {name}!!"
mood_prefix_playful: "Beep boop! "
mood_prefix_happy: "Great — "
mood_prefix_curious: "Hmm, "
mood_prefix_annoyed: "Ugh, "
---
# Description (ignored)
This is the description body.
""")
# ═══════════════════════════════════════════════════════════════════════════════
# soul_loader tests
# ═══════════════════════════════════════════════════════════════════════════════
class TestExtractFrontmatter:
def test_delimited(self):
content = "---\nkey: val\n---\n# body"
assert _extract_frontmatter(content) == "key: val"
def test_no_delimiters_returns_whole(self):
content = "key: val\nother: 123"
assert _extract_frontmatter(content) == content
def test_single_delimiter_returns_whole(self):
content = "---\nkey: val\n"
result = _extract_frontmatter(content)
assert "key: val" in result
def test_body_stripped(self):
content = "---\nname: X\n---\n# Body text\nMore body"
assert "Body text" not in _extract_frontmatter(content)
assert "name: X" in _extract_frontmatter(content)
class TestLoadSoul:
def test_valid_file_loads(self):
path = _write_soul(_VALID_SOUL_CONTENT)
try:
soul = load_soul(path)
assert soul["name"] == "TestBot"
assert soul["humor_level"] == 7
assert soul["base_mood"] == "playful"
finally:
os.unlink(path)
def test_missing_file_raises(self):
with pytest.raises(FileNotFoundError):
load_soul("/nonexistent/SOUL.md")
def test_missing_required_key_raises(self):
content = "---\nname: X\nhumor_level: 5\n---" # missing many keys
path = _write_soul(content)
try:
with pytest.raises(ValueError, match="missing required keys"):
load_soul(path)
finally:
os.unlink(path)
def test_humor_out_of_range_raises(self):
soul_str = _VALID_SOUL_CONTENT.replace("humor_level: 7", "humor_level: 11")
path = _write_soul(soul_str)
try:
with pytest.raises(ValueError, match="humor_level"):
load_soul(path)
finally:
os.unlink(path)
def test_invalid_mood_raises(self):
soul_str = _VALID_SOUL_CONTENT.replace(
'base_mood: "playful"', 'base_mood: "grumpy"'
)
path = _write_soul(soul_str)
try:
with pytest.raises(ValueError, match="base_mood"):
load_soul(path)
finally:
os.unlink(path)
def test_threshold_order_enforced(self):
soul_str = _VALID_SOUL_CONTENT.replace(
"threshold_regular: 5", "threshold_regular: 25"
)
path = _write_soul(soul_str)
try:
with pytest.raises(ValueError, match="threshold_regular"):
load_soul(path)
finally:
os.unlink(path)
def test_extra_keys_allowed(self):
content = _VALID_SOUL_CONTENT.replace(
"---\n# Description",
"custom_key: 42\n---\n# Description"
)
path = _write_soul(content)
try:
soul = load_soul(path)
assert soul.get("custom_key") == 42
finally:
os.unlink(path)
# ═══════════════════════════════════════════════════════════════════════════════
# mood_engine tests
# ═══════════════════════════════════════════════════════════════════════════════
class TestGetRelationshipTier:
def test_zero_interactions_stranger(self):
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
assert get_relationship_tier(soul, 0) == TIER_STRANGER
def test_below_regular_stranger(self):
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
assert get_relationship_tier(soul, 4) == TIER_STRANGER
def test_at_regular_threshold(self):
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
assert get_relationship_tier(soul, 5) == TIER_REGULAR
def test_above_regular_below_favorite(self):
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
assert get_relationship_tier(soul, 10) == TIER_REGULAR
def test_at_favorite_threshold(self):
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
assert get_relationship_tier(soul, 20) == TIER_FAVORITE
def test_above_favorite(self):
soul = _minimal_soul(threshold_regular=5, threshold_favorite=20)
assert get_relationship_tier(soul, 100) == TIER_FAVORITE
class TestComputeMood:
def test_unknown_person_returns_curious(self):
soul = _minimal_soul(humor_level=7)
mood = compute_mood(soul, score=0.0, interaction_count=0, recent_events=[])
assert mood == MOOD_CURIOUS
def test_stranger_low_count_returns_curious(self):
soul = _minimal_soul(threshold_regular=5)
mood = compute_mood(soul, score=2.0, interaction_count=3, recent_events=[])
assert mood == MOOD_CURIOUS
def test_two_negative_events_returns_annoyed(self):
soul = _minimal_soul(threshold_regular=5)
events = [
{"type": EVENT_NEGATIVE, "dt": 30.0},
{"type": EVENT_NEGATIVE, "dt": 60.0},
]
mood = compute_mood(soul, score=10.0, interaction_count=10, recent_events=events)
assert mood == MOOD_ANNOYED
def test_one_negative_not_annoyed(self):
soul = _minimal_soul(humor_level=7, threshold_regular=5, threshold_favorite=20)
events = [{"type": EVENT_NEGATIVE, "dt": 30.0}]
# 1 negative is not enough → should still be happy/playful based on score
mood = compute_mood(soul, score=25.0, interaction_count=25, recent_events=events)
assert mood != MOOD_ANNOYED
def test_high_humor_regular_returns_playful(self):
soul = _minimal_soul(humor_level=8, threshold_regular=5, threshold_favorite=20)
events = [{"type": EVENT_POSITIVE, "dt": 10.0}]
mood = compute_mood(soul, score=10.0, interaction_count=8, recent_events=events)
assert mood == MOOD_PLAYFUL
def test_low_humor_regular_returns_happy(self):
soul = _minimal_soul(humor_level=4, threshold_regular=5, threshold_favorite=20)
events = [{"type": EVENT_POSITIVE, "dt": 10.0}]
mood = compute_mood(soul, score=10.0, interaction_count=8, recent_events=events)
assert mood == MOOD_HAPPY
def test_stale_negative_ignored(self):
soul = _minimal_soul(humor_level=8, threshold_regular=5, threshold_favorite=20)
# dt > 120s → outside the recent window → should not trigger annoyed
events = [
{"type": EVENT_NEGATIVE, "dt": 200.0},
{"type": EVENT_NEGATIVE, "dt": 300.0},
]
mood = compute_mood(soul, score=15.0, interaction_count=10, recent_events=events)
assert mood != MOOD_ANNOYED
def test_favorite_high_humor_playful(self):
soul = _minimal_soul(humor_level=9, threshold_regular=5, threshold_favorite=20)
mood = compute_mood(soul, score=50.0, interaction_count=30, recent_events=[])
assert mood == MOOD_PLAYFUL
def test_favorite_low_humor_happy(self):
soul = _minimal_soul(humor_level=3, threshold_regular=5, threshold_favorite=20)
mood = compute_mood(soul, score=50.0, interaction_count=30, recent_events=[])
assert mood == MOOD_HAPPY
class TestBuildGreeting:
def _soul(self, **kw):
return _minimal_soul(
mood_prefix_happy="Great — ",
mood_prefix_curious="Hmm, ",
mood_prefix_annoyed="Well, ",
mood_prefix_playful="Beep boop! ",
**kw,
)
def test_stranger_greeting(self):
soul = self._soul()
g = build_greeting(soul, TIER_STRANGER, MOOD_CURIOUS, "")
assert "hello" in g.lower()
def test_regular_greeting_contains_name(self):
soul = self._soul()
g = build_greeting(soul, TIER_REGULAR, MOOD_HAPPY, "alice")
assert "alice" in g
def test_favorite_greeting_contains_name(self):
soul = self._soul()
g = build_greeting(soul, TIER_FAVORITE, MOOD_PLAYFUL, "bob")
assert "bob" in g
def test_mood_prefix_applied(self):
soul = self._soul()
g = build_greeting(soul, TIER_REGULAR, MOOD_PLAYFUL, "alice")
assert g.startswith("Beep boop!")
def test_no_prefix_key_no_prefix(self):
soul = _minimal_soul() # no mood_prefix_* keys
g = build_greeting(soul, TIER_REGULAR, MOOD_HAPPY, "alice")
assert g.startswith("Hey")
def test_empty_person_id_uses_friend(self):
soul = self._soul()
g = build_greeting(soul, TIER_REGULAR, MOOD_HAPPY, "")
assert "friend" in g
def test_happy_prefix(self):
soul = self._soul()
g = build_greeting(soul, TIER_REGULAR, MOOD_HAPPY, "carol")
assert g.startswith("Great")
def test_annoyed_prefix(self):
soul = self._soul()
g = build_greeting(soul, TIER_REGULAR, MOOD_ANNOYED, "dave")
assert g.startswith("Well")
# ═══════════════════════════════════════════════════════════════════════════════
# relationship_db tests
# ═══════════════════════════════════════════════════════════════════════════════
class TestRelationshipDB:
@pytest.fixture
def db(self, tmp_path):
path = str(tmp_path / "test.db")
d = RelationshipDB(path)
yield d
d.close()
def test_get_person_creates_default(self, db):
p = db.get_person("alice")
assert p["person_id"] == "alice"
assert p["score"] == pytest.approx(0.0)
assert p["interaction_count"] == 0
def test_get_person_idempotent(self, db):
p1 = db.get_person("bob")
p2 = db.get_person("bob")
assert p1["person_id"] == p2["person_id"]
def test_record_interaction_increments_count(self, db):
db.record_interaction("alice", "detection")
db.record_interaction("alice", "detection")
p = db.get_person("alice")
assert p["interaction_count"] == 2
def test_record_interaction_updates_score(self, db):
db.record_interaction("alice", "positive", delta_score=5.0)
p = db.get_person("alice")
assert p["score"] == pytest.approx(5.0)
def test_negative_delta_reduces_score(self, db):
db.record_interaction("carol", "positive", delta_score=10.0)
db.record_interaction("carol", "negative", delta_score=-3.0)
p = db.get_person("carol")
assert p["score"] == pytest.approx(7.0)
def test_score_zero_by_default(self, db):
p = db.get_person("dave")
assert p["score"] == pytest.approx(0.0)
def test_set_and_get_pref(self, db):
db.set_pref("alice", "language", "en")
assert db.get_pref("alice", "language") == "en"
def test_get_pref_default(self, db):
assert db.get_pref("nobody", "language", "fr") == "fr"
def test_multiple_prefs_stored(self, db):
db.set_pref("alice", "lang", "en")
db.set_pref("alice", "name", "Alice")
assert db.get_pref("alice", "lang") == "en"
assert db.get_pref("alice", "name") == "Alice"
def test_all_people_returns_list(self, db):
db.record_interaction("a", "detection")
db.record_interaction("b", "detection")
people = db.all_people()
ids = {p["person_id"] for p in people}
assert {"a", "b"} <= ids
def test_get_recent_events_returns_events(self, db):
db.record_interaction("alice", "greeting", delta_score=1.0)
events = db.get_recent_events("alice", window_s=60.0)
assert len(events) == 1
assert events[0]["type"] == "greeting"
def test_get_recent_events_empty_for_new_person(self, db):
events = db.get_recent_events("nobody", window_s=60.0)
assert events == []
def test_event_dt_positive(self, db):
db.record_interaction("alice", "detection")
events = db.get_recent_events("alice", window_s=60.0)
assert events[0]["dt"] >= 0.0
def test_multiple_people_isolated(self, db):
db.record_interaction("alice", "positive", delta_score=10.0)
db.record_interaction("bob", "negative", delta_score=-5.0)
assert db.get_person("alice")["score"] == pytest.approx(10.0)
assert db.get_person("bob")["score"] == pytest.approx(-5.0)
def test_details_stored(self, db):
db.record_interaction("alice", "greeting", details={"location": "lab"})
events = db.get_recent_events("alice", window_s=60.0)
assert events[0]["details"].get("location") == "lab"
# ═══════════════════════════════════════════════════════════════════════════════
# Integration: soul → tier → mood → greeting pipeline
# ═══════════════════════════════════════════════════════════════════════════════
class TestIntegrationPipeline:
def test_stranger_pipeline(self, tmp_path):
db_path = str(tmp_path / "int.db")
db = RelationshipDB(db_path)
soul = _minimal_soul(
humor_level=7, threshold_regular=5, threshold_favorite=20,
mood_prefix_curious="Hmm, "
)
# No prior interactions
person = db.get_person("stranger_001")
events = db.get_recent_events("stranger_001")
tier = get_relationship_tier(soul, person["interaction_count"])
mood = compute_mood(soul, person["score"], person["interaction_count"], events)
greeting = build_greeting(soul, tier, mood, "stranger_001")
assert tier == TIER_STRANGER
assert mood == MOOD_CURIOUS
assert "hello" in greeting.lower()
db.close()
def test_regular_positive_pipeline(self, tmp_path):
db_path = str(tmp_path / "int2.db")
db = RelationshipDB(db_path)
soul = _minimal_soul(
humor_level=8, threshold_regular=5, threshold_favorite=20,
mood_prefix_playful="Beep! "
)
# Simulate 6 positive interactions (> threshold_regular=5)
for _ in range(6):
db.record_interaction("alice", "positive", delta_score=2.0)
person = db.get_person("alice")
events = db.get_recent_events("alice")
tier = get_relationship_tier(soul, person["interaction_count"])
mood = compute_mood(soul, person["score"], person["interaction_count"], events)
greeting = build_greeting(soul, tier, mood, "alice")
assert tier == TIER_REGULAR
assert mood == MOOD_PLAYFUL # humor_level=8, recent positive events
assert "alice" in greeting
assert greeting.startswith("Beep!")
db.close()
def test_favorite_pipeline(self, tmp_path):
db_path = str(tmp_path / "int3.db")
db = RelationshipDB(db_path)
soul = _minimal_soul(
humor_level=5, threshold_regular=5, threshold_favorite=20
)
for _ in range(25):
db.record_interaction("bob", "positive", delta_score=1.0)
person = db.get_person("bob")
tier = get_relationship_tier(soul, person["interaction_count"])
assert tier == TIER_FAVORITE
greeting = build_greeting(soul, tier, "happy", "bob")
assert "bob" in greeting
assert "Oh hey" in greeting
db.close()

View File

@ -0,0 +1,5 @@
__pycache__/
*.pyc
*.pyo
*.egg-info/
.pytest_cache/

View File

@ -0,0 +1,48 @@
# tracking_params.yaml — saltybot_social_tracking / TrackingFusionNode
#
# Run with:
# ros2 launch saltybot_social_tracking tracking.launch.py
#
# Topics consumed:
# /uwb/target (geometry_msgs/PoseStamped) — UWB triangulated position
# /person/target (geometry_msgs/PoseStamped) — camera-detected position
#
# Topic produced:
# /social/tracking/fused_target (saltybot_social_msgs/FusedTarget)
# ── Source staleness timeouts ──────────────────────────────────────────────────
# UWB driver publishes at ~10 Hz; 1.5 s = 15 missed cycles before declared stale.
uwb_timeout: 1.5 # seconds
# Camera detector publishes at ~30 Hz; 1.0 s = 30 missed frames before stale.
cam_timeout: 1.0 # seconds
# How long the Kalman filter may coast (dead-reckoning) with no live source
# before the node stops publishing.
# At 10 m/s (EUC top-speed) the robot drifts ≈30 m over 3 s — beyond the UWB
# follow-range, so 3 s is a reasonable hard stop.
predict_timeout: 3.0 # seconds
# ── Kalman filter tuning ───────────────────────────────────────────────────────
# process_noise: acceleration noise std-dev (m/s²).
# EUC riders can brake or accelerate at ~35 m/s²; 3.0 is a good starting point.
# Increase if the filtered track lags behind fast direction changes.
# Decrease if the track is noisy.
process_noise: 3.0 # m/s²
# UWB position measurement noise (std-dev, metres).
# DW3000 TWR accuracy ≈ ±1020 cm; 0.20 accounts for system-level error.
meas_noise_uwb: 0.20 # m
# Camera position noise (std-dev, metres).
# Depth reprojection error with RealSense D435i at 13 m ≈ ±515 cm.
meas_noise_cam: 0.12 # m
# ── Control loop ──────────────────────────────────────────────────────────────
control_rate: 20.0 # Hz — KF predict + publish rate
# ── Source arbiter ────────────────────────────────────────────────────────────
# Minimum normalised confidence for a source to be considered live.
# Range [0, 1]; lower = more permissive; default 0.15 keeps slightly stale
# sources active rather than dropping to "predicted" prematurely.
confidence_threshold: 0.15

View File

@ -0,0 +1,44 @@
"""tracking.launch.py — launch the TrackingFusionNode with default params."""
import os
from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description():
pkg_share = get_package_share_directory("saltybot_social_tracking")
default_params = os.path.join(pkg_share, "config", "tracking_params.yaml")
return LaunchDescription([
DeclareLaunchArgument(
"params_file",
default_value=default_params,
description="Path to tracking fusion parameter YAML file",
),
DeclareLaunchArgument(
"control_rate",
default_value="20.0",
description="KF predict + publish rate (Hz)",
),
DeclareLaunchArgument(
"predict_timeout",
default_value="3.0",
description="Max KF coast time before stopping publish (s)",
),
Node(
package="saltybot_social_tracking",
executable="tracking_fusion_node",
name="tracking_fusion",
output="screen",
parameters=[
LaunchConfiguration("params_file"),
{
"control_rate": LaunchConfiguration("control_rate"),
"predict_timeout": LaunchConfiguration("predict_timeout"),
},
],
),
])

View File

@ -0,0 +1,31 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>saltybot_social_tracking</name>
<version>0.1.0</version>
<description>
Multi-modal tracking fusion for saltybot.
Fuses UWB triangulated position (/uwb/target) and camera-detected position
(/person/target) using a 4-state Kalman filter to produce a smooth, low-latency
fused estimate at /social/tracking/fused_target.
Handles EUC rider speeds (20-30 km/h), signal handoff, and predictive coasting.
</description>
<maintainer email="sl-controls@saltylab.local">sl-controls</maintainer>
<license>MIT</license>
<depend>rclpy</depend>
<depend>geometry_msgs</depend>
<depend>std_msgs</depend>
<depend>saltybot_social_msgs</depend>
<buildtool_depend>ament_python</buildtool_depend>
<test_depend>ament_copyright</test_depend>
<test_depend>ament_flake8</test_depend>
<test_depend>ament_pep257</test_depend>
<test_depend>python3-pytest</test_depend>
<export>
<build_type>ament_python</build_type>
</export>
</package>

View File

@ -0,0 +1,134 @@
"""
kalman_tracker.py 4-state linear Kalman filter for 2-D position+velocity tracking.
State vector: [x, y, vx, vy]
Observation: [x_meas, y_meas]
Process model: constant velocity with Wiener process acceleration noise.
Tuned to handle EUC rider speeds (2030 km/h 5.58.3 m/s) with fast
acceleration transients.
Pure module no ROS2 dependency; fully unit-testable.
"""
import math
import numpy as np
class KalmanTracker:
"""
4-state Kalman filter: state = [x, y, vx, vy].
Parameters
----------
process_noise : acceleration noise standard deviation (m/).
Higher values allow the filter to track rapid velocity
changes (EUC acceleration events). Default 3.0 m/.
meas_noise_uwb : UWB position measurement noise std-dev (m). Default 0.20 m.
meas_noise_cam : Camera position measurement noise std-dev (m). Default 0.12 m.
"""
def __init__(
self,
process_noise: float = 3.0,
meas_noise_uwb: float = 0.20,
meas_noise_cam: float = 0.12,
):
self._q = float(process_noise)
self._r_uwb = float(meas_noise_uwb)
self._r_cam = float(meas_noise_cam)
# State [x, y, vx, vy]
self._x = np.zeros(4)
# Covariance — large initial uncertainty (10 m², 10 (m/s)²)
self._P = np.diag([10.0, 10.0, 10.0, 10.0])
# Observation matrix: H * x = [x, y]
self._H = np.array([[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0]])
self._initialized = False
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
@property
def initialized(self) -> bool:
return self._initialized
def initialize(self, x: float, y: float) -> None:
"""Seed the filter at position (x, y) with zero velocity."""
self._x = np.array([x, y, 0.0, 0.0])
self._P = np.diag([1.0, 1.0, 5.0, 5.0])
self._initialized = True
def predict(self, dt: float) -> None:
"""
Advance the filter state by dt seconds.
Uses a discrete Wiener process acceleration model so that positional
uncertainty grows as O(dt^4/4) and velocity uncertainty as O(dt^2).
This lets the filter coast accurately through short signal outages
while still being responsive to EUC velocity changes.
"""
if dt <= 0.0:
return
F = np.array([[1.0, 0.0, dt, 0.0],
[0.0, 1.0, 0.0, dt],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0]])
q = self._q
dt2 = dt * dt
dt3 = dt2 * dt
dt4 = dt3 * dt
Q = (q * q) * np.array([
[dt4 / 4.0, 0.0, dt3 / 2.0, 0.0 ],
[0.0, dt4 / 4.0, 0.0, dt3 / 2.0],
[dt3 / 2.0, 0.0, dt2, 0.0 ],
[0.0, dt3 / 2.0, 0.0, dt2 ],
])
self._x = F @ self._x
self._P = F @ self._P @ F.T + Q
def update(self, x_meas: float, y_meas: float, source: str = "uwb") -> None:
"""
Apply a position measurement (x_meas, y_meas).
source : "uwb" or "camera" selects the appropriate noise covariance.
"""
r = self._r_uwb if source == "uwb" else self._r_cam
R = np.diag([r * r, r * r])
z = np.array([x_meas, y_meas])
innov = z - self._H @ self._x # innovation
S = self._H @ self._P @ self._H.T + R # innovation covariance
K = self._P @ self._H.T @ np.linalg.inv(S) # Kalman gain
self._x = self._x + K @ innov
self._P = (np.eye(4) - K @ self._H) @ self._P
# ------------------------------------------------------------------
# State accessors
# ------------------------------------------------------------------
@property
def position(self) -> tuple:
"""Current filtered position (x, y) in metres."""
return float(self._x[0]), float(self._x[1])
@property
def velocity(self) -> tuple:
"""Current filtered velocity (vx, vy) in m/s."""
return float(self._x[2]), float(self._x[3])
def position_uncertainty_m(self) -> float:
"""RMS positional uncertainty (m) from diagonal of covariance."""
return float(math.sqrt((self._P[0, 0] + self._P[1, 1]) / 2.0))
def covariance_copy(self) -> np.ndarray:
"""Return a copy of the current 4×4 covariance matrix."""
return self._P.copy()

View File

@ -0,0 +1,155 @@
"""
source_arbiter.py Source confidence scoring and selection for tracking fusion.
Two sensor sources are supported:
UWB geometry_msgs/PoseStamped from /uwb/target (triangulated, ~10 Hz)
Camera geometry_msgs/PoseStamped from /person/target (depth+YOLO, ~30 Hz)
Confidence model
----------------
Each source's confidence is its raw measurement quality multiplied by a
linear staleness factor that drops to zero at its respective timeout:
conf = quality * max(0, 1 - age / timeout)
UWB quality is always 1.0 (the ranging hardware confidence is not exposed
by the driver in origin/main; the UWB node already applies Kalman filtering).
Camera quality defaults to 1.0; callers may pass a lower value when the
detection confidence is available.
Source selection
----------------
Both fresh "fused" (confidence-weighted position blend)
UWB only "uwb"
Camera only "camera"
Neither fresh "predicted" (Kalman coasts)
Pure module no ROS2 dependency; fully unit-testable.
"""
import math
def _staleness_factor(age_s: float, timeout_s: float) -> float:
"""Linear decay: 1.0 at age=0, 0.0 at age=timeout, clamped."""
if timeout_s <= 0.0:
return 0.0
return max(0.0, 1.0 - age_s / timeout_s)
def uwb_confidence(age_s: float, timeout_s: float, quality: float = 1.0) -> float:
"""
UWB source confidence.
age_s : seconds since last UWB measurement (0; use large value if never)
timeout_s: staleness threshold (s); confidence reaches 0 at this age
quality : inherent measurement quality [0, 1] (default 1.0)
"""
return quality * _staleness_factor(age_s, timeout_s)
def camera_confidence(
age_s: float, timeout_s: float, quality: float = 1.0
) -> float:
"""
Camera source confidence.
age_s : seconds since last camera detection (0; use large value if never)
timeout_s: staleness threshold (s)
quality : YOLO detection confidence or other quality score [0, 1]
"""
return quality * _staleness_factor(age_s, timeout_s)
def select_source(
uwb_conf: float,
cam_conf: float,
threshold: float = 0.15,
) -> str:
"""
Choose the active tracking source label.
Returns one of: "fused", "uwb", "camera", "predicted".
threshold: minimum confidence for a source to be considered live.
Sources below threshold are ignored.
"""
uwb_ok = uwb_conf >= threshold
cam_ok = cam_conf >= threshold
if uwb_ok and cam_ok:
return "fused"
if uwb_ok:
return "uwb"
if cam_ok:
return "camera"
return "predicted"
def fuse_positions(
uwb_x: float,
uwb_y: float,
uwb_conf: float,
cam_x: float,
cam_y: float,
cam_conf: float,
) -> tuple:
"""
Confidence-weighted position fusion.
Returns (fused_x, fused_y).
When total confidence is zero (shouldn't happen in "fused" state, but
guarded), returns the UWB position as fallback.
"""
total = uwb_conf + cam_conf
if total <= 0.0:
return uwb_x, uwb_y
w = uwb_conf / total
return (
w * uwb_x + (1.0 - w) * cam_x,
w * uwb_y + (1.0 - w) * cam_y,
)
def composite_confidence(
uwb_conf: float,
cam_conf: float,
source: str,
kf_uncertainty_m: float,
max_kf_uncertainty_m: float = 3.0,
) -> float:
"""
Compute a single composite confidence value [0, 1] for the fused output.
source : current source label (from select_source)
kf_uncertainty_m : current KF positional RMS uncertainty
max_kf_uncertainty_m: uncertainty at which confidence collapses to 0
"""
if source == "predicted":
# Decay with growing KF uncertainty; no sensor feeds are live
raw = max(0.0, 1.0 - kf_uncertainty_m / max_kf_uncertainty_m)
return min(0.4, raw) # cap at 0.4 — caller should know this is dead-reckoning
if source == "fused":
raw = max(uwb_conf, cam_conf)
elif source == "uwb":
raw = uwb_conf
else: # "camera"
raw = cam_conf
# Scale by KF health (full confidence only if KF is tight)
kf_health = max(0.0, 1.0 - kf_uncertainty_m / max_kf_uncertainty_m)
return raw * (0.5 + 0.5 * kf_health)
def bearing_and_range(x: float, y: float) -> tuple:
"""
Compute bearing (rad, +ve = left) and range (m) to position (x, y).
Consistent with person_follower_node conventions:
bearing = atan2(y, x) (base_link frame: +X forward, +Y left)
range = sqrt( + )
"""
return math.atan2(y, x), math.sqrt(x * x + y * y)

View File

@ -0,0 +1,257 @@
"""
tracking_fusion_node.py Multi-modal tracking fusion for saltybot.
Subscribes
----------
/uwb/target (geometry_msgs/PoseStamped) UWB-triangulated position (~10 Hz)
/person/target (geometry_msgs/PoseStamped) camera-detected position (~30 Hz)
Publishes
---------
/social/tracking/fused_target (saltybot_social_msgs/FusedTarget) at control_rate Hz
Algorithm
---------
1. Each incoming measurement updates a 4-state Kalman filter [x, y, vx, vy].
2. A 20 Hz timer runs predict+select+publish:
a. KF predict(dt)
b. Compute per-source confidence from measurement age + staleness model
c. If either source is live:
- "fused" confidence-weighted position blend KF update
- "uwb" UWB position KF update
- "camera" camera position KF update
d. Build FusedTarget from KF state + composite confidence
3. If all sources are lost but within predict_timeout, keep publishing with
active_source="predicted" and degrading confidence.
4. Beyond predict_timeout, no message is published (node stays alive).
Kalman tuning for EUC speeds (2030 km/h 5.58.3 m/s)
---------------------------------------------------------
process_noise=3.0 m/ allows rapid acceleration events
predict_timeout=3.0 s coasts 30 m at 10 m/s; acceptable dead-reckoning
Parameters
----------
uwb_timeout : UWB staleness threshold (s) default 1.5
cam_timeout : Camera staleness threshold (s) default 1.0
predict_timeout : Max KF coast before no publish (s) default 3.0
process_noise : KF acceleration noise std-dev (m/) default 3.0
meas_noise_uwb : UWB position noise std-dev (m) default 0.20
meas_noise_cam : Camera position noise std-dev (m) default 0.12
control_rate : Publish / KF predict rate (Hz) default 20.0
confidence_threshold: Min source confidence to use (01) default 0.15
Usage
-----
ros2 launch saltybot_social_tracking tracking.launch.py
"""
import math
import time
import rclpy
from rclpy.node import Node
from geometry_msgs.msg import PoseStamped
from std_msgs.msg import Header
from saltybot_social_msgs.msg import FusedTarget
from saltybot_social_tracking.kalman_tracker import KalmanTracker
from saltybot_social_tracking.source_arbiter import (
uwb_confidence,
camera_confidence,
select_source,
fuse_positions,
composite_confidence,
bearing_and_range,
)
_BIG_AGE = 1e9 # sentinel: source never received
class TrackingFusionNode(Node):
def __init__(self):
super().__init__("tracking_fusion")
# ── Parameters ────────────────────────────────────────────────────────
self.declare_parameter("uwb_timeout", 1.5)
self.declare_parameter("cam_timeout", 1.0)
self.declare_parameter("predict_timeout", 3.0)
self.declare_parameter("process_noise", 3.0)
self.declare_parameter("meas_noise_uwb", 0.20)
self.declare_parameter("meas_noise_cam", 0.12)
self.declare_parameter("control_rate", 20.0)
self.declare_parameter("confidence_threshold", 0.15)
self._p = self._load_params()
# ── State ─────────────────────────────────────────────────────────────
self._kf = KalmanTracker(
process_noise=self._p["process_noise"],
meas_noise_uwb=self._p["meas_noise_uwb"],
meas_noise_cam=self._p["meas_noise_cam"],
)
self._last_uwb_t: float = 0.0 # monotonic; 0 = never received
self._last_cam_t: float = 0.0
self._uwb_x: float = 0.0
self._uwb_y: float = 0.0
self._cam_x: float = 0.0
self._cam_y: float = 0.0
self._uwb_tag_id: str = ""
self._last_predict_t: float = 0.0 # monotonic time of last predict call
self._last_any_t: float = 0.0 # monotonic time of last live measurement
# ── Subscriptions ─────────────────────────────────────────────────────
self.create_subscription(
PoseStamped, "/uwb/target", self._uwb_cb, 10)
self.create_subscription(
PoseStamped, "/person/target", self._cam_cb, 10)
# ── Publisher ─────────────────────────────────────────────────────────
self._pub = self.create_publisher(FusedTarget, "/social/tracking/fused_target", 10)
# ── Timer ─────────────────────────────────────────────────────────────
self._timer = self.create_timer(
1.0 / self._p["control_rate"], self._control_cb)
self.get_logger().info(
f"TrackingFusion ready "
f"rate={self._p['control_rate']}Hz "
f"uwb_timeout={self._p['uwb_timeout']}s "
f"cam_timeout={self._p['cam_timeout']}s "
f"predict_timeout={self._p['predict_timeout']}s "
f"process_noise={self._p['process_noise']}m/s²"
)
# ── Parameter helpers ──────────────────────────────────────────────────────
def _load_params(self) -> dict:
return {
"uwb_timeout": self.get_parameter("uwb_timeout").value,
"cam_timeout": self.get_parameter("cam_timeout").value,
"predict_timeout": self.get_parameter("predict_timeout").value,
"process_noise": self.get_parameter("process_noise").value,
"meas_noise_uwb": self.get_parameter("meas_noise_uwb").value,
"meas_noise_cam": self.get_parameter("meas_noise_cam").value,
"control_rate": self.get_parameter("control_rate").value,
"confidence_threshold": self.get_parameter("confidence_threshold").value,
}
# ── Measurement callbacks ──────────────────────────────────────────────────
def _uwb_cb(self, msg: PoseStamped) -> None:
self._uwb_x = msg.pose.position.x
self._uwb_y = msg.pose.position.y
self._uwb_tag_id = "" # PoseStamped has no tag field; tag reported via /uwb/bearing
t = time.monotonic()
self._last_uwb_t = t
self._last_any_t = t
# Seed KF on first measurement
if not self._kf.initialized:
self._kf.initialize(self._uwb_x, self._uwb_y)
self._last_predict_t = t
def _cam_cb(self, msg: PoseStamped) -> None:
self._cam_x = msg.pose.position.x
self._cam_y = msg.pose.position.y
t = time.monotonic()
self._last_cam_t = t
self._last_any_t = t
if not self._kf.initialized:
self._kf.initialize(self._cam_x, self._cam_y)
self._last_predict_t = t
# ── Control loop ───────────────────────────────────────────────────────────
def _control_cb(self) -> None:
self._p = self._load_params()
if not self._kf.initialized:
return # no data yet — nothing to publish
now = time.monotonic()
dt = now - self._last_predict_t if self._last_predict_t > 0.0 else (
1.0 / self._p["control_rate"]
)
self._last_predict_t = now
# KF predict
self._kf.predict(dt)
# Source confidence
uwb_age = (now - self._last_uwb_t) if self._last_uwb_t > 0.0 else _BIG_AGE
cam_age = (now - self._last_cam_t) if self._last_cam_t > 0.0 else _BIG_AGE
u_conf = uwb_confidence(uwb_age, self._p["uwb_timeout"])
c_conf = camera_confidence(cam_age, self._p["cam_timeout"])
threshold = self._p["confidence_threshold"]
source = select_source(u_conf, c_conf, threshold)
if source == "predicted":
# Check predict_timeout — stop publishing if too stale
last_live_age = (
(now - self._last_any_t) if self._last_any_t > 0.0 else _BIG_AGE
)
if last_live_age > self._p["predict_timeout"]:
return # silently stop publishing
# Apply measurement update if a live source exists
if source == "fused":
fx, fy = fuse_positions(
self._uwb_x, self._uwb_y, u_conf,
self._cam_x, self._cam_y, c_conf,
)
self._kf.update(fx, fy, source="uwb") # use tighter noise for blended
elif source == "uwb":
self._kf.update(self._uwb_x, self._uwb_y, source="uwb")
elif source == "camera":
self._kf.update(self._cam_x, self._cam_y, source="camera")
# "predicted" → no update; KF coasts
# Build and publish message
kx, ky = self._kf.position
vx, vy = self._kf.velocity
kf_unc = self._kf.position_uncertainty_m()
conf = composite_confidence(u_conf, c_conf, source, kf_unc)
bearing, range_m = bearing_and_range(kx, ky)
hdr = Header()
hdr.stamp = self.get_clock().now().to_msg()
hdr.frame_id = "base_link"
msg = FusedTarget()
msg.header = hdr
msg.position.x = kx
msg.position.y = ky
msg.position.z = 0.0
msg.velocity.x = vx
msg.velocity.y = vy
msg.velocity.z = 0.0
msg.range_m = float(range_m)
msg.bearing_rad = float(bearing)
msg.confidence = float(conf)
msg.active_source = source
msg.tag_id = self._uwb_tag_id if "uwb" in source else ""
self._pub.publish(msg)
# ── Entry point ────────────────────────────────────────────────────────────────
def main(args=None):
rclpy.init(args=args)
node = TrackingFusionNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.try_shutdown()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,5 @@
[develop]
script_dir=$base/lib/saltybot_social_tracking
[install]
install_scripts=$base/lib/saltybot_social_tracking

View File

@ -0,0 +1,32 @@
from setuptools import setup, find_packages
import os
from glob import glob
package_name = "saltybot_social_tracking"
setup(
name=package_name,
version="0.1.0",
packages=find_packages(exclude=["test"]),
data_files=[
("share/ament_index/resource_index/packages",
[f"resource/{package_name}"]),
(f"share/{package_name}", ["package.xml"]),
(os.path.join("share", package_name, "config"),
glob("config/*.yaml")),
(os.path.join("share", package_name, "launch"),
glob("launch/*.py")),
],
install_requires=["setuptools"],
zip_safe=True,
maintainer="sl-controls",
maintainer_email="sl-controls@saltylab.local",
description="Multi-modal tracking fusion (UWB + camera Kalman filter)",
license="MIT",
tests_require=["pytest"],
entry_points={
"console_scripts": [
f"tracking_fusion_node = {package_name}.tracking_fusion_node:main",
],
},
)

View File

@ -0,0 +1,438 @@
"""
test_tracking_fusion.py Unit tests for saltybot_social_tracking pure modules.
Tests cover:
- KalmanTracker: initialization, predict, update, state accessors
- source_arbiter: confidence functions, source selection, fusion, bearing
No ROS2 runtime required.
"""
import math
import sys
import os
# Allow running: python -m pytest test/test_tracking_fusion.py
# from the package root without installing the package.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import pytest
import numpy as np
from saltybot_social_tracking.kalman_tracker import KalmanTracker
from saltybot_social_tracking.source_arbiter import (
uwb_confidence,
camera_confidence,
select_source,
fuse_positions,
composite_confidence,
bearing_and_range,
)
# ─────────────────────────────────────────────────────────────────────────────
# KalmanTracker tests
# ─────────────────────────────────────────────────────────────────────────────
class TestKalmanTrackerInit:
def test_not_initialized_by_default(self):
kf = KalmanTracker()
assert not kf.initialized
def test_initialize_sets_position(self):
kf = KalmanTracker()
kf.initialize(3.0, 1.5)
assert kf.initialized
x, y = kf.position
assert abs(x - 3.0) < 1e-9
assert abs(y - 1.5) < 1e-9
def test_initialize_sets_zero_velocity(self):
kf = KalmanTracker()
kf.initialize(1.0, -2.0)
vx, vy = kf.velocity
assert abs(vx) < 1e-9
assert abs(vy) < 1e-9
def test_initialize_origin(self):
kf = KalmanTracker()
kf.initialize(0.0, 0.0)
assert kf.initialized
x, y = kf.position
assert x == 0.0 and y == 0.0
class TestKalmanTrackerPredict:
def test_predict_zero_dt_no_change(self):
kf = KalmanTracker()
kf.initialize(2.0, 1.0)
kf.predict(0.0)
x, y = kf.position
assert abs(x - 2.0) < 1e-9
assert abs(y - 1.0) < 1e-9
def test_predict_negative_dt_no_change(self):
kf = KalmanTracker()
kf.initialize(2.0, 1.0)
kf.predict(-0.1)
x, y = kf.position
assert abs(x - 2.0) < 1e-9
def test_predict_constant_velocity(self):
"""After a position update gives the filter a velocity, predict should extrapolate."""
kf = KalmanTracker(process_noise=0.001, meas_noise_uwb=0.001)
kf.initialize(0.0, 0.0)
# Force filter to track a moving target to build up velocity estimate
dt = 0.05
for i in range(40):
t = i * dt
kf.predict(dt)
kf.update(2.0 * t, 0.0, "uwb") # 2 m/s in x
# After many updates the velocity estimate should be close to 2 m/s
vx, vy = kf.velocity
assert abs(vx - 2.0) < 0.3, f"vx={vx:.3f}"
assert abs(vy) < 0.2
def test_predict_grows_uncertainty(self):
kf = KalmanTracker()
kf.initialize(0.0, 0.0)
unc_before = kf.position_uncertainty_m()
kf.predict(1.0)
unc_after = kf.position_uncertainty_m()
assert unc_after > unc_before
def test_predict_multiple_steps(self):
kf = KalmanTracker()
kf.initialize(0.0, 0.0)
kf.predict(0.1)
kf.predict(0.1)
kf.predict(0.1)
# No assertion on exact value; just verify no exception and state is finite
x, y = kf.position
assert math.isfinite(x) and math.isfinite(y)
class TestKalmanTrackerUpdate:
def test_update_pulls_position_toward_measurement(self):
kf = KalmanTracker()
kf.initialize(0.0, 0.0)
kf.update(5.0, 5.0, "uwb")
x, y = kf.position
assert x > 0.0 and y > 0.0
assert x < 5.0 and y < 5.0 # blended, not jumped
def test_update_reduces_uncertainty(self):
kf = KalmanTracker()
kf.initialize(0.0, 0.0)
kf.predict(1.0) # uncertainty grows
unc_mid = kf.position_uncertainty_m()
kf.update(0.1, 0.1, "uwb") # measurement corrects
unc_after = kf.position_uncertainty_m()
assert unc_after < unc_mid
def test_update_converges_to_true_position(self):
"""Many updates from same point should converge to that point."""
kf = KalmanTracker(meas_noise_uwb=0.01)
kf.initialize(0.0, 0.0)
for _ in range(50):
kf.update(3.0, -1.0, "uwb")
x, y = kf.position
assert abs(x - 3.0) < 0.05, f"x={x:.4f}"
assert abs(y - (-1.0)) < 0.05, f"y={y:.4f}"
def test_update_camera_source_different_noise(self):
"""Camera and UWB updates should both move state (noise model differs)."""
kf1 = KalmanTracker(meas_noise_uwb=0.20, meas_noise_cam=0.10)
kf1.initialize(0.0, 0.0)
kf1.update(5.0, 0.0, "uwb")
x_uwb, _ = kf1.position
kf2 = KalmanTracker(meas_noise_uwb=0.20, meas_noise_cam=0.10)
kf2.initialize(0.0, 0.0)
kf2.update(5.0, 0.0, "camera")
x_cam, _ = kf2.position
# Camera has lower noise → stronger pull toward measurement
assert x_cam > x_uwb
def test_update_unknown_source_defaults_to_camera_noise(self):
kf = KalmanTracker()
kf.initialize(0.0, 0.0)
kf.update(2.0, 0.0, "other") # unknown source — should not raise
x, _ = kf.position
assert x > 0.0
def test_position_uncertainty_finite(self):
kf = KalmanTracker()
kf.initialize(1.0, 1.0)
kf.predict(0.05)
kf.update(1.1, 0.9, "uwb")
assert math.isfinite(kf.position_uncertainty_m())
assert kf.position_uncertainty_m() >= 0.0
def test_covariance_copy_is_independent(self):
kf = KalmanTracker()
kf.initialize(0.0, 0.0)
cov = kf.covariance_copy()
cov[0, 0] = 9999.0 # mutate copy
assert kf.covariance_copy()[0, 0] != 9999.0
# ─────────────────────────────────────────────────────────────────────────────
# source_arbiter tests
# ─────────────────────────────────────────────────────────────────────────────
class TestUwbConfidence:
def test_zero_age_returns_quality(self):
assert abs(uwb_confidence(0.0, 1.5) - 1.0) < 1e-9
def test_at_timeout_returns_zero(self):
assert uwb_confidence(1.5, 1.5) == pytest.approx(0.0)
def test_beyond_timeout_returns_zero(self):
assert uwb_confidence(2.0, 1.5) == 0.0
def test_half_timeout_returns_half(self):
assert uwb_confidence(0.75, 1.5) == pytest.approx(0.5)
def test_quality_scales_result(self):
assert uwb_confidence(0.0, 1.5, quality=0.7) == pytest.approx(0.7)
def test_large_age_returns_zero(self):
assert uwb_confidence(1e9, 1.5) == 0.0
def test_zero_timeout_returns_zero(self):
assert uwb_confidence(0.0, 0.0) == 0.0
class TestCameraConfidence:
def test_zero_age_full_quality(self):
assert camera_confidence(0.0, 1.0, quality=1.0) == pytest.approx(1.0)
def test_at_timeout_zero(self):
assert camera_confidence(1.0, 1.0) == pytest.approx(0.0)
def test_beyond_timeout_zero(self):
assert camera_confidence(2.0, 1.0) == 0.0
def test_quality_scales(self):
# age=0, quality=0.8
assert camera_confidence(0.0, 1.0, quality=0.8) == pytest.approx(0.8)
def test_halfway(self):
assert camera_confidence(0.5, 1.0) == pytest.approx(0.5)
class TestSelectSource:
def test_both_above_threshold_fused(self):
assert select_source(0.8, 0.6) == "fused"
def test_only_uwb_above_threshold(self):
assert select_source(0.8, 0.0) == "uwb"
def test_only_cam_above_threshold(self):
assert select_source(0.0, 0.7) == "camera"
def test_both_below_threshold(self):
assert select_source(0.0, 0.0) == "predicted"
def test_threshold_boundary_uwb(self):
# Exactly at threshold — should be treated as live
assert select_source(0.15, 0.0, threshold=0.15) == "uwb"
def test_threshold_boundary_below(self):
assert select_source(0.14, 0.0, threshold=0.15) == "predicted"
def test_custom_threshold(self):
assert select_source(0.5, 0.0, threshold=0.6) == "predicted"
assert select_source(0.5, 0.0, threshold=0.4) == "uwb"
class TestFusePositions:
def test_equal_confidence_returns_midpoint(self):
x, y = fuse_positions(0.0, 0.0, 1.0, 4.0, 4.0, 1.0)
assert abs(x - 2.0) < 1e-9
assert abs(y - 2.0) < 1e-9
def test_full_uwb_weight_returns_uwb(self):
x, y = fuse_positions(3.0, 1.0, 1.0, 0.0, 0.0, 0.0)
assert abs(x - 3.0) < 1e-9
def test_full_cam_weight_returns_cam(self):
x, y = fuse_positions(0.0, 0.0, 0.0, -2.0, 5.0, 1.0)
assert abs(x - (-2.0)) < 1e-9
assert abs(y - 5.0) < 1e-9
def test_weighted_blend(self):
# UWB at (0,0) conf=3, camera at (4,0) conf=1 → x = 3/4*0 + 1/4*4 = 1
x, y = fuse_positions(0.0, 0.0, 3.0, 4.0, 0.0, 1.0)
assert abs(x - 1.0) < 1e-9
def test_zero_total_returns_uwb_fallback(self):
x, y = fuse_positions(7.0, 2.0, 0.0, 3.0, 1.0, 0.0)
assert abs(x - 7.0) < 1e-9
class TestCompositeConfidence:
def test_fused_source_high_confidence(self):
conf = composite_confidence(0.9, 0.8, "fused", 0.05)
assert conf > 0.7
def test_predicted_source_capped(self):
conf = composite_confidence(0.0, 0.0, "predicted", 0.1)
assert conf <= 0.4
def test_predicted_high_uncertainty_low_confidence(self):
conf = composite_confidence(0.0, 0.0, "predicted", 3.0, max_kf_uncertainty_m=3.0)
assert conf == pytest.approx(0.0)
def test_uwb_only(self):
conf = composite_confidence(0.8, 0.0, "uwb", 0.05)
assert conf > 0.3
def test_camera_only(self):
conf = composite_confidence(0.0, 0.7, "camera", 0.05)
assert conf > 0.2
def test_high_kf_uncertainty_reduces_confidence(self):
low_unc = composite_confidence(0.9, 0.0, "uwb", 0.1)
high_unc = composite_confidence(0.9, 0.0, "uwb", 2.9)
assert low_unc > high_unc
class TestBearingAndRange:
def test_straight_ahead(self):
bearing, rng = bearing_and_range(2.0, 0.0)
assert abs(bearing) < 1e-9
assert abs(rng - 2.0) < 1e-9
def test_left_of_robot(self):
# +Y = left in base_link frame; bearing should be positive
bearing, rng = bearing_and_range(0.0, 1.0)
assert abs(bearing - math.pi / 2.0) < 1e-9
assert abs(rng - 1.0) < 1e-9
def test_right_of_robot(self):
bearing, rng = bearing_and_range(0.0, -1.0)
assert abs(bearing - (-math.pi / 2.0)) < 1e-9
def test_diagonal(self):
bearing, rng = bearing_and_range(1.0, 1.0)
assert abs(bearing - math.pi / 4.0) < 1e-9
assert abs(rng - math.sqrt(2.0)) < 1e-9
def test_at_origin(self):
bearing, rng = bearing_and_range(0.0, 0.0)
assert rng == pytest.approx(0.0)
assert math.isfinite(bearing) # atan2(0,0) = 0 in most implementations
def test_range_always_non_negative(self):
for x, y in [(-1, 0), (0, -1), (-2, -3), (5, -5)]:
_, rng = bearing_and_range(x, y)
assert rng >= 0.0
# ─────────────────────────────────────────────────────────────────────────────
# Integration scenario tests
# ─────────────────────────────────────────────────────────────────────────────
class TestIntegrationScenarios:
def test_euc_speed_velocity_tracking(self):
"""Verify KF can track EUC speed (8 m/s) within 0.5 m/s after warm-up."""
kf = KalmanTracker(process_noise=3.0, meas_noise_uwb=0.20)
kf.initialize(0.0, 0.0)
dt = 1.0 / 10.0 # 10 Hz UWB rate
speed = 8.0 # m/s (≈29 km/h)
for i in range(60):
t = i * dt
kf.predict(dt)
kf.update(speed * t, 0.0, "uwb")
vx, vy = kf.velocity
assert abs(vx - speed) < 0.6, f"vx={vx:.2f} expected≈{speed}"
assert abs(vy) < 0.3
def test_signal_loss_recovery(self):
"""
After 1 s of signal loss the filter should still have a reasonable
position estimate (not diverged to infinity).
"""
kf = KalmanTracker(process_noise=3.0)
kf.initialize(2.0, 0.5)
# Warm up with 2 m/s x motion
dt = 0.05
for i in range(20):
kf.predict(dt)
kf.update(2.0 * (i + 1) * dt, 0.0, "uwb")
# Coast for 1 second (20 × 50 ms) without measurements
for _ in range(20):
kf.predict(dt)
x, y = kf.position
assert math.isfinite(x) and math.isfinite(y)
assert abs(x) < 20.0 # shouldn't have drifted more than 20 m
def test_uwb_to_camera_handoff(self):
"""
Simulate UWB going stale and camera taking over Kalman should
smoothly continue tracking without a jump.
"""
kf = KalmanTracker(meas_noise_uwb=0.20, meas_noise_cam=0.12)
kf.initialize(0.0, 0.0)
dt = 0.05
# Phase 1: UWB active
for i in range(20):
kf.predict(dt)
kf.update(float(i) * 0.1, 0.0, "uwb")
x_at_handoff, _ = kf.position
# Phase 2: Camera takes over from same trajectory
for i in range(20, 40):
kf.predict(dt)
kf.update(float(i) * 0.1, 0.0, "camera")
x_after, _ = kf.position
# Position should have continued progressing (not stuck or reset)
assert x_after > x_at_handoff
def test_confidence_degradation_during_coast(self):
"""Composite confidence should drop as KF uncertainty grows during coast."""
kf = KalmanTracker(process_noise=3.0)
kf.initialize(2.0, 0.0)
# Fresh: tight uncertainty → high confidence
unc_fresh = kf.position_uncertainty_m()
conf_fresh = composite_confidence(0.0, 0.0, "predicted", unc_fresh)
# After 2 s coast
for _ in range(40):
kf.predict(0.05)
unc_stale = kf.position_uncertainty_m()
conf_stale = composite_confidence(0.0, 0.0, "predicted", unc_stale)
assert conf_fresh >= conf_stale
def test_fused_source_confidence_weighted_position(self):
"""Fused position should sit between UWB and camera, closer to higher-conf source."""
# UWB at x=0 with high conf, camera at x=10 with low conf
uwb_c = 0.9
cam_c = 0.1
fx, fy = fuse_positions(0.0, 0.0, uwb_c, 10.0, 0.0, cam_c)
# Should be much closer to UWB (0) than camera (10)
assert fx < 3.0, f"fused_x={fx:.2f}"
def test_select_source_transitions(self):
"""Verify correct source transitions as confidences change."""
assert select_source(0.9, 0.8) == "fused"
assert select_source(0.9, 0.0) == "uwb"
assert select_source(0.0, 0.8) == "camera"
assert select_source(0.0, 0.0) == "predicted"