Add configurator to strict typing (#87279)

This commit is contained in:
epenet 2023-02-03 16:02:55 +01:00 committed by GitHub
parent 9306e0371e
commit 923abdb02a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 48 additions and 24 deletions

View File

@ -84,6 +84,7 @@ homeassistant.components.camera.*
homeassistant.components.canary.* homeassistant.components.canary.*
homeassistant.components.clickatell.* homeassistant.components.clickatell.*
homeassistant.components.clicksend.* homeassistant.components.clicksend.*
homeassistant.components.configurator.*
homeassistant.components.cover.* homeassistant.components.cover.*
homeassistant.components.cpuspeed.* homeassistant.components.cpuspeed.*
homeassistant.components.crownstone.* homeassistant.components.crownstone.*

View File

@ -12,7 +12,7 @@ from collections.abc import Callable
from contextlib import suppress from contextlib import suppress
from datetime import datetime from datetime import datetime
import functools as ft import functools as ft
from typing import Any, cast from typing import Any
from homeassistant.const import ATTR_ENTITY_PICTURE, ATTR_FRIENDLY_NAME from homeassistant.const import ATTR_ENTITY_PICTURE, ATTR_FRIENDLY_NAME
from homeassistant.core import HomeAssistant, ServiceCall, callback as async_callback from homeassistant.core import HomeAssistant, ServiceCall, callback as async_callback
@ -80,7 +80,7 @@ def async_request_config(
if DATA_REQUESTS not in hass.data: if DATA_REQUESTS not in hass.data:
hass.data[DATA_REQUESTS] = {} hass.data[DATA_REQUESTS] = {}
hass.data[DATA_REQUESTS][request_id] = instance _get_requests(hass)[request_id] = instance
return request_id return request_id
@ -98,10 +98,10 @@ def request_config(hass: HomeAssistant, *args: Any, **kwargs: Any) -> str:
@bind_hass @bind_hass
@async_callback @async_callback
def async_notify_errors(hass, request_id, error): def async_notify_errors(hass: HomeAssistant, request_id: str, error: str) -> None:
"""Add errors to a config request.""" """Add errors to a config request."""
with suppress(KeyError): # If request_id does not exist with suppress(KeyError): # If request_id does not exist
hass.data[DATA_REQUESTS][request_id].async_notify_errors(request_id, error) _get_requests(hass)[request_id].async_notify_errors(request_id, error)
@bind_hass @bind_hass
@ -117,7 +117,7 @@ def notify_errors(hass: HomeAssistant, request_id: str, error: str) -> None:
def async_request_done(hass: HomeAssistant, request_id: str) -> None: def async_request_done(hass: HomeAssistant, request_id: str) -> None:
"""Mark a configuration request as done.""" """Mark a configuration request as done."""
with suppress(KeyError): # If request_id does not exist with suppress(KeyError): # If request_id does not exist
hass.data[DATA_REQUESTS].pop(request_id).async_request_done(request_id) _get_requests(hass).pop(request_id).async_request_done(request_id)
@bind_hass @bind_hass
@ -133,10 +133,15 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
def _get_requests(hass: HomeAssistant) -> dict[str, Configurator]:
"""Return typed configurator_requests data."""
return hass.data[DATA_REQUESTS] # type: ignore[no-any-return]
class Configurator: class Configurator:
"""The class to keep track of current configuration requests.""" """The class to keep track of current configuration requests."""
def __init__(self, hass): def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the configurator.""" """Initialize the configurator."""
self.hass = hass self.hass = hass
self._cur_id = 0 self._cur_id = 0
@ -190,14 +195,15 @@ class Configurator:
return request_id return request_id
@async_callback @async_callback
def async_notify_errors(self, request_id, error): def async_notify_errors(self, request_id: str, error: str) -> None:
"""Update the state with errors.""" """Update the state with errors."""
if not self._validate_request_id(request_id): if not self._validate_request_id(request_id):
return return
entity_id = self._requests[request_id][0] entity_id = self._requests[request_id][0]
state = self.hass.states.get(entity_id) if (state := self.hass.states.get(entity_id)) is None:
return
new_data = dict(state.attributes) new_data = dict(state.attributes)
new_data[ATTR_ERRORS] = error new_data[ATTR_ERRORS] = error
@ -205,7 +211,7 @@ class Configurator:
self.hass.states.async_set(entity_id, STATE_CONFIGURE, new_data) self.hass.states.async_set(entity_id, STATE_CONFIGURE, new_data)
@async_callback @async_callback
def async_request_done(self, request_id): def async_request_done(self, request_id: str) -> None:
"""Remove the configuration request.""" """Remove the configuration request."""
if not self._validate_request_id(request_id): if not self._validate_request_id(request_id):
return return
@ -219,7 +225,7 @@ class Configurator:
self.hass.states.async_set(entity_id, STATE_CONFIGURED) self.hass.states.async_set(entity_id, STATE_CONFIGURED)
@async_callback @async_callback
def deferred_remove(now: datetime): def deferred_remove(now: datetime) -> None:
"""Remove the request state.""" """Remove the request state."""
self.hass.states.async_remove(entity_id) self.hass.states.async_remove(entity_id)
@ -227,22 +233,24 @@ class Configurator:
async def async_handle_service_call(self, call: ServiceCall) -> None: async def async_handle_service_call(self, call: ServiceCall) -> None:
"""Handle a configure service call.""" """Handle a configure service call."""
request_id = call.data.get(ATTR_CONFIGURE_ID) request_id: str | None = call.data.get(ATTR_CONFIGURE_ID)
if not self._validate_request_id(request_id): if not request_id or not self._validate_request_id(request_id):
return return
_, _, callback = self._requests[cast(str, request_id)] _, _, callback = self._requests[request_id]
# field validation goes here? # field validation goes here?
if callback: if callback and (
await self.hass.async_add_job(callback, call.data.get(ATTR_FIELDS, {})) job := self.hass.async_add_job(callback, call.data.get(ATTR_FIELDS, {}))
):
await job
def _generate_unique_id(self): def _generate_unique_id(self) -> str:
"""Generate a unique configurator ID.""" """Generate a unique configurator ID."""
self._cur_id += 1 self._cur_id += 1
return f"{id(self)}-{self._cur_id}" return f"{id(self)}-{self._cur_id}"
def _validate_request_id(self, request_id): def _validate_request_id(self, request_id: str) -> bool:
"""Validate that the request belongs to this instance.""" """Validate that the request belongs to this instance."""
return request_id in self._requests return request_id in self._requests

View File

@ -593,6 +593,16 @@ disallow_untyped_defs = true
warn_return_any = true warn_return_any = true
warn_unreachable = true warn_unreachable = true
[mypy-homeassistant.components.configurator.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.cover.*] [mypy-homeassistant.components.cover.*]
check_untyped_defs = true check_untyped_defs = true
disallow_incomplete_defs = true disallow_incomplete_defs = true

View File

@ -4,12 +4,13 @@ from datetime import timedelta
import homeassistant.components.configurator as configurator import homeassistant.components.configurator as configurator
from homeassistant.const import ATTR_FRIENDLY_NAME from homeassistant.const import ATTR_FRIENDLY_NAME
from homeassistant.core import HomeAssistant
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from tests.common import async_fire_time_changed from tests.common import async_fire_time_changed
async def test_request_least_info(hass): async def test_request_least_info(hass: HomeAssistant) -> None:
"""Test request config with least amount of data.""" """Test request config with least amount of data."""
request_id = configurator.async_request_config(hass, "Test Request", lambda _: None) request_id = configurator.async_request_config(hass, "Test Request", lambda _: None)
@ -27,7 +28,7 @@ async def test_request_least_info(hass):
assert state.attributes.get(configurator.ATTR_CONFIGURE_ID) == request_id assert state.attributes.get(configurator.ATTR_CONFIGURE_ID) == request_id
async def test_request_all_info(hass): async def test_request_all_info(hass: HomeAssistant) -> None:
"""Test request config with all possible info.""" """Test request config with all possible info."""
exp_attr = { exp_attr = {
ATTR_FRIENDLY_NAME: "Test Request", ATTR_FRIENDLY_NAME: "Test Request",
@ -61,7 +62,7 @@ async def test_request_all_info(hass):
assert state.attributes == exp_attr assert state.attributes == exp_attr
async def test_callback_called_on_configure(hass): async def test_callback_called_on_configure(hass: HomeAssistant) -> None:
"""Test if our callback gets called when configure service called.""" """Test if our callback gets called when configure service called."""
calls = [] calls = []
request_id = configurator.async_request_config( request_id = configurator.async_request_config(
@ -78,7 +79,7 @@ async def test_callback_called_on_configure(hass):
assert len(calls) == 1, "Callback not called" assert len(calls) == 1, "Callback not called"
async def test_state_change_on_notify_errors(hass): async def test_state_change_on_notify_errors(hass: HomeAssistant) -> None:
"""Test state change on notify errors.""" """Test state change on notify errors."""
request_id = configurator.async_request_config(hass, "Test Request", lambda _: None) request_id = configurator.async_request_config(hass, "Test Request", lambda _: None)
error = "Oh no bad bad bad" error = "Oh no bad bad bad"
@ -90,12 +91,14 @@ async def test_state_change_on_notify_errors(hass):
assert state.attributes.get(configurator.ATTR_ERRORS) == error assert state.attributes.get(configurator.ATTR_ERRORS) == error
async def test_notify_errors_fail_silently_on_bad_request_id(hass): async def test_notify_errors_fail_silently_on_bad_request_id(
hass: HomeAssistant,
) -> None:
"""Test if notify errors fails silently with a bad request id.""" """Test if notify errors fails silently with a bad request id."""
configurator.async_notify_errors(hass, 2015, "Try this error") configurator.async_notify_errors(hass, 2015, "Try this error")
async def test_request_done_works(hass): async def test_request_done_works(hass: HomeAssistant) -> None:
"""Test if calling request done works.""" """Test if calling request done works."""
request_id = configurator.async_request_config(hass, "Test Request", lambda _: None) request_id = configurator.async_request_config(hass, "Test Request", lambda _: None)
configurator.async_request_done(hass, request_id) configurator.async_request_done(hass, request_id)
@ -105,6 +108,8 @@ async def test_request_done_works(hass):
assert len(hass.states.async_all()) == 0 assert len(hass.states.async_all()) == 0
async def test_request_done_fail_silently_on_bad_request_id(hass): async def test_request_done_fail_silently_on_bad_request_id(
hass: HomeAssistant,
) -> None:
"""Test that request_done fails silently with a bad request id.""" """Test that request_done fails silently with a bad request id."""
configurator.async_request_done(hass, 2016) configurator.async_request_done(hass, 2016)