140 lines
4.2 KiB
Python
140 lines
4.2 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import queue
|
|
import threading
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, Optional
|
|
|
|
|
|
@dataclass(frozen=True)
class MqttConfig:
    """Immutable settings describing which broker and topic to subscribe to.

    Instances are frozen (attribute assignment raises) and compare/hash by
    value, so one config can be shared freely between subscribers.
    """

    # Broker address, written either as "host" or as "host:port".
    server: str
    # Topic filter handed to the client's subscribe() call.
    topic: str
    # Login name; when None no credentials are configured on the client.
    username: Optional[str] = None
    # Password paired with `username` (ignored when `username` is None).
    password: Optional[str] = None
    # Client identifier forwarded to the MQTT client constructor as-is.
    client_id: Optional[str] = None
    # Keep-alive value forwarded to connect(); paho interprets it as seconds.
    keepalive: int = 60
|
|
|
|
|
|
def _parse_server(server: str, default_port: int = 1883) -> tuple[str, int]:
    """Split an MQTT server spec into ``(host, port)``.

    Accepted forms:

    * ``host`` or ``host:port``
    * ``[ipv6]`` or ``[ipv6]:port`` (bracketed IPv6 literal)
    * a bare IPv6 literal such as ``::1``; it always gets *default_port*,
      because a trailing ``:NNNN`` cannot be told apart from an address
      group (use brackets to give an IPv6 host a port).

    Surrounding whitespace is ignored.

    Args:
        server: The server spec to parse.
        default_port: Port used when the spec does not name one.

    Returns:
        ``(host, port)`` with brackets removed from IPv6 hosts.

    Raises:
        ValueError: if the spec or host is empty, the brackets are
            malformed, or the port is not an integer in 1-65535.
    """
    import ipaddress  # local import: only needed for IPv6 detection

    server = server.strip()
    if not server:
        raise ValueError("MQTT server must not be empty")

    # Bracketed IPv6: "[addr]" or "[addr]:port".
    if server.startswith("["):
        end = server.find("]")
        if end == -1:
            raise ValueError("MQTT server has an unterminated '[' in its host")
        host = server[1:end].strip()
        if not host:
            raise ValueError("MQTT server host must not be empty")
        rest = server[end + 1:]
        if not rest:
            return host, default_port
        if not rest.startswith(":"):
            raise ValueError("MQTT server must be '[host]' or '[host]:port'")
        return host, _parse_port(rest[1:])

    # Several colons without brackets: accept it only if it is a valid bare
    # IPv6 literal; otherwise fall through so "a:b:c" still reports a bad port.
    if server.count(":") > 1:
        try:
            ipaddress.IPv6Address(server)
        except ValueError:
            pass
        else:
            return server, default_port

    # Accept host or host:port.
    if ":" in server:
        host, port_s = server.rsplit(":", 1)
        host = host.strip()
        if not host:
            raise ValueError("MQTT server host must not be empty")
        return host, _parse_port(port_s)

    return server, default_port


def _parse_port(port_s: str) -> int:
    """Convert *port_s* to a TCP port number, rejecting values outside 1-65535."""
    try:
        port = int(port_s)
    except ValueError as exc:
        raise ValueError("MQTT server port must be an integer") from exc
    if not 1 <= port <= 65535:
        raise ValueError("MQTT server port must be between 1 and 65535")
    return port
|
|
|
|
|
|
class MqttJsonSubscriber:
    """Subscribes to a topic and yields decoded JSON payloads.

    Incoming messages are decoded on paho-mqtt's background network thread
    and pushed into a bounded in-memory queue; callers pull them with
    :meth:`get`. Only UTF-8 encoded JSON *objects* are kept; anything else is
    dropped. When the queue is full the oldest entry is discarded so the
    newest values keep flowing.

    Use as a context manager (``with MqttJsonSubscriber(cfg) as sub:``) or
    call :meth:`open` / :meth:`close` explicitly.
    """

    def __init__(
        self,
        config: "MqttConfig",
        *,
        max_queue_size: int = 1000,
        connect_timeout: float = 5.0,
    ):
        """Store settings; no network activity happens until :meth:`open`.

        Args:
            config: Broker address, topic, credentials and keep-alive.
            max_queue_size: Number of messages buffered before the oldest is
                dropped (``<= 0`` makes the queue unbounded, per
                ``queue.Queue``).
            connect_timeout: Seconds :meth:`open` waits for the broker to
                answer the connection attempt.
        """
        self._config = config
        self._connect_timeout = connect_timeout
        self._queue: queue.Queue[Dict[str, Any]] = queue.Queue(maxsize=max_queue_size)
        self._client = None

    def __enter__(self) -> "MqttJsonSubscriber":
        self.open()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def open(self) -> None:
        """Connect to the broker, start the network thread and subscribe.

        Does nothing if already open. Blocks until the broker answers or
        ``connect_timeout`` elapses.

        Raises:
            RuntimeError: if paho-mqtt is not installed, the broker does not
                answer in time, or it refuses the connection (non-zero rc).
            ValueError: if ``config.server`` is malformed.

        Errors raised by ``client.connect`` itself are not caught here and
        propagate unchanged.
        """
        if self._client is not None:
            return

        try:
            import paho.mqtt.client as mqtt
        except ImportError as exc:  # pragma: no cover
            raise RuntimeError(
                "MQTT requires paho-mqtt. Install it (pip install paho-mqtt) or run with --source usb."
            ) from exc

        host, port = _parse_server(self._config.server)

        client = mqtt.Client(client_id=self._config.client_id)
        if self._config.username is not None:
            client.username_pw_set(self._config.username, self._config.password)

        connected = threading.Event()
        connect_rc: Dict[str, int] = {"rc": -1}

        def on_connect(_client, _userdata, _flags, rc):
            # paho calls this for the initial connect and after every
            # reconnect, so subscribing here also restores the subscription.
            connect_rc["rc"] = int(rc)
            connected.set()
            if rc == 0:
                _client.subscribe(self._config.topic)

        def on_message(_client, _userdata, msg):
            self._enqueue_payload(msg.payload)

        client.on_connect = on_connect
        client.on_message = on_message

        client.connect(host, port, keepalive=self._config.keepalive)
        client.loop_start()

        if not connected.wait(timeout=self._connect_timeout):
            self._shutdown_client(client)
            raise RuntimeError("MQTT connect timed out")

        if connect_rc["rc"] != 0:
            self._shutdown_client(client)
            raise RuntimeError(f"MQTT connect failed (rc={connect_rc['rc']})")

        self._client = client

    def close(self) -> None:
        """Disconnect and stop the network thread; safe to call repeatedly."""
        client, self._client = self._client, None
        if client is not None:
            self._shutdown_client(client)

    def get(self, *, timeout: Optional[float] = None) -> Optional[Dict[str, Any]]:
        """Get the next JSON message dict, or None on timeout.

        ``timeout=None`` blocks until a message arrives.
        """
        try:
            return self._queue.get(timeout=timeout)
        except queue.Empty:
            return None

    def _enqueue_payload(self, payload: bytes) -> None:
        """Decode *payload* as a UTF-8 JSON object and queue it.

        Invalid UTF-8, malformed or overly nested JSON, and JSON values that
        are not objects are dropped silently. This runs on paho's network
        thread and must not raise: paho re-raises callback exceptions by
        default, which would end the network loop.
        """
        try:
            data = json.loads(payload.decode("utf-8", errors="strict"))
        except (ValueError, RecursionError):
            # UnicodeDecodeError and json.JSONDecodeError both subclass
            # ValueError; RecursionError comes from pathologically deep JSON.
            return
        if not isinstance(data, dict):
            return

        try:
            self._queue.put_nowait(data)
        except queue.Full:
            # Drop oldest to keep latest values moving.
            try:
                self._queue.get_nowait()
            except queue.Empty:
                pass  # a consumer emptied the queue in the meantime
            try:
                self._queue.put_nowait(data)
            except queue.Full:
                pass  # defensive: give up on this message rather than raise

    @staticmethod
    def _shutdown_client(client) -> None:
        """Best-effort teardown: send DISCONNECT, then stop the network thread.

        Disconnecting first (the order shown in paho's docs) lets the loop
        thread flush the DISCONNECT packet and exit on its own before
        ``loop_stop`` joins it. A failing ``disconnect`` is ignored because
        the connection may already be gone.
        """
        try:
            client.disconnect()
        except Exception:
            pass
        finally:
            client.loop_stop()
|