home-assistant-core/homeassistant/components/websocket_api/http.py

262 lines
8.6 KiB
Python

"""View to accept incoming websocket connection."""
from __future__ import annotations
import asyncio
from collections.abc import Callable
from contextlib import suppress
import datetime as dt
import logging
from typing import Any, Final
from aiohttp import WSMsgType, web
import async_timeout
from homeassistant.components.http import HomeAssistantView
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.helpers.event import async_call_later
from .auth import AuthPhase, auth_required_message
from .const import (
CANCELLATION_ERRORS,
DATA_CONNECTIONS,
MAX_PENDING_MSG,
PENDING_MSG_PEAK,
PENDING_MSG_PEAK_TIME,
SIGNAL_WEBSOCKET_CONNECTED,
SIGNAL_WEBSOCKET_DISCONNECTED,
URL,
)
from .error import Disconnect
from .messages import message_to_json
# Base logger for websocket connections, named "<this module>.connection".
# WebSocketHandler wraps it in a WebSocketAdapter so each record carries the
# id of the connection that emitted it.
_WS_LOGGER: Final = logging.getLogger(f"{__name__}.connection")
class WebsocketAPIView(HomeAssistantView):
    """HTTP view that upgrades requests on the websocket API URL."""

    # View-level auth is switched off; the handler runs its own auth phase
    # once the socket has been opened.
    requires_auth: bool = False
    url: str = URL
    name: str = "websocketapi"

    async def get(self, request: web.Request) -> web.WebSocketResponse:
        """Serve one websocket connection.

        A new WebSocketHandler is created for every request, bound to the
        ``hass`` object stored on the aiohttp application, and its response
        object is returned once the connection has finished.
        """
        hass = request.app["hass"]
        handler = WebSocketHandler(hass, request)
        return await handler.async_handle()
class WebSocketAdapter(logging.LoggerAdapter):
    """Logger adapter that prefixes messages with the connection id."""

    def process(self, msg: str, kwargs: Any) -> tuple[str, Any]:
        """Prefix *msg* with ``[<connid>]`` taken from the adapter's extra data.

        Args:
            msg: The log message (format string) being emitted.
            kwargs: Keyword arguments of the logging call; passed through
                untouched.

        Returns:
            A ``(message, kwargs)`` tuple. The message is prefixed with the
            connection id when one is available, otherwise it is returned
            unchanged.
        """
        # LoggerAdapter permits ``extra`` to be None (or to lack our key);
        # logging must never raise because of that, so fall back to the
        # plain message instead of indexing blindly.
        if not self.extra or "connid" not in self.extra:
            return msg, kwargs
        return f'[{self.extra["connid"]}] {msg}', kwargs
class WebSocketHandler:
    """Handle an active websocket client connection.

    One instance serves exactly one request. Incoming frames are read by
    ``async_handle``; outgoing messages are placed on a bounded queue by
    ``_send_message`` and written to the socket by a separate ``_writer``
    task. A client that lets the queue fill up, or that keeps it above
    ``PENDING_MSG_PEAK`` for ``PENDING_MSG_PEAK_TIME``, is disconnected.
    """

    def __init__(self, hass: HomeAssistant, request: web.Request) -> None:
        """Initialize an active connection.

        Args:
            hass: The Home Assistant instance this connection belongs to.
            request: The aiohttp request that is being upgraded.
        """
        self.hass = hass
        self.request = request
        self.wsock = web.WebSocketResponse(heartbeat=55)
        # Outgoing messages; a ``None`` entry tells the writer to stop.
        self._to_write: asyncio.Queue = asyncio.Queue(maxsize=MAX_PENDING_MSG)
        # Task running async_handle / _writer; both are set in async_handle.
        self._handle_task: asyncio.Task | None = None
        self._writer_task: asyncio.Task | None = None
        self._logger = WebSocketAdapter(_WS_LOGGER, {"connid": id(self)})
        # Cancel handle of the pending "still above the peak?" timer, if any.
        self._peak_checker_unsub: Callable[[], None] | None = None

    def _cancel_peak_checker(self) -> None:
        """Cancel the pending peak-check timer, if one is scheduled."""
        if self._peak_checker_unsub is not None:
            self._peak_checker_unsub()
            self._peak_checker_unsub = None

    async def _writer(self) -> None:
        """Write outgoing messages.

        Drains the queue until the socket is closed or a ``None`` sentinel
        is dequeued.
        """
        try:
            # Exceptions if socket disconnected or cancelled by connection handler
            with suppress(RuntimeError, ConnectionResetError, *CANCELLATION_ERRORS):
                while not self.wsock.closed:
                    if (message := await self._to_write.get()) is None:
                        break
                    self._logger.debug("Sending %s", message)
                    await self.wsock.send_str(message)
        finally:
            # Clean up the peak checker when we shut down the writer. This
            # lives in ``finally`` so the timer is released even when the loop
            # ends with an exception that is not suppressed above.
            self._cancel_peak_checker()

    @callback
    def _send_message(self, message: str | dict[str, Any]) -> None:
        """Send a message to the client.

        Closes connection if the client is not reading the messages.

        Async friendly.

        Args:
            message: Either an already serialized string, or a dict that is
                serialized with ``message_to_json`` before being queued.
        """
        if not isinstance(message, str):
            message = message_to_json(message)

        try:
            self._to_write.put_nowait(message)
        except asyncio.QueueFull:
            self._logger.error(
                "Client exceeded max pending messages [2]: %s", MAX_PENDING_MSG
            )
            self._cancel()
            # The connection is being torn down; do not arm a new peak
            # timer for it.
            return

        if self._to_write.qsize() < PENDING_MSG_PEAK:
            # Back under the peak: any scheduled check is no longer needed.
            self._cancel_peak_checker()
            return

        # At or above the peak: make sure exactly one check is scheduled.
        if self._peak_checker_unsub is None:
            self._peak_checker_unsub = async_call_later(
                self.hass, PENDING_MSG_PEAK_TIME, self._check_write_peak
            )

    @callback
    def _check_write_peak(self, _utc_time: dt.datetime) -> None:
        """Check that we are no longer above the write peak.

        Timer callback scheduled by ``_send_message``; cancels the
        connection if the queue has not dropped below ``PENDING_MSG_PEAK``.
        """
        # The timer has fired, so there is nothing left to unsubscribe.
        self._peak_checker_unsub = None

        if self._to_write.qsize() < PENDING_MSG_PEAK:
            return

        self._logger.error(
            "Client unable to keep up with pending messages. Stayed over %s for %s seconds",
            PENDING_MSG_PEAK,
            PENDING_MSG_PEAK_TIME,
        )
        self._cancel()

    @callback
    def _cancel(self) -> None:
        """Cancel the connection by cancelling the handler and writer tasks."""
        if self._handle_task is not None:
            self._handle_task.cancel()
        if self._writer_task is not None:
            self._writer_task.cancel()

    async def async_handle(self) -> web.WebSocketResponse:
        """Handle a websocket response.

        Prepares the socket, starts the writer task, runs the auth phase
        (one message, 10 second limit) and then the command phase (every
        further text message is passed to the authenticated connection).
        On exit the writer is flushed and the socket closed.

        Returns:
            The ``WebSocketResponse`` for this request, whether or not the
            handshake succeeded.
        """
        request = self.request
        wsock = self.wsock
        try:
            async with async_timeout.timeout(10):
                await wsock.prepare(request)
        except asyncio.TimeoutError:
            self._logger.warning("Timeout preparing request from %s", request.remote)
            return wsock

        self._logger.debug("Connected from %s", request.remote)
        # Remember our own task so _cancel() can interrupt the receive loop.
        self._handle_task = asyncio.current_task()

        @callback
        def handle_hass_stop(event: Event) -> None:
            """Cancel this connection."""
            self._cancel()

        unsub_stop = self.hass.bus.async_listen(
            EVENT_HOMEASSISTANT_STOP, handle_hass_stop
        )

        # As the webserver is now started before the start
        # event we do not want to block for websocket responses
        self._writer_task = asyncio.create_task(self._writer())

        auth = AuthPhase(
            self._logger, self.hass, self._send_message, self._cancel, request
        )
        # Set once authentication has succeeded; None means "never authed".
        connection = None
        # Reason logged at warning level on exit; None means a quiet exit.
        disconnect_warn = None

        try:
            self._send_message(auth_required_message())

            # Auth Phase
            try:
                async with async_timeout.timeout(10):
                    msg = await wsock.receive()
            except asyncio.TimeoutError as err:
                disconnect_warn = "Did not receive auth message within 10 seconds"
                raise Disconnect from err

            if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
                raise Disconnect

            if msg.type != WSMsgType.TEXT:
                disconnect_warn = "Received non-Text message."
                raise Disconnect

            try:
                msg_data = msg.json()
            except ValueError as err:
                disconnect_warn = "Received invalid JSON."
                raise Disconnect from err

            self._logger.debug("Received %s", msg_data)
            connection = await auth.async_handle(msg_data)
            # Only authenticated connections are counted.
            self.hass.data[DATA_CONNECTIONS] = (
                self.hass.data.get(DATA_CONNECTIONS, 0) + 1
            )
            self.hass.helpers.dispatcher.async_dispatcher_send(
                SIGNAL_WEBSOCKET_CONNECTED
            )

            # Command phase
            while not wsock.closed:
                msg = await wsock.receive()

                if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
                    break

                if msg.type != WSMsgType.TEXT:
                    disconnect_warn = "Received non-Text message."
                    break

                try:
                    msg_data = msg.json()
                except ValueError:
                    disconnect_warn = "Received invalid JSON."
                    break

                self._logger.debug("Received %s", msg_data)
                connection.async_handle(msg_data)

        except asyncio.CancelledError:
            # Also reached when _cancel() cancels this task (slow client,
            # Home Assistant stopping); the cancellation is not re-raised.
            self._logger.info("Connection closed by client")

        except Disconnect:
            # Deliberate: Disconnect is only used to leave the auth phase;
            # the reason, if any, is already stored in disconnect_warn.
            pass

        except Exception:  # pylint: disable=broad-except
            self._logger.exception("Unexpected error inside websocket API")

        finally:
            unsub_stop()

            if connection is not None:
                connection.async_handle_close()

            try:
                # Sentinel that makes the writer leave its loop after
                # everything queued before it has been sent.
                self._to_write.put_nowait(None)
                # Make sure all error messages are written before closing
                await self._writer_task
                await wsock.close()
            except asyncio.QueueFull:  # can be raised by put_nowait
                # No room for the sentinel, so stop the writer forcibly.
                self._writer_task.cancel()

            finally:
                if disconnect_warn is None:
                    self._logger.debug("Disconnected")
                else:
                    self._logger.warning("Disconnected: %s", disconnect_warn)

                if connection is not None:
                    self.hass.data[DATA_CONNECTIONS] -= 1
                # NOTE(review): the disconnect signal is sent for every
                # prepared socket, including ones that never authenticated
                # - confirm this nesting against the upstream file.
                self.hass.helpers.dispatcher.async_dispatcher_send(
                    SIGNAL_WEBSOCKET_DISCONNECTED
                )

        return wsock