feat: ROS2 gimbal control node (Issue #548) #558

Merged
sl-jetson merged 1 commits from sl-jetson/issue-548-gimbal-ros2 into main 2026-03-14 11:39:59 -04:00
10 changed files with 1185 additions and 0 deletions

View File

@ -0,0 +1,29 @@
gimbal_node:
ros__parameters:
# Serial port connecting to STM32 over JLINK protocol
serial_port: "/dev/ttyTHS1"
baud_rate: 921600
# Soft angle limits (degrees, ± from center)
pan_limit_deg: 150.0
tilt_limit_deg: 45.0
# Home position (degrees from center)
home_pan_deg: 0.0
home_tilt_deg: 0.0
# Motion profile
max_speed_deg_s: 90.0 # maximum pan/tilt speed (°/s)
accel_deg_s2: 180.0 # trapezoidal acceleration (°/s²)
# Update rates
update_rate_hz: 20.0 # motion tick + JLINK TX rate
state_publish_hz: 10.0 # /saltybot/gimbal/state publish rate
# Serial reconnect
reconnect_delay_s: 2.0
# D435i camera parameters (for look_at 3D→pan/tilt projection)
camera_focal_length_px: 600.0
image_width_px: 848
image_height_px: 480

View File

@ -0,0 +1,50 @@
"""gimbal.launch.py — Launch saltybot_gimbal node (Issue #548)."""
import os
from ament_index_python.packages import get_package_share_directory
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
def generate_launch_description() -> LaunchDescription:
pkg = get_package_share_directory("saltybot_gimbal")
serial_port_arg = DeclareLaunchArgument(
"serial_port",
default_value="/dev/ttyTHS1",
description="JLINK serial port to STM32",
)
pan_limit_arg = DeclareLaunchArgument(
"pan_limit_deg",
default_value="150.0",
description="Pan soft limit ± degrees",
)
tilt_limit_arg = DeclareLaunchArgument(
"tilt_limit_deg",
default_value="45.0",
description="Tilt soft limit ± degrees",
)
gimbal_node = Node(
package="saltybot_gimbal",
executable="gimbal_node",
name="gimbal_node",
output="screen",
parameters=[
os.path.join(pkg, "config", "gimbal_params.yaml"),
{
"serial_port": LaunchConfiguration("serial_port"),
"pan_limit_deg": LaunchConfiguration("pan_limit_deg"),
"tilt_limit_deg": LaunchConfiguration("tilt_limit_deg"),
},
],
)
return LaunchDescription([
serial_port_arg,
pan_limit_arg,
tilt_limit_arg,
gimbal_node,
])

View File

@ -0,0 +1,21 @@
<?xml version="1.0"?>
<package format="3">
<name>saltybot_gimbal</name>
<version>1.0.0</version>
<description>
ROS2 gimbal control node: pan/tilt camera head via JLINK serial to STM32.
Smooth trapezoidal motion profiles, configurable limits, look_at 3D projection.
Issue #548.
</description>
<maintainer email="sl-jetson@saltylab.local">sl-jetson</maintainer>
<license>Apache-2.0</license>
<buildtool_depend>ament_python</buildtool_depend>
<depend>rclpy</depend>
<depend>geometry_msgs</depend>
<depend>std_msgs</depend>
<depend>std_srvs</depend>
<test_depend>pytest</test_depend>
</package>

View File

@ -0,0 +1,497 @@
#!/usr/bin/env python3
"""gimbal_node.py — ROS2 gimbal control node for SaltyBot pan/tilt camera head (Issue #548).
Controls pan/tilt gimbal via JLINK binary protocol over serial to STM32.
Implements smooth trapezoidal motion profiles with configurable axis limits.
Subscribed topics:
/saltybot/gimbal/cmd (geometry_msgs/Vector3) x=pan_deg, y=tilt_deg, z=speed_deg_s
/saltybot/gimbal/look_at_target (geometry_msgs/PointStamped) 3D look-at target (camera frame)
Published topics:
/saltybot/gimbal/state (std_msgs/String) JSON: pan_deg, tilt_deg, moving, fault,
/saltybot/gimbal/cmd_echo (geometry_msgs/Vector3) current target cmd (monitoring)
Services:
/saltybot/gimbal/home (std_srvs/Trigger) command gimbal to home position
/saltybot/gimbal/look_at (std_srvs/Trigger) execute look-at to last look_at_target
Parameters (config/gimbal_params.yaml):
serial_port /dev/ttyTHS1 JLINK serial device
baud_rate 921600
pan_limit_deg 150.0 soft limit ± degrees
tilt_limit_deg 45.0 soft limit ± degrees
home_pan_deg 0.0 home pan position
home_tilt_deg 0.0 home tilt position
max_speed_deg_s 90.0 default motion speed
accel_deg_s2 180.0 trapezoidal acceleration
update_rate_hz 20.0 motion profile update rate
state_publish_hz 10.0 /saltybot/gimbal/state publish rate
reconnect_delay_s 2.0 serial reconnect interval
camera_focal_length_px 600.0 D435i focal length (for look_at projection)
image_width_px 848
image_height_px 480
"""
from __future__ import annotations
import json
import math
import threading
import time
from typing import Optional
import rclpy
from rclpy.node import Node
from rclpy.qos import HistoryPolicy, QoSProfile, ReliabilityPolicy
import serial
from geometry_msgs.msg import PointStamped, Vector3
from std_msgs.msg import String
from std_srvs.srv import Trigger
from .jlink_gimbal import (
encode_gimbal_pos,
encode_gimbal_home,
decode_gimbal_state,
TLM_GIMBAL_STATE,
GimbalStateFrame,
STX, ETX,
_crc16_ccitt,
)
def _clamp(v: float, lo: float, hi: float) -> float:
return max(lo, min(hi, v))
class MotionAxis:
"""Single-axis trapezoidal velocity profile.
Tracks current position and velocity; each tick() call advances the
axis toward the target using bounded acceleration and velocity.
"""
def __init__(self, initial_pos: float, max_speed: float, accel: float) -> None:
self.pos = initial_pos
self.vel = 0.0
self.target = initial_pos
self.max_speed = max_speed
self.accel = accel
def set_target(self, target: float, speed: Optional[float] = None) -> None:
self.target = target
if speed is not None:
self.max_speed = max(1.0, speed)
def tick(self, dt: float) -> float:
"""Advance one timestep. Returns new position."""
error = self.target - self.pos
if abs(error) < 0.01:
self.pos = self.target
self.vel = 0.0
return self.pos
# Deceleration distance at current speed: v²/(2a)
decel_dist = (self.vel ** 2) / (2.0 * self.accel) if self.accel > 0 else 0.0
direction = 1.0 if error > 0 else -1.0
if abs(error) <= decel_dist + 0.01:
# Deceleration phase
desired_vel = direction * math.sqrt(max(0.0, 2.0 * self.accel * abs(error)))
else:
# Acceleration/cruise phase
desired_vel = direction * self.max_speed
# Apply acceleration limit
dv = desired_vel - self.vel
dv = _clamp(dv, -self.accel * dt, self.accel * dt)
self.vel += dv
self.vel = _clamp(self.vel, -self.max_speed, self.max_speed)
self.pos += self.vel * dt
# Don't overshoot
if direction > 0 and self.pos > self.target:
self.pos = self.target
self.vel = 0.0
elif direction < 0 and self.pos < self.target:
self.pos = self.target
self.vel = 0.0
return self.pos
@property
def is_moving(self) -> bool:
return abs(self.target - self.pos) > 0.05 or abs(self.vel) > 0.05
class JLinkGimbalSerial:
"""Serial connection manager for JLINK gimbal frames.
Handles connect/disconnect/reconnect and provides send() and
read_state() methods. Thread-safe via internal lock.
"""
_PARSER_WAIT_STX = 0
_PARSER_WAIT_CMD = 1
_PARSER_WAIT_LEN = 2
_PARSER_PAYLOAD = 3
_PARSER_CRC_HI = 4
_PARSER_CRC_LO = 5
_PARSER_WAIT_ETX = 6
def __init__(self, port: str, baud: int) -> None:
self._port = port
self._baud = baud
self._ser: Optional[serial.Serial] = None
self._lock = threading.Lock()
self._parser_state = self._PARSER_WAIT_STX
self._parser_cmd = 0
self._parser_len = 0
self._parser_payload = bytearray()
self._parser_crc_rcvd = 0
def connect(self) -> bool:
try:
self._ser = serial.Serial(
port=self._port,
baudrate=self._baud,
timeout=0.05,
bytesize=8,
parity=serial.PARITY_NONE,
stopbits=1,
)
self._reset_parser()
return True
except Exception as e:
print(f"[gimbal] serial connect failed ({self._port}): {e}")
self._ser = None
return False
def disconnect(self) -> None:
with self._lock:
if self._ser:
self._ser.close()
self._ser = None
def is_connected(self) -> bool:
return self._ser is not None and self._ser.is_open
def send(self, frame: bytes) -> bool:
with self._lock:
if not self.is_connected():
return False
try:
self._ser.write(frame)
return True
except Exception as e:
print(f"[gimbal] serial write error: {e}")
self._ser = None
return False
def read_pending(self) -> Optional[GimbalStateFrame]:
"""Drain RX buffer; return GimbalStateFrame if a complete telemetry
frame was received, or None if no complete frame is available."""
with self._lock:
if not self.is_connected():
return None
try:
raw = self._ser.read(self._ser.in_waiting or 1)
except Exception:
self._ser = None
return None
result = None
for b in raw:
frame = self._feed(b)
if frame is not None:
result = frame # keep last complete frame
return result
def _reset_parser(self) -> None:
self._parser_state = self._PARSER_WAIT_STX
self._parser_payload = bytearray()
def _feed(self, byte: int) -> Optional[GimbalStateFrame]:
s = self._parser_state
if s == self._PARSER_WAIT_STX:
if byte == STX:
self._parser_state = self._PARSER_WAIT_CMD
return None
if s == self._PARSER_WAIT_CMD:
self._parser_cmd = byte
self._parser_state = self._PARSER_WAIT_LEN
return None
if s == self._PARSER_WAIT_LEN:
self._parser_len = byte
self._parser_payload = bytearray()
if byte > 64:
self._reset_parser()
return None
self._parser_state = self._PARSER_PAYLOAD if byte > 0 else self._PARSER_CRC_HI
return None
if s == self._PARSER_PAYLOAD:
self._parser_payload.append(byte)
if len(self._parser_payload) == self._parser_len:
self._parser_state = self._PARSER_CRC_HI
return None
if s == self._PARSER_CRC_HI:
self._parser_crc_rcvd = byte << 8
self._parser_state = self._PARSER_CRC_LO
return None
if s == self._PARSER_CRC_LO:
self._parser_crc_rcvd |= byte
self._parser_state = self._PARSER_WAIT_ETX
return None
if s == self._PARSER_WAIT_ETX:
self._reset_parser()
if byte != ETX:
return None
crc_data = bytes([self._parser_cmd, self._parser_len]) + self._parser_payload
if _crc16_ccitt(crc_data) != self._parser_crc_rcvd:
return None
if self._parser_cmd == TLM_GIMBAL_STATE:
return decode_gimbal_state(bytes(self._parser_payload))
return None
self._reset_parser()
return None
class GimbalNode(Node):
"""ROS2 gimbal control node: smooth motion profiles + JLINK serial bridge."""
def __init__(self) -> None:
super().__init__("gimbal_node")
# ── Parameters ─────────────────────────────────────────────────────
self.declare_parameter("serial_port", "/dev/ttyTHS1")
self.declare_parameter("baud_rate", 921600)
self.declare_parameter("pan_limit_deg", 150.0)
self.declare_parameter("tilt_limit_deg", 45.0)
self.declare_parameter("home_pan_deg", 0.0)
self.declare_parameter("home_tilt_deg", 0.0)
self.declare_parameter("max_speed_deg_s", 90.0)
self.declare_parameter("accel_deg_s2", 180.0)
self.declare_parameter("update_rate_hz", 20.0)
self.declare_parameter("state_publish_hz", 10.0)
self.declare_parameter("reconnect_delay_s", 2.0)
self.declare_parameter("camera_focal_length_px", 600.0)
self.declare_parameter("image_width_px", 848)
self.declare_parameter("image_height_px", 480)
self._port = self.get_parameter("serial_port").value
self._baud = self.get_parameter("baud_rate").value
self._pan_limit = self.get_parameter("pan_limit_deg").value
self._tilt_limit = self.get_parameter("tilt_limit_deg").value
self._home_pan = self.get_parameter("home_pan_deg").value
self._home_tilt = self.get_parameter("home_tilt_deg").value
self._max_speed = self.get_parameter("max_speed_deg_s").value
self._accel = self.get_parameter("accel_deg_s2").value
update_hz = self.get_parameter("update_rate_hz").value
state_hz = self.get_parameter("state_publish_hz").value
self._reconnect_del = self.get_parameter("reconnect_delay_s").value
self._focal_px = self.get_parameter("camera_focal_length_px").value
self._img_w = self.get_parameter("image_width_px").value
self._img_h = self.get_parameter("image_height_px").value
# ── State ──────────────────────────────────────────────────────────
self._pan_axis = MotionAxis(self._home_pan, self._max_speed, self._accel)
self._tilt_axis = MotionAxis(self._home_tilt, self._max_speed, self._accel)
self._hw_state: Optional[GimbalStateFrame] = None
self._look_at_target: Optional[PointStamped] = None
self._last_cmd_time = time.monotonic()
self._reconnect_ts = 0.0
self._state_lock = threading.Lock()
# ── Serial ─────────────────────────────────────────────────────────
self._serial = JLinkGimbalSerial(self._port, self._baud)
if not self._serial.connect():
self.get_logger().warn(
f"[gimbal] serial not available on {self._port} — will retry"
)
# ── Subscribers ────────────────────────────────────────────────────
qos = QoSProfile(
history=HistoryPolicy.KEEP_LAST,
depth=5,
reliability=ReliabilityPolicy.BEST_EFFORT,
)
self.create_subscription(Vector3, "/saltybot/gimbal/cmd",
self._on_cmd, qos)
self.create_subscription(PointStamped, "/saltybot/gimbal/look_at_target",
self._on_look_at_target, 10)
# ── Publishers ─────────────────────────────────────────────────────
self._pub_state = self.create_publisher(String, "/saltybot/gimbal/state", 10)
self._pub_cmd_echo = self.create_publisher(Vector3, "/saltybot/gimbal/cmd_echo", 10)
# ── Services ───────────────────────────────────────────────────────
self.create_service(Trigger, "/saltybot/gimbal/home", self._svc_home)
self.create_service(Trigger, "/saltybot/gimbal/look_at", self._svc_look_at)
# ── Timers ─────────────────────────────────────────────────────────
self.create_timer(1.0 / update_hz, self._motion_tick)
self.create_timer(1.0 / state_hz, self._publish_state)
self.get_logger().info(
f"[gimbal] node ready — port={self._port} "
f"pan=±{self._pan_limit}° tilt=±{self._tilt_limit}° "
f"speed={self._max_speed}°/s accel={self._accel}°/s²"
)
# ── Subscriber callbacks ───────────────────────────────────────────────
def _on_cmd(self, msg: Vector3) -> None:
"""Handle /saltybot/gimbal/cmd (x=pan, y=tilt, z=speed_deg_s)."""
pan = _clamp(msg.x, -self._pan_limit, self._pan_limit)
tilt = _clamp(msg.y, -self._tilt_limit, self._tilt_limit)
speed = msg.z if msg.z > 0.0 else self._max_speed
speed = _clamp(speed, 1.0, self._max_speed)
with self._state_lock:
self._pan_axis.set_target(pan, speed)
self._tilt_axis.set_target(tilt, speed)
self._last_cmd_time = time.monotonic()
def _on_look_at_target(self, msg: PointStamped) -> None:
"""Cache look-at 3D target for use by /look_at service."""
self._look_at_target = msg
# ── Service callbacks ──────────────────────────────────────────────────
def _svc_home(self, _req: Trigger.Request, resp: Trigger.Response) -> Trigger.Response:
"""Return gimbal to home position."""
with self._state_lock:
self._pan_axis.set_target(self._home_pan, self._max_speed)
self._tilt_axis.set_target(self._home_tilt, self._max_speed)
self._serial.send(encode_gimbal_home())
resp.success = True
resp.message = f"Homing to pan={self._home_pan}° tilt={self._home_tilt}°"
self.get_logger().info("[gimbal] home command")
return resp
def _svc_look_at(self, _req: Trigger.Request, resp: Trigger.Response) -> Trigger.Response:
"""Convert last look_at_target point to pan/tilt and command motion."""
tgt = self._look_at_target
if tgt is None:
resp.success = False
resp.message = "No look_at_target received on /saltybot/gimbal/look_at_target"
return resp
# Project 3D point (camera frame) → pan/tilt angles.
# Assumes point is in camera_depth_optical_frame (Z=forward, X=right, Y=down).
x, y, z = tgt.point.x, tgt.point.y, tgt.point.z
if z <= 0.0:
resp.success = False
resp.message = f"Invalid look_at_target: z={z:.3f} (must be > 0)"
return resp
pan_deg = math.degrees(math.atan2(x, z))
tilt_deg = -math.degrees(math.atan2(y, z)) # negative: down in image = tilt up
pan_deg = _clamp(pan_deg, -self._pan_limit, self._pan_limit)
tilt_deg = _clamp(tilt_deg, -self._tilt_limit, self._tilt_limit)
with self._state_lock:
self._pan_axis.set_target(pan_deg, self._max_speed)
self._tilt_axis.set_target(tilt_deg, self._max_speed)
resp.success = True
resp.message = (
f"Looking at ({x:.2f}, {y:.2f}, {z:.2f}) → "
f"pan={pan_deg:.1f}° tilt={tilt_deg:.1f}°"
)
self.get_logger().info(f"[gimbal] look_at: {resp.message}")
return resp
# ── Motion tick ────────────────────────────────────────────────────────
def _motion_tick(self) -> None:
"""Advance motion profiles and send JLINK frame at update_rate_hz."""
# Reconnect if needed
if not self._serial.is_connected():
now = time.monotonic()
if now - self._reconnect_ts >= self._reconnect_del:
self._reconnect_ts = now
if self._serial.connect():
self.get_logger().info(f"[gimbal] serial reconnected on {self._port}")
dt = 1.0 / self.get_parameter("update_rate_hz").value
with self._state_lock:
pan = self._pan_axis.tick(dt)
tilt = self._tilt_axis.tick(dt)
moving = self._pan_axis.is_moving or self._tilt_axis.is_moving
speed = max(self._pan_axis.max_speed, self._tilt_axis.max_speed)
# Send JLINK frame
frame = encode_gimbal_pos(pan, tilt, speed)
self._serial.send(frame)
# Read telemetry
hw = self._serial.read_pending()
if hw is not None:
with self._state_lock:
self._hw_state = hw
# Echo command
echo = Vector3(x=pan, y=tilt, z=speed)
self._pub_cmd_echo.publish(echo)
# ── State publisher ────────────────────────────────────────────────────
def _publish_state(self) -> None:
"""Publish gimbal state as JSON string."""
with self._state_lock:
pan = self._pan_axis.pos
tilt = self._tilt_axis.pos
pan_t = self._pan_axis.target
tilt_t = self._tilt_axis.target
moving = self._pan_axis.is_moving or self._tilt_axis.is_moving
hw = self._hw_state
state: dict = {
"pan_deg": round(pan, 2),
"tilt_deg": round(tilt, 2),
"pan_target_deg": round(pan_t, 2),
"tilt_target_deg": round(tilt_t, 2),
"moving": moving,
"serial_ok": self._serial.is_connected(),
}
if hw is not None:
state["hw_pan_deg"] = round(hw.pan_deg, 1)
state["hw_tilt_deg"] = round(hw.tilt_deg, 1)
state["hw_pan_speed_raw"] = hw.pan_speed_raw
state["hw_tilt_speed_raw"] = hw.tilt_speed_raw
state["hw_torque_en"] = hw.torque_en
state["hw_rx_err_pct"] = hw.rx_err_pct
self._pub_state.publish(String(data=json.dumps(state)))
def destroy_node(self) -> None:
self._serial.disconnect()
super().destroy_node()
def main(args=None) -> None:
rclpy.init(args=args)
node = GimbalNode()
try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,141 @@
"""jlink_gimbal.py — JLINK binary frame codec for gimbal commands (Issue #548).
Matches the JLINK protocol defined in include/jlink.h (Issue #547 STM32 side).
Command type (Jetson STM32):
0x0B GIMBAL_POS int16 pan_x10 + int16 tilt_x10 + uint16 speed (6 bytes)
pan_x10 = pan_deg * 10 (±1500 for ±150°)
tilt_x10 = tilt_deg * 10 (±450 for ±45°)
speed = servo speed register 04095 (0 = max)
Telemetry type (STM32 Jetson):
0x84 GIMBAL_STATE int16 pan_x10 + int16 tilt_x10 +
uint16 pan_speed_raw + uint16 tilt_speed_raw +
uint8 torque_en + uint8 rx_err_pct (10 bytes)
Frame format (shared with stm32_protocol.py):
[STX=0x02][CMD][LEN][PAYLOAD...][CRC16_hi][CRC16_lo][ETX=0x03]
CRC16-CCITT: poly=0x1021, init=0xFFFF, covers CMD+LEN+PAYLOAD bytes.
"""
from __future__ import annotations
import struct
from dataclasses import dataclass
from typing import Optional
# ── Frame constants ────────────────────────────────────────────────────────────
STX = 0x02
ETX = 0x03
# ── Command / telemetry type codes ─────────────────────────────────────────────
CMD_GIMBAL_POS = 0x0B # Jetson → STM32: set pan/tilt target
TLM_GIMBAL_STATE = 0x84 # STM32 → Jetson: measured state
# Speed register: 0 = maximum servo speed; 4095 = slowest non-zero speed.
# Map deg/s to this register: speed_reg = max(0, 4095 - int(deg_s * 4095 / 360))
_MAX_SPEED_DEGS = 360.0 # degrees/sec at speed_reg=0
# ── Parsed telemetry ───────────────────────────────────────────────────────────
@dataclass
class GimbalStateFrame:
pan_deg: float # measured pan position (degrees, + = right)
tilt_deg: float # measured tilt position (degrees, + = up)
pan_speed_raw: int # pan servo present-speed register value
tilt_speed_raw: int # tilt servo present-speed register value
torque_en: bool # True = servo torque is on
rx_err_pct: int # SCS bus error rate 0100 %
# ── CRC16-CCITT ────────────────────────────────────────────────────────────────
def _crc16_ccitt(data: bytes, init: int = 0xFFFF) -> int:
"""CRC16-CCITT: poly=0x1021, init=0xFFFF, no final XOR."""
crc = init
for byte in data:
crc ^= byte << 8
for _ in range(8):
crc = ((crc << 1) ^ 0x1021) if (crc & 0x8000) else (crc << 1)
crc &= 0xFFFF
return crc
# ── Frame builder ──────────────────────────────────────────────────────────────
def _build_frame(cmd: int, payload: bytes) -> bytes:
length = len(payload)
header = bytes([cmd, length])
crc = _crc16_ccitt(header + payload)
return bytes([STX, cmd, length]) + payload + struct.pack(">H", crc) + bytes([ETX])
# ── Speed conversion ───────────────────────────────────────────────────────────
def degs_to_speed_reg(deg_s: float) -> int:
"""Convert degrees/second to servo speed register value.
speed_reg=0 is maximum speed; speed_reg=4095 is the slowest non-zero.
Linear mapping: 0 deg/s 0 (max), 360 deg/s 0, clamped.
"""
if deg_s <= 0.0:
return 0 # max speed
ratio = min(1.0, deg_s / _MAX_SPEED_DEGS)
return max(0, min(4095, int((1.0 - ratio) * 4095)))
def speed_reg_to_degs(reg: int) -> float:
"""Convert speed register to approximate degrees/second."""
if reg == 0:
return _MAX_SPEED_DEGS
return (1.0 - reg / 4095.0) * _MAX_SPEED_DEGS
# ── Encoder ────────────────────────────────────────────────────────────────────
def encode_gimbal_pos(pan_deg: float, tilt_deg: float, speed_deg_s: float) -> bytes:
"""Build GIMBAL_POS frame.
Args:
pan_deg: Pan angle (degrees, clamped to ±327°).
tilt_deg: Tilt angle (degrees, clamped to ±327°).
speed_deg_s: Motion speed (degrees/second; 0 = max).
Returns:
Complete JLINK binary frame (10 bytes total).
"""
pan_x10 = max(-32767, min(32767, int(round(pan_deg * 10))))
tilt_x10 = max(-32767, min(32767, int(round(tilt_deg * 10))))
speed = degs_to_speed_reg(speed_deg_s)
return _build_frame(CMD_GIMBAL_POS, struct.pack(">hhH", pan_x10, tilt_x10, speed))
def encode_gimbal_home() -> bytes:
"""Build GIMBAL_POS frame targeting (0°, 0°) at default speed."""
return encode_gimbal_pos(0.0, 0.0, 0.0) # speed=0 → max speed for homing
# ── Decoder ────────────────────────────────────────────────────────────────────
def decode_gimbal_state(payload: bytes) -> Optional[GimbalStateFrame]:
"""Decode GIMBAL_STATE telemetry payload (10 bytes).
Returns:
GimbalStateFrame on success, None if payload is too short.
"""
if len(payload) < 10:
return None
pan_x10, tilt_x10, pan_spd, tilt_spd, torque_en, rx_err_pct = (
struct.unpack_from(">hhHHBB", payload)
)
return GimbalStateFrame(
pan_deg=pan_x10 / 10.0,
tilt_deg=tilt_x10 / 10.0,
pan_speed_raw=pan_spd,
tilt_speed_raw=tilt_spd,
torque_en=bool(torque_en),
rx_err_pct=int(rx_err_pct),
)

View File

@ -0,0 +1,2 @@
[develop]
script_dir=$base/lib/saltybot_gimbal/scripts

View File

@ -0,0 +1,23 @@
from setuptools import setup
package_name = 'saltybot_gimbal'
setup(
name=package_name,
version='1.0.0',
packages=[package_name],
data_files=[
('share/ament_index/resource_index/packages', ['resource/' + package_name]),
('share/' + package_name, ['package.xml']),
('share/' + package_name + '/launch', ['launch/gimbal.launch.py']),
('share/' + package_name + '/config', ['config/gimbal_params.yaml']),
],
install_requires=['setuptools'],
zip_safe=True,
author='Salty Lab',
entry_points={
'console_scripts': [
'gimbal_node = saltybot_gimbal.gimbal_node:main',
],
},
)

View File

@ -0,0 +1,422 @@
"""test_gimbal.py — Unit tests for saltybot_gimbal (Issue #548).
Runs without hardware: all serial I/O is mocked.
"""
import math
import struct
import sys
import types
import unittest
from unittest.mock import MagicMock
# ── Stub ROS2 / serial so tests run offline ────────────────────────────────────
def _make_rclpy_stub():
rclpy = types.ModuleType("rclpy")
rclpy.init = lambda **_: None
rclpy.spin = lambda _: None
rclpy.shutdown = lambda: None
node_mod = types.ModuleType("rclpy.node")
class _Node:
def __init__(self, *a, **kw): pass
def declare_parameter(self, *a, **kw): pass
def get_parameter(self, name):
defaults = {
"serial_port": "/dev/null", "baud_rate": 921600,
"pan_limit_deg": 150.0, "tilt_limit_deg": 45.0,
"home_pan_deg": 0.0, "home_tilt_deg": 0.0,
"max_speed_deg_s": 90.0, "accel_deg_s2": 180.0,
"update_rate_hz": 20.0, "state_publish_hz": 10.0,
"reconnect_delay_s": 2.0,
"camera_focal_length_px": 600.0,
"image_width_px": 848, "image_height_px": 480,
}
m = MagicMock()
m.value = defaults.get(name, 0)
return m
def get_logger(self): return MagicMock()
def create_subscription(self, *a, **kw): pass
def create_publisher(self, *a, **kw): return MagicMock()
def create_service(self, *a, **kw): pass
def create_timer(self, *a, **kw): pass
def destroy_node(self): pass
node_mod.Node = _Node
rclpy.node = node_mod
qos_mod = types.ModuleType("rclpy.qos")
qos_mod.QoSProfile = MagicMock(return_value=None)
qos_mod.HistoryPolicy = MagicMock()
qos_mod.ReliabilityPolicy = MagicMock()
rclpy.qos = qos_mod
return rclpy
def _install_stubs():
rclpy_stub = _make_rclpy_stub()
sys.modules.setdefault("rclpy", rclpy_stub)
sys.modules.setdefault("rclpy.node", rclpy_stub.node)
sys.modules.setdefault("rclpy.qos", rclpy_stub.qos)
geo = types.ModuleType("geometry_msgs")
geo_msg = types.ModuleType("geometry_msgs.msg")
class _Vector3:
def __init__(self, x=0.0, y=0.0, z=0.0): self.x=x; self.y=y; self.z=z
class _PointStamped:
def __init__(self): self.point = _Vector3()
geo_msg.Vector3 = _Vector3
geo_msg.PointStamped = _PointStamped
geo.msg = geo_msg
sys.modules.setdefault("geometry_msgs", geo)
sys.modules.setdefault("geometry_msgs.msg", geo_msg)
std_msgs = types.ModuleType("std_msgs")
std_msgs_msg = types.ModuleType("std_msgs.msg")
class _String:
def __init__(self, data=""): self.data = data
std_msgs_msg.String = _String
std_msgs.msg = std_msgs_msg
sys.modules.setdefault("std_msgs", std_msgs)
sys.modules.setdefault("std_msgs.msg", std_msgs_msg)
std_srvs = types.ModuleType("std_srvs")
std_srvs_srv = types.ModuleType("std_srvs.srv")
class _Trigger:
class Request: pass
class Response:
success = False
message = ""
std_srvs_srv.Trigger = _Trigger
std_srvs.srv = std_srvs_srv
sys.modules.setdefault("std_srvs", std_srvs)
sys.modules.setdefault("std_srvs.srv", std_srvs_srv)
sys.modules.setdefault("serial", MagicMock())
_install_stubs()
from saltybot_gimbal.jlink_gimbal import ( # noqa: E402
encode_gimbal_pos, encode_gimbal_home, decode_gimbal_state,
GimbalStateFrame, CMD_GIMBAL_POS, TLM_GIMBAL_STATE,
STX, ETX, _crc16_ccitt, _build_frame,
degs_to_speed_reg, speed_reg_to_degs,
)
from saltybot_gimbal.gimbal_node import MotionAxis, _clamp # noqa: E402
# ── Helper ─────────────────────────────────────────────────────────────────────
def _parse_frame(raw: bytes):
"""Parse a JLINK frame; return (cmd, payload) or raise on error."""
assert raw[0] == STX, f"STX expected, got 0x{raw[0]:02x}"
cmd = raw[1]
length = raw[2]
payload = raw[3:3 + length]
crc_hi = raw[3 + length]
crc_lo = raw[4 + length]
crc_rcvd = (crc_hi << 8) | crc_lo
assert raw[-1] == ETX, f"ETX expected, got 0x{raw[-1]:02x}"
crc_data = bytes([cmd, length]) + payload
crc_calc = _crc16_ccitt(crc_data)
assert crc_rcvd == crc_calc, (
f"CRC mismatch: got {crc_rcvd:#06x}, expected {crc_calc:#06x}"
)
return cmd, payload
# ── CRC16 Tests ────────────────────────────────────────────────────────────────
class TestCRC16(unittest.TestCase):
def test_empty_data(self):
self.assertEqual(_crc16_ccitt(b""), 0xFFFF)
def test_known_value_123456789(self):
# Standard CCITT-FFFF check value for "123456789"
self.assertEqual(_crc16_ccitt(b"123456789"), 0x29B1)
def test_single_byte_is_int(self):
v = _crc16_ccitt(b"\x0B")
self.assertIsInstance(v, int)
self.assertLessEqual(v, 0xFFFF)
def test_different_data_different_crc(self):
self.assertNotEqual(_crc16_ccitt(b"\x00\x01"), _crc16_ccitt(b"\x00\x02"))
# ── Speed Conversion Tests ─────────────────────────────────────────────────────
class TestSpeedConversion(unittest.TestCase):
def test_zero_speed_is_max(self):
self.assertEqual(degs_to_speed_reg(0.0), 0)
def test_negative_speed_is_max(self):
self.assertEqual(degs_to_speed_reg(-10.0), 0)
def test_360_degs_maps_to_zero(self):
# 360°/s → ratio=1 → reg=0 (max)
self.assertEqual(degs_to_speed_reg(360.0), 0)
def test_mid_speed(self):
reg = degs_to_speed_reg(90.0)
self.assertGreater(reg, 0)
self.assertLess(reg, 4095)
def test_round_trip_approx(self):
for deg_s in (30.0, 60.0, 120.0, 200.0):
reg = degs_to_speed_reg(deg_s)
back = speed_reg_to_degs(reg)
self.assertAlmostEqual(back, deg_s, delta=5.0)
def test_reg_zero_is_max_speed(self):
self.assertAlmostEqual(speed_reg_to_degs(0), 360.0)
# ── Encode Tests ───────────────────────────────────────────────────────────────
class TestEncodeGimbalPos(unittest.TestCase):
def test_zero_pose(self):
raw = encode_gimbal_pos(0.0, 0.0, 90.0)
cmd, payload = _parse_frame(raw)
self.assertEqual(cmd, CMD_GIMBAL_POS) # 0x0B
pan_x10, tilt_x10, speed = struct.unpack(">hhH", payload)
self.assertEqual(pan_x10, 0)
self.assertEqual(tilt_x10, 0)
def test_positive_pan_tilt(self):
raw = encode_gimbal_pos(30.0, 15.0, 45.0)
_, payload = _parse_frame(raw)
pan_x10, tilt_x10, _ = struct.unpack(">hhH", payload)
self.assertEqual(pan_x10, 300) # 30.0 * 10
self.assertEqual(tilt_x10, 150) # 15.0 * 10
def test_negative_angles(self):
raw = encode_gimbal_pos(-90.0, -20.0, 60.0)
_, payload = _parse_frame(raw)
pan_x10, tilt_x10, _ = struct.unpack(">hhH", payload)
self.assertEqual(pan_x10, -900)
self.assertEqual(tilt_x10, -200)
def test_pan_limit_150_deg(self):
raw = encode_gimbal_pos(150.0, 0.0, 30.0)
_, payload = _parse_frame(raw)
pan_x10, _, _ = struct.unpack(">hhH", payload)
self.assertEqual(pan_x10, 1500)
def test_tilt_limit_45_deg(self):
raw = encode_gimbal_pos(0.0, 45.0, 30.0)
_, payload = _parse_frame(raw)
_, tilt_x10, _ = struct.unpack(">hhH", payload)
self.assertEqual(tilt_x10, 450)
def test_extreme_clamped(self):
raw = encode_gimbal_pos(9999.0, -9999.0, 0.0)
_, payload = _parse_frame(raw)
pan_x10, tilt_x10, _ = struct.unpack(">hhH", payload)
self.assertEqual(pan_x10, 32767)
self.assertEqual(tilt_x10, -32767)
def test_frame_byte_layout(self):
raw = encode_gimbal_pos(10.0, -5.0, 30.0)
self.assertEqual(raw[0], STX)
self.assertEqual(raw[-1], ETX)
self.assertEqual(raw[1], CMD_GIMBAL_POS) # 0x0B
self.assertEqual(raw[2], 6) # 6-byte payload
def test_crc_valid(self):
# _parse_frame raises if CRC is bad
_parse_frame(encode_gimbal_pos(45.0, -10.0, 120.0))
class TestEncodeGimbalHome(unittest.TestCase):
def test_homes_to_zero(self):
raw = encode_gimbal_home()
cmd, payload = _parse_frame(raw)
self.assertEqual(cmd, CMD_GIMBAL_POS)
pan_x10, tilt_x10, _ = struct.unpack(">hhH", payload)
self.assertEqual(pan_x10, 0)
self.assertEqual(tilt_x10, 0)
def test_crc_valid(self):
_parse_frame(encode_gimbal_home())
# ── Decode Tests ───────────────────────────────────────────────────────────────
class TestDecodeGimbalState(unittest.TestCase):
def _make_payload(self, pan_x10, tilt_x10, pan_spd, tilt_spd,
torque_en, rx_err_pct) -> bytes:
return struct.pack(">hhHHBB",
pan_x10, tilt_x10,
pan_spd, tilt_spd,
torque_en, rx_err_pct)
def test_zero_state(self):
state = decode_gimbal_state(self._make_payload(0, 0, 0, 0, 0, 0))
self.assertIsNotNone(state)
self.assertAlmostEqual(state.pan_deg, 0.0)
self.assertAlmostEqual(state.tilt_deg, 0.0)
self.assertFalse(state.torque_en)
self.assertEqual(state.rx_err_pct, 0)
def test_angles_decode(self):
state = decode_gimbal_state(self._make_payload(500, -250, 100, 200, 1, 5))
self.assertAlmostEqual(state.pan_deg, 50.0)
self.assertAlmostEqual(state.tilt_deg, -25.0)
def test_torque_en(self):
state = decode_gimbal_state(self._make_payload(0, 0, 0, 0, 1, 0))
self.assertTrue(state.torque_en)
def test_torque_off(self):
state = decode_gimbal_state(self._make_payload(0, 0, 0, 0, 0, 0))
self.assertFalse(state.torque_en)
def test_speed_raw(self):
state = decode_gimbal_state(self._make_payload(0, 0, 1234, 5678, 0, 0))
self.assertEqual(state.pan_speed_raw, 1234)
self.assertEqual(state.tilt_speed_raw, 5678)
def test_rx_err_pct(self):
state = decode_gimbal_state(self._make_payload(0, 0, 0, 0, 0, 42))
self.assertEqual(state.rx_err_pct, 42)
def test_short_payload_returns_none(self):
self.assertIsNone(decode_gimbal_state(b"\x00\x01\x02"))
def test_empty_payload_returns_none(self):
self.assertIsNone(decode_gimbal_state(b""))
def test_exactly_10_bytes_ok(self):
payload = self._make_payload(100, -50, 512, 512, 1, 3)
state = decode_gimbal_state(payload)
self.assertIsNotNone(state)
# ── MotionAxis Tests ───────────────────────────────────────────────────────────
class TestMotionAxis(unittest.TestCase):
def test_starts_at_initial_pos(self):
ax = MotionAxis(10.0, 90.0, 180.0)
self.assertAlmostEqual(ax.pos, 10.0)
def test_not_moving_at_rest(self):
ax = MotionAxis(0.0, 90.0, 180.0)
self.assertFalse(ax.is_moving)
def test_moves_toward_target(self):
ax = MotionAxis(0.0, 90.0, 180.0)
ax.set_target(90.0)
for _ in range(5):
ax.tick(0.05)
self.assertGreater(ax.pos, 0.0)
def test_reaches_target(self):
ax = MotionAxis(0.0, 90.0, 180.0)
ax.set_target(10.0)
for _ in range(200):
ax.tick(0.05)
self.assertAlmostEqual(ax.pos, 10.0, places=1)
self.assertFalse(ax.is_moving)
def test_negative_target(self):
ax = MotionAxis(0.0, 90.0, 180.0)
ax.set_target(-20.0)
for _ in range(200):
ax.tick(0.05)
self.assertAlmostEqual(ax.pos, -20.0, places=1)
def test_faster_speed_reaches_sooner(self):
ax_fast = MotionAxis(0.0, 180.0, 360.0)
ax_slow = MotionAxis(0.0, 30.0, 360.0)
ax_fast.set_target(30.0)
ax_slow.set_target(30.0)
for _ in range(5):
ax_fast.tick(0.05)
ax_slow.tick(0.05)
self.assertGreater(ax_fast.pos, ax_slow.pos)
def test_is_moving_while_en_route(self):
ax = MotionAxis(0.0, 90.0, 180.0)
ax.set_target(60.0)
ax.tick(0.05)
self.assertTrue(ax.is_moving)
def test_no_overshoot(self):
ax = MotionAxis(0.0, 90.0, 180.0)
ax.set_target(5.0)
for _ in range(500):
ax.tick(0.05)
self.assertLessEqual(ax.pos, 5.01)
def test_set_target_updates_speed(self):
ax = MotionAxis(0.0, 90.0, 180.0)
ax.set_target(10.0, speed=45.0)
self.assertAlmostEqual(ax.max_speed, 45.0)
# ── Clamp Tests ────────────────────────────────────────────────────────────────
class TestClamp(unittest.TestCase):
def test_within_range(self):
self.assertEqual(_clamp(5.0, 0.0, 10.0), 5.0)
def test_below_min(self):
self.assertEqual(_clamp(-5.0, 0.0, 10.0), 0.0)
def test_above_max(self):
self.assertEqual(_clamp(15.0, 0.0, 10.0), 10.0)
def test_at_lower_boundary(self):
self.assertEqual(_clamp(0.0, 0.0, 10.0), 0.0)
def test_at_upper_boundary(self):
self.assertEqual(_clamp(10.0, 0.0, 10.0), 10.0)
# ── Look-at Projection Tests ───────────────────────────────────────────────────
class TestLookAtProjection(unittest.TestCase):
"""Test the look-at 3D → pan/tilt math (mirrors gimbal_node logic)."""
def _project(self, x, y, z):
pan_deg = math.degrees(math.atan2(x, z))
tilt_deg = -math.degrees(math.atan2(y, z))
return pan_deg, tilt_deg
def test_straight_ahead(self):
pan, tilt = self._project(0.0, 0.0, 1.0)
self.assertAlmostEqual(pan, 0.0, places=5)
self.assertAlmostEqual(tilt, 0.0, places=5)
def test_right_target(self):
pan, tilt = self._project(1.0, 0.0, 1.0)
self.assertAlmostEqual(pan, 45.0, places=4)
self.assertAlmostEqual(tilt, 0.0, places=4)
def test_up_target(self):
# camera frame: negative y = above centre → tilt up (positive)
pan, tilt = self._project(0.0, -1.0, 1.0)
self.assertAlmostEqual(pan, 0.0, places=4)
self.assertAlmostEqual(tilt, 45.0, places=4)
def test_left_down_target(self):
pan, tilt = self._project(-1.0, 1.0, 2.0)
self.assertLess(pan, 0.0) # left → negative pan
self.assertLess(tilt, 0.0) # below → negative tilt
def test_90_degree_right(self):
pan, _ = self._project(1.0, 0.0, 0.0)
self.assertAlmostEqual(pan, 90.0, places=4)
if __name__ == "__main__":
unittest.main()