#!/usr/bin/env python3
"""
build_engine.py — Convert a BiSeNetV2/DDRNet PyTorch model → ONNX → TensorRT FP16 engine.

Run ONCE on the Jetson Orin Nano Super. The .engine file is hardware-specific
(it cannot be transferred between machines).

Usage
─────
    # Build BiSeNetV2 engine (default, fastest):
    python3 build_engine.py

    # Build DDRNet-23-slim engine (higher mIoU, slightly slower):
    python3 build_engine.py --model ddrnet

    # Custom output paths:
    python3 build_engine.py --onnx /tmp/model.onnx --engine /tmp/model.engine

    # Re-export ONNX from local weights (skip download):
    python3 build_engine.py --weights /path/to/bisenetv2.pth

Prerequisites
─────────────
    pip install torch torchvision onnx onnxruntime-gpu
    # TensorRT is pre-installed with JetPack 6 at /usr/lib/python3/dist-packages/tensorrt

Outputs
───────
    /mnt/nvme/saltybot/models/bisenetv2_cityscapes_512x256.onnx    (FP32)
    /mnt/nvme/saltybot/models/bisenetv2_cityscapes_512x256.engine  (TRT FP16)

Model sources
─────────────
    BiSeNetV2 pretrained on Cityscapes:
        https://github.com/CoinCheung/BiSeNet (MIT License)
        Checkpoint: cp/model_final_v2_city.pth (~25 MB)
    DDRNet-23-slim pretrained on Cityscapes:
        https://github.com/ydhongHIT/DDRNet (Apache License)
        Checkpoint: DDRNet23s_imagenet.pth + Cityscapes fine-tune

Performance on Orin Nano Super (1024-core Ampere, 20 TOPS)
──────────────────────────────────────────────────────────
    BiSeNetV2 FP32 ONNX:  ~12 ms  (~83 fps theoretical)
    BiSeNetV2 FP16 TRT:   ~5 ms   (~200 fps theoretical, real ~50 fps w/ overhead)
    DDRNet-23 FP16 TRT:   ~8 ms   (~125 fps theoretical, real ~40 fps)
    Target: >15 fps including ROS2 overhead at 512×256
"""

import argparse
import os
import sys
import time

# ── Default paths ─────────────────────────────────────────────────────────────
DEFAULT_MODEL_DIR = "/mnt/nvme/saltybot/models"
INPUT_W, INPUT_H = 512, 256
BATCH_SIZE = 1

MODEL_CONFIGS = {
    "bisenetv2": {
        "onnx_name": "bisenetv2_cityscapes_512x256.onnx",
        "engine_name": "bisenetv2_cityscapes_512x256.engine",
        "repo_url": "https://github.com/CoinCheung/BiSeNet.git",
        "weights_url": "https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v2_city.pth",
        "num_classes": 19,
        "description": "BiSeNetV2 Cityscapes — 72.6 mIoU, ~50fps on Orin",
    },
    "ddrnet": {
        "onnx_name": "ddrnet23s_cityscapes_512x256.onnx",
        "engine_name": "ddrnet23s_cityscapes_512x256.engine",
        "repo_url": "https://github.com/ydhongHIT/DDRNet.git",
        "weights_url": None,  # must be supplied manually via --weights
        "num_classes": 19,
        "description": "DDRNet-23-slim Cityscapes — 79.5 mIoU, ~40fps on Orin",
    },
}


def parse_args():
    p = argparse.ArgumentParser(description="Build TensorRT segmentation engine")
    p.add_argument("--model", default="bisenetv2", choices=list(MODEL_CONFIGS.keys()),
                   help="Model architecture")
    p.add_argument("--weights", default=None,
                   help="Path to .pth weights file (downloads if not set)")
    p.add_argument("--onnx", default=None, help="Output ONNX path")
    p.add_argument("--engine", default=None, help="Output TRT engine path")
    p.add_argument("--fp32", action="store_true",
                   help="Build FP32 engine instead of FP16 (slower, more accurate)")
    p.add_argument("--workspace-gb", type=float, default=2.0,
                   help="TRT builder workspace in GB (default 2.0)")
    p.add_argument("--skip-onnx", action="store_true",
                   help="Skip ONNX export (use existing .onnx file)")
    return p.parse_args()


def ensure_dir(path: str):
    os.makedirs(path, exist_ok=True)


def download_weights(url: str, dest: str):
    """Download model weights with progress display."""
    import urllib.request

    print(f" Downloading weights from {url}")
    print(f" → {dest}")

    def progress(count, block, total):
        # urlretrieve calls this as (block_count, block_size, total_size);
        # guard against servers that omit Content-Length (total <= 0).
        pct = min(count * block / max(total, 1) * 100, 100)
        bar = "#" * int(pct / 2)
        sys.stdout.write(f"\r [{bar:<50}] {pct:.1f}%")
        sys.stdout.flush()

    urllib.request.urlretrieve(url, dest, reporthook=progress)
    print()
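

# For reference (not used by the build itself): the 19 Cityscapes train-ID
# classes the exported logits index into, in standard order. Presumably
# class 1 ("sidewalk") is the one the downstream node cares about.
CITYSCAPES_CLASSES = [
    "road", "sidewalk", "building", "wall", "fence", "pole",
    "traffic light", "traffic sign", "vegetation", "terrain", "sky",
    "person", "rider", "car", "truck", "bus", "train",
    "motorcycle", "bicycle",
]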


def export_bisenetv2_onnx(weights_path: str, onnx_path: str, num_classes: int = 19):
    """
    Export BiSeNetV2 from CoinCheung/BiSeNet to ONNX.
    Expects the BiSeNet repo to be cloned at /tmp/BiSeNet or to be on PYTHONPATH.
    """
    import torch

    # Try to import BiSeNetV2 — clone the repo first if it is missing
    try:
        sys.path.insert(0, "/tmp/BiSeNet")
        from lib.models.bisenetv2 import BiSeNetV2
    except ImportError:
        print(" Cloning BiSeNet repository to /tmp/BiSeNet ...")
        os.system("git clone --depth 1 https://github.com/CoinCheung/BiSeNet.git /tmp/BiSeNet")
        sys.path.insert(0, "/tmp/BiSeNet")
        from lib.models.bisenetv2 import BiSeNetV2

    print(" Loading BiSeNetV2 weights ...")
    net = BiSeNetV2(num_classes)
    state = torch.load(weights_path, map_location="cpu")
    # Handle both a bare state_dict and a checkpoint wrapper
    state_dict = state.get("state_dict", state.get("model", state))
    # Strip the "module." prefix left by DataParallel checkpoints
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    net.load_state_dict(state_dict, strict=False)
    net.eval()

    print(f" Exporting ONNX → {onnx_path}")
    dummy = torch.randn(BATCH_SIZE, 3, INPUT_H, INPUT_W)

    # BiSeNetV2 inference returns (logits, *aux) — we only want the logits
    class _Wrapper(torch.nn.Module):
        def __init__(self, net):
            super().__init__()
            self.net = net

        def forward(self, x):
            out = self.net(x)
            return out[0] if isinstance(out, (list, tuple)) else out

    wrapped = _Wrapper(net)
    torch.onnx.export(
        wrapped,
        dummy,
        onnx_path,
        opset_version=12,
        input_names=["image"],
        output_names=["logits"],
        dynamic_axes=None,  # fixed batch=1 for TRT
        do_constant_folding=True,
    )

    # Verify the exported graph
    import onnx
    model = onnx.load(onnx_path)
    onnx.checker.check_model(model)
    print(f" ONNX export verified ({os.path.getsize(onnx_path) / 1e6:.1f} MB)")


def export_ddrnet_onnx(weights_path: str, onnx_path: str, num_classes: int = 19):
    """Export DDRNet-23-slim to ONNX."""
    import torch

    try:
        sys.path.insert(0, "/tmp/DDRNet/lib")
        from models.ddrnet_23_slim import get_seg_model
    except ImportError:
        print(" Cloning DDRNet repository to /tmp/DDRNet ...")
        os.system("git clone --depth 1 https://github.com/ydhongHIT/DDRNet.git /tmp/DDRNet")
        sys.path.insert(0, "/tmp/DDRNet/lib")
        from models.ddrnet_23_slim import get_seg_model

    print(" Loading DDRNet weights ...")
    net = get_seg_model(num_classes=num_classes)
    net.load_state_dict(
        {k.replace("module.", ""): v
         for k, v in torch.load(weights_path, map_location="cpu").items()},
        strict=False,
    )
    net.eval()

    print(f" Exporting ONNX → {onnx_path}")
    dummy = torch.randn(BATCH_SIZE, 3, INPUT_H, INPUT_W)
    torch.onnx.export(
        net,
        dummy,
        onnx_path,
        opset_version=12,
        input_names=["image"],
        output_names=["logits"],
        do_constant_folding=True,
    )

    import onnx
    onnx.checker.check_model(onnx.load(onnx_path))
    print(f" ONNX export verified ({os.path.getsize(onnx_path) / 1e6:.1f} MB)")
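

# Optional smoke test (a sketch, not wired into main()): onnxruntime-gpu is
# already listed in the prerequisites, so the exported graph can be checked
# before the slow TRT build. Assumes the "image"/"logits" tensor names used
# by the export functions above.
def sanity_check_onnx(onnx_path: str):
    """Run one dummy inference through onnxruntime and print the logits shape."""
    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession(
        onnx_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
    )
    dummy = np.random.randn(BATCH_SIZE, 3, INPUT_H, INPUT_W).astype(np.float32)
    (logits,) = sess.run(["logits"], {"image": dummy})
    print(f" onnxruntime OK — logits shape: {logits.shape}")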


def build_trt_engine(onnx_path: str, engine_path: str, fp16: bool = True,
                     workspace_gb: float = 2.0):
    """
    Build a TensorRT engine from an ONNX model.
    Uses an explicit batch dimension (not implicit batch) and enables FP16
    mode by default (≈2× speedup on the Orin's Ampere GPU vs FP32).
    """
    try:
        import tensorrt as trt
    except ImportError:
        print("ERROR: TensorRT not found. Install JetPack 6 or:")
        print(" pip install tensorrt (x86 only)")
        sys.exit(1)

    logger = trt.Logger(trt.Logger.INFO)

    print(f"\n Building TRT engine ({'FP16' if fp16 else 'FP32'}) from: {onnx_path}")
    print(f" Output: {engine_path}")
    print(f" Workspace: {workspace_gb:.1f} GB")
    print(" This may take 5–15 minutes on first build (layer timing / tactic selection)...")
    t0 = time.time()

    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, logger)
    config = builder.create_builder_config()

    # Workspace
    config.set_memory_pool_limit(
        trt.MemoryPoolType.WORKSPACE, int(workspace_gb * 1024**3)
    )

    # FP16 precision
    if fp16 and builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)
        print(" FP16 mode enabled")
    elif fp16:
        print(" WARNING: FP16 not available on this platform, using FP32")

    # Parse ONNX
    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            for i in range(parser.num_errors):
                print(f" ONNX parse error: {parser.get_error(i)}")
            sys.exit(1)

    print(f" Network: {network.num_inputs} input(s), {network.num_outputs} output(s)")
    inp = network.get_input(0)
    print(f" Input shape: {inp.shape} dtype: {inp.dtype}")

    # Build
    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        print("ERROR: TRT engine build failed")
        sys.exit(1)

    with open(engine_path, "wb") as f:
        f.write(serialized_engine)

    elapsed = time.time() - t0
    size_mb = os.path.getsize(engine_path) / 1e6
    print(f"\n Engine built in {elapsed:.1f}s ({size_mb:.1f} MB)")
    print(f" Saved: {engine_path}")


def validate_engine(engine_path: str):
    """Run a dummy inference to confirm the engine works and measure latency."""
    try:
        import tensorrt as trt
        import pycuda.driver as cuda
        import pycuda.autoinit  # noqa: F401 — import initializes the CUDA context
        import numpy as np
    except ImportError:
        print(" Skipping validation (pycuda not available)")
        return

    print("\n Validating engine with dummy input...")
    trt_logger = trt.Logger(trt.Logger.WARNING)
    with open(engine_path, "rb") as f:
        engine = trt.Runtime(trt_logger).deserialize_cuda_engine(f.read())
    ctx = engine.create_execution_context()
    stream = cuda.Stream()

    # Allocate page-locked host buffers and device buffers for every binding
    bindings = []
    host_in = host_out = dev_in = dev_out = None
    for i in range(engine.num_bindings):
        shape = engine.get_binding_shape(i)
        dtype = trt.nptype(engine.get_binding_dtype(i))
        h_buf = cuda.pagelocked_empty(int(np.prod(shape)), dtype)
        d_buf = cuda.mem_alloc(h_buf.nbytes)
        bindings.append(int(d_buf))
        if engine.binding_is_input(i):
            host_in, dev_in = h_buf, d_buf
            host_in[:] = np.random.randn(*shape).astype(dtype).ravel()
        else:
            host_out, dev_out = h_buf, d_buf

    # Warm up
    for _ in range(5):
        cuda.memcpy_htod_async(dev_in, host_in, stream)
        ctx.execute_async_v2(bindings, stream.handle)
        cuda.memcpy_dtoh_async(host_out, dev_out, stream)
        stream.synchronize()

    # Benchmark (synchronize each iteration so we measure full round-trip latency)
    N = 20
    t0 = time.time()
    for _ in range(N):
        cuda.memcpy_htod_async(dev_in, host_in, stream)
        ctx.execute_async_v2(bindings, stream.handle)
        cuda.memcpy_dtoh_async(host_out, dev_out, stream)
        stream.synchronize()
    elapsed_ms = (time.time() - t0) * 1000 / N

    out_shape = engine.get_binding_shape(engine.num_bindings - 1)
    print(f" Output shape: {out_shape}")
    print(f" Inference latency: {elapsed_ms:.1f}ms ({1000/elapsed_ms:.1f} fps) [avg {N} runs]")
    if elapsed_ms < 1000 / 15:
        print(" PASS: target >15fps achieved")
    else:
        print(f" WARNING: latency {elapsed_ms:.1f}ms > target {1000/15:.1f}ms")
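

# NOTE: validate_engine() uses the TensorRT 8.x bindings API (num_bindings /
# get_binding_shape / execute_async_v2), matching the TRT 8.6 that ships with
# JetPack 6.0. Later JetPack 6.x releases bundle TensorRT 10, where that API
# was removed; a minimal sketch of the name-based replacement:
#
#     for i in range(engine.num_io_tensors):
#         name = engine.get_tensor_name(i)
#         shape = engine.get_tensor_shape(name)
#         ...allocate buffers as above...
#         ctx.set_tensor_address(name, int(d_buf))
#     ctx.execute_async_v3(stream.handle)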


def main():
    args = parse_args()
    cfg = MODEL_CONFIGS[args.model]
    print(f"\nBuilding TRT engine for: {cfg['description']}")

    ensure_dir(DEFAULT_MODEL_DIR)
    onnx_path = args.onnx or os.path.join(DEFAULT_MODEL_DIR, cfg["onnx_name"])
    engine_path = args.engine or os.path.join(DEFAULT_MODEL_DIR, cfg["engine_name"])

    # ── Step 1: Download weights if needed ────────────────────────────────────
    if not args.skip_onnx:
        weights_path = args.weights
        if weights_path is None:
            if cfg["weights_url"] is None:
                print(f"ERROR: No weights URL for {args.model}. Supply --weights /path/to.pth")
                sys.exit(1)
            weights_path = os.path.join(DEFAULT_MODEL_DIR, f"{args.model}_cityscapes.pth")
            if not os.path.exists(weights_path):
                print("\nStep 1: Downloading pretrained weights")
                download_weights(cfg["weights_url"], weights_path)
            else:
                print(f"\nStep 1: Using cached weights: {weights_path}")
        else:
            print(f"\nStep 1: Using provided weights: {weights_path}")

        # ── Step 2: Export ONNX ────────────────────────────────────────────
        print(f"\nStep 2: Exporting {args.model.upper()} to ONNX ({INPUT_W}×{INPUT_H})")
        if args.model == "bisenetv2":
            export_bisenetv2_onnx(weights_path, onnx_path, cfg["num_classes"])
        else:
            export_ddrnet_onnx(weights_path, onnx_path, cfg["num_classes"])
    else:
        print(f"\nStep 1+2: Skipping ONNX export, using: {onnx_path}")
        if not os.path.exists(onnx_path):
            print(f"ERROR: ONNX file not found: {onnx_path}")
            sys.exit(1)

    # ── Step 3: Build TRT engine ───────────────────────────────────────────────
    print("\nStep 3: Building TensorRT engine")
    build_trt_engine(
        onnx_path,
        engine_path,
        fp16=not args.fp32,
        workspace_gb=args.workspace_gb,
    )

    # ── Step 4: Validate ───────────────────────────────────────────────────────
    print("\nStep 4: Validating engine")
    validate_engine(engine_path)

    print(f"\nDone. Engine: {engine_path}")
    print("Start the node:")
    print("  ros2 launch saltybot_segmentation sidewalk_segmentation.launch.py \\")
    print(f"    engine_path:={engine_path}")


if __name__ == "__main__":
    main()
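
# Cross-check (assumes the stock trtexec binary from JetPack is on PATH): the
# build and latency numbers above can be reproduced independently with, e.g.:
#   trtexec --onnx=/mnt/nvme/saltybot/models/bisenetv2_cityscapes_512x256.onnx \
#           --saveEngine=/tmp/check.engine --fp16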