Merge pull request 'feat(social): conversation topic memory (Issue #299)' (#304) from sl-jetson/issue-299-topic-memory into main
Some checks failed
Some checks failed
This commit is contained in:
commit
a8838cfbbd
@ -0,0 +1,14 @@
|
||||
# Parameters for the conversation topic memory node (Issue #299).
topic_memory_node:
  ros__parameters:
    conversation_topic: "/social/conversation_text"   # Input: JSON String {person_id, text}
    output_topic: "/saltybot/conversation_topics"     # Output: JSON String per utterance

    # Keyword extraction
    min_word_length: 3        # Skip words shorter than this
    max_keywords_per_msg: 10  # Extract at most this many keywords per utterance

    # Per-person rolling window
    max_topics_per_person: 30 # Keep last N unique topics per person

    # Stale-person pruning (0 = disabled)
    prune_after_s: 1800.0     # Forget person after 30 min of inactivity
|
||||
@ -0,0 +1,42 @@
|
||||
"""topic_memory.launch.py — Launch conversation topic memory node (Issue #299).
|
||||
|
||||
Usage:
|
||||
ros2 launch saltybot_social topic_memory.launch.py
|
||||
ros2 launch saltybot_social topic_memory.launch.py max_topics_per_person:=50
|
||||
"""
|
||||
|
||||
import os
|
||||
from ament_index_python.packages import get_package_share_directory
|
||||
from launch import LaunchDescription
|
||||
from launch.actions import DeclareLaunchArgument
|
||||
from launch.substitutions import LaunchConfiguration
|
||||
from launch_ros.actions import Node
|
||||
|
||||
|
||||
def generate_launch_description():
    """Build the launch description for the conversation topic memory node.

    Defaults come from ``config/topic_memory_params.yaml``; three parameters
    can additionally be overridden on the command line, e.g.::

        ros2 launch saltybot_social topic_memory.launch.py max_topics_per_person:=50
    """
    share_dir = get_package_share_directory("saltybot_social")
    params_file = os.path.join(share_dir, "config", "topic_memory_params.yaml")

    # CLI-overridable parameters: (name, default, help text).
    arg_specs = [
        ("max_topics_per_person", "30", "Rolling topic window per person"),
        ("max_keywords_per_msg", "10", "Max keywords extracted per utterance"),
        ("prune_after_s", "1800.0", "Forget persons idle this long (0=off)"),
    ]
    launch_args = [
        DeclareLaunchArgument(name, default_value=default, description=desc)
        for name, default, desc in arg_specs
    ]
    overrides = {name: LaunchConfiguration(name) for name, _, _ in arg_specs}

    memory_node = Node(
        package="saltybot_social",
        executable="topic_memory_node",
        name="topic_memory_node",
        output="screen",
        # YAML file first; the per-argument override dict takes precedence.
        parameters=[params_file, overrides],
    )

    return LaunchDescription(launch_args + [memory_node])
|
||||
@ -0,0 +1,268 @@
|
||||
"""topic_memory_node.py — Conversation topic memory.
|
||||
Issue #299
|
||||
|
||||
Subscribes to /social/conversation_text (std_msgs/String, JSON payload
|
||||
{"person_id": "...", "text": "..."}), extracts key topics via stop-word
|
||||
filtered keyword extraction, and maintains a per-person rolling topic
|
||||
history.
|
||||
|
||||
On every message that yields at least one new keyword the node publishes
|
||||
an updated topic snapshot on /saltybot/conversation_topics (std_msgs/String,
|
||||
JSON) — enabling recall like "last time you mentioned coffee" or "you talked
|
||||
about the weather with alice".
|
||||
|
||||
Published JSON format
|
||||
─────────────────────
|
||||
{
|
||||
"person_id": "alice",
|
||||
"recent_topics": ["coffee", "weather", "robot"], // most-recent first
|
||||
"new_topics": ["coffee"], // keywords added this turn
|
||||
"ts": 1234567890.123
|
||||
}
|
||||
|
||||
Keyword extraction pipeline
|
||||
────────────────────────────
|
||||
1. Lowercase + tokenise on non-word characters
|
||||
2. Filter: length >= min_word_length, alphabetic only
|
||||
3. Remove stop words (built-in English list)
|
||||
4. Deduplicate within the utterance
|
||||
5. Take first max_keywords_per_msg tokens
|
||||
|
||||
Per-person storage
|
||||
──────────────────
|
||||
Ordered list (insertion order), capped at max_topics_per_person.
|
||||
Duplicate keywords are promoted to the front (most-recent position).
|
||||
Persons not seen for prune_after_s seconds are pruned on next publish
|
||||
(set to 0 to disable pruning).
|
||||
|
||||
Parameters
|
||||
──────────
|
||||
conversation_topic (str, "/social/conversation_text")
|
||||
output_topic (str, "/saltybot/conversation_topics")
|
||||
min_word_length (int, 3) minimum character length to keep
|
||||
max_keywords_per_msg (int, 10) max keywords extracted per utterance
|
||||
max_topics_per_person(int, 30) rolling window per person
|
||||
prune_after_s (float, 1800.0) forget persons idle this long (0=off)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import string
|
||||
import time
|
||||
import threading
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import rclpy
|
||||
from rclpy.node import Node
|
||||
from rclpy.qos import QoSProfile
|
||||
from std_msgs.msg import String
|
||||
|
||||
|
||||
# ── Stop-word list (English) ───────────────────────────────────────────────────
|
||||
|
||||
# Common English function/filler words excluded from topic extraction.
STOP_WORDS: frozenset = frozenset({
    "the", "a", "an", "and", "or", "but", "nor", "not", "so", "yet",
    "for", "with", "from", "into", "onto", "upon", "about", "above",
    "after", "before", "between", "during", "through", "under", "over",
    "at", "by", "in", "of", "on", "to", "up", "as",
    "is", "are", "was", "were", "be", "been", "being",
    "have", "has", "had", "do", "does", "did",
    "will", "would", "could", "should", "may", "might", "shall", "can",
    "it", "its", "this", "that", "these", "those",
    "i", "me", "my", "myself", "we", "our", "ours",
    "you", "your", "yours", "he", "him", "his", "she", "her", "hers",
    "they", "them", "their", "theirs",
    "what", "which", "who", "whom", "when", "where", "why", "how",
    "all", "each", "every", "both", "few", "more", "most",
    "other", "some", "such", "no", "only", "own", "same",
    "than", "then", "too", "very", "just", "also",
    "get", "got", "say", "said", "know", "think", "go", "going", "come",
    "like", "want", "see", "take", "make", "give", "look",
    "yes", "yeah", "okay", "ok", "oh", "ah", "um", "uh", "well",
    "now", "here", "there", "hi", "hey", "hello",
})

# Matches any character that is neither a word character (\w) nor whitespace.
_PUNCT_RE = re.compile(r"[^\w\s]")


def extract_keywords(text: str,
                     min_length: int = 3,
                     max_keywords: int = 10) -> List[str]:
    """Extract up to *max_keywords* meaningful keywords from *text*.

    Pipeline: lowercase, replace punctuation with spaces, tokenise, keep
    alphabetic tokens of at least *min_length* characters that are not stop
    words, deduplicate while preserving first-occurrence order.
    """
    tokens = _PUNCT_RE.sub(" ", text.lower()).split()
    # Insertion-order dict doubles as an ordered set for deduplication.
    keywords: dict = {}
    for raw in tokens:
        word = raw.strip(string.punctuation + "_")
        if len(word) < min_length or not word.isalpha():
            continue
        if word in STOP_WORDS or word in keywords:
            continue
        keywords[word] = None
        if len(keywords) >= max_keywords:
            break
    return list(keywords)
|
||||
|
||||
|
||||
# ── Per-person topic memory ───────────────────────────────────────────────────
|
||||
|
||||
class PersonTopicMemory:
    """Rolling, deduplicated topic history for a single person.

    Topics are stored oldest-first; a re-mentioned topic moves to the
    newest slot, and the list is capped at *max_topics* by evicting from
    the oldest end.
    """

    def __init__(self, max_topics: int = 30) -> None:
        self._max = max_topics           # rolling-window capacity
        self._topics: List[str] = []     # oldest -> newest
        self._topic_set: set = set()     # O(1) membership index
        self.last_updated: float = 0.0   # monotonic timestamp of last add()

    def add(self, keywords: List[str]) -> List[str]:
        """Record *keywords* and return only the previously unseen ones.

        Known keywords are promoted to the most-recent position; oldest
        entries are evicted once the window exceeds capacity.
        """
        fresh: List[str] = []
        for word in keywords:
            if word in self._topic_set:
                # Drop the old slot so the re-append lands at the newest end.
                self._topics.remove(word)
            else:
                self._topic_set.add(word)
                fresh.append(word)
            self._topics.append(word)
        # Evict from the oldest end until within capacity.
        while len(self._topics) > self._max:
            self._topic_set.discard(self._topics.pop(0))
        self.last_updated = time.monotonic()
        return fresh

    @property
    def recent_topics(self) -> List[str]:
        """Topics ordered most-recent first."""
        return self._topics[::-1]

    def __len__(self) -> int:
        return len(self._topics)
|
||||
|
||||
|
||||
# ── ROS2 node ─────────────────────────────────────────────────────────────────
|
||||
|
||||
class TopicMemoryNode(Node):
    """Extracts and remembers conversation topics per person.

    Subscribes to a JSON conversation stream, runs stop-word-filtered
    keyword extraction on each utterance, stores results in per-person
    rolling `PersonTopicMemory` windows, and publishes a JSON snapshot
    per utterance. All access to the person map is guarded by `_lock`.
    """

    def __init__(self) -> None:
        super().__init__("topic_memory_node")

        # Declare parameters with defaults; YAML/CLI overrides win.
        self.declare_parameter("conversation_topic", "/social/conversation_text")
        self.declare_parameter("output_topic", "/saltybot/conversation_topics")
        self.declare_parameter("min_word_length", 3)
        self.declare_parameter("max_keywords_per_msg", 10)
        self.declare_parameter("max_topics_per_person", 30)
        self.declare_parameter("prune_after_s", 1800.0)

        conv_topic = self.get_parameter("conversation_topic").value
        out_topic = self.get_parameter("output_topic").value
        self._min_len = self.get_parameter("min_word_length").value
        self._max_kw = self.get_parameter("max_keywords_per_msg").value
        self._max_tp = self.get_parameter("max_topics_per_person").value
        self._prune_s = self.get_parameter("prune_after_s").value

        # person_id -> PersonTopicMemory
        self._memory: Dict[str, PersonTopicMemory] = {}
        # Guards _memory; the subscription callback and the public accessors
        # (get_memory / all_persons) may run on different threads.
        self._lock = threading.Lock()

        qos = QoSProfile(depth=10)
        self._pub = self.create_publisher(String, out_topic, qos)
        self._sub = self.create_subscription(
            String, conv_topic, self._on_conversation, qos
        )

        self.get_logger().info(
            f"TopicMemoryNode ready "
            f"(min_len={self._min_len}, max_kw={self._max_kw}, "
            f"max_topics={self._max_tp}, prune_after={self._prune_s}s)"
        )

    # ── Subscription ───────────────────────────────────────────────────────

    def _on_conversation(self, msg: String) -> None:
        """Handle one utterance: parse JSON, extract keywords, publish snapshot.

        Silently ignores empty text and utterances that yield no keywords;
        malformed payloads are logged at WARN and dropped.
        """
        try:
            payload = json.loads(msg.data)
            # AttributeError is caught below when payload is not a dict
            # (e.g. a bare JSON string/number has no .get).
            person_id: str = str(payload.get("person_id", "unknown"))
            text: str = str(payload.get("text", ""))
        except (json.JSONDecodeError, AttributeError) as exc:
            # NOTE(review): get_logger().warn is deprecated in newer rclpy in
            # favour of .warning — confirm against the target ROS distro.
            self.get_logger().warn(f"Bad conversation_text payload: {exc}")
            return

        if not text.strip():
            return

        keywords = extract_keywords(text, self._min_len, self._max_kw)
        if not keywords:
            return

        # Mutate the person map under the lock; snapshot the results so the
        # publish below can happen outside the critical section.
        with self._lock:
            self._prune_stale()
            if person_id not in self._memory:
                self._memory[person_id] = PersonTopicMemory(self._max_tp)
            mem = self._memory[person_id]
            new_topics = mem.add(keywords)
            recent = mem.recent_topics

        out = String()
        out.data = json.dumps({
            "person_id": person_id,
            "recent_topics": recent,
            "new_topics": new_topics,
            "ts": time.time(),
        })
        self._pub.publish(out)

        if new_topics:
            self.get_logger().info(
                f"[{person_id}] new topics: {new_topics} | "
                f"memory: {recent[:5]}{'...' if len(recent) > 5 else ''}"
            )

    # ── Helpers ────────────────────────────────────────────────────────────

    def _prune_stale(self) -> None:
        """Remove persons not seen for prune_after_s seconds (call under lock)."""
        if self._prune_s <= 0:
            return
        now = time.monotonic()
        # Collect first, then delete — can't mutate the dict while iterating.
        stale = [pid for pid, m in self._memory.items()
                 if (now - m.last_updated) > self._prune_s]
        for pid in stale:
            del self._memory[pid]
            self.get_logger().info(f"Pruned stale person: {pid}")

    def get_memory(self, person_id: str) -> Optional[PersonTopicMemory]:
        """Return topic memory for a person (None if not seen)."""
        with self._lock:
            return self._memory.get(person_id)

    def all_persons(self) -> Dict[str, List[str]]:
        """Return {person_id: recent_topics} snapshot for all known persons."""
        with self._lock:
            return {pid: m.recent_topics for pid, m in self._memory.items()}
|
||||
|
||||
|
||||
def main(args=None) -> None:
    """Entry point: spin the topic memory node until shutdown.

    Ctrl-C is treated as a normal shutdown request; the node is destroyed
    and rclpy shut down on every exit path.
    """
    rclpy.init(args=args)
    topic_node = TopicMemoryNode()
    try:
        rclpy.spin(topic_node)
    except KeyboardInterrupt:
        pass  # clean exit on Ctrl-C
    finally:
        # Always release node resources, even if spin raised.
        topic_node.destroy_node()
        rclpy.shutdown()
|
||||
@ -53,6 +53,8 @@ setup(
|
||||
'face_track_servo_node = saltybot_social.face_track_servo_node:main',
|
||||
# Speech volume auto-adjuster (Issue #289)
|
||||
'volume_adjust_node = saltybot_social.volume_adjust_node:main',
|
||||
# Conversation topic memory (Issue #299)
|
||||
'topic_memory_node = saltybot_social.topic_memory_node:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
522
jetson/ros2_ws/src/saltybot_social/test/test_topic_memory.py
Normal file
522
jetson/ros2_ws/src/saltybot_social/test/test_topic_memory.py
Normal file
@ -0,0 +1,522 @@
|
||||
"""test_topic_memory.py — Offline tests for topic_memory_node (Issue #299).
|
||||
|
||||
Stubs out rclpy so tests run without a ROS install.
|
||||
"""
|
||||
|
||||
import importlib.util
import json
import os
import sys
import time
import types
import unittest
|
||||
|
||||
|
||||
# ── ROS2 stubs ────────────────────────────────────────────────────────────────
|
||||
|
||||
def _make_ros_stubs():
    """Install minimal rclpy / std_msgs stand-ins into sys.modules.

    Lets the node module import cleanly on machines without ROS. The stub
    Node records publishers, subscriptions, parameters and log lines on the
    instance so tests can inspect them directly.

    Returns the (_Node, _FakePub, _String) stub classes.
    """
    # Create empty placeholder modules for everything the node imports.
    for mod_name in ("rclpy", "rclpy.node", "rclpy.qos",
                     "std_msgs", "std_msgs.msg"):
        if mod_name not in sys.modules:
            sys.modules[mod_name] = types.ModuleType(mod_name)

    class _Node:
        # In-memory replacement for rclpy.node.Node.
        def __init__(self, name="node"):
            self._name = name
            # _params may have been pre-seeded (see _make_node) — keep it.
            if not hasattr(self, "_params"):
                self._params = {}
            self._pubs = {}    # topic -> _FakePub
            self._subs = {}    # topic -> callback
            self._logs = []    # (level, message) tuples

        def declare_parameter(self, name, default):
            # Like real rclpy: declaring never overwrites an existing value.
            if name not in self._params:
                self._params[name] = default

        def get_parameter(self, name):
            # Mimic the rclpy Parameter object's .value attribute.
            class _P:
                def __init__(self, v): self.value = v
            return _P(self._params.get(name))

        def create_publisher(self, msg_type, topic, qos):
            pub = _FakePub()
            self._pubs[topic] = pub
            return pub

        def create_subscription(self, msg_type, topic, cb, qos):
            self._subs[topic] = cb
            return object()

        def get_logger(self):
            # Closure over the node so log lines land in node._logs.
            node = self
            class _L:
                def info(self, m): node._logs.append(("INFO", m))
                def warn(self, m): node._logs.append(("WARN", m))
                def error(self, m): node._logs.append(("ERROR", m))
            return _L()

        def destroy_node(self): pass

    class _FakePub:
        # Collects published messages for assertions.
        def __init__(self):
            self.msgs = []
        def publish(self, msg):
            self.msgs.append(msg)

    class _QoSProfile:
        def __init__(self, depth=10): self.depth = depth

    class _String:
        # Stand-in for std_msgs.msg.String: just a .data attribute.
        def __init__(self): self.data = ""

    # Wire the stubs into the placeholder modules.
    rclpy_mod = sys.modules["rclpy"]
    rclpy_mod.init = lambda args=None: None
    rclpy_mod.spin = lambda node: None
    rclpy_mod.shutdown = lambda: None

    sys.modules["rclpy.node"].Node = _Node
    sys.modules["rclpy.qos"].QoSProfile = _QoSProfile
    sys.modules["std_msgs.msg"].String = _String

    return _Node, _FakePub, _String


# Install stubs at import time, before the node module is loaded.
_Node, _FakePub, _String = _make_ros_stubs()
|
||||
|
||||
|
||||
# ── Module loader ─────────────────────────────────────────────────────────────
|
||||
|
||||
_SRC = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/saltybot_social/topic_memory_node.py"
|
||||
)
|
||||
|
||||
|
||||
def _load_mod():
    """Import the node module under test directly from its file path.

    A fresh module object is returned on every call, so test classes get
    isolated module state.
    """
    spec = importlib.util.spec_from_file_location("topic_memory_testmod", _SRC)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
def _make_node(mod, **kwargs):
    """Construct a TopicMemoryNode with parameter overrides.

    Uses __new__ to allocate the instance without running __init__, seeds
    the stub Node's _params dict first (declare_parameter never overwrites
    existing values, so the overrides win), then runs __init__ manually.
    """
    node = mod.TopicMemoryNode.__new__(mod.TopicMemoryNode)
    defaults = {
        "conversation_topic": "/social/conversation_text",
        "output_topic": "/saltybot/conversation_topics",
        "min_word_length": 3,
        "max_keywords_per_msg": 10,
        "max_topics_per_person": 30,
        "prune_after_s": 1800.0,
    }
    defaults.update(kwargs)
    # Pre-seed parameters before __init__ declares them.
    node._params = dict(defaults)
    mod.TopicMemoryNode.__init__(node)
    return node
|
||||
|
||||
|
||||
def _msg(person_id, text):
    """Build a stub String message carrying the conversation JSON payload."""
    message = _String()
    message.data = json.dumps({"person_id": person_id, "text": text})
    return message
|
||||
|
||||
|
||||
def _send(node, person_id, text):
    """Deliver a conversation message straight to the node's subscription callback."""
    callback = node._subs["/social/conversation_text"]
    callback(_msg(person_id, text))
|
||||
|
||||
|
||||
# ── Tests: extract_keywords ───────────────────────────────────────────────────
|
||||
|
||||
class TestExtractKeywords(unittest.TestCase):
    """Unit tests for the stop-word-filtered keyword extractor."""

    @classmethod
    def setUpClass(cls): cls.mod = _load_mod()

    def _kw(self, text, min_len=3, max_kw=10):
        # Thin wrapper so each test reads as a one-liner.
        return self.mod.extract_keywords(text, min_len, max_kw)

    def test_basic(self):
        kws = self._kw("I love drinking coffee")
        self.assertIn("love", kws)
        self.assertIn("drinking", kws)
        self.assertIn("coffee", kws)

    def test_stop_words_removed(self):
        kws = self._kw("the quick brown fox")
        self.assertNotIn("the", kws)

    def test_short_words_removed(self):
        kws = self._kw("go to the museum now", min_len=4)
        self.assertNotIn("go", kws)
        self.assertNotIn("to", kws)
        self.assertNotIn("the", kws)

    def test_deduplication(self):
        kws = self._kw("coffee coffee coffee")
        self.assertEqual(kws.count("coffee"), 1)

    def test_max_keywords_cap(self):
        # NOTE(review): "wordN" tokens contain digits and fail isalpha(), so
        # this only checks the cap as an upper bound (result is empty here).
        text = " ".join(f"word{i}" for i in range(20))
        kws = self._kw(text, max_kw=5)
        self.assertLessEqual(len(kws), 5)

    def test_punctuation_stripped(self):
        kws = self._kw("Hello, world! How's weather?")
        # "world" and "weather" should be found, punctuation removed
        self.assertIn("world", kws)
        self.assertIn("weather", kws)

    def test_case_insensitive(self):
        kws = self._kw("Robot ROBOT robot")
        self.assertEqual(kws.count("robot"), 1)

    def test_empty_text(self):
        self.assertEqual(self._kw(""), [])

    def test_all_stop_words(self):
        self.assertEqual(self._kw("the is a and"), [])

    def test_non_alpha_excluded(self):
        kws = self._kw("model42 123 price500")
        # alphanumeric tokens like "model42" contain digits → excluded
        self.assertEqual(kws, [])

    def test_preserves_order(self):
        kws = self._kw("zebra apple mango")
        self.assertEqual(kws, ["zebra", "apple", "mango"])

    def test_min_length_three(self):
        kws = self._kw("cat dog elephant", min_len=3)
        self.assertIn("cat", kws)
        self.assertIn("dog", kws)
        self.assertIn("elephant", kws)

    def test_min_length_four_excludes_short(self):
        kws = self._kw("cat dog elephant", min_len=4)
        self.assertNotIn("cat", kws)
        self.assertNotIn("dog", kws)
        self.assertIn("elephant", kws)
|
||||
|
||||
|
||||
# ── Tests: PersonTopicMemory ──────────────────────────────────────────────────
|
||||
|
||||
class TestPersonTopicMemory(unittest.TestCase):
    """Rolling-window semantics of the per-person topic store."""

    @classmethod
    def setUpClass(cls): cls.mod = _load_mod()

    def _mem(self, max_topics=10):
        return self.mod.PersonTopicMemory(max_topics)

    def test_empty_initially(self):
        self.assertEqual(len(self._mem()), 0)

    def test_add_returns_new(self):
        m = self._mem()
        added = m.add(["coffee", "weather"])
        self.assertEqual(added, ["coffee", "weather"])

    def test_duplicate_not_in_added(self):
        m = self._mem()
        m.add(["coffee"])
        added = m.add(["coffee", "weather"])
        self.assertNotIn("coffee", added)
        self.assertIn("weather", added)

    def test_recent_topics_most_recent_first(self):
        m = self._mem()
        m.add(["coffee"])
        m.add(["weather"])
        topics = m.recent_topics
        self.assertEqual(topics[0], "weather")
        self.assertEqual(topics[1], "coffee")

    def test_duplicate_promoted_to_front(self):
        m = self._mem()
        m.add(["coffee", "weather", "robot"])
        m.add(["coffee"])  # promote coffee to front
        topics = m.recent_topics
        self.assertEqual(topics[0], "coffee")

    def test_cap_evicts_oldest(self):
        m = self._mem(max_topics=3)
        m.add(["alpha", "beta", "gamma"])
        m.add(["delta"])  # alpha should be evicted
        topics = m.recent_topics
        self.assertNotIn("alpha", topics)
        self.assertIn("delta", topics)
        self.assertEqual(len(topics), 3)

    def test_len(self):
        m = self._mem()
        m.add(["coffee", "weather", "robot"])
        self.assertEqual(len(m), 3)

    def test_last_updated_set(self):
        # add() stamps last_updated with time.monotonic().
        m = self._mem()
        before = time.monotonic()
        m.add(["test"])
        self.assertGreaterEqual(m.last_updated, before)

    def test_empty_add(self):
        m = self._mem()
        added = m.add([])
        self.assertEqual(added, [])
        self.assertEqual(len(m), 0)

    def test_many_duplicates_stay_within_cap(self):
        m = self._mem(max_topics=5)
        for _ in range(10):
            m.add(["coffee"])
        self.assertEqual(len(m), 1)
        self.assertIn("coffee", m.recent_topics)
|
||||
|
||||
|
||||
# ── Tests: node init ──────────────────────────────────────────────────────────
|
||||
|
||||
class TestNodeInit(unittest.TestCase):
    """Node construction and pub/sub wiring (against the rclpy stubs)."""

    @classmethod
    def setUpClass(cls): cls.mod = _load_mod()

    def test_instantiates(self):
        self.assertIsNotNone(_make_node(self.mod))

    def test_pub_registered(self):
        node = _make_node(self.mod)
        self.assertIn("/saltybot/conversation_topics", node._pubs)

    def test_sub_registered(self):
        node = _make_node(self.mod)
        self.assertIn("/social/conversation_text", node._subs)

    def test_memory_empty(self):
        node = _make_node(self.mod)
        self.assertEqual(node.all_persons(), {})

    def test_custom_topics(self):
        # Topic parameters must flow through to the created endpoints.
        node = _make_node(self.mod,
                          conversation_topic="/my/conv",
                          output_topic="/my/topics")
        self.assertIn("/my/conv", node._subs)
        self.assertIn("/my/topics", node._pubs)
|
||||
|
||||
|
||||
# ── Tests: on_conversation callback ──────────────────────────────────────────
|
||||
|
||||
class TestOnConversation(unittest.TestCase):
    """End-to-end behaviour of the subscription callback.

    Each test pushes JSON utterances through _send() and inspects the
    messages captured by the fake publisher or the node's internal memory.
    """

    @classmethod
    def setUpClass(cls): cls.mod = _load_mod()

    def setUp(self):
        self.node = _make_node(self.mod)
        self.pub = self.node._pubs["/saltybot/conversation_topics"]

    def test_publishes_on_keyword(self):
        _send(self.node, "alice", "I love drinking coffee")
        self.assertEqual(len(self.pub.msgs), 1)

    def test_payload_is_json(self):
        _send(self.node, "alice", "I love drinking coffee")
        payload = json.loads(self.pub.msgs[-1].data)
        self.assertIsInstance(payload, dict)

    def test_payload_person_id(self):
        _send(self.node, "bob", "I enjoy hiking mountains")
        payload = json.loads(self.pub.msgs[-1].data)
        self.assertEqual(payload["person_id"], "bob")

    def test_payload_recent_topics(self):
        _send(self.node, "alice", "I love coffee weather robots")
        payload = json.loads(self.pub.msgs[-1].data)
        self.assertIsInstance(payload["recent_topics"], list)
        self.assertGreater(len(payload["recent_topics"]), 0)

    def test_payload_new_topics(self):
        _send(self.node, "alice", "I love coffee")
        payload = json.loads(self.pub.msgs[-1].data)
        self.assertIn("coffee", payload["new_topics"])

    def test_payload_has_ts(self):
        _send(self.node, "alice", "I love coffee")
        payload = json.loads(self.pub.msgs[-1].data)
        self.assertIn("ts", payload)

    def test_all_stop_words_no_publish(self):
        # No keywords extracted -> callback returns before publishing.
        _send(self.node, "alice", "the is and or")
        self.assertEqual(len(self.pub.msgs), 0)

    def test_empty_text_no_publish(self):
        _send(self.node, "alice", "")
        self.assertEqual(len(self.pub.msgs), 0)

    def test_bad_json_no_crash(self):
        m = _String(); m.data = "not json at all"
        self.node._subs["/social/conversation_text"](m)
        self.assertEqual(len(self.pub.msgs), 0)
        # Exactly one WARN should have been logged for the bad payload.
        warns = [entry for entry in self.node._logs if entry[0] == "WARN"]
        self.assertEqual(len(warns), 1)

    def test_missing_person_id_defaults_unknown(self):
        m = _String(); m.data = json.dumps({"text": "coffee weather hiking"})
        self.node._subs["/social/conversation_text"](m)
        payload = json.loads(self.pub.msgs[-1].data)
        self.assertEqual(payload["person_id"], "unknown")

    def test_duplicate_topic_not_in_new(self):
        _send(self.node, "alice", "coffee mountains weather")
        _send(self.node, "alice", "coffee")  # coffee already known
        payload = json.loads(self.pub.msgs[-1].data)
        self.assertNotIn("coffee", payload["new_topics"])

    def test_topics_accumulate_across_turns(self):
        _send(self.node, "alice", "coffee weather")
        _send(self.node, "alice", "mountains hiking")
        mem = self.node.get_memory("alice")
        topics = mem.recent_topics
        self.assertIn("coffee", topics)
        self.assertIn("mountains", topics)

    def test_separate_persons_independent(self):
        _send(self.node, "alice", "coffee weather")
        _send(self.node, "bob", "robots motors")
        alice = self.node.get_memory("alice").recent_topics
        bob = self.node.get_memory("bob").recent_topics
        self.assertIn("coffee", alice)
        self.assertNotIn("robots", alice)
        self.assertIn("robots", bob)
        self.assertNotIn("coffee", bob)

    def test_all_persons_snapshot(self):
        _send(self.node, "alice", "coffee weather")
        _send(self.node, "bob", "robots motors")
        persons = self.node.all_persons()
        self.assertIn("alice", persons)
        self.assertIn("bob", persons)

    def test_recent_topics_most_recent_first(self):
        _send(self.node, "alice", "alpha")
        _send(self.node, "alice", "beta")
        _send(self.node, "alice", "gamma")
        topics = self.node.get_memory("alice").recent_topics
        self.assertEqual(topics[0], "gamma")

    def test_stop_words_not_stored(self):
        _send(self.node, "alice", "the weather and coffee")
        topics = self.node.get_memory("alice").recent_topics
        self.assertNotIn("the", topics)
        self.assertNotIn("and", topics)

    def test_max_keywords_respected(self):
        node = _make_node(self.mod, max_keywords_per_msg=3)
        _send(node, "alice",
              "coffee weather hiking mountains ocean desert forest lake")
        mem = node.get_memory("alice")
        # Only 3 keywords should have been extracted per message
        self.assertLessEqual(len(mem), 3)

    def test_max_topics_cap(self):
        node = _make_node(self.mod, max_topics_per_person=5, max_keywords_per_msg=20)
        words = ["alpha", "beta", "gamma", "delta", "epsilon",
                 "zeta", "eta", "theta"]
        # Was `for i, w in enumerate(words)` with `i` unused — plain loop.
        for w in words:
            _send(node, "alice", w)
        mem = node.get_memory("alice")
        self.assertLessEqual(len(mem), 5)
|
||||
|
||||
|
||||
# ── Tests: prune ──────────────────────────────────────────────────────────────
|
||||
|
||||
class TestPrune(unittest.TestCase):
    """Stale-person pruning behaviour (prune_after_s parameter)."""

    @classmethod
    def setUpClass(cls): cls.mod = _load_mod()

    def test_no_prune_when_disabled(self):
        node = _make_node(self.mod, prune_after_s=0.0)
        _send(node, "alice", "coffee weather")
        # Manually expire timestamp
        node._memory["alice"].last_updated = time.monotonic() - 9999
        _send(node, "bob", "robots motors")  # triggers prune check
        # alice should still be present (prune disabled)
        self.assertIn("alice", node.all_persons())

    def test_prune_stale_person(self):
        node = _make_node(self.mod, prune_after_s=1.0)
        _send(node, "alice", "coffee weather")
        node._memory["alice"].last_updated = time.monotonic() - 10.0  # stale
        _send(node, "bob", "robots motors")  # triggers prune
        self.assertNotIn("alice", node.all_persons())

    def test_fresh_person_not_pruned(self):
        node = _make_node(self.mod, prune_after_s=1.0)
        _send(node, "alice", "coffee weather")
        # alice is fresh (just added)
        _send(node, "bob", "robots motors")
        self.assertIn("alice", node.all_persons())
|
||||
|
||||
|
||||
# ── Tests: source and config ──────────────────────────────────────────────────
|
||||
|
||||
class TestNodeSrc(unittest.TestCase):
    """Cheap text-level smoke checks on the node module's source."""

    @classmethod
    def setUpClass(cls):
        with open(_SRC) as f: cls.src = f.read()

    def test_issue_tag(self): self.assertIn("#299", self.src)
    def test_input_topic(self): self.assertIn("/social/conversation_text", self.src)
    def test_output_topic(self): self.assertIn("/saltybot/conversation_topics", self.src)
    def test_extract_keywords(self): self.assertIn("extract_keywords", self.src)
    def test_person_topic_memory(self):self.assertIn("PersonTopicMemory", self.src)
    def test_stop_words(self): self.assertIn("STOP_WORDS", self.src)
    def test_json_output(self): self.assertIn("json.dumps", self.src)
    def test_person_id_in_output(self):self.assertIn("person_id", self.src)
    def test_recent_topics_key(self): self.assertIn("recent_topics", self.src)
    def test_new_topics_key(self): self.assertIn("new_topics", self.src)
    def test_threading_lock(self): self.assertIn("threading.Lock", self.src)
    def test_prune_method(self): self.assertIn("_prune_stale", self.src)
    def test_main_defined(self): self.assertIn("def main", self.src)
    def test_min_word_length_param(self):self.assertIn("min_word_length", self.src)
    def test_max_topics_param(self): self.assertIn("max_topics_per_person", self.src)
|
||||
|
||||
|
||||
class TestConfig(unittest.TestCase):
|
||||
_CONFIG = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/config/topic_memory_params.yaml"
|
||||
)
|
||||
_LAUNCH = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/launch/topic_memory.launch.py"
|
||||
)
|
||||
_SETUP = (
|
||||
"/Users/seb/AI/saltylab-firmware/jetson/ros2_ws/src/"
|
||||
"saltybot_social/setup.py"
|
||||
)
|
||||
|
||||
def test_config_exists(self):
|
||||
import os; self.assertTrue(os.path.exists(self._CONFIG))
|
||||
|
||||
def test_config_min_word_length(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("min_word_length", c)
|
||||
|
||||
def test_config_max_topics(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("max_topics_per_person", c)
|
||||
|
||||
def test_config_prune(self):
|
||||
with open(self._CONFIG) as f: c = f.read()
|
||||
self.assertIn("prune_after_s", c)
|
||||
|
||||
def test_launch_exists(self):
|
||||
import os; self.assertTrue(os.path.exists(self._LAUNCH))
|
||||
|
||||
def test_launch_max_topics_arg(self):
|
||||
with open(self._LAUNCH) as f: c = f.read()
|
||||
self.assertIn("max_topics_per_person", c)
|
||||
|
||||
def test_entry_point(self):
|
||||
with open(self._SETUP) as f: c = f.read()
|
||||
self.assertIn("topic_memory_node", c)
|
||||
|
||||
|
||||
# Allow running the suite directly: `python test_topic_memory.py`.
if __name__ == "__main__":
    unittest.main()
|
||||
Loading…
x
Reference in New Issue
Block a user