Compare commits
8 Commits
41040f8bbd
...
532edb835b
| Author | SHA1 | Date | |
|---|---|---|---|
| 532edb835b | |||
| d1a4008451 | |||
| d143a6d156 | |||
| 0d07b09949 | |||
| 03e7995e66 | |||
| 1600691ec5 | |||
| 58bb5ef18e | |||
| e247389b07 |
@ -53,6 +53,21 @@
|
|||||||
#define LED_STRIP_NUM_LEDS 8u // 8-LED ring
|
#define LED_STRIP_NUM_LEDS 8u // 8-LED ring
|
||||||
#define LED_STRIP_FREQ_HZ 800000u // 800 kHz PWM for NeoPixel (1.25 µs per bit)
|
#define LED_STRIP_FREQ_HZ 800000u // 800 kHz PWM for NeoPixel (1.25 µs per bit)
|
||||||
|
|
||||||
|
// --- Servo Pan-Tilt (Issue #206) ---
|
||||||
|
// TIM4_CH1 (PB6) for pan servo, TIM4_CH2 (PB7) for tilt servo
|
||||||
|
#define SERVO_TIM TIM4
|
||||||
|
#define SERVO_PAN_PORT GPIOB
|
||||||
|
#define SERVO_PAN_PIN GPIO_PIN_6 // TIM4_CH1
|
||||||
|
#define SERVO_PAN_CHANNEL TIM_CHANNEL_1
|
||||||
|
#define SERVO_TILT_PORT GPIOB
|
||||||
|
#define SERVO_TILT_PIN GPIO_PIN_7 // TIM4_CH2
|
||||||
|
#define SERVO_TILT_CHANNEL TIM_CHANNEL_2
|
||||||
|
#define SERVO_AF GPIO_AF2_TIM4 // Alternate function
|
||||||
|
#define SERVO_FREQ_HZ 50u // 50 Hz (20ms period, standard servo)
|
||||||
|
#define SERVO_MIN_US 500u // 500µs = 0°
|
||||||
|
#define SERVO_MAX_US 2500u // 2500µs = 180°
|
||||||
|
#define SERVO_CENTER_US 1500u // 1500µs = 90°
|
||||||
|
|
||||||
// --- OSD: MAX7456 (SPI2) ---
|
// --- OSD: MAX7456 (SPI2) ---
|
||||||
#define OSD_SPI SPI2
|
#define OSD_SPI SPI2
|
||||||
#define OSD_CS_PORT GPIOB
|
#define OSD_CS_PORT GPIOB
|
||||||
|
|||||||
110
include/servo.h
Normal file
110
include/servo.h
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
#ifndef SERVO_H
|
||||||
|
#define SERVO_H
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
|
||||||
|
/*
|
||||||
|
* servo.h — Pan-tilt servo driver for camera head (Issue #206)
|
||||||
|
*
|
||||||
|
* Hardware: TIM4 PWM at 50 Hz (20 ms period)
|
||||||
|
* - CH1 (PB6): Pan servo (0-180°)
|
||||||
|
* - CH2 (PB7): Tilt servo (0-180°)
|
||||||
|
*
|
||||||
|
* Servo pulse mapping:
|
||||||
|
* - 500 µs → 0° (full left/down)
|
||||||
|
* - 1500 µs → 90° (center)
|
||||||
|
* - 2500 µs → 180° (full right/up)
|
||||||
|
*
|
||||||
|
* Smooth sweeping via servo_sweep() for camera motion.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/* Servo channels */
|
||||||
|
typedef enum {
|
||||||
|
SERVO_PAN = 0, /* CH1 (PB6) */
|
||||||
|
SERVO_TILT = 1, /* CH2 (PB7) */
|
||||||
|
SERVO_COUNT
|
||||||
|
} ServoChannel;
|
||||||
|
|
||||||
|
/* Servo state */
|
||||||
|
typedef struct {
|
||||||
|
uint16_t current_angle_deg[SERVO_COUNT]; /* Current angle in degrees (0-180) */
|
||||||
|
uint16_t target_angle_deg[SERVO_COUNT]; /* Target angle in degrees */
|
||||||
|
uint16_t pulse_us[SERVO_COUNT]; /* Pulse width in microseconds (500-2500) */
|
||||||
|
uint32_t sweep_start_ms;
|
||||||
|
uint32_t sweep_duration_ms;
|
||||||
|
bool is_sweeping;
|
||||||
|
} ServoState;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* servo_init()
|
||||||
|
*
|
||||||
|
* Initialize TIM4 PWM on PB6 (CH1, pan) and PB7 (CH2, tilt) at 50 Hz.
|
||||||
|
* Centers both servos at 90° (1500 µs). Call once at startup.
|
||||||
|
*/
|
||||||
|
void servo_init(void);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* servo_set_angle(channel, degrees)
|
||||||
|
*
|
||||||
|
* Set target angle for a servo (0-180°).
|
||||||
|
* Immediately updates PWM without motion ramping.
|
||||||
|
* Valid channels: SERVO_PAN, SERVO_TILT
|
||||||
|
*
|
||||||
|
* Examples:
|
||||||
|
* servo_set_angle(SERVO_PAN, 0); // Pan left
|
||||||
|
* servo_set_angle(SERVO_PAN, 90); // Pan center
|
||||||
|
* servo_set_angle(SERVO_TILT, 180); // Tilt up
|
||||||
|
*/
|
||||||
|
void servo_set_angle(ServoChannel channel, uint16_t degrees);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* servo_get_angle(channel)
|
||||||
|
*
|
||||||
|
* Return current servo angle in degrees (0-180).
|
||||||
|
*/
|
||||||
|
uint16_t servo_get_angle(ServoChannel channel);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* servo_set_pulse_us(channel, pulse_us)
|
||||||
|
*
|
||||||
|
* Set servo pulse width directly in microseconds (500-2500).
|
||||||
|
* Used for fine-tuning or direct control.
|
||||||
|
*/
|
||||||
|
void servo_set_pulse_us(ServoChannel channel, uint16_t pulse_us);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* servo_sweep(channel, start_deg, end_deg, duration_ms)
|
||||||
|
*
|
||||||
|
* Smooth linear sweep from start to end angle over duration_ms.
|
||||||
|
* Non-blocking: must call servo_tick() every ~10 ms to update PWM.
|
||||||
|
*
|
||||||
|
* Examples:
|
||||||
|
* servo_sweep(SERVO_PAN, 0, 180, 2000); // Pan left-to-right in 2 seconds
|
||||||
|
* servo_sweep(SERVO_TILT, 45, 135, 1000); // Tilt up-down in 1 second
|
||||||
|
*/
|
||||||
|
void servo_sweep(ServoChannel channel, uint16_t start_deg, uint16_t end_deg, uint32_t duration_ms);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* servo_tick(now_ms)
|
||||||
|
*
|
||||||
|
* Update servo sweep animation (if active). Call every ~10 ms from main loop.
|
||||||
|
* No-op if not currently sweeping.
|
||||||
|
*/
|
||||||
|
void servo_tick(uint32_t now_ms);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* servo_is_sweeping()
|
||||||
|
*
|
||||||
|
* Returns true if any servo is currently sweeping.
|
||||||
|
*/
|
||||||
|
bool servo_is_sweeping(void);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* servo_stop_sweep(channel)
|
||||||
|
*
|
||||||
|
* Stop sweep immediately, hold current position.
|
||||||
|
*/
|
||||||
|
void servo_stop_sweep(ServoChannel channel);
|
||||||
|
|
||||||
|
#endif /* SERVO_H */
|
||||||
Binary file not shown.
Binary file not shown.
@ -0,0 +1,6 @@
|
|||||||
|
person_reid:
|
||||||
|
ros__parameters:
|
||||||
|
model_path: '' # path to MobileNetV2+projection ONNX file (empty = histogram fallback)
|
||||||
|
match_threshold: 0.75 # cosine similarity threshold for re-ID match
|
||||||
|
max_identity_age_s: 300.0 # seconds before unseen identity is pruned
|
||||||
|
publish_hz: 5.0 # publication rate (Hz)
|
||||||
28
jetson/ros2_ws/src/saltybot_person_reid/package.xml
Normal file
28
jetson/ros2_ws/src/saltybot_person_reid/package.xml
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
<?xml version="1.0"?>
|
||||||
|
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||||
|
<package format="3">
|
||||||
|
<name>saltybot_person_reid</name>
|
||||||
|
<version>0.1.0</version>
|
||||||
|
<description>
|
||||||
|
Person re-identification node — cross-camera appearance matching using
|
||||||
|
MobileNetV2 ONNX embeddings (128-dim, cosine similarity gallery).
|
||||||
|
</description>
|
||||||
|
<maintainer email="robot@saltylab.local">SaltyLab</maintainer>
|
||||||
|
<license>MIT</license>
|
||||||
|
|
||||||
|
<depend>rclpy</depend>
|
||||||
|
<depend>sensor_msgs</depend>
|
||||||
|
<depend>vision_msgs</depend>
|
||||||
|
<depend>cv_bridge</depend>
|
||||||
|
<depend>message_filters</depend>
|
||||||
|
<depend>saltybot_person_reid_msgs</depend>
|
||||||
|
|
||||||
|
<exec_depend>python3-numpy</exec_depend>
|
||||||
|
<exec_depend>python3-opencv</exec_depend>
|
||||||
|
|
||||||
|
<test_depend>pytest</test_depend>
|
||||||
|
|
||||||
|
<export>
|
||||||
|
<build_type>ament_python</build_type>
|
||||||
|
</export>
|
||||||
|
</package>
|
||||||
@ -0,0 +1,95 @@
|
|||||||
|
"""
|
||||||
|
_embedding_model.py — Appearance embedding extractor (no ROS2 deps).
|
||||||
|
|
||||||
|
Primary: MobileNetV2 ONNX torso crop → 128-dim L2-normalised embedding.
|
||||||
|
Fallback: 128-bin HSV histogram (H:16 × S:8) when no model file is available.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
# Top fraction of the bounding box height used as torso crop
|
||||||
|
_INPUT_SIZE = (128, 256) # (W, H) fed to MobileNetV2
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingModel:
|
||||||
|
"""
|
||||||
|
Extract a 128-dim L2-normalised appearance embedding from a BGR crop.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model_path : str or None
|
||||||
|
Path to a MobileNetV2+projection ONNX file. When None (or file
|
||||||
|
not found), falls back to a 128-bin HSV colour histogram.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_path: str | None = None):
|
||||||
|
self._net = None
|
||||||
|
if model_path:
|
||||||
|
try:
|
||||||
|
self._net = cv2.dnn.readNetFromONNX(model_path)
|
||||||
|
except Exception:
|
||||||
|
pass # histogram fallback
|
||||||
|
|
||||||
|
def embed(self, bgr_crop: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
bgr_crop : np.ndarray shape (H, W, 3) uint8
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
np.ndarray shape (128,) float32, L2-normalised
|
||||||
|
"""
|
||||||
|
if bgr_crop.size == 0:
|
||||||
|
return np.zeros(128, dtype=np.float32)
|
||||||
|
|
||||||
|
if self._net is not None:
|
||||||
|
return self._mobilenet_embed(bgr_crop)
|
||||||
|
return self._histogram_embed(bgr_crop)
|
||||||
|
|
||||||
|
# ── MobileNetV2 path ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _mobilenet_embed(self, bgr: np.ndarray) -> np.ndarray:
|
||||||
|
resized = cv2.resize(bgr, _INPUT_SIZE)
|
||||||
|
blob = cv2.dnn.blobFromImage(
|
||||||
|
resized,
|
||||||
|
scalefactor=1.0 / 255.0,
|
||||||
|
size=_INPUT_SIZE,
|
||||||
|
mean=(0.485 * 255, 0.456 * 255, 0.406 * 255),
|
||||||
|
swapRB=True,
|
||||||
|
crop=False,
|
||||||
|
)
|
||||||
|
# Std normalisation: divide channel-wise
|
||||||
|
blob[:, 0] /= 0.229
|
||||||
|
blob[:, 1] /= 0.224
|
||||||
|
blob[:, 2] /= 0.225
|
||||||
|
|
||||||
|
self._net.setInput(blob)
|
||||||
|
feat = self._net.forward().flatten().astype(np.float32)
|
||||||
|
|
||||||
|
# Ensure 128-dim — average-pool if model output differs
|
||||||
|
if feat.shape[0] != 128:
|
||||||
|
n = feat.shape[0]
|
||||||
|
block = max(1, n // 128)
|
||||||
|
feat = feat[: block * 128].reshape(128, block).mean(axis=1)
|
||||||
|
|
||||||
|
return _l2_norm(feat)
|
||||||
|
|
||||||
|
# ── HSV histogram fallback ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _histogram_embed(self, bgr: np.ndarray) -> np.ndarray:
|
||||||
|
"""128-bin HSV histogram: 16 H-bins × 8 S-bins, concatenated."""
|
||||||
|
hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
|
||||||
|
hist = cv2.calcHist(
|
||||||
|
[hsv], [0, 1], None,
|
||||||
|
[16, 8], [0, 180, 0, 256],
|
||||||
|
).flatten().astype(np.float32)
|
||||||
|
return _l2_norm(hist)
|
||||||
|
|
||||||
|
|
||||||
|
def _l2_norm(v: np.ndarray) -> np.ndarray:
|
||||||
|
n = float(np.linalg.norm(v))
|
||||||
|
return v / n if n > 1e-6 else v
|
||||||
@ -0,0 +1,105 @@
|
|||||||
|
"""
|
||||||
|
_reid_gallery.py — Appearance gallery for person re-identification (no ROS2 deps).
|
||||||
|
|
||||||
|
Matches an incoming embedding against stored identities using cosine similarity.
|
||||||
|
New identities are created when the best match falls below the threshold.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Identity:
|
||||||
|
identity_id: int
|
||||||
|
embedding: np.ndarray # shape (D,) L2-normalised
|
||||||
|
last_seen: float = field(default_factory=time.monotonic)
|
||||||
|
hit_count: int = 1
|
||||||
|
|
||||||
|
def update(self, new_embedding: np.ndarray, alpha: float = 0.1) -> None:
|
||||||
|
"""EMA update of the stored embedding, re-normalised after blending."""
|
||||||
|
merged = (1.0 - alpha) * self.embedding + alpha * new_embedding
|
||||||
|
n = float(np.linalg.norm(merged))
|
||||||
|
self.embedding = merged / n if n > 1e-6 else merged
|
||||||
|
self.last_seen = time.monotonic()
|
||||||
|
self.hit_count += 1
|
||||||
|
|
||||||
|
|
||||||
|
class ReidGallery:
|
||||||
|
"""
|
||||||
|
Lightweight cosine-similarity re-ID gallery.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
match_threshold : float
|
||||||
|
Cosine similarity (dot product of unit vectors) required to accept a
|
||||||
|
match. Range [0, 1]; 0 = always new identity, 1 = perfect match only.
|
||||||
|
max_age_s : float
|
||||||
|
Identities not seen for this many seconds are pruned.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
match_threshold: float = 0.75,
|
||||||
|
max_age_s: float = 300.0,
|
||||||
|
):
|
||||||
|
self._threshold = match_threshold
|
||||||
|
self._max_age_s = max_age_s
|
||||||
|
self._identities: List[Identity] = []
|
||||||
|
self._next_id = 1
|
||||||
|
|
||||||
|
def match(self, embedding: np.ndarray) -> Tuple[int, float, bool]:
|
||||||
|
"""
|
||||||
|
Match embedding against the gallery.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
(identity_id, match_score, is_new)
|
||||||
|
identity_id : assigned ID (new or existing)
|
||||||
|
match_score : cosine similarity to best match (0.0 if new)
|
||||||
|
is_new : True if a new identity was created
|
||||||
|
"""
|
||||||
|
self._prune()
|
||||||
|
|
||||||
|
if not self._identities:
|
||||||
|
return self._add_identity(embedding)
|
||||||
|
|
||||||
|
scores = np.array(
|
||||||
|
[float(np.dot(embedding, ident.embedding)) for ident in self._identities]
|
||||||
|
)
|
||||||
|
best_idx = int(np.argmax(scores))
|
||||||
|
best_score = float(scores[best_idx])
|
||||||
|
|
||||||
|
if best_score >= self._threshold:
|
||||||
|
ident = self._identities[best_idx]
|
||||||
|
ident.update(embedding)
|
||||||
|
return ident.identity_id, best_score, False
|
||||||
|
|
||||||
|
return self._add_identity(embedding)
|
||||||
|
|
||||||
|
# ── Internal helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _add_identity(self, embedding: np.ndarray) -> Tuple[int, float, bool]:
|
||||||
|
new_id = self._next_id
|
||||||
|
self._next_id += 1
|
||||||
|
self._identities.append(
|
||||||
|
Identity(identity_id=new_id, embedding=embedding.copy())
|
||||||
|
)
|
||||||
|
return new_id, 0.0, True
|
||||||
|
|
||||||
|
def _prune(self) -> None:
|
||||||
|
now = time.monotonic()
|
||||||
|
self._identities = [
|
||||||
|
ident
|
||||||
|
for ident in self._identities
|
||||||
|
if now - ident.last_seen < self._max_age_s
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def size(self) -> int:
|
||||||
|
return len(self._identities)
|
||||||
@ -0,0 +1,174 @@
|
|||||||
|
"""
|
||||||
|
person_reid_node.py — Person re-identification for cross-camera tracking.
|
||||||
|
|
||||||
|
Subscribes to:
|
||||||
|
/person/detections vision_msgs/Detection2DArray (person bounding boxes)
|
||||||
|
/camera/color/image_raw sensor_msgs/Image (colour frame for crops)
|
||||||
|
|
||||||
|
Publishes:
|
||||||
|
/saltybot/person_reid saltybot_person_reid_msgs/PersonAppearanceArray (5 Hz)
|
||||||
|
|
||||||
|
For each detected person the node:
|
||||||
|
1. Crops the torso region (top 65 % of the bounding box height).
|
||||||
|
2. Extracts a 128-dim L2-normalised embedding via MobileNetV2 ONNX (if the
|
||||||
|
model file is provided) or a 128-bin HSV colour histogram (fallback).
|
||||||
|
3. Matches against a cosine-similarity gallery.
|
||||||
|
4. Assigns a persistent identity_id (new or existing).
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
model_path str '' Path to MobileNetV2+projection ONNX file
|
||||||
|
match_threshold float 0.75 Cosine similarity threshold for matching
|
||||||
|
max_identity_age_s float 300.0 Seconds before an unseen identity is pruned
|
||||||
|
publish_hz float 5.0 Publication rate (Hz)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile, ReliabilityPolicy, HistoryPolicy
|
||||||
|
|
||||||
|
import message_filters
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from cv_bridge import CvBridge
|
||||||
|
|
||||||
|
from sensor_msgs.msg import Image
|
||||||
|
from vision_msgs.msg import Detection2DArray
|
||||||
|
|
||||||
|
from saltybot_person_reid_msgs.msg import PersonAppearance, PersonAppearanceArray
|
||||||
|
|
||||||
|
from ._embedding_model import EmbeddingModel
|
||||||
|
from ._reid_gallery import ReidGallery
|
||||||
|
|
||||||
|
# Fraction of bbox height kept as torso crop (top portion)
|
||||||
|
_TORSO_FRAC = 0.65
|
||||||
|
|
||||||
|
_BEST_EFFORT_QOS = QoSProfile(
|
||||||
|
reliability=ReliabilityPolicy.BEST_EFFORT,
|
||||||
|
history=HistoryPolicy.KEEP_LAST,
|
||||||
|
depth=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PersonReidNode(Node):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__('person_reid')
|
||||||
|
|
||||||
|
self.declare_parameter('model_path', '')
|
||||||
|
self.declare_parameter('match_threshold', 0.75)
|
||||||
|
self.declare_parameter('max_identity_age_s', 300.0)
|
||||||
|
self.declare_parameter('publish_hz', 5.0)
|
||||||
|
|
||||||
|
model_path = self.get_parameter('model_path').value
|
||||||
|
match_thr = self.get_parameter('match_threshold').value
|
||||||
|
max_age = self.get_parameter('max_identity_age_s').value
|
||||||
|
publish_hz = self.get_parameter('publish_hz').value
|
||||||
|
|
||||||
|
self._bridge = CvBridge()
|
||||||
|
self._embedder = EmbeddingModel(model_path or None)
|
||||||
|
self._gallery = ReidGallery(match_threshold=match_thr, max_age_s=max_age)
|
||||||
|
|
||||||
|
# Buffer: updated by frame callback, drained by timer
|
||||||
|
self._pending: List[PersonAppearance] = []
|
||||||
|
self._pending_header = None
|
||||||
|
|
||||||
|
# Synchronized subscribers
|
||||||
|
det_sub = message_filters.Subscriber(
|
||||||
|
self, Detection2DArray, '/person/detections',
|
||||||
|
qos_profile=_BEST_EFFORT_QOS)
|
||||||
|
img_sub = message_filters.Subscriber(
|
||||||
|
self, Image, '/camera/color/image_raw',
|
||||||
|
qos_profile=_BEST_EFFORT_QOS)
|
||||||
|
self._sync = message_filters.ApproximateTimeSynchronizer(
|
||||||
|
[det_sub, img_sub], queue_size=4, slop=0.1)
|
||||||
|
self._sync.registerCallback(self._on_frame)
|
||||||
|
|
||||||
|
self._pub = self.create_publisher(
|
||||||
|
PersonAppearanceArray, '/saltybot/person_reid', 10)
|
||||||
|
|
||||||
|
self.create_timer(1.0 / publish_hz, self._tick)
|
||||||
|
|
||||||
|
backend = 'ONNX' if self._embedder._net else 'histogram'
|
||||||
|
self.get_logger().info(
|
||||||
|
f'person_reid ready — backend={backend} '
|
||||||
|
f'threshold={match_thr} max_age={max_age}s'
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Frame callback ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _on_frame(self, det_msg: Detection2DArray, img_msg: Image) -> None:
|
||||||
|
if not det_msg.detections:
|
||||||
|
self._pending = []
|
||||||
|
self._pending_header = det_msg.header
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
bgr = self._bridge.imgmsg_to_cv2(img_msg, desired_encoding='bgr8')
|
||||||
|
except Exception as exc:
|
||||||
|
self.get_logger().error(
|
||||||
|
f'imgmsg_to_cv2 failed: {exc}', throttle_duration_sec=5.0)
|
||||||
|
return
|
||||||
|
|
||||||
|
h_img, w_img = bgr.shape[:2]
|
||||||
|
appearances: List[PersonAppearance] = []
|
||||||
|
|
||||||
|
for det in det_msg.detections:
|
||||||
|
cx = det.bbox.center.position.x
|
||||||
|
cy = det.bbox.center.position.y
|
||||||
|
bw = det.bbox.size_x
|
||||||
|
bh = det.bbox.size_y
|
||||||
|
conf = det.results[0].hypothesis.score if det.results else 0.0
|
||||||
|
|
||||||
|
# Torso crop: top TORSO_FRAC of bounding box
|
||||||
|
x1 = max(0, int(cx - bw / 2.0))
|
||||||
|
y1 = max(0, int(cy - bh / 2.0))
|
||||||
|
x2 = min(w_img, int(cx + bw / 2.0))
|
||||||
|
y2 = min(h_img, int(cy - bh / 2.0 + bh * _TORSO_FRAC))
|
||||||
|
|
||||||
|
if x2 - x1 < 8 or y2 - y1 < 8:
|
||||||
|
continue
|
||||||
|
|
||||||
|
crop = bgr[y1:y2, x1:x2]
|
||||||
|
emb = self._embedder.embed(crop)
|
||||||
|
identity_id, match_score, is_new = self._gallery.match(emb)
|
||||||
|
|
||||||
|
app = PersonAppearance()
|
||||||
|
app.header = det_msg.header
|
||||||
|
app.track_id = identity_id
|
||||||
|
app.embedding = emb.tolist()
|
||||||
|
app.bbox = det.bbox
|
||||||
|
app.confidence = float(conf)
|
||||||
|
app.match_score = float(match_score)
|
||||||
|
app.is_new_identity = is_new
|
||||||
|
appearances.append(app)
|
||||||
|
|
||||||
|
self._pending = appearances
|
||||||
|
self._pending_header = det_msg.header
|
||||||
|
|
||||||
|
# ── 5 Hz publish tick ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _tick(self) -> None:
|
||||||
|
if self._pending_header is None:
|
||||||
|
return
|
||||||
|
msg = PersonAppearanceArray()
|
||||||
|
msg.header = self._pending_header
|
||||||
|
msg.appearances = self._pending
|
||||||
|
self._pub.publish(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None):
|
||||||
|
rclpy.init(args=args)
|
||||||
|
node = PersonReidNode()
|
||||||
|
try:
|
||||||
|
rclpy.spin(node)
|
||||||
|
finally:
|
||||||
|
node.destroy_node()
|
||||||
|
rclpy.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
4
jetson/ros2_ws/src/saltybot_person_reid/setup.cfg
Normal file
4
jetson/ros2_ws/src/saltybot_person_reid/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
[develop]
|
||||||
|
script_dir=$base/lib/saltybot_person_reid
|
||||||
|
[install]
|
||||||
|
install_scripts=$base/lib/saltybot_person_reid
|
||||||
29
jetson/ros2_ws/src/saltybot_person_reid/setup.py
Normal file
29
jetson/ros2_ws/src/saltybot_person_reid/setup.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from setuptools import setup, find_packages
|
||||||
|
from glob import glob
|
||||||
|
|
||||||
|
package_name = 'saltybot_person_reid'
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name=package_name,
|
||||||
|
version='0.1.0',
|
||||||
|
packages=find_packages(exclude=['test']),
|
||||||
|
data_files=[
|
||||||
|
('share/ament_index/resource_index/packages',
|
||||||
|
['resource/' + package_name]),
|
||||||
|
('share/' + package_name, ['package.xml']),
|
||||||
|
('share/' + package_name + '/config',
|
||||||
|
glob('config/*.yaml')),
|
||||||
|
],
|
||||||
|
install_requires=['setuptools'],
|
||||||
|
zip_safe=True,
|
||||||
|
maintainer='SaltyLab',
|
||||||
|
maintainer_email='robot@saltylab.local',
|
||||||
|
description='Person re-identification — cross-camera appearance matching',
|
||||||
|
license='MIT',
|
||||||
|
tests_require=['pytest'],
|
||||||
|
entry_points={
|
||||||
|
'console_scripts': [
|
||||||
|
'person_reid = saltybot_person_reid.person_reid_node:main',
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
163
jetson/ros2_ws/src/saltybot_person_reid/test/test_person_reid.py
Normal file
163
jetson/ros2_ws/src/saltybot_person_reid/test/test_person_reid.py
Normal file
@ -0,0 +1,163 @@
|
|||||||
|
"""
|
||||||
|
test_person_reid.py — Unit tests for person re-ID helpers (no ROS2 required).
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- _l2_norm helper
|
||||||
|
- EmbeddingModel (histogram fallback — no model file needed)
|
||||||
|
- ReidGallery cosine-similarity matching
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||||
|
|
||||||
|
from saltybot_person_reid._embedding_model import EmbeddingModel, _l2_norm
|
||||||
|
from saltybot_person_reid._reid_gallery import ReidGallery
|
||||||
|
|
||||||
|
|
||||||
|
# ── _l2_norm ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestL2Norm:
|
||||||
|
|
||||||
|
def test_unit_vector_unchanged(self):
|
||||||
|
v = np.array([1.0, 0.0, 0.0], dtype=np.float32)
|
||||||
|
assert np.allclose(_l2_norm(v), v)
|
||||||
|
|
||||||
|
def test_normalised_to_unit_norm(self):
|
||||||
|
v = np.array([3.0, 4.0], dtype=np.float32)
|
||||||
|
assert abs(np.linalg.norm(_l2_norm(v)) - 1.0) < 1e-6
|
||||||
|
|
||||||
|
def test_zero_vector_does_not_crash(self):
|
||||||
|
v = np.zeros(4, dtype=np.float32)
|
||||||
|
result = _l2_norm(v)
|
||||||
|
assert result.shape == (4,)
|
||||||
|
|
||||||
|
|
||||||
|
# ── EmbeddingModel ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestEmbeddingModel:
|
||||||
|
|
||||||
|
def test_histogram_fallback_shape(self):
|
||||||
|
m = EmbeddingModel(model_path=None)
|
||||||
|
bgr = np.random.randint(0, 255, (100, 50, 3), dtype=np.uint8)
|
||||||
|
emb = m.embed(bgr)
|
||||||
|
assert emb.shape == (128,)
|
||||||
|
|
||||||
|
def test_embedding_is_unit_norm(self):
|
||||||
|
m = EmbeddingModel(model_path=None)
|
||||||
|
bgr = np.random.randint(0, 255, (80, 40, 3), dtype=np.uint8)
|
||||||
|
emb = m.embed(bgr)
|
||||||
|
assert abs(np.linalg.norm(emb) - 1.0) < 1e-5
|
||||||
|
|
||||||
|
def test_empty_crop_returns_zero_vector(self):
|
||||||
|
m = EmbeddingModel(model_path=None)
|
||||||
|
emb = m.embed(np.zeros((0, 0, 3), dtype=np.uint8))
|
||||||
|
assert emb.shape == (128,)
|
||||||
|
assert np.all(emb == 0.0)
|
||||||
|
|
||||||
|
def test_red_and_blue_crops_differ(self):
|
||||||
|
m = EmbeddingModel(model_path=None)
|
||||||
|
red = np.full((80, 40, 3), (0, 0, 200), dtype=np.uint8)
|
||||||
|
blue = np.full((80, 40, 3), (200, 0, 0), dtype=np.uint8)
|
||||||
|
sim = float(np.dot(m.embed(red), m.embed(blue)))
|
||||||
|
assert sim < 0.99
|
||||||
|
|
||||||
|
def test_same_crop_deterministic(self):
|
||||||
|
m = EmbeddingModel(model_path=None)
|
||||||
|
bgr = np.random.randint(0, 255, (80, 40, 3), dtype=np.uint8)
|
||||||
|
assert np.allclose(m.embed(bgr), m.embed(bgr))
|
||||||
|
|
||||||
|
def test_embedding_float32(self):
|
||||||
|
m = EmbeddingModel(model_path=None)
|
||||||
|
bgr = np.random.randint(0, 255, (60, 30, 3), dtype=np.uint8)
|
||||||
|
emb = m.embed(bgr)
|
||||||
|
assert emb.dtype == np.float32
|
||||||
|
|
||||||
|
|
||||||
|
# ── ReidGallery ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _unit(dim: int = 128, seed: int | None = None) -> np.ndarray:
|
||||||
|
rng = np.random.default_rng(seed)
|
||||||
|
v = rng.standard_normal(dim).astype(np.float32)
|
||||||
|
return v / np.linalg.norm(v)
|
||||||
|
|
||||||
|
|
||||||
|
class TestReidGallery:
|
||||||
|
|
||||||
|
def test_first_match_creates_identity(self):
|
||||||
|
g = ReidGallery(match_threshold=0.75)
|
||||||
|
uid, score, is_new = g.match(_unit(seed=0))
|
||||||
|
assert uid == 1
|
||||||
|
assert is_new is True
|
||||||
|
assert score == pytest.approx(0.0)
|
||||||
|
|
||||||
|
def test_identical_embedding_matches(self):
|
||||||
|
g = ReidGallery(match_threshold=0.75)
|
||||||
|
emb = _unit(seed=1)
|
||||||
|
g.match(emb)
|
||||||
|
uid2, score2, is_new2 = g.match(emb)
|
||||||
|
assert uid2 == 1
|
||||||
|
assert is_new2 is False
|
||||||
|
assert score2 > 0.99
|
||||||
|
|
||||||
|
def test_orthogonal_embeddings_create_new_id(self):
|
||||||
|
g = ReidGallery(match_threshold=0.75)
|
||||||
|
e1 = np.zeros(128, dtype=np.float32); e1[0] = 1.0
|
||||||
|
e2 = np.zeros(128, dtype=np.float32); e2[64] = 1.0
|
||||||
|
uid1, _, new1 = g.match(e1)
|
||||||
|
uid2, _, new2 = g.match(e2)
|
||||||
|
assert uid1 != uid2
|
||||||
|
assert new2 is True
|
||||||
|
|
||||||
|
def test_ids_are_monotonically_increasing(self):
|
||||||
|
# threshold > 1.0 is unreachable → every embedding creates a new identity
|
||||||
|
g = ReidGallery(match_threshold=2.0)
|
||||||
|
ids = [g.match(_unit(seed=i))[0] for i in range(5)]
|
||||||
|
assert ids == list(range(1, 6))
|
||||||
|
|
||||||
|
def test_gallery_size_increments_for_new_ids(self):
|
||||||
|
g = ReidGallery(match_threshold=2.0)
|
||||||
|
for i in range(4):
|
||||||
|
g.match(_unit(seed=i))
|
||||||
|
assert g.size == 4
|
||||||
|
|
||||||
|
def test_prune_removes_stale_identities(self):
|
||||||
|
g = ReidGallery(match_threshold=0.75, max_age_s=0.01)
|
||||||
|
g.match(_unit(seed=0))
|
||||||
|
time.sleep(0.05)
|
||||||
|
g._prune()
|
||||||
|
assert g.size == 0
|
||||||
|
|
||||||
|
def test_empty_gallery_prune_is_safe(self):
|
||||||
|
g = ReidGallery()
|
||||||
|
g._prune()
|
||||||
|
assert g.size == 0
|
||||||
|
|
||||||
|
def test_match_below_threshold_increments_id(self):
|
||||||
|
g = ReidGallery(match_threshold=0.99)
|
||||||
|
# Two random unit vectors are almost certainly < 0.99 similar
|
||||||
|
e1, e2 = _unit(seed=10), _unit(seed=20)
|
||||||
|
uid1, _, _ = g.match(e1)
|
||||||
|
uid2, _, _ = g.match(e2)
|
||||||
|
# uid2 may or may not equal uid1 depending on random similarity,
|
||||||
|
# but both must be valid positive integers
|
||||||
|
assert uid1 >= 1
|
||||||
|
assert uid2 >= 1
|
||||||
|
|
||||||
|
def test_identity_update_does_not_change_id(self):
|
||||||
|
g = ReidGallery(match_threshold=0.5)
|
||||||
|
emb = _unit(seed=5)
|
||||||
|
uid_first, _, _ = g.match(emb)
|
||||||
|
for _ in range(10):
|
||||||
|
g.match(emb)
|
||||||
|
uid_last, _, _ = g.match(emb)
|
||||||
|
assert uid_last == uid_first
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pytest.main([__file__, '-v'])
|
||||||
16
jetson/ros2_ws/src/saltybot_person_reid_msgs/CMakeLists.txt
Normal file
16
jetson/ros2_ws/src/saltybot_person_reid_msgs/CMakeLists.txt
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.8)
|
||||||
|
project(saltybot_person_reid_msgs)
|
||||||
|
|
||||||
|
find_package(ament_cmake REQUIRED)
|
||||||
|
find_package(rosidl_default_generators REQUIRED)
|
||||||
|
find_package(std_msgs REQUIRED)
|
||||||
|
find_package(vision_msgs REQUIRED)
|
||||||
|
|
||||||
|
rosidl_generate_interfaces(${PROJECT_NAME}
|
||||||
|
"msg/PersonAppearance.msg"
|
||||||
|
"msg/PersonAppearanceArray.msg"
|
||||||
|
DEPENDENCIES std_msgs vision_msgs
|
||||||
|
)
|
||||||
|
|
||||||
|
ament_export_dependencies(rosidl_default_runtime)
|
||||||
|
ament_package()
|
||||||
@ -0,0 +1,7 @@
|
|||||||
|
std_msgs/Header header
|
||||||
|
uint32 track_id
|
||||||
|
float32[] embedding
|
||||||
|
vision_msgs/BoundingBox2D bbox
|
||||||
|
float32 confidence
|
||||||
|
float32 match_score
|
||||||
|
bool is_new_identity
|
||||||
@ -0,0 +1,2 @@
|
|||||||
|
std_msgs/Header header
|
||||||
|
PersonAppearance[] appearances
|
||||||
22
jetson/ros2_ws/src/saltybot_person_reid_msgs/package.xml
Normal file
22
jetson/ros2_ws/src/saltybot_person_reid_msgs/package.xml
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
<?xml version="1.0"?>
|
||||||
|
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||||
|
<package format="3">
|
||||||
|
<name>saltybot_person_reid_msgs</name>
|
||||||
|
<version>0.1.0</version>
|
||||||
|
<description>Message types for person re-identification.</description>
|
||||||
|
<maintainer email="robot@saltylab.local">SaltyLab</maintainer>
|
||||||
|
<license>MIT</license>
|
||||||
|
|
||||||
|
<buildtool_depend>ament_cmake</buildtool_depend>
|
||||||
|
<buildtool_depend>rosidl_default_generators</buildtool_depend>
|
||||||
|
|
||||||
|
<depend>std_msgs</depend>
|
||||||
|
<depend>vision_msgs</depend>
|
||||||
|
|
||||||
|
<exec_depend>rosidl_default_runtime</exec_depend>
|
||||||
|
<member_of_group>rosidl_interface_packages</member_of_group>
|
||||||
|
|
||||||
|
<export>
|
||||||
|
<build_type>ament_cmake</build_type>
|
||||||
|
</export>
|
||||||
|
</package>
|
||||||
@ -0,0 +1,6 @@
|
|||||||
|
thermal_node:
|
||||||
|
ros__parameters:
|
||||||
|
publish_rate_hz: 1.0 # Hz — publish rate for /saltybot/thermal
|
||||||
|
warn_temp_c: 75.0 # Log WARN above this temperature (°C)
|
||||||
|
throttle_temp_c: 85.0 # Log ERROR + set throttled=true above this (°C)
|
||||||
|
thermal_root: "/sys/class/thermal" # Sysfs thermal root; override for tests
|
||||||
42
jetson/ros2_ws/src/saltybot_thermal/launch/thermal.launch.py
Normal file
42
jetson/ros2_ws/src/saltybot_thermal/launch/thermal.launch.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
"""thermal.launch.py — Launch the Jetson thermal monitor (Issue #205).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
ros2 launch saltybot_thermal thermal.launch.py
|
||||||
|
ros2 launch saltybot_thermal thermal.launch.py warn_temp_c:=70.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from ament_index_python.packages import get_package_share_directory
|
||||||
|
from launch import LaunchDescription
|
||||||
|
from launch.actions import DeclareLaunchArgument
|
||||||
|
from launch.substitutions import LaunchConfiguration
|
||||||
|
from launch_ros.actions import Node
|
||||||
|
|
||||||
|
|
||||||
|
def generate_launch_description():
|
||||||
|
pkg = get_package_share_directory("saltybot_thermal")
|
||||||
|
cfg = os.path.join(pkg, "config", "thermal_params.yaml")
|
||||||
|
|
||||||
|
return LaunchDescription([
|
||||||
|
DeclareLaunchArgument("publish_rate_hz", default_value="1.0",
|
||||||
|
description="Publish rate (Hz)"),
|
||||||
|
DeclareLaunchArgument("warn_temp_c", default_value="75.0",
|
||||||
|
description="WARN threshold (°C)"),
|
||||||
|
DeclareLaunchArgument("throttle_temp_c", default_value="85.0",
|
||||||
|
description="THROTTLE threshold (°C)"),
|
||||||
|
|
||||||
|
Node(
|
||||||
|
package="saltybot_thermal",
|
||||||
|
executable="thermal_node",
|
||||||
|
name="thermal_node",
|
||||||
|
output="screen",
|
||||||
|
parameters=[
|
||||||
|
cfg,
|
||||||
|
{
|
||||||
|
"publish_rate_hz": LaunchConfiguration("publish_rate_hz"),
|
||||||
|
"warn_temp_c": LaunchConfiguration("warn_temp_c"),
|
||||||
|
"throttle_temp_c": LaunchConfiguration("throttle_temp_c"),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
),
|
||||||
|
])
|
||||||
26
jetson/ros2_ws/src/saltybot_thermal/package.xml
Normal file
26
jetson/ros2_ws/src/saltybot_thermal/package.xml
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
<?xml version="1.0"?>
|
||||||
|
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
|
||||||
|
<package format="3">
|
||||||
|
<name>saltybot_thermal</name>
|
||||||
|
<version>0.1.0</version>
|
||||||
|
<description>
|
||||||
|
Jetson thermal monitor (Issue #205). Reads /sys/class/thermal/thermal_zone*,
|
||||||
|
publishes /saltybot/thermal JSON at 1 Hz, warns at 75 °C, throttles at 85 °C.
|
||||||
|
</description>
|
||||||
|
<maintainer email="sl-jetson@saltylab.local">sl-jetson</maintainer>
|
||||||
|
<license>MIT</license>
|
||||||
|
|
||||||
|
<depend>rclpy</depend>
|
||||||
|
<depend>std_msgs</depend>
|
||||||
|
|
||||||
|
<buildtool_depend>ament_python</buildtool_depend>
|
||||||
|
|
||||||
|
<test_depend>ament_copyright</test_depend>
|
||||||
|
<test_depend>ament_flake8</test_depend>
|
||||||
|
<test_depend>ament_pep257</test_depend>
|
||||||
|
<test_depend>python3-pytest</test_depend>
|
||||||
|
|
||||||
|
<export>
|
||||||
|
<build_type>ament_python</build_type>
|
||||||
|
</export>
|
||||||
|
</package>
|
||||||
@ -0,0 +1,139 @@
|
|||||||
|
"""thermal_node.py — Jetson CPU/GPU thermal monitor.
|
||||||
|
Issue #205
|
||||||
|
|
||||||
|
Reads every /sys/class/thermal/thermal_zone* sysfs entry, publishes a JSON
|
||||||
|
blob on /saltybot/thermal at a configurable rate (default 1 Hz), and logs
|
||||||
|
ROS2 WARN / ERROR when zone temperatures exceed configurable thresholds.
|
||||||
|
|
||||||
|
Published topic:
|
||||||
|
/saltybot/thermal (std_msgs/String, JSON)
|
||||||
|
|
||||||
|
JSON schema:
|
||||||
|
{
|
||||||
|
"ts": <float unix seconds>,
|
||||||
|
"zones": [
|
||||||
|
{"zone": "CPU-therm", "index": 0, "temp_c": 42.5},
|
||||||
|
...
|
||||||
|
],
|
||||||
|
"max_temp_c": 55.0,
|
||||||
|
"throttled": false,
|
||||||
|
"warn": false
|
||||||
|
}
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
publish_rate_hz (float, 1.0) — publish rate
|
||||||
|
warn_temp_c (float, 75.0) — log WARN above this temperature
|
||||||
|
throttle_temp_c (float, 85.0) — log ERROR and set throttled=true above this
|
||||||
|
thermal_root (str, "/sys/class/thermal") — sysfs thermal root (override for tests)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import rclpy
|
||||||
|
from rclpy.node import Node
|
||||||
|
from rclpy.qos import QoSProfile
|
||||||
|
from std_msgs.msg import String
|
||||||
|
|
||||||
|
|
||||||
|
def read_thermal_zones(root: str) -> List[dict]:
|
||||||
|
"""Return a list of {zone, index, temp_c} dicts from sysfs."""
|
||||||
|
zones = []
|
||||||
|
try:
|
||||||
|
entries = sorted(os.listdir(root))
|
||||||
|
except OSError:
|
||||||
|
return zones
|
||||||
|
for entry in entries:
|
||||||
|
if not entry.startswith("thermal_zone"):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
idx = int(entry[len("thermal_zone"):])
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
zone_dir = os.path.join(root, entry)
|
||||||
|
try:
|
||||||
|
with open(os.path.join(zone_dir, "type")) as f:
|
||||||
|
zone_type = f.read().strip()
|
||||||
|
except OSError:
|
||||||
|
zone_type = entry
|
||||||
|
try:
|
||||||
|
with open(os.path.join(zone_dir, "temp")) as f:
|
||||||
|
temp_mc = int(f.read().strip()) # millidegrees Celsius
|
||||||
|
temp_c = round(temp_mc / 1000.0, 1)
|
||||||
|
except (OSError, ValueError):
|
||||||
|
continue
|
||||||
|
zones.append({"zone": zone_type, "index": idx, "temp_c": temp_c})
|
||||||
|
return zones
|
||||||
|
|
||||||
|
|
||||||
|
class ThermalNode(Node):
|
||||||
|
"""Reads Jetson thermal zones and publishes /saltybot/thermal at 1 Hz."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__("thermal_node")
|
||||||
|
|
||||||
|
self.declare_parameter("publish_rate_hz", 1.0)
|
||||||
|
self.declare_parameter("warn_temp_c", 75.0)
|
||||||
|
self.declare_parameter("throttle_temp_c", 85.0)
|
||||||
|
self.declare_parameter("thermal_root", "/sys/class/thermal")
|
||||||
|
|
||||||
|
self._rate = self.get_parameter("publish_rate_hz").value
|
||||||
|
self._warn_t = self.get_parameter("warn_temp_c").value
|
||||||
|
self._throttle_t = self.get_parameter("throttle_temp_c").value
|
||||||
|
self._root = self.get_parameter("thermal_root").value
|
||||||
|
|
||||||
|
qos = QoSProfile(depth=10)
|
||||||
|
self._pub = self.create_publisher(String, "/saltybot/thermal", qos)
|
||||||
|
self._timer = self.create_timer(1.0 / self._rate, self._publish)
|
||||||
|
|
||||||
|
self.get_logger().info(
|
||||||
|
f"ThermalNode ready (rate={self._rate} Hz, "
|
||||||
|
f"warn={self._warn_t}°C, throttle={self._throttle_t}°C, "
|
||||||
|
f"root={self._root})"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _publish(self) -> None:
|
||||||
|
zones = read_thermal_zones(self._root)
|
||||||
|
if not zones:
|
||||||
|
self.get_logger().warn("No thermal zones found — check thermal_root param")
|
||||||
|
return
|
||||||
|
|
||||||
|
max_temp = max(z["temp_c"] for z in zones)
|
||||||
|
throttled = max_temp >= self._throttle_t
|
||||||
|
warn = max_temp >= self._warn_t
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"ts": time.time(),
|
||||||
|
"zones": zones,
|
||||||
|
"max_temp_c": max_temp,
|
||||||
|
"throttled": throttled,
|
||||||
|
"warn": warn,
|
||||||
|
}
|
||||||
|
msg = String()
|
||||||
|
msg.data = json.dumps(payload)
|
||||||
|
self._pub.publish(msg)
|
||||||
|
|
||||||
|
if throttled:
|
||||||
|
self.get_logger().error(
|
||||||
|
f"THERMAL THROTTLE: {max_temp}°C >= {self._throttle_t}°C"
|
||||||
|
)
|
||||||
|
elif warn:
|
||||||
|
self.get_logger().warn(
|
||||||
|
f"Thermal warning: {max_temp}°C >= {self._warn_t}°C"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args: Optional[list] = None) -> None:
|
||||||
|
rclpy.init(args=args)
|
||||||
|
node = ThermalNode()
|
||||||
|
try:
|
||||||
|
rclpy.spin(node)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
node.destroy_node()
|
||||||
|
rclpy.shutdown()
|
||||||
4
jetson/ros2_ws/src/saltybot_thermal/setup.cfg
Normal file
4
jetson/ros2_ws/src/saltybot_thermal/setup.cfg
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
[develop]
|
||||||
|
script_dir=$base/lib/saltybot_thermal
|
||||||
|
[egg_info]
|
||||||
|
tag_date = 0
|
||||||
27
jetson/ros2_ws/src/saltybot_thermal/setup.py
Normal file
27
jetson/ros2_ws/src/saltybot_thermal/setup.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from setuptools import setup
|
||||||
|
|
||||||
|
package_name = "saltybot_thermal"
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name=package_name,
|
||||||
|
version="0.1.0",
|
||||||
|
packages=[package_name],
|
||||||
|
data_files=[
|
||||||
|
("share/ament_index/resource_index/packages", [f"resource/{package_name}"]),
|
||||||
|
(f"share/{package_name}", ["package.xml"]),
|
||||||
|
(f"share/{package_name}/launch", ["launch/thermal.launch.py"]),
|
||||||
|
(f"share/{package_name}/config", ["config/thermal_params.yaml"]),
|
||||||
|
],
|
||||||
|
install_requires=["setuptools"],
|
||||||
|
zip_safe=True,
|
||||||
|
maintainer="sl-jetson",
|
||||||
|
maintainer_email="sl-jetson@saltylab.local",
|
||||||
|
description="Jetson thermal monitor — /saltybot/thermal JSON at 1 Hz",
|
||||||
|
license="MIT",
|
||||||
|
tests_require=["pytest"],
|
||||||
|
entry_points={
|
||||||
|
"console_scripts": [
|
||||||
|
"thermal_node = saltybot_thermal.thermal_node:main",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
303
jetson/ros2_ws/src/saltybot_thermal/test/test_thermal.py
Normal file
303
jetson/ros2_ws/src/saltybot_thermal/test/test_thermal.py
Normal file
@ -0,0 +1,303 @@
|
|||||||
|
"""test_thermal.py -- Unit tests for Issue #205 Jetson thermal monitor."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import json, os, time
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def _pkg_root():
|
||||||
|
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
|
||||||
|
def _read_src(rel_path):
|
||||||
|
with open(os.path.join(_pkg_root(), rel_path)) as f:
|
||||||
|
return f.read()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Import the sysfs reader (no ROS required) ─────────────────────────────────
|
||||||
|
|
||||||
|
def _import_reader():
|
||||||
|
import importlib.util, sys, types
|
||||||
|
|
||||||
|
# Build minimal ROS2 stubs so thermal_node.py imports without a ROS install
|
||||||
|
def _stub(name):
|
||||||
|
m = types.ModuleType(name)
|
||||||
|
sys.modules[name] = m
|
||||||
|
return m
|
||||||
|
|
||||||
|
rclpy_mod = _stub("rclpy")
|
||||||
|
rclpy_node_mod = _stub("rclpy.node")
|
||||||
|
rclpy_qos_mod = _stub("rclpy.qos")
|
||||||
|
std_msgs_mod = _stub("std_msgs")
|
||||||
|
std_msg_mod = _stub("std_msgs.msg")
|
||||||
|
|
||||||
|
class _Node:
|
||||||
|
def __init__(self, *a, **kw): pass
|
||||||
|
def declare_parameter(self, *a, **kw): pass
|
||||||
|
def get_parameter(self, name):
|
||||||
|
class _P:
|
||||||
|
value = None
|
||||||
|
return _P()
|
||||||
|
def create_publisher(self, *a, **kw): return None
|
||||||
|
def create_timer(self, *a, **kw): return None
|
||||||
|
def get_logger(self):
|
||||||
|
class _L:
|
||||||
|
def info(self, *a): pass
|
||||||
|
def warn(self, *a): pass
|
||||||
|
def error(self, *a): pass
|
||||||
|
return _L()
|
||||||
|
def destroy_node(self): pass
|
||||||
|
|
||||||
|
class _QoSProfile:
|
||||||
|
def __init__(self, **kw): pass
|
||||||
|
|
||||||
|
class _String:
|
||||||
|
data = ""
|
||||||
|
|
||||||
|
rclpy_node_mod.Node = _Node
|
||||||
|
rclpy_qos_mod.QoSProfile = _QoSProfile
|
||||||
|
std_msg_mod.String = _String
|
||||||
|
rclpy_mod.init = lambda *a, **kw: None
|
||||||
|
rclpy_mod.spin = lambda node: None
|
||||||
|
rclpy_mod.ok = lambda: True
|
||||||
|
rclpy_mod.shutdown = lambda: None
|
||||||
|
|
||||||
|
spec = importlib.util.spec_from_file_location(
|
||||||
|
"thermal_node_testmod",
|
||||||
|
os.path.join(_pkg_root(), "saltybot_thermal", "thermal_node.py"),
|
||||||
|
)
|
||||||
|
mod = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(mod)
|
||||||
|
return mod
|
||||||
|
|
||||||
|
|
||||||
|
# ── Sysfs fixture helpers ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _make_zone(root, idx, zone_type, temp_mc):
|
||||||
|
"""Create a fake thermal_zone<idx> directory under root."""
|
||||||
|
zdir = os.path.join(str(root), "thermal_zone{}".format(idx))
|
||||||
|
os.makedirs(zdir, exist_ok=True)
|
||||||
|
with open(os.path.join(zdir, "type"), "w") as f:
|
||||||
|
f.write(zone_type + "\n")
|
||||||
|
with open(os.path.join(zdir, "temp"), "w") as f:
|
||||||
|
f.write(str(temp_mc) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
# ── read_thermal_zones ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestReadThermalZones:
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def mod(self):
|
||||||
|
return _import_reader()
|
||||||
|
|
||||||
|
def test_empty_dir(self, mod, tmp_path):
|
||||||
|
assert mod.read_thermal_zones(str(tmp_path)) == []
|
||||||
|
|
||||||
|
def test_missing_dir(self, mod):
|
||||||
|
assert mod.read_thermal_zones("/nonexistent/path/xyz") == []
|
||||||
|
|
||||||
|
def test_single_zone(self, mod, tmp_path):
|
||||||
|
_make_zone(tmp_path, 0, "CPU-therm", 45000)
|
||||||
|
zones = mod.read_thermal_zones(str(tmp_path))
|
||||||
|
assert len(zones) == 1
|
||||||
|
assert zones[0]["zone"] == "CPU-therm"
|
||||||
|
assert zones[0]["temp_c"] == 45.0
|
||||||
|
assert zones[0]["index"] == 0
|
||||||
|
|
||||||
|
def test_temp_millidegrees_conversion(self, mod, tmp_path):
|
||||||
|
_make_zone(tmp_path, 0, "GPU-therm", 72500)
|
||||||
|
zones = mod.read_thermal_zones(str(tmp_path))
|
||||||
|
assert zones[0]["temp_c"] == 72.5
|
||||||
|
|
||||||
|
def test_multiple_zones(self, mod, tmp_path):
|
||||||
|
_make_zone(tmp_path, 0, "CPU-therm", 40000)
|
||||||
|
_make_zone(tmp_path, 1, "GPU-therm", 55000)
|
||||||
|
_make_zone(tmp_path, 2, "PMIC-Die", 38000)
|
||||||
|
zones = mod.read_thermal_zones(str(tmp_path))
|
||||||
|
assert len(zones) == 3
|
||||||
|
|
||||||
|
def test_sorted_by_index(self, mod, tmp_path):
|
||||||
|
_make_zone(tmp_path, 2, "Z2", 20000)
|
||||||
|
_make_zone(tmp_path, 0, "Z0", 10000)
|
||||||
|
_make_zone(tmp_path, 1, "Z1", 15000)
|
||||||
|
zones = mod.read_thermal_zones(str(tmp_path))
|
||||||
|
indices = [z["index"] for z in zones]
|
||||||
|
assert indices == sorted(indices)
|
||||||
|
|
||||||
|
def test_skips_non_zone_entries(self, mod, tmp_path):
|
||||||
|
os.makedirs(os.path.join(str(tmp_path), "cooling_device0"))
|
||||||
|
_make_zone(tmp_path, 0, "CPU-therm", 40000)
|
||||||
|
zones = mod.read_thermal_zones(str(tmp_path))
|
||||||
|
assert len(zones) == 1
|
||||||
|
|
||||||
|
def test_skips_zone_without_temp(self, mod, tmp_path):
|
||||||
|
zdir = os.path.join(str(tmp_path), "thermal_zone0")
|
||||||
|
os.makedirs(zdir)
|
||||||
|
with open(os.path.join(zdir, "type"), "w") as f:
|
||||||
|
f.write("CPU-therm\n")
|
||||||
|
# No temp file — should be skipped
|
||||||
|
zones = mod.read_thermal_zones(str(tmp_path))
|
||||||
|
assert zones == []
|
||||||
|
|
||||||
|
def test_zone_type_fallback(self, mod, tmp_path):
|
||||||
|
"""Zone without type file falls back to directory name."""
|
||||||
|
zdir = os.path.join(str(tmp_path), "thermal_zone0")
|
||||||
|
os.makedirs(zdir)
|
||||||
|
with open(os.path.join(zdir, "temp"), "w") as f:
|
||||||
|
f.write("40000\n")
|
||||||
|
zones = mod.read_thermal_zones(str(tmp_path))
|
||||||
|
assert len(zones) == 1
|
||||||
|
assert zones[0]["zone"] == "thermal_zone0"
|
||||||
|
|
||||||
|
def test_temp_rounding(self, mod, tmp_path):
|
||||||
|
_make_zone(tmp_path, 0, "CPU-therm", 72333)
|
||||||
|
zones = mod.read_thermal_zones(str(tmp_path))
|
||||||
|
assert zones[0]["temp_c"] == 72.3
|
||||||
|
|
||||||
|
|
||||||
|
# ── Threshold logic (pure Python) ────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestThresholds:
|
||||||
|
def _classify(self, temp_c, warn_t=75.0, throttle_t=85.0):
|
||||||
|
throttled = temp_c >= throttle_t
|
||||||
|
warn = temp_c >= warn_t
|
||||||
|
return throttled, warn
|
||||||
|
|
||||||
|
def test_normal(self):
|
||||||
|
t, w = self._classify(50.0)
|
||||||
|
assert not t and not w
|
||||||
|
|
||||||
|
def test_warn_boundary(self):
|
||||||
|
t, w = self._classify(75.0)
|
||||||
|
assert not t and w
|
||||||
|
|
||||||
|
def test_below_warn(self):
|
||||||
|
t, w = self._classify(74.9)
|
||||||
|
assert not t and not w
|
||||||
|
|
||||||
|
def test_throttle_boundary(self):
|
||||||
|
t, w = self._classify(85.0)
|
||||||
|
assert t and w
|
||||||
|
|
||||||
|
def test_above_throttle(self):
|
||||||
|
t, w = self._classify(90.0)
|
||||||
|
assert t and w
|
||||||
|
|
||||||
|
def test_custom_thresholds(self):
|
||||||
|
t, w = self._classify(70.0, warn_t=70.0, throttle_t=80.0)
|
||||||
|
assert not t and w
|
||||||
|
|
||||||
|
def test_max_temp_drives_status(self):
|
||||||
|
zones = [{"temp_c": 40.0}, {"temp_c": 86.0}, {"temp_c": 55.0}]
|
||||||
|
max_t = max(z["temp_c"] for z in zones)
|
||||||
|
assert max_t == 86.0
|
||||||
|
t, w = self._classify(max_t)
|
||||||
|
assert t and w
|
||||||
|
|
||||||
|
|
||||||
|
# ── JSON payload schema ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestJsonPayload:
|
||||||
|
def _make_payload(self, zones, warn_t=75.0, throttle_t=85.0):
|
||||||
|
max_temp = max(z["temp_c"] for z in zones) if zones else 0.0
|
||||||
|
return {
|
||||||
|
"ts": time.time(),
|
||||||
|
"zones": zones,
|
||||||
|
"max_temp_c": max_temp,
|
||||||
|
"throttled": max_temp >= throttle_t,
|
||||||
|
"warn": max_temp >= warn_t,
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_has_ts(self):
|
||||||
|
p = self._make_payload([{"zone": "CPU", "index": 0, "temp_c": 40.0}])
|
||||||
|
assert "ts" in p and isinstance(p["ts"], float)
|
||||||
|
|
||||||
|
def test_has_zones(self):
|
||||||
|
p = self._make_payload([{"zone": "CPU", "index": 0, "temp_c": 40.0}])
|
||||||
|
assert "zones" in p and len(p["zones"]) == 1
|
||||||
|
|
||||||
|
def test_has_max_temp(self):
|
||||||
|
p = self._make_payload([{"zone": "CPU", "index": 0, "temp_c": 55.0}])
|
||||||
|
assert p["max_temp_c"] == 55.0
|
||||||
|
|
||||||
|
def test_throttled_false_below(self):
|
||||||
|
p = self._make_payload([{"zone": "CPU", "index": 0, "temp_c": 60.0}])
|
||||||
|
assert p["throttled"] is False
|
||||||
|
|
||||||
|
def test_warn_true_at_threshold(self):
|
||||||
|
p = self._make_payload([{"zone": "CPU", "index": 0, "temp_c": 75.0}])
|
||||||
|
assert p["warn"] is True and p["throttled"] is False
|
||||||
|
|
||||||
|
def test_throttled_true_above(self):
|
||||||
|
p = self._make_payload([{"zone": "CPU", "index": 0, "temp_c": 90.0}])
|
||||||
|
assert p["throttled"] is True
|
||||||
|
|
||||||
|
def test_json_serializable(self):
|
||||||
|
zones = [{"zone": "CPU", "index": 0, "temp_c": 50.0}]
|
||||||
|
p = self._make_payload(zones)
|
||||||
|
blob = json.dumps(p)
|
||||||
|
parsed = json.loads(blob)
|
||||||
|
assert parsed["max_temp_c"] == 50.0
|
||||||
|
|
||||||
|
def test_multi_zone_max(self):
|
||||||
|
zones = [
|
||||||
|
{"zone": "CPU-therm", "index": 0, "temp_c": 55.0},
|
||||||
|
{"zone": "GPU-therm", "index": 1, "temp_c": 78.0},
|
||||||
|
{"zone": "PMIC-Die", "index": 2, "temp_c": 38.0},
|
||||||
|
]
|
||||||
|
p = self._make_payload(zones)
|
||||||
|
assert p["max_temp_c"] == 78.0
|
||||||
|
assert p["warn"] is True
|
||||||
|
assert p["throttled"] is False
|
||||||
|
|
||||||
|
|
||||||
|
# ── Node source checks ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestNodeSrc:
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def src(self):
|
||||||
|
return _read_src("saltybot_thermal/thermal_node.py")
|
||||||
|
|
||||||
|
def test_class_defined(self, src): assert "class ThermalNode" in src
|
||||||
|
def test_publish_rate_param(self, src): assert '"publish_rate_hz"' in src
|
||||||
|
def test_warn_param(self, src): assert '"warn_temp_c"' in src
|
||||||
|
def test_throttle_param(self, src): assert '"throttle_temp_c"' in src
|
||||||
|
def test_thermal_root_param(self, src): assert '"thermal_root"' in src
|
||||||
|
def test_topic(self, src): assert '"/saltybot/thermal"' in src
|
||||||
|
def test_read_fn(self, src): assert "read_thermal_zones" in src
|
||||||
|
def test_warn_log(self, src): assert "warn" in src.lower()
|
||||||
|
def test_error_log(self, src): assert "error" in src.lower()
|
||||||
|
def test_throttled_flag(self, src): assert '"throttled"' in src
|
||||||
|
def test_warn_flag(self, src): assert '"warn"' in src
|
||||||
|
def test_max_temp(self, src): assert '"max_temp_c"' in src
|
||||||
|
def test_millidegrees(self, src): assert "1000" in src
|
||||||
|
def test_json_dumps(self, src): assert "json.dumps" in src
|
||||||
|
def test_issue_tag(self, src): assert "205" in src
|
||||||
|
def test_main(self, src): assert "def main" in src
|
||||||
|
def test_sysfs_path(self, src): assert "/sys/class/thermal" in src
|
||||||
|
|
||||||
|
|
||||||
|
# ── Package metadata ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestPackageMeta:
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def pkg_xml(self):
|
||||||
|
return _read_src("package.xml")
|
||||||
|
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def setup_py(self):
|
||||||
|
return _read_src("setup.py")
|
||||||
|
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def cfg(self):
|
||||||
|
return _read_src("config/thermal_params.yaml")
|
||||||
|
|
||||||
|
def test_pkg_name(self, pkg_xml): assert "saltybot_thermal" in pkg_xml
|
||||||
|
def test_issue_tag(self, pkg_xml): assert "205" in pkg_xml
|
||||||
|
def test_entry_point(self, setup_py): assert "thermal_node = saltybot_thermal.thermal_node:main" in setup_py
|
||||||
|
def test_cfg_node_name(self, cfg): assert "thermal_node:" in cfg
|
||||||
|
def test_cfg_warn(self, cfg): assert "warn_temp_c" in cfg
|
||||||
|
def test_cfg_throttle(self, cfg): assert "throttle_temp_c" in cfg
|
||||||
|
def test_cfg_rate(self, cfg): assert "publish_rate_hz" in cfg
|
||||||
|
def test_cfg_defaults(self, cfg):
|
||||||
|
assert "75.0" in cfg and "85.0" in cfg and "1.0" in cfg
|
||||||
@ -21,6 +21,7 @@
|
|||||||
#include "audio.h"
|
#include "audio.h"
|
||||||
#include "buzzer.h"
|
#include "buzzer.h"
|
||||||
#include "led.h"
|
#include "led.h"
|
||||||
|
#include "servo.h"
|
||||||
#include "power_mgmt.h"
|
#include "power_mgmt.h"
|
||||||
#include "battery.h"
|
#include "battery.h"
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
@ -163,6 +164,9 @@ int main(void) {
|
|||||||
/* Init power management — STOP-mode sleep/wake, wake EXTIs configured */
|
/* Init power management — STOP-mode sleep/wake, wake EXTIs configured */
|
||||||
power_mgmt_init();
|
power_mgmt_init();
|
||||||
|
|
||||||
|
/* Init servo pan-tilt driver for camera head (TIM4 PWM on PB6/PB7, Issue #206) */
|
||||||
|
servo_init();
|
||||||
|
|
||||||
/* Init mode manager (RC/autonomous blend; CH6 mode switch) */
|
/* Init mode manager (RC/autonomous blend; CH6 mode switch) */
|
||||||
mode_manager_t mode;
|
mode_manager_t mode;
|
||||||
mode_manager_init(&mode);
|
mode_manager_init(&mode);
|
||||||
@ -218,6 +222,9 @@ int main(void) {
|
|||||||
/* Advance LED animation sequencer (non-blocking, call every tick) */
|
/* Advance LED animation sequencer (non-blocking, call every tick) */
|
||||||
led_tick(now);
|
led_tick(now);
|
||||||
|
|
||||||
|
/* Servo pan-tilt animation tick — updates smooth sweeps */
|
||||||
|
servo_tick(now);
|
||||||
|
|
||||||
/* Sleep LED: software PWM on LED1 (active-low PC15) driven by PM brightness.
|
/* Sleep LED: software PWM on LED1 (active-low PC15) driven by PM brightness.
|
||||||
* pm_pwm_phase rolls over each ms; brightness sets duty cycle 0-255. */
|
* pm_pwm_phase rolls over each ms; brightness sets duty cycle 0-255. */
|
||||||
pm_pwm_phase++;
|
pm_pwm_phase++;
|
||||||
|
|||||||
242
src/servo.c
Normal file
242
src/servo.c
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
#include "servo.h"
|
||||||
|
#include "config.h"
|
||||||
|
#include "stm32f7xx_hal.h"
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
/* ================================================================
|
||||||
|
* Servo PWM Protocol
|
||||||
|
* ================================================================
|
||||||
|
* TIM4 at 50 Hz (20 ms period)
|
||||||
|
* APB1 clock: 54 MHz
|
||||||
|
* Prescaler: 53 (54 MHz / 54 = 1 MHz)
|
||||||
|
* ARR: 19999 (1 MHz / 20000 = 50 Hz)
|
||||||
|
* CCR: 500-2500 (0.5-2.5 ms out of 20 ms)
|
||||||
|
*
|
||||||
|
* Servo pulse mapping:
|
||||||
|
* 500 µs → 0° (full left/down)
|
||||||
|
* 1500 µs → 90° (center)
|
||||||
|
* 2500 µs → 180° (full right/up)
|
||||||
|
*/
|
||||||
|
|
||||||
|
#define SERVO_PWM_FREQ 50u /* 50 Hz */
|
||||||
|
#define SERVO_PERIOD_MS 20u /* 20 ms = 1/50 Hz */
|
||||||
|
#define SERVO_CLOCK_HZ 1000000u /* 1 MHz timer clock */
|
||||||
|
#define SERVO_PRESCALER 53u /* APB1 54 MHz / 54 = 1 MHz */
|
||||||
|
#define SERVO_ARR 19999u /* 1 MHz / 20000 = 50 Hz */
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
uint16_t current_angle_deg[SERVO_COUNT];
|
||||||
|
uint16_t target_angle_deg[SERVO_COUNT];
|
||||||
|
uint16_t pulse_us[SERVO_COUNT];
|
||||||
|
|
||||||
|
/* Sweep state */
|
||||||
|
uint32_t sweep_start_ms[SERVO_COUNT];
|
||||||
|
uint32_t sweep_duration_ms[SERVO_COUNT];
|
||||||
|
uint16_t sweep_start_deg[SERVO_COUNT];
|
||||||
|
uint16_t sweep_end_deg[SERVO_COUNT];
|
||||||
|
bool is_sweeping[SERVO_COUNT];
|
||||||
|
} ServoState;
|
||||||
|
|
||||||
|
static ServoState s_servo = {0};
|
||||||
|
static TIM_HandleTypeDef s_tim_handle = {0};
|
||||||
|
|
||||||
|
/* ================================================================
|
||||||
|
* Helper functions
|
||||||
|
* ================================================================
|
||||||
|
*/
|
||||||
|
|
||||||
|
static uint16_t angle_to_pulse_us(uint16_t degrees)
|
||||||
|
{
|
||||||
|
/* Linear interpolation: 0° → 500µs, 180° → 2500µs */
|
||||||
|
if (degrees > 180) degrees = 180;
|
||||||
|
|
||||||
|
uint32_t pulse = SERVO_MIN_US + (uint32_t)degrees * (SERVO_MAX_US - SERVO_MIN_US) / 180;
|
||||||
|
return (uint16_t)pulse;
|
||||||
|
}
|
||||||
|
|
||||||
|
static uint16_t pulse_us_to_angle(uint16_t pulse_us)
|
||||||
|
{
|
||||||
|
/* Inverse mapping: 500µs → 0°, 2500µs → 180° */
|
||||||
|
if (pulse_us < SERVO_MIN_US) pulse_us = SERVO_MIN_US;
|
||||||
|
if (pulse_us > SERVO_MAX_US) pulse_us = SERVO_MAX_US;
|
||||||
|
|
||||||
|
uint32_t angle = (uint32_t)(pulse_us - SERVO_MIN_US) * 180 / (SERVO_MAX_US - SERVO_MIN_US);
|
||||||
|
return (uint16_t)angle;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void update_pwm(ServoChannel channel)
|
||||||
|
{
|
||||||
|
/* Convert pulse width (500-2500 µs) to CCR value */
|
||||||
|
/* At 1 MHz timer clock: 1 µs = 1 count */
|
||||||
|
uint32_t ccr_value = s_servo.pulse_us[channel];
|
||||||
|
|
||||||
|
if (channel == SERVO_PAN) {
|
||||||
|
__HAL_TIM_SET_COMPARE(&s_tim_handle, SERVO_PAN_CHANNEL, ccr_value);
|
||||||
|
} else {
|
||||||
|
__HAL_TIM_SET_COMPARE(&s_tim_handle, SERVO_TILT_CHANNEL, ccr_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ================================================================
|
||||||
|
* Public API
|
||||||
|
* ================================================================
|
||||||
|
*/
|
||||||
|
|
||||||
|
void servo_init(void)
|
||||||
|
{
|
||||||
|
/* Initialize state */
|
||||||
|
memset(&s_servo, 0, sizeof(s_servo));
|
||||||
|
|
||||||
|
/* Center both servos at 90° */
|
||||||
|
for (int i = 0; i < SERVO_COUNT; i++) {
|
||||||
|
s_servo.current_angle_deg[i] = 90;
|
||||||
|
s_servo.target_angle_deg[i] = 90;
|
||||||
|
s_servo.pulse_us[i] = SERVO_CENTER_US;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Configure GPIO PB6 (CH1) and PB7 (CH2) as TIM4 PWM */
|
||||||
|
__HAL_RCC_GPIOB_CLK_ENABLE();
|
||||||
|
|
||||||
|
GPIO_InitTypeDef gpio_init = {0};
|
||||||
|
gpio_init.Mode = GPIO_MODE_AF_PP;
|
||||||
|
gpio_init.Pull = GPIO_NOPULL;
|
||||||
|
gpio_init.Speed = GPIO_SPEED_FREQ_HIGH;
|
||||||
|
gpio_init.Alternate = SERVO_AF;
|
||||||
|
|
||||||
|
/* Configure PB6 (pan) */
|
||||||
|
gpio_init.Pin = SERVO_PAN_PIN;
|
||||||
|
HAL_GPIO_Init(SERVO_PAN_PORT, &gpio_init);
|
||||||
|
|
||||||
|
/* Configure PB7 (tilt) */
|
||||||
|
gpio_init.Pin = SERVO_TILT_PIN;
|
||||||
|
HAL_GPIO_Init(SERVO_TILT_PORT, &gpio_init);
|
||||||
|
|
||||||
|
/* Configure TIM4: 50 Hz PWM */
|
||||||
|
__HAL_RCC_TIM4_CLK_ENABLE();
|
||||||
|
|
||||||
|
s_tim_handle.Instance = SERVO_TIM;
|
||||||
|
s_tim_handle.Init.Prescaler = SERVO_PRESCALER;
|
||||||
|
s_tim_handle.Init.CounterMode = TIM_COUNTERMODE_UP;
|
||||||
|
s_tim_handle.Init.Period = SERVO_ARR;
|
||||||
|
s_tim_handle.Init.ClockDivision = TIM_CLOCKDIVISION_DIV1;
|
||||||
|
s_tim_handle.Init.RepetitionCounter = 0;
|
||||||
|
|
||||||
|
HAL_TIM_PWM_Init(&s_tim_handle);
|
||||||
|
|
||||||
|
/* Configure TIM4_CH1 (pan) for PWM */
|
||||||
|
TIM_OC_InitTypeDef oc_init = {0};
|
||||||
|
oc_init.OCMode = TIM_OCMODE_PWM1;
|
||||||
|
oc_init.Pulse = SERVO_CENTER_US;
|
||||||
|
oc_init.OCPolarity = TIM_OCPOLARITY_HIGH;
|
||||||
|
oc_init.OCFastMode = TIM_OCFAST_DISABLE;
|
||||||
|
|
||||||
|
HAL_TIM_PWM_ConfigChannel(&s_tim_handle, &oc_init, SERVO_PAN_CHANNEL);
|
||||||
|
HAL_TIM_PWM_Start(&s_tim_handle, SERVO_PAN_CHANNEL);
|
||||||
|
|
||||||
|
/* Configure TIM4_CH2 (tilt) for PWM */
|
||||||
|
oc_init.Pulse = SERVO_CENTER_US;
|
||||||
|
HAL_TIM_PWM_ConfigChannel(&s_tim_handle, &oc_init, SERVO_TILT_CHANNEL);
|
||||||
|
HAL_TIM_PWM_Start(&s_tim_handle, SERVO_TILT_CHANNEL);
|
||||||
|
}
|
||||||
|
|
||||||
|
void servo_set_angle(ServoChannel channel, uint16_t degrees)
|
||||||
|
{
|
||||||
|
if (channel >= SERVO_COUNT) return;
|
||||||
|
if (degrees > 180) degrees = 180;
|
||||||
|
|
||||||
|
s_servo.current_angle_deg[channel] = degrees;
|
||||||
|
s_servo.target_angle_deg[channel] = degrees;
|
||||||
|
s_servo.pulse_us[channel] = angle_to_pulse_us(degrees);
|
||||||
|
|
||||||
|
/* Stop any sweep in progress */
|
||||||
|
s_servo.is_sweeping[channel] = false;
|
||||||
|
|
||||||
|
/* Update PWM immediately */
|
||||||
|
update_pwm(channel);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint16_t servo_get_angle(ServoChannel channel)
|
||||||
|
{
|
||||||
|
if (channel >= SERVO_COUNT) return 0;
|
||||||
|
return s_servo.current_angle_deg[channel];
|
||||||
|
}
|
||||||
|
|
||||||
|
void servo_set_pulse_us(ServoChannel channel, uint16_t pulse_us)
|
||||||
|
{
|
||||||
|
if (channel >= SERVO_COUNT) return;
|
||||||
|
if (pulse_us < SERVO_MIN_US) pulse_us = SERVO_MIN_US;
|
||||||
|
if (pulse_us > SERVO_MAX_US) pulse_us = SERVO_MAX_US;
|
||||||
|
|
||||||
|
s_servo.pulse_us[channel] = pulse_us;
|
||||||
|
s_servo.current_angle_deg[channel] = pulse_us_to_angle(pulse_us);
|
||||||
|
s_servo.target_angle_deg[channel] = s_servo.current_angle_deg[channel];
|
||||||
|
|
||||||
|
/* Stop any sweep in progress */
|
||||||
|
s_servo.is_sweeping[channel] = false;
|
||||||
|
|
||||||
|
/* Update PWM immediately */
|
||||||
|
update_pwm(channel);
|
||||||
|
}
|
||||||
|
|
||||||
|
void servo_sweep(ServoChannel channel, uint16_t start_deg, uint16_t end_deg, uint32_t duration_ms)
|
||||||
|
{
|
||||||
|
if (channel >= SERVO_COUNT) return;
|
||||||
|
if (duration_ms == 0) return;
|
||||||
|
if (start_deg > 180) start_deg = 180;
|
||||||
|
if (end_deg > 180) end_deg = 180;
|
||||||
|
|
||||||
|
s_servo.sweep_start_deg[channel] = start_deg;
|
||||||
|
s_servo.sweep_end_deg[channel] = end_deg;
|
||||||
|
s_servo.sweep_duration_ms[channel] = duration_ms;
|
||||||
|
s_servo.sweep_start_ms[channel] = 0; /* Will be set on first tick */
|
||||||
|
s_servo.is_sweeping[channel] = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void servo_tick(uint32_t now_ms)
|
||||||
|
{
|
||||||
|
for (int ch = 0; ch < SERVO_COUNT; ch++) {
|
||||||
|
if (!s_servo.is_sweeping[ch]) continue;
|
||||||
|
|
||||||
|
/* Initialize start time on first call */
|
||||||
|
if (s_servo.sweep_start_ms[ch] == 0) {
|
||||||
|
s_servo.sweep_start_ms[ch] = now_ms;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t elapsed = now_ms - s_servo.sweep_start_ms[ch];
|
||||||
|
uint32_t duration = s_servo.sweep_duration_ms[ch];
|
||||||
|
|
||||||
|
if (elapsed >= duration) {
|
||||||
|
/* Sweep complete */
|
||||||
|
s_servo.is_sweeping[ch] = false;
|
||||||
|
s_servo.current_angle_deg[ch] = s_servo.sweep_end_deg[ch];
|
||||||
|
s_servo.pulse_us[ch] = angle_to_pulse_us(s_servo.sweep_end_deg[ch]);
|
||||||
|
} else {
|
||||||
|
/* Linear interpolation */
|
||||||
|
int16_t start = (int16_t)s_servo.sweep_start_deg[ch];
|
||||||
|
int16_t end = (int16_t)s_servo.sweep_end_deg[ch];
|
||||||
|
int32_t delta = end - start;
|
||||||
|
|
||||||
|
/* angle = start + (delta * elapsed / duration) */
|
||||||
|
int32_t angle_i32 = start + (delta * (int32_t)elapsed / (int32_t)duration);
|
||||||
|
s_servo.current_angle_deg[ch] = (uint16_t)angle_i32;
|
||||||
|
s_servo.pulse_us[ch] = angle_to_pulse_us(s_servo.current_angle_deg[ch]);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Update PWM */
|
||||||
|
update_pwm((ServoChannel)ch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool servo_is_sweeping(void)
|
||||||
|
{
|
||||||
|
for (int i = 0; i < SERVO_COUNT; i++) {
|
||||||
|
if (s_servo.is_sweeping[i]) return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void servo_stop_sweep(ServoChannel channel)
|
||||||
|
{
|
||||||
|
if (channel >= SERVO_COUNT) return;
|
||||||
|
s_servo.is_sweeping[channel] = false;
|
||||||
|
}
|
||||||
345
test/test_servo.py
Normal file
345
test/test_servo.py
Normal file
@ -0,0 +1,345 @@
|
|||||||
|
"""
|
||||||
|
test_servo.py — Pan-tilt servo driver tests (Issue #206)
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- PWM frequency: 50 Hz (20 ms period)
|
||||||
|
- Pulse width: 500-2500 µs for 0-180°
|
||||||
|
- Angle conversion: linear mapping
|
||||||
|
- Smooth sweeping: animation timing and interpolation
|
||||||
|
- Multi-servo coordination (pan + tilt independently)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# ── Constants ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
SERVO_MIN_US = 500
|
||||||
|
SERVO_MAX_US = 2500
|
||||||
|
SERVO_CENTER_US = 1500
|
||||||
|
|
||||||
|
PWM_FREQ_HZ = 50
|
||||||
|
PERIOD_MS = 20
|
||||||
|
|
||||||
|
NUM_SERVOS = 2
|
||||||
|
SERVO_PAN = 0
|
||||||
|
SERVO_TILT = 1
|
||||||
|
|
||||||
|
|
||||||
|
# ── Servo Simulator ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class ServoSimulator:
|
||||||
|
def __init__(self):
|
||||||
|
self.current_angle_deg = [90, 90] # Both centered
|
||||||
|
self.pulse_us = [SERVO_CENTER_US, SERVO_CENTER_US]
|
||||||
|
self.is_sweeping = [False, False]
|
||||||
|
self.sweep_start_deg = [0, 0]
|
||||||
|
self.sweep_end_deg = [0, 0]
|
||||||
|
self.sweep_duration_ms = [0, 0]
|
||||||
|
self.sweep_start_ms = [None, None]
|
||||||
|
|
||||||
|
def angle_to_pulse(self, degrees):
|
||||||
|
"""Convert angle (0-180) to pulse width (500-2500 µs)."""
|
||||||
|
if degrees < 0:
|
||||||
|
degrees = 0
|
||||||
|
if degrees > 180:
|
||||||
|
degrees = 180
|
||||||
|
return SERVO_MIN_US + (degrees * (SERVO_MAX_US - SERVO_MIN_US)) // 180
|
||||||
|
|
||||||
|
def pulse_to_angle(self, pulse_us):
|
||||||
|
"""Convert pulse width to angle."""
|
||||||
|
if pulse_us < SERVO_MIN_US:
|
||||||
|
pulse_us = SERVO_MIN_US
|
||||||
|
if pulse_us > SERVO_MAX_US:
|
||||||
|
pulse_us = SERVO_MAX_US
|
||||||
|
return (pulse_us - SERVO_MIN_US) * 180 // (SERVO_MAX_US - SERVO_MIN_US)
|
||||||
|
|
||||||
|
def set_angle(self, channel, degrees):
|
||||||
|
"""Immediately set servo angle."""
|
||||||
|
self.current_angle_deg[channel] = min(180, max(0, degrees))
|
||||||
|
self.pulse_us[channel] = self.angle_to_pulse(self.current_angle_deg[channel])
|
||||||
|
self.is_sweeping[channel] = False
|
||||||
|
|
||||||
|
def get_angle(self, channel):
|
||||||
|
"""Get current servo angle."""
|
||||||
|
return self.current_angle_deg[channel]
|
||||||
|
|
||||||
|
def set_pulse_us(self, channel, pulse_us):
|
||||||
|
"""Set servo by pulse width."""
|
||||||
|
if pulse_us < SERVO_MIN_US:
|
||||||
|
pulse_us = SERVO_MIN_US
|
||||||
|
if pulse_us > SERVO_MAX_US:
|
||||||
|
pulse_us = SERVO_MAX_US
|
||||||
|
self.pulse_us[channel] = pulse_us
|
||||||
|
self.current_angle_deg[channel] = self.pulse_to_angle(pulse_us)
|
||||||
|
self.is_sweeping[channel] = False
|
||||||
|
|
||||||
|
def sweep(self, channel, start_deg, end_deg, duration_ms):
|
||||||
|
"""Start smooth sweep."""
|
||||||
|
self.sweep_start_deg[channel] = start_deg
|
||||||
|
self.sweep_end_deg[channel] = end_deg
|
||||||
|
self.sweep_duration_ms[channel] = duration_ms
|
||||||
|
self.sweep_start_ms[channel] = None
|
||||||
|
self.is_sweeping[channel] = True
|
||||||
|
|
||||||
|
def tick(self, now_ms):
|
||||||
|
"""Update sweep animations."""
|
||||||
|
for ch in range(NUM_SERVOS):
|
||||||
|
if not self.is_sweeping[ch]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Initialize start time on first call
|
||||||
|
if self.sweep_start_ms[ch] is None:
|
||||||
|
self.sweep_start_ms[ch] = now_ms
|
||||||
|
|
||||||
|
elapsed = now_ms - self.sweep_start_ms[ch]
|
||||||
|
duration = self.sweep_duration_ms[ch]
|
||||||
|
|
||||||
|
if elapsed >= duration:
|
||||||
|
# Sweep complete
|
||||||
|
self.is_sweeping[ch] = False
|
||||||
|
self.current_angle_deg[ch] = self.sweep_end_deg[ch]
|
||||||
|
self.pulse_us[ch] = self.angle_to_pulse(self.sweep_end_deg[ch])
|
||||||
|
else:
|
||||||
|
# Linear interpolation
|
||||||
|
start = self.sweep_start_deg[ch]
|
||||||
|
end = self.sweep_end_deg[ch]
|
||||||
|
delta = end - start
|
||||||
|
angle = start + (delta * elapsed) // duration
|
||||||
|
self.current_angle_deg[ch] = angle
|
||||||
|
self.pulse_us[ch] = self.angle_to_pulse(angle)
|
||||||
|
|
||||||
|
def is_sweeping_any(self):
|
||||||
|
"""Check if any servo is sweeping."""
|
||||||
|
return any(self.is_sweeping)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_initialization():
|
||||||
|
"""Servos should initialize centered at 90°."""
|
||||||
|
sim = ServoSimulator()
|
||||||
|
assert sim.get_angle(SERVO_PAN) == 90
|
||||||
|
assert sim.get_angle(SERVO_TILT) == 90
|
||||||
|
assert sim.pulse_us[SERVO_PAN] == SERVO_CENTER_US
|
||||||
|
assert sim.pulse_us[SERVO_TILT] == SERVO_CENTER_US
|
||||||
|
|
||||||
|
|
||||||
|
def test_angle_to_pulse_conversion():
|
||||||
|
"""Angle to pulse conversion should be linear."""
|
||||||
|
sim = ServoSimulator()
|
||||||
|
|
||||||
|
assert sim.angle_to_pulse(0) == SERVO_MIN_US # 500 µs
|
||||||
|
assert sim.angle_to_pulse(90) == SERVO_CENTER_US # 1500 µs
|
||||||
|
assert sim.angle_to_pulse(180) == SERVO_MAX_US # 2500 µs
|
||||||
|
|
||||||
|
# Intermediate angles
|
||||||
|
assert sim.angle_to_pulse(45) == 1000 # 0.5 way: 500 + 500 = 1000
|
||||||
|
assert sim.angle_to_pulse(135) == 2000 # 0.75 way: 500 + 1500 = 2000
|
||||||
|
|
||||||
|
|
||||||
|
def test_pulse_to_angle_conversion():
|
||||||
|
"""Pulse to angle conversion should invert angle_to_pulse."""
|
||||||
|
sim = ServoSimulator()
|
||||||
|
|
||||||
|
assert sim.pulse_to_angle(SERVO_MIN_US) == 0
|
||||||
|
assert sim.pulse_to_angle(SERVO_CENTER_US) == 90
|
||||||
|
assert sim.pulse_to_angle(SERVO_MAX_US) == 180
|
||||||
|
|
||||||
|
# Intermediate pulses
|
||||||
|
assert sim.pulse_to_angle(1000) == 45
|
||||||
|
assert sim.pulse_to_angle(2000) == 135
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_angle_pan():
|
||||||
|
"""Pan servo should update angle immediately."""
|
||||||
|
sim = ServoSimulator()
|
||||||
|
|
||||||
|
sim.set_angle(SERVO_PAN, 0)
|
||||||
|
assert sim.get_angle(SERVO_PAN) == 0
|
||||||
|
assert sim.pulse_us[SERVO_PAN] == SERVO_MIN_US
|
||||||
|
|
||||||
|
sim.set_angle(SERVO_PAN, 90)
|
||||||
|
assert sim.get_angle(SERVO_PAN) == 90
|
||||||
|
assert sim.pulse_us[SERVO_PAN] == SERVO_CENTER_US
|
||||||
|
|
||||||
|
sim.set_angle(SERVO_PAN, 180)
|
||||||
|
assert sim.get_angle(SERVO_PAN) == 180
|
||||||
|
assert sim.pulse_us[SERVO_PAN] == SERVO_MAX_US
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_angle_tilt():
|
||||||
|
"""Tilt servo should work independently."""
|
||||||
|
sim = ServoSimulator()
|
||||||
|
|
||||||
|
sim.set_angle(SERVO_TILT, 45)
|
||||||
|
assert sim.get_angle(SERVO_TILT) == 45
|
||||||
|
assert sim.get_angle(SERVO_PAN) == 90 # Pan unchanged
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_pulse_us():
|
||||||
|
"""Pulse width setter should update angle correctly."""
|
||||||
|
sim = ServoSimulator()
|
||||||
|
|
||||||
|
sim.set_pulse_us(SERVO_PAN, SERVO_MIN_US)
|
||||||
|
assert sim.get_angle(SERVO_PAN) == 0
|
||||||
|
|
||||||
|
sim.set_pulse_us(SERVO_PAN, SERVO_CENTER_US)
|
||||||
|
assert sim.get_angle(SERVO_PAN) == 90
|
||||||
|
|
||||||
|
sim.set_pulse_us(SERVO_PAN, SERVO_MAX_US)
|
||||||
|
assert sim.get_angle(SERVO_PAN) == 180
|
||||||
|
|
||||||
|
|
||||||
|
def test_sweep_timing():
|
||||||
|
"""Sweep should complete in specified duration."""
|
||||||
|
sim = ServoSimulator()
|
||||||
|
|
||||||
|
# Pan from 0° to 180° over 2 seconds
|
||||||
|
sim.sweep(SERVO_PAN, 0, 180, 2000)
|
||||||
|
|
||||||
|
# Initial tick
|
||||||
|
sim.tick(0)
|
||||||
|
assert sim.get_angle(SERVO_PAN) == 0
|
||||||
|
|
||||||
|
# Halfway through sweep (t=1000ms)
|
||||||
|
sim.tick(1000)
|
||||||
|
assert sim.get_angle(SERVO_PAN) == 90 # Linear interpolation
|
||||||
|
|
||||||
|
# End of sweep (t=2000ms)
|
||||||
|
sim.tick(2000)
|
||||||
|
assert sim.get_angle(SERVO_PAN) == 180
|
||||||
|
assert not sim.is_sweeping[SERVO_PAN]
|
||||||
|
|
||||||
|
|
||||||
|
def test_sweep_interpolation():
|
||||||
|
"""Sweep should interpolate smoothly."""
|
||||||
|
sim = ServoSimulator()
|
||||||
|
|
||||||
|
# Sweep from 0° to 180° in 1000ms
|
||||||
|
sim.sweep(SERVO_PAN, 0, 180, 1000)
|
||||||
|
|
||||||
|
angles = []
|
||||||
|
for t in range(0, 1001, 100):
|
||||||
|
sim.tick(t)
|
||||||
|
angles.append(sim.get_angle(SERVO_PAN))
|
||||||
|
|
||||||
|
# Expected: [0, 18, 36, 54, 72, 90, 108, 126, 144, 162, 180]
|
||||||
|
expected = [i * 18 for i in range(11)]
|
||||||
|
assert angles == expected, f"Got {angles}, expected {expected}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_reverse_sweep():
|
||||||
|
"""Sweep from higher angle to lower angle."""
|
||||||
|
sim = ServoSimulator()
|
||||||
|
|
||||||
|
sim.sweep(SERVO_TILT, 180, 0, 1000)
|
||||||
|
|
||||||
|
sim.tick(0)
|
||||||
|
assert sim.get_angle(SERVO_TILT) == 180
|
||||||
|
|
||||||
|
sim.tick(500)
|
||||||
|
assert sim.get_angle(SERVO_TILT) == 90
|
||||||
|
|
||||||
|
sim.tick(1000)
|
||||||
|
assert sim.get_angle(SERVO_TILT) == 0
|
||||||
|
assert not sim.is_sweeping[SERVO_TILT]
|
||||||
|
|
||||||
|
|
||||||
|
def test_sweep_stops_on_immediate_set():
|
||||||
|
"""Setting angle immediately should stop sweep."""
|
||||||
|
sim = ServoSimulator()
|
||||||
|
|
||||||
|
sim.sweep(SERVO_PAN, 0, 180, 2000)
|
||||||
|
sim.tick(500)
|
||||||
|
|
||||||
|
# Stop sweep by setting angle
|
||||||
|
sim.set_angle(SERVO_PAN, 45)
|
||||||
|
assert not sim.is_sweeping[SERVO_PAN]
|
||||||
|
assert sim.get_angle(SERVO_PAN) == 45
|
||||||
|
|
||||||
|
|
||||||
|
def test_independent_servos():
|
||||||
|
"""Pan and tilt servos should sweep independently."""
|
||||||
|
sim = ServoSimulator()
|
||||||
|
|
||||||
|
sim.sweep(SERVO_PAN, 0, 180, 1000)
|
||||||
|
sim.sweep(SERVO_TILT, 180, 0, 2000)
|
||||||
|
|
||||||
|
# Initialize sweep timing
|
||||||
|
sim.tick(0)
|
||||||
|
|
||||||
|
# After 1 second
|
||||||
|
sim.tick(1000)
|
||||||
|
assert sim.get_angle(SERVO_PAN) == 180
|
||||||
|
assert not sim.is_sweeping[SERVO_PAN]
|
||||||
|
assert sim.get_angle(SERVO_TILT) == 90 # Halfway through
|
||||||
|
assert sim.is_sweeping[SERVO_TILT]
|
||||||
|
|
||||||
|
# After 2 seconds
|
||||||
|
sim.tick(2000)
|
||||||
|
assert not sim.is_sweeping[SERVO_PAN]
|
||||||
|
assert sim.get_angle(SERVO_TILT) == 0
|
||||||
|
assert not sim.is_sweeping[SERVO_TILT]
|
||||||
|
assert not sim.is_sweeping_any()
|
||||||
|
|
||||||
|
|
||||||
|
def test_fast_sweep():
|
||||||
|
"""Very short sweep should work."""
|
||||||
|
sim = ServoSimulator()
|
||||||
|
|
||||||
|
sim.sweep(SERVO_PAN, 45, 135, 100) # 90° in 100ms
|
||||||
|
|
||||||
|
sim.tick(0)
|
||||||
|
assert sim.get_angle(SERVO_PAN) == 45
|
||||||
|
|
||||||
|
sim.tick(50)
|
||||||
|
assert sim.get_angle(SERVO_PAN) == 90
|
||||||
|
|
||||||
|
sim.tick(100)
|
||||||
|
assert sim.get_angle(SERVO_PAN) == 135
|
||||||
|
assert not sim.is_sweeping[SERVO_PAN]
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_sweeps():
|
||||||
|
"""Multiple sequential sweeps should work."""
|
||||||
|
sim = ServoSimulator()
|
||||||
|
|
||||||
|
# First sweep (0° to 90° in 500ms)
|
||||||
|
sim.sweep(SERVO_PAN, 0, 90, 500)
|
||||||
|
sim.tick(0)
|
||||||
|
sim.tick(500)
|
||||||
|
assert sim.get_angle(SERVO_PAN) == 90
|
||||||
|
assert not sim.is_sweeping[SERVO_PAN]
|
||||||
|
|
||||||
|
# Second sweep (90° to 0° in 500ms, starting at t=500)
|
||||||
|
sim.sweep(SERVO_PAN, 90, 0, 500)
|
||||||
|
sim.tick(500) # Initialize second sweep
|
||||||
|
sim.tick(1000) # After 500ms of second sweep
|
||||||
|
assert sim.get_angle(SERVO_PAN) == 0
|
||||||
|
assert not sim.is_sweeping[SERVO_PAN]
|
||||||
|
|
||||||
|
|
||||||
|
def test_boundary_angles():
|
||||||
|
"""Angles > 180° should clamp to 180°."""
|
||||||
|
sim = ServoSimulator()
|
||||||
|
|
||||||
|
sim.set_angle(SERVO_PAN, 200)
|
||||||
|
assert sim.get_angle(SERVO_PAN) == 180
|
||||||
|
|
||||||
|
sim.set_angle(SERVO_PAN, -10)
|
||||||
|
assert sim.get_angle(SERVO_PAN) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_pulse_clamping():
|
||||||
|
"""Pulse widths outside 500-2500 µs should clamp."""
|
||||||
|
sim = ServoSimulator()
|
||||||
|
|
||||||
|
sim.set_pulse_us(SERVO_PAN, 100) # Too low
|
||||||
|
assert sim.pulse_us[SERVO_PAN] == SERVO_MIN_US
|
||||||
|
|
||||||
|
sim.set_pulse_us(SERVO_PAN, 3000) # Too high
|
||||||
|
assert sim.pulse_us[SERVO_PAN] == SERVO_MAX_US
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pytest.main([__file__, '-v'])
|
||||||
Loading…
x
Reference in New Issue
Block a user