feat(social): conversation topic memory (Issue #299) #304
@ -0,0 +1,14 @@
|
|||||||
|
topic_memory_node:
|
||||||
|
ros__parameters:
|
||||||
|
conversation_topic: "/social/conversation_text" # Input: JSON String {person_id, text}
|
||||||
|
output_topic: "/saltybot/conversation_topics" # Output: JSON String per utterance
|
||||||
|
|
||||||
|
# Keyword extraction
|
||||||
|
min_word_length: 3 # Skip words shorter than this
|
||||||
|
max_keywords_per_msg: 10 # Extract at most this many keywords per utterance
|
||||||
|
|
||||||
|
# Per-person rolling window
|
||||||
|
max_topics_per_person: 30 # Keep last N unique topics per person
|
||||||
|
|
||||||
|
# Stale-person pruning (0 = disabled)
|
||||||
|
prune_after_s: 1800.0 # Forget person after 30 min of inactivity
|
||||||
@ -0,0 +1,42 @@
|
|||||||
|
"""topic_memory.launch.py — Launch conversation topic memory node (Issue #299).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ros2 launch saltybot_social topic_memory.launch.py
|
||||||
|
ros2 launch saltybot_social topic_memory.launch.py max_topics_per_person:=50
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from ament_index_python.packages import get_package_share_directory
|
||||||
|
from launch import LaunchDescription
|
||||||
|
from launch.actions import DeclareLaunchArgument
|
||||||
|
from launch.substitutions import LaunchConfiguration
|
||||||
|
from launch_ros.actions import Node
|
||||||
|
|
||||||
|
|
||||||
|
def generate_launch_description():
    """Build the launch description for topic_memory_node (Issue #299).

    Loads baseline parameters from config/topic_memory_params.yaml, then
    overlays three tunables exposed as launch arguments.

    NOTE(review): the inline parameter dict appears *after* the YAML file in
    the Node's `parameters` list, so the launch-argument defaults always
    shadow the YAML values for these three keys — editing the YAML alone
    will not change them unless the argument is also passed. Confirm this
    precedence is intended.
    """
    pkg = get_package_share_directory("saltybot_social")
    cfg = os.path.join(pkg, "config", "topic_memory_params.yaml")

    return LaunchDescription([
        DeclareLaunchArgument("max_topics_per_person", default_value="30",
                              description="Rolling topic window per person"),
        DeclareLaunchArgument("max_keywords_per_msg", default_value="10",
                              description="Max keywords extracted per utterance"),
        DeclareLaunchArgument("prune_after_s", default_value="1800.0",
                              description="Forget persons idle this long (0=off)"),

        Node(
            package="saltybot_social",
            executable="topic_memory_node",
            name="topic_memory_node",
            output="screen",
            parameters=[
                cfg,
                # Later entries override earlier ones (see NOTE in docstring).
                {
                    "max_topics_per_person": LaunchConfiguration("max_topics_per_person"),
                    "max_keywords_per_msg": LaunchConfiguration("max_keywords_per_msg"),
                    "prune_after_s": LaunchConfiguration("prune_after_s"),
                },
            ],
        ),
    ])
|
||||||
@ -0,0 +1,268 @@
|
|||||||
|
"""topic_memory_node.py — Conversation topic memory.
|
||||||
|
Issue #299
|
||||||
|
|
||||||
|
Subscribes to /social/conversation_text (std_msgs/String, JSON payload
|
||||||
|
{"person_id": "...", "text": "..."}), extracts key topics via stop-word
|
||||||
|
filtered keyword extraction, and maintains a per-person rolling topic
|
||||||
|
history.
|
||||||
|
|
||||||
|
On every message that yields at least one new keyword the node publishes
|
||||||
|
an updated topic snapshot on /saltybot/conversation_topics (std_msgs/String,
|
||||||
|
JSON) — enabling recall like "last time you mentioned coffee" or "you talked
|
||||||
|
about the weather with alice".
|
||||||
|
|
||||||
|
Published JSON format
|
||||||
|
─────────────────────
|
||||||
|
{
|
||||||
|
"person_id": "alice",
|
||||||
|
"recent_topics": ["coffee", "weather", "robot"], // most-recent first
|
||||||
|
"new_topics": ["coffee"], // keywords added this turn
|
||||||
|
"ts": 1234567890.123
|
||||||
|
}
|
||||||
|
|
||||||
|
Keyword extraction pipeline
|
||||||
|
────────────────────────────
|
||||||
|
1. Lowercase + tokenise on non-word characters
|
||||||
|
2. Filter: length >= min_word_length, alphabetic only
|
||||||
|
3. Remove stop words (built-in English list)
|
||||||
|
4. Deduplicate within the utterance
|
||||||
|
5. Take first max_keywords_per_msg tokens
|
||||||
|
|
||||||
|
Per-person storage
|
||||||
|
──────────────────
|
||||||
|
Ordered list (insertion order), capped at max_topics_per_person.
|
||||||
|
Duplicate keywords are promoted to the front (most-recent position).
|
||||||
|
Persons not seen for prune_after_s seconds are pruned on next publish
|
||||||
|
(set to 0 to disable pruning).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
──────────
|
||||||
|
conversation_topic (str, "/social/conversation_text")
|
||||||
|
output_topic (str, "/saltybot/conversation_topics")
|
||||||
|
min_word_length (int, 3) minimum character length to keep
|
||||||
|
max_keywords_per_msg (int, 10) max keywords extracted per utterance
|
||||||
|
max_topics_per_person(int, 30) rolling window per person
|
||||||
|
prune_after_s (float, 1800.0) forget persons idle this long (0=off)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import string
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile
|
||||||
|
from std_msgs.msg import String
|
||||||
|
|
||||||
|
|
||||||
|
# ── Stop-word list (English) ───────────────────────────────────────────────────

# High-frequency English function words, auxiliaries, pronouns, and
# conversational fillers ("yeah", "um") that carry no topical meaning,
# so extracted keywords skew toward content words.
STOP_WORDS: frozenset = frozenset({
    "the", "a", "an", "and", "or", "but", "nor", "not", "so", "yet",
    "for", "with", "from", "into", "onto", "upon", "about", "above",
    "after", "before", "between", "during", "through", "under", "over",
    "at", "by", "in", "of", "on", "to", "up", "as",
    "is", "are", "was", "were", "be", "been", "being",
    "have", "has", "had", "do", "does", "did",
    "will", "would", "could", "should", "may", "might", "shall", "can",
    "it", "its", "this", "that", "these", "those",
    "i", "me", "my", "myself", "we", "our", "ours",
    "you", "your", "yours", "he", "him", "his", "she", "her", "hers",
    "they", "them", "their", "theirs",
    "what", "which", "who", "whom", "when", "where", "why", "how",
    "all", "each", "every", "both", "few", "more", "most",
    "other", "some", "such", "no", "only", "own", "same",
    "than", "then", "too", "very", "just", "also",
    "get", "got", "say", "said", "know", "think", "go", "going", "come",
    "like", "want", "see", "take", "make", "give", "look",
    "yes", "yeah", "okay", "ok", "oh", "ah", "um", "uh", "well",
    "now", "here", "there", "hi", "hey", "hello",
})

# Any character that is neither a word character nor whitespace
# (i.e. punctuation); replaced with spaces before tokenising.
_PUNCT_RE = re.compile(r"[^\w\s]")


# ── Keyword extraction ────────────────────────────────────────────────────────

def extract_keywords(text: str,
                     min_length: int = 3,
                     max_keywords: int = 10) -> List[str]:
    """Return a deduplicated list of meaningful keywords from *text*.

    Steps: lowercase -> strip punctuation -> split -> filter stop words &
    length -> deduplicate (insertion-order preserved) -> cap at
    *max_keywords*.

    Fix: a non-positive *max_keywords* now yields an empty list. Previously
    the cap was checked only after inserting a token, so max_keywords=0
    still let the first qualifying keyword through.
    """
    if max_keywords <= 0:
        return []
    cleaned = _PUNCT_RE.sub(" ", text.lower())
    seen: dict = {}  # ordered-set via insertion-order dict
    for tok in cleaned.split():
        # \w includes "_", which _PUNCT_RE leaves in place — strip it here.
        tok = tok.strip(string.punctuation + "_")
        if (len(tok) >= min_length
                and tok.isalpha()
                and tok not in STOP_WORDS
                and tok not in seen):
            seen[tok] = None
            if len(seen) >= max_keywords:
                break
    return list(seen)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Per-person topic memory ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class PersonTopicMemory:
    """Bounded, ordered, deduplicated topic window for a single person.

    Topics are kept oldest-first internally; re-adding a known topic moves
    it to the newest slot. When the window exceeds its cap, the oldest
    entries are dropped.
    """

    def __init__(self, max_topics: int = 30) -> None:
        self._max = max_topics
        self._topics: List[str] = []      # oldest -> newest
        self._topic_set: set = set()      # membership mirror of _topics
        self.last_updated: float = 0.0    # monotonic timestamp of last add()

    def add(self, keywords: List[str]) -> List[str]:
        """Merge *keywords* into the window.

        Known keywords are promoted to the most-recent position; unseen
        ones are appended and returned. The oldest entries are evicted
        whenever the window exceeds its cap.
        """
        fresh: List[str] = []
        for kw in keywords:
            if kw in self._topic_set:
                # Already known: move to the newest position.
                self._topics.remove(kw)
            else:
                self._topic_set.add(kw)
                fresh.append(kw)
            self._topics.append(kw)
        # Drop everything beyond the cap from the old end in one pass.
        overflow = len(self._topics) - self._max
        if overflow > 0:
            for old in self._topics[:overflow]:
                self._topic_set.discard(old)
            del self._topics[:overflow]
        self.last_updated = time.monotonic()
        return fresh

    @property
    def recent_topics(self) -> List[str]:
        """Topics ordered most-recent first."""
        return self._topics[::-1]

    def __len__(self) -> int:
        return len(self._topics)
|
||||||
|
|
||||||
|
|
||||||
|
# ── ROS2 node ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TopicMemoryNode(Node):
    """Extracts and remembers conversation topics per person.

    Subscribes to a JSON-over-String conversation topic, runs
    extract_keywords() on each utterance, stores the result in a
    per-person PersonTopicMemory, and republishes a JSON snapshot.
    """

    def __init__(self) -> None:
        super().__init__("topic_memory_node")

        # Declared defaults mirror config/topic_memory_params.yaml.
        self.declare_parameter("conversation_topic", "/social/conversation_text")
        self.declare_parameter("output_topic", "/saltybot/conversation_topics")
        self.declare_parameter("min_word_length", 3)
        self.declare_parameter("max_keywords_per_msg", 10)
        self.declare_parameter("max_topics_per_person", 30)
        self.declare_parameter("prune_after_s", 1800.0)

        conv_topic = self.get_parameter("conversation_topic").value
        out_topic = self.get_parameter("output_topic").value
        self._min_len = self.get_parameter("min_word_length").value
        self._max_kw = self.get_parameter("max_keywords_per_msg").value
        self._max_tp = self.get_parameter("max_topics_per_person").value
        self._prune_s = self.get_parameter("prune_after_s").value

        # person_id -> PersonTopicMemory
        self._memory: Dict[str, PersonTopicMemory] = {}
        # Guards _memory: the subscription callback and the public accessors
        # (get_memory / all_persons) may be called from different threads.
        self._lock = threading.Lock()

        qos = QoSProfile(depth=10)
        self._pub = self.create_publisher(String, out_topic, qos)
        self._sub = self.create_subscription(
            String, conv_topic, self._on_conversation, qos
        )

        self.get_logger().info(
            f"TopicMemoryNode ready "
            f"(min_len={self._min_len}, max_kw={self._max_kw}, "
            f"max_topics={self._max_tp}, prune_after={self._prune_s}s)"
        )

    # ── Subscription ───────────────────────────────────────────────────────

    def _on_conversation(self, msg: String) -> None:
        """Handle one utterance: extract keywords, update memory, publish.

        Expects msg.data to be JSON {"person_id": ..., "text": ...};
        person_id falls back to "unknown" when absent. Malformed payloads
        are logged and dropped. Nothing is published when the text yields
        no keywords.
        """
        try:
            payload = json.loads(msg.data)
            person_id: str = str(payload.get("person_id", "unknown"))
            text: str = str(payload.get("text", ""))
        except (json.JSONDecodeError, AttributeError) as exc:
            # AttributeError covers valid JSON that is not an object
            # (e.g. a bare list or number has no .get).
            self.get_logger().warn(f"Bad conversation_text payload: {exc}")
            return

        if not text.strip():
            return

        keywords = extract_keywords(text, self._min_len, self._max_kw)
        if not keywords:
            return

        with self._lock:
            # Opportunistic pruning: stale persons are dropped only when a
            # new message arrives — there is no timer.
            self._prune_stale()
            if person_id not in self._memory:
                self._memory[person_id] = PersonTopicMemory(self._max_tp)
            mem = self._memory[person_id]
            new_topics = mem.add(keywords)
            recent = mem.recent_topics

        # Build and publish the per-utterance snapshot (see module docstring
        # for the JSON schema).
        out = String()
        out.data = json.dumps({
            "person_id": person_id,
            "recent_topics": recent,
            "new_topics": new_topics,
            "ts": time.time(),
        })
        self._pub.publish(out)

        if new_topics:
            self.get_logger().info(
                f"[{person_id}] new topics: {new_topics} | "
                f"memory: {recent[:5]}{'...' if len(recent) > 5 else ''}"
            )

    # ── Helpers ────────────────────────────────────────────────────────────

    def _prune_stale(self) -> None:
        """Remove persons not seen for prune_after_s seconds (call under lock)."""
        if self._prune_s <= 0:
            return
        now = time.monotonic()
        stale = [pid for pid, m in self._memory.items()
                 if (now - m.last_updated) > self._prune_s]
        for pid in stale:
            del self._memory[pid]
            self.get_logger().info(f"Pruned stale person: {pid}")

    def get_memory(self, person_id: str) -> Optional[PersonTopicMemory]:
        """Return topic memory for a person (None if not seen)."""
        with self._lock:
            return self._memory.get(person_id)

    def all_persons(self) -> Dict[str, List[str]]:
        """Return {person_id: recent_topics} snapshot for all known persons."""
        with self._lock:
            return {pid: m.recent_topics for pid, m in self._memory.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None) -> None:
    """Entry point: initialise rclpy, spin TopicMemoryNode, clean up on exit."""
    rclpy.init(args=args)
    node = TopicMemoryNode()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the node; exit quietly.
        pass
    finally:
        node.destroy_node()
        rclpy.shutdown()
|
||||||
@ -53,6 +53,8 @@ setup(
|
|||||||
'face_track_servo_node = saltybot_social.face_track_servo_node:main',
|
'face_track_servo_node = saltybot_social.face_track_servo_node:main',
|
||||||
# Speech volume auto-adjuster (Issue #289)
|
# Speech volume auto-adjuster (Issue #289)
|
||||||
'volume_adjust_node = saltybot_social.volume_adjust_node:main',
|
'volume_adjust_node = saltybot_social.volume_adjust_node:main',
|
||||||
|
# Conversation topic memory (Issue #299)
|
||||||
|
'topic_memory_node = saltybot_social.topic_memory_node:main',
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
522
jetson/ros2_ws/src/saltybot_social/test/test_topic_memory.py
Normal file
522
jetson/ros2_ws/src/saltybot_social/test/test_topic_memory.py
Normal file
@ -0,0 +1,522 @@
|
|||||||
|
"""test_topic_memory.py — Offline tests for topic_memory_node (Issue #299).
|
||||||
|
|
||||||
|
Stubs out rclpy so tests run without a ROS install.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib.util
import json
import os
import sys
import time
import types
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
# ── ROS2 stubs ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _make_ros_stubs():
    """Install minimal rclpy/std_msgs stand-ins into sys.modules.

    Lets topic_memory_node import cleanly without a ROS 2 install.
    Returns the stub (_Node, _FakePub, _String) classes for direct use
    by the tests.
    """
    for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
                     "std_msgs", "std_msgs.msg"):
        if mod_name not in sys.modules:
            sys.modules[mod_name] = types.ModuleType(mod_name)

    class _Node:
        # Mimics rclpy.node.Node: records params, publishers, subscriptions,
        # and log lines on the instance so tests can inspect them.
        def __init__(self, name="node"):
            self._name = name
            # _params may be pre-seeded by _make_node before __init__ runs;
            # don't clobber it in that case.
            if not hasattr(self, "_params"):
                self._params = {}
            self._pubs = {}
            self._subs = {}
            self._logs = []

        def declare_parameter(self, name, default):
            # First value wins, so pre-seeded overrides survive declaration.
            if name not in self._params:
                self._params[name] = default

        def get_parameter(self, name):
            class _P:
                def __init__(self, v): self.value = v
            return _P(self._params.get(name))

        def create_publisher(self, msg_type, topic, qos):
            pub = _FakePub()
            self._pubs[topic] = pub
            return pub

        def create_subscription(self, msg_type, topic, cb, qos):
            self._subs[topic] = cb
            return object()

        def get_logger(self):
            node = self
            class _L:
                # Each log call is recorded as a (LEVEL, message) tuple.
                def info(self, m): node._logs.append(("INFO", m))
                def warn(self, m): node._logs.append(("WARN", m))
                def error(self, m): node._logs.append(("ERROR", m))
            return _L()

        def destroy_node(self): pass

    class _FakePub:
        # Captures published messages for later assertions.
        def __init__(self):
            self.msgs = []
        def publish(self, msg):
            self.msgs.append(msg)

    class _QoSProfile:
        def __init__(self, depth=10): self.depth = depth

    class _String:
        # Stand-in for std_msgs.msg.String: just a .data attribute.
        def __init__(self): self.data = ""

    rclpy_mod = sys.modules["rclpy"]
    rclpy_mod.init = lambda args=None: None
    rclpy_mod.spin = lambda node: None
    rclpy_mod.shutdown = lambda: None

    sys.modules["rclpy.node"].Node = _Node
    sys.modules["rclpy.qos"].QoSProfile = _QoSProfile
    sys.modules["std_msgs.msg"].String = _String

    return _Node, _FakePub, _String


# Stubs must be installed before the node module is loaded below.
_Node, _FakePub, _String = _make_ros_stubs()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Module loader ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_SRC = (
|
||||||
|
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||||
|
"saltybot_social/saltybot_social/topic_memory_node.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_mod():
    """Load the node module straight from its source file, bypassing sys.path."""
    spec = importlib.util.spec_from_file_location("topic_memory_testmod", _SRC)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
|
||||||
|
|
||||||
|
|
||||||
|
def _make_node(mod, **kwargs):
    """Construct a TopicMemoryNode with stubbed ROS and the given param overrides.

    Pre-seeds the stub parameter store before running the real __init__ so
    declare_parameter() picks up the overrides instead of its defaults.
    """
    params = {
        "conversation_topic": "/social/conversation_text",
        "output_topic": "/saltybot/conversation_topics",
        "min_word_length": 3,
        "max_keywords_per_msg": 10,
        "max_topics_per_person": 30,
        "prune_after_s": 1800.0,
    }
    params.update(kwargs)
    node = mod.TopicMemoryNode.__new__(mod.TopicMemoryNode)
    node._params = params
    mod.TopicMemoryNode.__init__(node)
    return node
|
||||||
|
|
||||||
|
|
||||||
|
def _msg(person_id, text):
    """Wrap a {person_id, text} payload in a stub String message."""
    message = _String()
    message.data = json.dumps({"person_id": person_id, "text": text})
    return message
|
||||||
|
|
||||||
|
|
||||||
|
def _send(node, person_id, text):
    """Deliver a conversation message straight into the node's callback."""
    node._subs["/social/conversation_text"](_msg(person_id, text))
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests: extract_keywords ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestExtractKeywords(unittest.TestCase):
    """Unit tests for the stop-word-filtered keyword extractor."""

    @classmethod
    def setUpClass(cls): cls.mod = _load_mod()

    def _kw(self, text, min_len=3, max_kw=10):
        # Thin wrapper so each test reads as a single call.
        return self.mod.extract_keywords(text, min_len, max_kw)

    def test_basic(self):
        kws = self._kw("I love drinking coffee")
        self.assertIn("love", kws)
        self.assertIn("drinking", kws)
        self.assertIn("coffee", kws)

    def test_stop_words_removed(self):
        kws = self._kw("the quick brown fox")
        self.assertNotIn("the", kws)

    def test_short_words_removed(self):
        kws = self._kw("go to the museum now", min_len=4)
        self.assertNotIn("go", kws)
        self.assertNotIn("to", kws)
        self.assertNotIn("the", kws)

    def test_deduplication(self):
        kws = self._kw("coffee coffee coffee")
        self.assertEqual(kws.count("coffee"), 1)

    def test_max_keywords_cap(self):
        # Fix: the original used tokens like "word0", which contain digits
        # and are rejected by isalpha() — zero keywords were extracted and
        # the cap assertion was vacuous. Use purely alphabetic tokens so
        # the cap is actually exercised, and pin the exact count.
        text = " ".join("word" + chr(ord("a") + i) for i in range(20))
        kws = self._kw(text, max_kw=5)
        self.assertEqual(len(kws), 5)

    def test_punctuation_stripped(self):
        kws = self._kw("Hello, world! How's weather?")
        # "world" and "weather" should be found, punctuation removed
        self.assertIn("world", kws)
        self.assertIn("weather", kws)

    def test_case_insensitive(self):
        kws = self._kw("Robot ROBOT robot")
        self.assertEqual(kws.count("robot"), 1)

    def test_empty_text(self):
        self.assertEqual(self._kw(""), [])

    def test_all_stop_words(self):
        self.assertEqual(self._kw("the is a and"), [])

    def test_non_alpha_excluded(self):
        kws = self._kw("model42 123 price500")
        # alphanumeric tokens like "model42" contain digits → excluded
        self.assertEqual(kws, [])

    def test_preserves_order(self):
        kws = self._kw("zebra apple mango")
        self.assertEqual(kws, ["zebra", "apple", "mango"])

    def test_min_length_three(self):
        kws = self._kw("cat dog elephant", min_len=3)
        self.assertIn("cat", kws)
        self.assertIn("dog", kws)
        self.assertIn("elephant", kws)

    def test_min_length_four_excludes_short(self):
        kws = self._kw("cat dog elephant", min_len=4)
        self.assertNotIn("cat", kws)
        self.assertNotIn("dog", kws)
        self.assertIn("elephant", kws)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests: PersonTopicMemory ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestPersonTopicMemory(unittest.TestCase):
    """Unit tests for the per-person rolling topic window."""

    @classmethod
    def setUpClass(cls): cls.mod = _load_mod()

    def _mem(self, max_topics=10):
        # Fresh memory instance per call.
        return self.mod.PersonTopicMemory(max_topics)

    def test_empty_initially(self):
        self.assertEqual(len(self._mem()), 0)

    def test_add_returns_new(self):
        m = self._mem()
        added = m.add(["coffee", "weather"])
        self.assertEqual(added, ["coffee", "weather"])

    def test_duplicate_not_in_added(self):
        m = self._mem()
        m.add(["coffee"])
        added = m.add(["coffee", "weather"])
        self.assertNotIn("coffee", added)
        self.assertIn("weather", added)

    def test_recent_topics_most_recent_first(self):
        m = self._mem()
        m.add(["coffee"])
        m.add(["weather"])
        topics = m.recent_topics
        self.assertEqual(topics[0], "weather")
        self.assertEqual(topics[1], "coffee")

    def test_duplicate_promoted_to_front(self):
        m = self._mem()
        m.add(["coffee", "weather", "robot"])
        m.add(["coffee"])  # promote coffee to front
        topics = m.recent_topics
        self.assertEqual(topics[0], "coffee")

    def test_cap_evicts_oldest(self):
        m = self._mem(max_topics=3)
        m.add(["alpha", "beta", "gamma"])
        m.add(["delta"])  # alpha should be evicted
        topics = m.recent_topics
        self.assertNotIn("alpha", topics)
        self.assertIn("delta", topics)
        self.assertEqual(len(topics), 3)

    def test_len(self):
        m = self._mem()
        m.add(["coffee", "weather", "robot"])
        self.assertEqual(len(m), 3)

    def test_last_updated_set(self):
        m = self._mem()
        before = time.monotonic()
        m.add(["test"])
        self.assertGreaterEqual(m.last_updated, before)

    def test_empty_add(self):
        m = self._mem()
        added = m.add([])
        self.assertEqual(added, [])
        self.assertEqual(len(m), 0)

    def test_many_duplicates_stay_within_cap(self):
        m = self._mem(max_topics=5)
        for _ in range(10):
            m.add(["coffee"])
        self.assertEqual(len(m), 1)
        self.assertIn("coffee", m.recent_topics)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests: node init ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestNodeInit(unittest.TestCase):
    """Construction-time wiring: publisher, subscription, empty memory."""

    @classmethod
    def setUpClass(cls): cls.mod = _load_mod()

    def test_instantiates(self):
        self.assertIsNotNone(_make_node(self.mod))

    def test_pub_registered(self):
        node = _make_node(self.mod)
        self.assertIn("/saltybot/conversation_topics", node._pubs)

    def test_sub_registered(self):
        node = _make_node(self.mod)
        self.assertIn("/social/conversation_text", node._subs)

    def test_memory_empty(self):
        node = _make_node(self.mod)
        self.assertEqual(node.all_persons(), {})

    def test_custom_topics(self):
        node = _make_node(self.mod,
                          conversation_topic="/my/conv",
                          output_topic="/my/topics")
        self.assertIn("/my/conv", node._subs)
        self.assertIn("/my/topics", node._pubs)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests: on_conversation callback ──────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestOnConversation(unittest.TestCase):
    """Behaviour of the _on_conversation callback: payloads, memory, edge cases."""

    @classmethod
    def setUpClass(cls): cls.mod = _load_mod()

    def setUp(self):
        self.node = _make_node(self.mod)
        self.pub = self.node._pubs["/saltybot/conversation_topics"]

    def test_publishes_on_keyword(self):
        _send(self.node, "alice", "I love drinking coffee")
        self.assertEqual(len(self.pub.msgs), 1)

    def test_payload_is_json(self):
        _send(self.node, "alice", "I love drinking coffee")
        payload = json.loads(self.pub.msgs[-1].data)
        self.assertIsInstance(payload, dict)

    def test_payload_person_id(self):
        _send(self.node, "bob", "I enjoy hiking mountains")
        payload = json.loads(self.pub.msgs[-1].data)
        self.assertEqual(payload["person_id"], "bob")

    def test_payload_recent_topics(self):
        _send(self.node, "alice", "I love coffee weather robots")
        payload = json.loads(self.pub.msgs[-1].data)
        self.assertIsInstance(payload["recent_topics"], list)
        self.assertGreater(len(payload["recent_topics"]), 0)

    def test_payload_new_topics(self):
        _send(self.node, "alice", "I love coffee")
        payload = json.loads(self.pub.msgs[-1].data)
        self.assertIn("coffee", payload["new_topics"])

    def test_payload_has_ts(self):
        _send(self.node, "alice", "I love coffee")
        payload = json.loads(self.pub.msgs[-1].data)
        self.assertIn("ts", payload)

    def test_all_stop_words_no_publish(self):
        _send(self.node, "alice", "the is and or")
        self.assertEqual(len(self.pub.msgs), 0)

    def test_empty_text_no_publish(self):
        _send(self.node, "alice", "")
        self.assertEqual(len(self.pub.msgs), 0)

    def test_bad_json_no_crash(self):
        m = _String(); m.data = "not json at all"
        self.node._subs["/social/conversation_text"](m)
        self.assertEqual(len(self.pub.msgs), 0)
        # Idiom fix: name the loop variable clearly instead of `l`
        # (easily confused with `1`; flagged by most linters).
        warns = [entry for entry in self.node._logs if entry[0] == "WARN"]
        self.assertEqual(len(warns), 1)

    def test_missing_person_id_defaults_unknown(self):
        m = _String(); m.data = json.dumps({"text": "coffee weather hiking"})
        self.node._subs["/social/conversation_text"](m)
        payload = json.loads(self.pub.msgs[-1].data)
        self.assertEqual(payload["person_id"], "unknown")

    def test_duplicate_topic_not_in_new(self):
        _send(self.node, "alice", "coffee mountains weather")
        _send(self.node, "alice", "coffee")  # coffee already known
        payload = json.loads(self.pub.msgs[-1].data)
        self.assertNotIn("coffee", payload["new_topics"])

    def test_topics_accumulate_across_turns(self):
        _send(self.node, "alice", "coffee weather")
        _send(self.node, "alice", "mountains hiking")
        mem = self.node.get_memory("alice")
        topics = mem.recent_topics
        self.assertIn("coffee", topics)
        self.assertIn("mountains", topics)

    def test_separate_persons_independent(self):
        _send(self.node, "alice", "coffee weather")
        _send(self.node, "bob", "robots motors")
        alice = self.node.get_memory("alice").recent_topics
        bob = self.node.get_memory("bob").recent_topics
        self.assertIn("coffee", alice)
        self.assertNotIn("robots", alice)
        self.assertIn("robots", bob)
        self.assertNotIn("coffee", bob)

    def test_all_persons_snapshot(self):
        _send(self.node, "alice", "coffee weather")
        _send(self.node, "bob", "robots motors")
        persons = self.node.all_persons()
        self.assertIn("alice", persons)
        self.assertIn("bob", persons)

    def test_recent_topics_most_recent_first(self):
        _send(self.node, "alice", "alpha")
        _send(self.node, "alice", "beta")
        _send(self.node, "alice", "gamma")
        topics = self.node.get_memory("alice").recent_topics
        self.assertEqual(topics[0], "gamma")

    def test_stop_words_not_stored(self):
        _send(self.node, "alice", "the weather and coffee")
        topics = self.node.get_memory("alice").recent_topics
        self.assertNotIn("the", topics)
        self.assertNotIn("and", topics)

    def test_max_keywords_respected(self):
        node = _make_node(self.mod, max_keywords_per_msg=3)
        _send(node, "alice",
              "coffee weather hiking mountains ocean desert forest lake")
        mem = node.get_memory("alice")
        # Only 3 keywords should have been extracted per message
        self.assertLessEqual(len(mem), 3)

    def test_max_topics_cap(self):
        node = _make_node(self.mod, max_topics_per_person=5, max_keywords_per_msg=20)
        words = ["alpha", "beta", "gamma", "delta", "epsilon",
                 "zeta", "eta", "theta"]
        # Idiom fix: the original used enumerate() but never read the index.
        for w in words:
            _send(node, "alice", w)
        mem = node.get_memory("alice")
        self.assertLessEqual(len(mem), 5)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests: prune ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestPrune(unittest.TestCase):
    """Stale-person pruning: disabled mode, expiry, and freshness."""

    @classmethod
    def setUpClass(cls): cls.mod = _load_mod()

    def test_no_prune_when_disabled(self):
        node = _make_node(self.mod, prune_after_s=0.0)
        _send(node, "alice", "coffee weather")
        # Manually expire timestamp
        node._memory["alice"].last_updated = time.monotonic() - 9999
        _send(node, "bob", "robots motors")  # triggers prune check
        # alice should still be present (prune disabled)
        self.assertIn("alice", node.all_persons())

    def test_prune_stale_person(self):
        node = _make_node(self.mod, prune_after_s=1.0)
        _send(node, "alice", "coffee weather")
        node._memory["alice"].last_updated = time.monotonic() - 10.0  # stale
        _send(node, "bob", "robots motors")  # triggers prune
        self.assertNotIn("alice", node.all_persons())

    def test_fresh_person_not_pruned(self):
        node = _make_node(self.mod, prune_after_s=1.0)
        _send(node, "alice", "coffee weather")
        # alice is fresh (just added)
        _send(node, "bob", "robots motors")
        self.assertIn("alice", node.all_persons())
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests: source and config ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestNodeSrc(unittest.TestCase):
    """Textual smoke checks: required identifiers appear in the node source."""

    @classmethod
    def setUpClass(cls):
        with open(_SRC) as f: cls.src = f.read()

    def test_issue_tag(self): self.assertIn("#299", self.src)
    def test_input_topic(self): self.assertIn("/social/conversation_text", self.src)
    def test_output_topic(self): self.assertIn("/saltybot/conversation_topics", self.src)
    def test_extract_keywords(self): self.assertIn("extract_keywords", self.src)
    def test_person_topic_memory(self):self.assertIn("PersonTopicMemory", self.src)
    def test_stop_words(self): self.assertIn("STOP_WORDS", self.src)
    def test_json_output(self): self.assertIn("json.dumps", self.src)
    def test_person_id_in_output(self):self.assertIn("person_id", self.src)
    def test_recent_topics_key(self): self.assertIn("recent_topics", self.src)
    def test_new_topics_key(self): self.assertIn("new_topics", self.src)
    def test_threading_lock(self): self.assertIn("threading.Lock", self.src)
    def test_prune_method(self): self.assertIn("_prune_stale", self.src)
    def test_main_defined(self): self.assertIn("def main", self.src)
    def test_min_word_length_param(self):self.assertIn("min_word_length", self.src)
    def test_max_topics_param(self): self.assertIn("max_topics_per_person", self.src)
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfig(unittest.TestCase):
|
||||||
|
_CONFIG = (
|
||||||
|
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||||
|
"saltybot_social/config/topic_memory_params.yaml"
|
||||||
|
)
|
||||||
|
_LAUNCH = (
|
||||||
|
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||||
|
"saltybot_social/launch/topic_memory.launch.py"
|
||||||
|
)
|
||||||
|
_SETUP = (
|
||||||
|
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||||
|
"saltybot_social/setup.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_config_exists(self):
|
||||||
|
import os; self.assertTrue(os.path.exists(self._CONFIG))
|
||||||
|
|
||||||
|
def test_config_min_word_length(self):
|
||||||
|
with open(self._CONFIG) as f: c = f.read()
|
||||||
|
self.assertIn("min_word_length", c)
|
||||||
|
|
||||||
|
def test_config_max_topics(self):
|
||||||
|
with open(self._CONFIG) as f: c = f.read()
|
||||||
|
self.assertIn("max_topics_per_person", c)
|
||||||
|
|
||||||
|
def test_config_prune(self):
|
||||||
|
with open(self._CONFIG) as f: c = f.read()
|
||||||
|
self.assertIn("prune_after_s", c)
|
||||||
|
|
||||||
|
def test_launch_exists(self):
|
||||||
|
import os; self.assertTrue(os.path.exists(self._LAUNCH))
|
||||||
|
|
||||||
|
def test_launch_max_topics_arg(self):
|
||||||
|
with open(self._LAUNCH) as f: c = f.read()
|
||||||
|
self.assertIn("max_topics_per_person", c)
|
||||||
|
|
||||||
|
def test_entry_point(self):
|
||||||
|
with open(self._SETUP) as f: c = f.read()
|
||||||
|
self.assertIn("topic_memory_node", c)
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running the suite directly: python test_topic_memory.py
if __name__ == "__main__":
    unittest.main()
|
||||||
Loading…
x
Reference in New Issue
Block a user