"""Async connection to a thermostat over mTLS on port 7878."""
from __future__ import annotations
import asyncio
import contextlib
import hashlib
import logging
import math
import os
import ssl
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar
if TYPE_CHECKING:
from collections.abc import Callable
from typing import Self
import orjson
from .certs import CERT_SETS, CertSet, create_ssl_context
from .const import (
BACKOFF_FACTOR,
CONNECT_TIMEOUT,
DEFAULT_PORT,
HEARTBEAT_INTERVAL,
INITIAL_STATE_TIMEOUT,
PAIRING_TIMEOUT,
RECONNECT_DELAY,
RECONNECT_MAX,
RESPONSE_TIMEOUT,
FanMode,
HoldType,
ZoneMode,
)
from .exceptions import (
AuthenticationError,
PairingError,
SteamloopConnectionError,
SteamloopError,
)
from .models import (
CoolingStatusUpdatedEvent,
EmergencyHeatUpdatedEvent,
ErrorResponse,
FanModeUpdatedEvent,
HeatingStatusUpdatedEvent,
IndoorRelativeHumidityUpdatedEvent,
IndoorTemperatureUpdatedEvent,
LoginResponse,
SetSecretKeyRequest,
SupportedZoneModesUpdatedEvent,
TemperatureSetpointUpdatedEvent,
ThermostatState,
Zone,
ZoneAddedEvent,
ZoneModeUpdatedEvent,
ZoneNameUpdatedEvent,
)
_LOGGER = logging.getLogger(__name__)
def _encode_message(msg: dict[str, Any]) -> bytes:
    r"""
    Serialize *msg* into the thermostat's wire format.

    A frame is compact JSON followed by a space and a null byte
    (``" \x00"``). The null byte is the delimiter the thermostat uses
    to find message boundaries.
    """
    terminator = b" \x00"
    return b"".join((orjson.dumps(msg), terminator))
def _pairing_path(ip: str, directory: Path | None = None) -> Path:
"""Return the pairing file path for a thermostat IP."""
md5 = hashlib.md5(ip.encode()).hexdigest() # noqa: S324
base = directory or Path.cwd()
return base / f"pairing_{md5}.json"
[docs]
async def load_pairing(ip: str, directory: Path | None = None) -> dict[str, str] | None:
    """
    Read the stored pairing data for a thermostat IP.

    The blocking file read is run in the event loop's default executor.

    Args:
        ip: Thermostat IP address.
        directory: Directory to load from. Defaults to current directory.

    Returns:
        Pairing dict with secret_key, device_type, device_id, or None.
    """
    pairing_file = _pairing_path(ip, directory)
    return await asyncio.get_running_loop().run_in_executor(
        None, _load_pairing_sync, pairing_file
    )
def _load_pairing_sync(path: Path) -> dict[str, str] | None:
    """
    Read and parse the pairing file at *path* (blocking).

    Returns the decoded JSON content, or None when the file does not
    exist.
    """
    with contextlib.suppress(FileNotFoundError):
        return orjson.loads(path.read_bytes())
    return None
[docs]
async def save_pairing(
    ip: str, login_info: dict[str, str], directory: Path | None = None
) -> None:
    """
    Persist pairing data for a thermostat IP.

    The blocking file write is run in the event loop's default executor;
    the destination path is logged once the write has finished.

    Args:
        ip: Thermostat IP address.
        login_info: Dict with secret_key, device_type, device_id.
        directory: Directory to save to. Defaults to current directory.
    """
    target = _pairing_path(ip, directory)
    await asyncio.get_running_loop().run_in_executor(
        None, _save_pairing_sync, target, login_info
    )
    _LOGGER.info("Pairing saved to %s", target)
def _save_pairing_sync(path: Path, login_info: dict[str, str]) -> None:
    """
    Write the pairing file at *path* (blocking, atomic via rename).

    The data is written to a sibling ``.tmp`` file created with mode
    0o600 and then renamed over *path*, so readers never observe a
    partially written file. If writing or renaming fails, the temporary
    file is removed and the original exception is re-raised.

    Args:
        path: Final location of the pairing file.
        login_info: Pairing data to serialize as indented JSON.
    """
    tmp_path = path.with_suffix(".tmp")
    # Serialize first so a serialization error never touches the disk.
    payload = orjson.dumps(login_info, option=orjson.OPT_INDENT_2)
    fd = os.open(tmp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    try:
        # A buffered file object writes the whole payload; a bare
        # os.write() may write fewer bytes than requested.
        with os.fdopen(fd, "wb") as handle:
            handle.write(payload)
        tmp_path.replace(path)
    except BaseException:
        # Do not leave a stale temp file (holding a secret key) behind.
        with contextlib.suppress(OSError):
            tmp_path.unlink()
        raise
[docs]
class ThermostatProtocol(asyncio.Protocol):
    """
    Low-level protocol handler for thermostat communication.

    Handles framing (null-byte delimited JSON) and delegates parsed
    messages to the owning ThermostatConnection.
    """

    def __init__(self, connection: ThermostatConnection) -> None:
        """Bind the protocol to the connection that receives its callbacks."""
        self._connection = connection
        self._transport: asyncio.Transport | None = None
        self._buf = bytearray()
        # Set by close(): a close requested by the owner must not be
        # reported back to it as a dropped connection.
        self._closed_by_owner = False

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Store the transport once the connection is established."""
        self._transport = transport  # type: ignore[assignment]

    def data_received(self, data: bytes) -> None:
        """Buffer incoming bytes and dispatch every complete message."""
        self._buf.extend(data)
        self._process_buffer()

    def connection_lost(self, exc: Exception | None) -> None:
        """
        Notify the owning connection that the transport went away.

        Nothing is reported when the owner itself called close(). The
        transport may finish shutting down only after the owner has
        already opened a new connection, and a late notification would
        then mark that new connection as lost.
        """
        if self._closed_by_owner:
            return
        self._connection._on_connection_lost(exc)  # noqa: SLF001

    def send(self, msg: dict[str, Any]) -> None:
        """
        Send a message to the thermostat (sync — no drain needed).

        Raises:
            SteamloopConnectionError: If no transport is attached.
        """
        if self._transport is None:
            raise SteamloopConnectionError("Not connected")
        encoded = _encode_message(msg)
        _LOGGER.debug("[>] TX %d bytes: %r", len(encoded), encoded)
        self._transport.write(encoded)

    def send_request(self, command: str, data: dict[str, str]) -> None:
        """Send a Request-wrapped command to the thermostat."""
        self.send({"Request": {command: data}})

    def close(self) -> None:
        """Close the transport and discard any partially received data."""
        self._closed_by_owner = True
        if self._transport is not None:
            self._transport.close()
            self._transport = None
        self._buf.clear()

    def _process_buffer(self) -> None:
        """
        Extract complete JSON messages from the buffer and dispatch them.

        Messages on the wire are terminated by a null byte (0x00).
        The JSON is extracted by finding the first '{' and last '}'
        within each null-delimited segment. Incomplete data (no null
        terminator yet) stays in the buffer for the next data_received().
        """
        # find() locates the terminator in a single scan per frame.
        while (idx := self._buf.find(b"\x00")) >= 0:
            segment = self._buf[:idx]
            del self._buf[: idx + 1]
            if not segment:
                continue
            text = segment.decode("utf-8", errors="replace")
            start = text.find("{")
            end = text.rfind("}")
            if start < 0 or end <= start:
                # No JSON object in this segment; drop it.
                continue
            candidate = text[start : end + 1]
            try:
                msg = orjson.loads(candidate)
            except orjson.JSONDecodeError:
                _LOGGER.warning("Failed to parse JSON: %s", candidate[:200])
            else:
                self._connection._on_message(msg)  # noqa: SLF001
[docs]
class ThermostatConnection:
    """
    Async connection to a thermostat over mTLS.

    After calling connect() and login(), call start_background_tasks()
    to begin sending heartbeats. Events are dispatched automatically
    via the protocol's data_received callback.

    If the connection drops, it will automatically reconnect with
    exponential backoff. Call disconnect() to stop everything.
    """

    def __init__(
        self,
        ip: str,
        port: int = DEFAULT_PORT,
        *,
        cert_set: CertSet | None = None,
        secret_key: str,
        device_type: str = "automation",
        device_id: str = "module",
    ) -> None:
        """
        Store the connection settings and create empty runtime state.

        No I/O happens here; call connect() to open the connection.

        Args:
            ip: Thermostat IP address.
            port: TCP port of the thermostat.
            cert_set: Certificate set to use. When None, connect() tries
                every entry of CERT_SETS in order.
            secret_key: Secret key sent by login().
            device_type: Device type sent in the login request.
            device_id: Device ID sent in the login request.
        """
        self._ip = ip
        self._port = port
        self._cert_set = cert_set
        self._secret_key = secret_key
        self._device_type = device_type
        self._device_id = device_id
        self._protocol: ThermostatProtocol | None = None
        self._transport: asyncio.Transport | None = None
        self._run_task: asyncio.Task[None] | None = None
        self._connection_lost_event = asyncio.Event()
        # Only set while login()/pair() are waiting for replies.
        self._message_queue: asyncio.Queue[dict[str, Any]] | None = None
        self.state = ThermostatState()
        self._event_callbacks: list[Callable[[dict[str, Any]], None]] = []
        self._connected = False

    @property
    def connected(self) -> bool:
        """Return True if the connection is active."""
        return self._connected

    @property
    def secret_key(self) -> str:
        """Return the secret key used for authentication."""
        return self._secret_key

    def add_event_callback(
        self, callback: Callable[[dict[str, Any]], None]
    ) -> Callable[[], None]:
        """
        Register an event callback.

        The callback is invoked with every parsed message dict.

        Returns:
            A callable that unregisters the callback. Calling it more
            than once is harmless.
        """
        self._event_callbacks.append(callback)

        def _remove() -> None:
            with contextlib.suppress(ValueError):
                self._event_callbacks.remove(callback)

        return _remove

    # --- Connection lifecycle ---

    async def connect(self) -> None:
        """
        Establish the TLS connection to the thermostat.

        If no cert_set was specified, tries each cert set in order until
        one succeeds and remembers the one that worked. If already
        connected, the existing connection is closed first.

        Raises:
            SteamloopConnectionError: If the connection fails.
        """
        if self._connected:
            self._close_transport()
        if self._cert_set is not None:
            await self._connect_with_cert_set(self._cert_set)
            return
        last_exc: Exception | None = None
        for cert_set in CERT_SETS:
            try:
                await self._connect_with_cert_set(cert_set)
            except SteamloopConnectionError as exc:
                _LOGGER.warning("Failed with %s certs: %s", cert_set.name, exc)
                last_exc = exc
            else:
                self._cert_set = cert_set
                return
        raise SteamloopConnectionError(
            f"Could not connect with any certificate set: {last_exc}"
        )

    async def _connect_with_cert_set(self, cert_set: CertSet) -> None:
        """
        Connect using a specific certificate set.

        The SSL context is built in the default executor. Every
        connection failure is converted to SteamloopConnectionError.

        Raises:
            SteamloopConnectionError: On TLS, timeout or TCP failure.
        """
        loop = asyncio.get_running_loop()
        ssl_ctx = await loop.run_in_executor(None, create_ssl_context, cert_set)
        _LOGGER.info(
            "Connecting to %s:%s using %s certificates...",
            self._ip,
            self._port,
            cert_set.name,
        )
        # Handler order matters: the SSL errors and TimeoutError are all
        # OSError subclasses, so the most specific ones come first.
        try:
            self._transport, self._protocol = await asyncio.wait_for(
                loop.create_connection(
                    lambda: ThermostatProtocol(self),
                    self._ip,
                    self._port,
                    ssl=ssl_ctx,
                ),
                timeout=CONNECT_TIMEOUT,
            )
        except ssl.SSLCertVerificationError as exc:
            raise SteamloopConnectionError(
                f"TLS cert verification failed: {exc!r}"
            ) from exc
        except ssl.SSLError as exc:
            raise SteamloopConnectionError(
                f"TLS handshake failed: {exc!r} "
                f"(errno={exc.errno}, "
                f"reason={getattr(exc, 'reason', 'unknown')})"
            ) from exc
        except TimeoutError as exc:
            raise SteamloopConnectionError(
                f"Connection timed out after {CONNECT_TIMEOUT}s"
            ) from exc
        except OSError as exc:
            raise SteamloopConnectionError(f"TCP connect failed: {exc!r}") from exc
        self._connected = True
        self._connection_lost_event.clear()
        _LOGGER.info("TLS connected to %s:%s", self._ip, self._port)

    def _close_transport(self) -> None:
        """Close the underlying transport and forget the protocol."""
        self._connected = False
        if self._protocol is not None:
            self._protocol.close()
            self._protocol = None
        self._transport = None

    # --- Sending ---

    def send(self, msg: dict[str, Any]) -> None:
        """
        Send a message to the thermostat.

        Raises:
            SteamloopConnectionError: If not connected.
        """
        if self._protocol is None:
            raise SteamloopConnectionError("Not connected")
        self._protocol.send(msg)

    def send_request(self, command: str, data: dict[str, str]) -> None:
        """
        Send a Request-wrapped command to the thermostat.

        Raises:
            SteamloopConnectionError: If not connected.
        """
        if self._protocol is None:
            raise SteamloopConnectionError("Not connected")
        self._protocol.send_request(command, data)

    # --- Message handling ---

    def _on_message(self, msg: dict[str, Any]) -> None:
        """Called by the protocol for every complete message."""
        self._dispatch(msg)
        # login()/pair() install a queue while they wait for replies.
        if self._message_queue is not None:
            self._message_queue.put_nowait(msg)

    def _on_connection_lost(self, exc: Exception | None) -> None:
        """Called by the protocol when the connection drops."""
        self._connected = False
        if exc:
            _LOGGER.warning("Connection lost: %s", exc)
        else:
            _LOGGER.warning("Connection closed by thermostat")
        # Wakes the background loop so it can reconnect.
        self._connection_lost_event.set()

    def _dispatch(self, msg: dict[str, Any]) -> None:
        """Update internal state from a message and notify callbacks."""
        if "Event" in msg:
            self._process_event(msg["Event"])
        # Iterate over a snapshot: a callback may unregister itself (or
        # register another one) while being called, and mutating the
        # list being iterated would skip callbacks.
        for cb in tuple(self._event_callbacks):
            try:
                cb(msg)
            except Exception:
                _LOGGER.exception("Error in event callback")

    def _get_zone(self, zone_id: str) -> Zone:
        """Get or create a zone by ID."""
        zone = self.state.zones.get(zone_id)
        if zone is None:
            # Only build a Zone when the ID is not known yet.
            zone = self.state.zones[zone_id] = Zone(zone_id=zone_id)
        return zone

    def _process_event(self, event: dict[str, Any]) -> None:
        """
        Process a single event and update thermostat state.

        Each Event dict is expected to contain exactly one key (the
        event type). Event types without a registered handler are
        ignored; KeyError, ValueError and TypeError raised by a handler
        are logged and do not propagate.
        """
        for event_type, data in event.items():
            handler = self._EVENT_HANDLERS.get(event_type)
            if handler is None:
                continue
            try:
                handler(self, data)
            except (KeyError, ValueError, TypeError) as exc:
                _LOGGER.warning("Error handling %s event: %s", event_type, exc)

    def _handle_zone_added(self, data: ZoneAddedEvent) -> None:
        """Create the zone named in the event unless it already exists."""
        self._get_zone(data["zone_id"])

    def _handle_zone_name_updated(self, data: ZoneNameUpdatedEvent) -> None:
        """Store the new name of a zone."""
        self._get_zone(data["zone_id"]).name = data["zone_name"]

    def _handle_indoor_temperature_updated(
        self, data: IndoorTemperatureUpdatedEvent
    ) -> None:
        """Store the indoor temperature reported for a zone."""
        self._get_zone(data["zone_id"]).indoor_temperature = data["indoor_temperature"]

    def _handle_temperature_setpoint_updated(
        self, data: TemperatureSetpointUpdatedEvent
    ) -> None:
        """Update a zone's setpoints, deadband and hold type; absent fields are kept."""
        zone = self._get_zone(data["zone_id"])
        zone.heat_setpoint = data.get("heat_setpoint", zone.heat_setpoint)
        zone.cool_setpoint = data.get("cool_setpoint", zone.cool_setpoint)
        zone.deadband = data.get("deadband", zone.deadband)
        hold_str = data.get("hold_type")
        if hold_str is not None:
            zone.hold_type = HoldType(int(hold_str))

    def _handle_zone_mode_updated(self, data: ZoneModeUpdatedEvent) -> None:
        """Store the mode reported for a zone."""
        self._get_zone(data["zone_id"]).mode = ZoneMode(int(data["zone_mode"]))

    def _handle_supported_zone_modes_updated(
        self, data: SupportedZoneModesUpdatedEvent
    ) -> None:
        """Parse the comma-separated mode list, skipping values that are not valid modes."""
        modes: list[ZoneMode] = []
        for raw in data["modes"].split(","):
            stripped = raw.strip()
            if stripped:
                with contextlib.suppress(ValueError):
                    modes.append(ZoneMode(int(stripped)))
        self.state.supported_modes = modes

    def _handle_fan_mode_updated(self, data: FanModeUpdatedEvent) -> None:
        """Store the reported fan mode."""
        self.state.fan_mode = FanMode(int(data["fan_mode"]))

    def _handle_emergency_heat_updated(self, data: EmergencyHeatUpdatedEvent) -> None:
        """Store the reported emergency heat value."""
        self.state.emergency_heat = data["emergency_heat"]

    def _handle_indoor_relative_humidity_updated(
        self, data: IndoorRelativeHumidityUpdatedEvent
    ) -> None:
        """Store the reported relative humidity."""
        self.state.relative_humidity = data["relative_humidity"]

    def _handle_cooling_status_updated(self, data: CoolingStatusUpdatedEvent) -> None:
        """Store the reported cooling status."""
        self.state.cooling_active = data["cooling_active"]

    def _handle_heating_status_updated(self, data: HeatingStatusUpdatedEvent) -> None:
        """Store the reported heating status."""
        self.state.heating_active = data["heating_active"]

    # Maps event type -> unbound handler; _process_event calls it as
    # handler(self, data).
    _EVENT_HANDLERS: ClassVar[dict[str, Callable[..., None]]] = {
        "ZoneAdded": _handle_zone_added,
        "ZoneNameUpdated": _handle_zone_name_updated,
        "IndoorTemperatureUpdated": _handle_indoor_temperature_updated,
        "TemperatureSetpointUpdated": _handle_temperature_setpoint_updated,
        "ZoneModeUpdated": _handle_zone_mode_updated,
        "SupportedZoneModesUpdated": _handle_supported_zone_modes_updated,
        "FanModeUpdated": _handle_fan_mode_updated,
        "EmergencyHeatUpdated": _handle_emergency_heat_updated,
        "IndoorRelativeHumidityUpdated": _handle_indoor_relative_humidity_updated,
        "CoolingStatusUpdated": _handle_cooling_status_updated,
        "HeatingStatusUpdated": _handle_heating_status_updated,
    }

    # --- Login / Pairing ---

    async def login(self) -> LoginResponse:
        """
        Authenticate with the thermostat.

        After receiving the login response, waits for the initial burst
        of state events (zone discovery, temperatures, etc.) to arrive
        before returning. This ensures ``state`` is fully populated.

        Returns:
            LoginResponse on success.

        Raises:
            AuthenticationError: If authentication fails or no login
                response arrives within RESPONSE_TIMEOUT.
            SteamloopConnectionError: If not connected.
        """
        queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
        self._message_queue = queue
        try:
            self.send_request(
                "Login",
                {
                    "device_id": self._device_id,
                    "device_type": self._device_type,
                    "secret_key": self._secret_key,
                },
            )
            loop = asyncio.get_running_loop()
            deadline = loop.time() + RESPONSE_TIMEOUT
            login_resp: LoginResponse | None = None
            while (remaining := deadline - loop.time()) > 0:
                try:
                    msg = await asyncio.wait_for(queue.get(), timeout=remaining)
                except TimeoutError:
                    break
                if "Response" not in msg:
                    continue
                resp = msg["Response"]
                if "LoginResponse" in resp:
                    login_resp = resp["LoginResponse"]
                    # Status "1" is the success value.
                    if login_resp.get("status") == "1":
                        _LOGGER.info("Authenticated successfully")
                        break
                    raise AuthenticationError(f"Authentication failed: {login_resp}")
                if "Error" in resp:
                    err: ErrorResponse = resp["Error"]
                    raise AuthenticationError(
                        f"Error {err.get('error_type')}: {err.get('description')}"
                    )
            if login_resp is None:
                raise AuthenticationError("No login response received")
            # Drain the initial state burst — the thermostat sends zone
            # discovery events right after the login response. Keep
            # reading until the stream goes quiet so that callers see a
            # fully-populated ``state`` when login() returns.
            while True:
                try:
                    await asyncio.wait_for(queue.get(), timeout=INITIAL_STATE_TIMEOUT)
                except TimeoutError:
                    break
            return login_resp
        finally:
            self._message_queue = None

    async def pair(self) -> SetSecretKeyRequest:
        """
        Pair with the thermostat.

        The thermostat must be in pairing mode. Sends a login request with
        an empty secret key and waits for the thermostat to send a
        SetSecretKey request containing the new secret key. The new key
        is stored on this connection and acknowledged to the thermostat.

        Returns:
            SetSecretKeyRequest with the new secret_key.

        Raises:
            PairingError: If pairing fails, the SetSecretKey request is
                malformed, or pairing times out.
            SteamloopConnectionError: If not connected.
        """
        queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
        self._message_queue = queue
        try:
            self.send_request(
                "Login",
                {
                    "device_id": self._device_id,
                    "device_type": self._device_type,
                    "secret_key": "",
                },
            )
            _LOGGER.info(
                "Waiting for pairing response (put thermostat in pairing mode)..."
            )
            loop = asyncio.get_running_loop()
            deadline = loop.time() + PAIRING_TIMEOUT
            while (remaining := deadline - loop.time()) > 0:
                try:
                    # Wait in slices of at most 5s; the loop condition
                    # re-checks the overall deadline after each slice.
                    msg = await asyncio.wait_for(
                        queue.get(), timeout=min(remaining, 5.0)
                    )
                except TimeoutError:
                    continue
                if "Request" in msg and "SetSecretKey" in msg.get("Request", {}):
                    ssk: SetSecretKeyRequest = msg["Request"]["SetSecretKey"]
                    try:
                        secret_key = ssk["secret_key"]
                    except (KeyError, TypeError) as exc:
                        # A malformed request is a pairing failure, not
                        # a bare KeyError/TypeError for the caller.
                        raise PairingError(
                            "SetSecretKey request did not contain a secret key"
                        ) from exc
                    self._secret_key = secret_key
                    _LOGGER.debug("Received secret key")
                    self.send(
                        {
                            "Response": {
                                "SecretKeyUpdated": {
                                    "secret_key": secret_key,
                                }
                            }
                        }
                    )
                    return ssk
                if "Response" in msg:
                    resp = msg["Response"]
                    if "LoginResponse" in resp:
                        status = resp["LoginResponse"].get("status")
                        if status == "1":
                            # Login accepted — the thermostat typically
                            # sends SetSecretKey next, so keep waiting.
                            _LOGGER.info("Login accepted, waiting for secret key...")
                            continue
                        raise PairingError(
                            f"Thermostat rejected pairing (status={status})"
                        )
                    if "Error" in resp:
                        err = resp["Error"]
                        raise PairingError(
                            f"Pairing error {err.get('error_type')}: "
                            f"{err.get('description')}"
                        )
        finally:
            self._message_queue = None
        raise PairingError(f"Pairing timeout — no response in {PAIRING_TIMEOUT}s")

    # --- Background tasks ---

    async def _heartbeat_loop(self) -> None:
        """Send periodic heartbeats to keep the connection alive."""
        while self._connected:
            await asyncio.sleep(HEARTBEAT_INTERVAL)
            # Re-check after sleeping: the connection may have dropped.
            if self._connected and self._protocol is not None:
                self._protocol.send({"Heartbeat": {}})

    async def _run_loop(self) -> None:
        """
        Main background loop: send heartbeats, auto-reconnect on failure.

        Events are dispatched by the protocol's data_received callback.
        This loop just manages heartbeats and reconnection. Cancelling
        the task ends the loop.
        """
        delay = RECONNECT_DELAY
        try:
            while True:
                heartbeat = asyncio.create_task(self._heartbeat_loop())
                try:
                    await self._connection_lost_event.wait()
                finally:
                    # Runs on connection loss and on cancellation alike:
                    # stop the heartbeat and release the transport.
                    self._connected = False
                    heartbeat.cancel()
                    with contextlib.suppress(asyncio.CancelledError):
                        await heartbeat
                    self._close_transport()
                # Reconnect with exponential backoff
                while True:
                    _LOGGER.info("Reconnecting in %.0fs...", delay)
                    await asyncio.sleep(delay)
                    try:
                        await self.connect()
                        # Start from empty state; login() repopulates it.
                        self.state = ThermostatState()
                        await self.login()
                        _LOGGER.info("Reconnected successfully")
                        delay = RECONNECT_DELAY
                        break
                    except (SteamloopError, OSError) as exc:
                        _LOGGER.warning("Reconnect failed: %s", exc)
                        self._close_transport()
                        delay = min(delay * BACKOFF_FACTOR, RECONNECT_MAX)
        except asyncio.CancelledError:
            return

    def start_background_tasks(self) -> None:
        """
        Start the background loop with heartbeats and auto-reconnect.

        Does nothing when the loop is already running, so a repeated
        call cannot orphan a running task.
        """
        if self._run_task is not None and not self._run_task.done():
            return
        self._run_task = asyncio.create_task(self._run_loop())

    async def disconnect(self) -> None:
        """Disconnect from the thermostat and stop all background tasks."""
        if self._run_task:
            self._run_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._run_task
            self._run_task = None
        self._close_transport()
        _LOGGER.info("Disconnected")

    async def __aenter__(self) -> Self:
        """Connect, login, and start background tasks."""
        await self.connect()
        try:
            await self.login()
        except BaseException:
            # Do not leave an open transport behind when login fails or
            # is cancelled.
            self._close_transport()
            raise
        self.start_background_tasks()
        return self

    async def __aexit__(self, *args: object) -> None:
        """Disconnect and stop all background tasks."""
        await self.disconnect()

    # --- Command helpers ---

    def set_temperature_setpoint(
        self,
        zone_id: str,
        *,
        heat_setpoint: str | None = None,
        cool_setpoint: str | None = None,
        hold_type: HoldType = HoldType.MANUAL,
    ) -> None:
        """
        Set temperature setpoints for a zone.

        If a setpoint isn't provided, the current state value is used
        ("55" for heat / "75" for cool when no value is known).
        Deadband is always taken from current state (3 when unknown).
        If the resulting setpoints would violate the deadband, the
        opposite setpoint is automatically adjusted to maintain the
        minimum gap:

        - If only heat_setpoint is provided, cool is raised if needed.
        - If only cool_setpoint is provided, heat is lowered if needed.
        - If both are provided, cool is raised to maintain the gap.

        Adjusted values are sent as whole numbers, rounded away from the
        other setpoint so the gap never falls below the deadband.

        Raises:
            SteamloopConnectionError: If not connected.
            ValueError: If a setpoint is not a valid number.
        """
        zone = self.state.zones.get(zone_id)
        db = float(zone.deadband) if zone and zone.deadband else 3.0
        heat_requested = heat_setpoint is not None
        cool_requested = cool_setpoint is not None
        # Fall back to the defaults both for an unknown zone and for a
        # known zone whose setpoint has not been reported yet.
        if heat_setpoint is None:
            heat_setpoint = (zone.heat_setpoint if zone else None) or "55"
        if cool_setpoint is None:
            cool_setpoint = (zone.cool_setpoint if zone else None) or "75"
        heat_f = float(heat_setpoint)
        cool_f = float(cool_setpoint)
        if cool_f - heat_f < db:
            if cool_requested and not heat_requested:
                # Only cool was requested — lower heat. floor() keeps
                # the gap >= deadband; int() truncates toward zero.
                heat_setpoint = str(math.floor(cool_f - db))
            else:
                # Only heat was requested, or both — raise cool. ceil()
                # keeps the gap >= deadband for fractional values.
                cool_setpoint = str(math.ceil(heat_f + db))
        self.send_request(
            "UpdateTemperatureSetpoint",
            {
                "zone_id": zone_id,
                "heat_setpoint": heat_setpoint,
                "cool_setpoint": cool_setpoint,
                "deadband": str(int(db)),
                "hold_type": str(int(hold_type)),
            },
        )

    def set_fan_mode(self, mode: FanMode | int) -> None:
        """Set the fan operating mode."""
        self.send_request("UpdateFanMode", {"fan_mode": str(int(mode))})

    def set_zone_mode(self, zone_id: str, mode: ZoneMode | int) -> None:
        """Set the HVAC mode for a zone."""
        self.send_request(
            "UpdateZoneMode",
            {"zone_id": zone_id, "zone_mode": str(int(mode))},
        )

    def set_emergency_heat(self, enabled: bool) -> None:
        """Enable or disable emergency heat (sent as "1" / "2")."""
        self.send_request(
            "UpdateEmergencyHeat",
            {"emergency_heat": "1" if enabled else "2"},
        )

    def heartbeat(self) -> None:
        """Send a heartbeat to keep the connection alive."""
        self.send({"Heartbeat": {}})