mirror of
https://github.com/esphome/esphome.git
synced 2025-06-15 23:06:58 +02:00
547 lines
18 KiB
Python
547 lines
18 KiB
Python
"""Common fixtures for integration tests."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections.abc import AsyncGenerator, Generator
|
|
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
import platform
|
|
import signal
|
|
import socket
|
|
import sys
|
|
import tempfile
|
|
from typing import TextIO
|
|
|
|
from aioesphomeapi import APIClient, APIConnectionError, ReconnectLogic
|
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
import esphome.config
|
|
from esphome.core import CORE
|
|
from esphome.platformio_api import get_idedata
|
|
|
|
from .const import (
|
|
API_CONNECTION_TIMEOUT,
|
|
DEFAULT_API_PORT,
|
|
LOCALHOST,
|
|
PORT_POLL_INTERVAL,
|
|
PORT_WAIT_TIMEOUT,
|
|
SIGINT_TIMEOUT,
|
|
SIGTERM_TIMEOUT,
|
|
)
|
|
from .types import (
|
|
APIClientConnectedFactory,
|
|
APIClientFactory,
|
|
CompileFunction,
|
|
ConfigWriter,
|
|
RunCompiledFunction,
|
|
)
|
|
|
|
# Skip all integration tests on Windows
|
|
if platform.system() == "Windows":
|
|
pytest.skip(
|
|
"Integration tests are not supported on Windows", allow_module_level=True
|
|
)
|
|
|
|
import pty # not available on Windows
|
|
|
|
|
|
@pytest.fixture(scope="module", autouse=True)
|
|
def enable_aioesphomeapi_debug_logging():
|
|
"""Enable debug logging for aioesphomeapi to help diagnose connection issues."""
|
|
# Get the aioesphomeapi logger
|
|
logger = logging.getLogger("aioesphomeapi")
|
|
# Save the original level
|
|
original_level = logger.level
|
|
# Set to DEBUG level
|
|
logger.setLevel(logging.DEBUG)
|
|
# Also ensure we have a handler that outputs to console
|
|
if not logger.handlers:
|
|
handler = logging.StreamHandler()
|
|
handler.setLevel(logging.DEBUG)
|
|
formatter = logging.Formatter(
|
|
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
|
)
|
|
handler.setFormatter(formatter)
|
|
logger.addHandler(handler)
|
|
yield
|
|
# Restore original level
|
|
logger.setLevel(original_level)
|
|
|
|
|
|
@pytest.fixture
|
|
def integration_test_dir() -> Generator[Path]:
|
|
"""Create a temporary directory for integration tests."""
|
|
with tempfile.TemporaryDirectory() as tmpdir:
|
|
yield Path(tmpdir)
|
|
|
|
|
|
@pytest.fixture
|
|
def reserved_tcp_port() -> Generator[tuple[int, socket.socket]]:
|
|
"""Reserve an unused TCP port by holding the socket open."""
|
|
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
|
s.bind(("", 0))
|
|
port = s.getsockname()[1]
|
|
try:
|
|
yield port, s
|
|
finally:
|
|
s.close()
|
|
|
|
|
|
@pytest.fixture
|
|
def unused_tcp_port(reserved_tcp_port: tuple[int, socket.socket]) -> int:
|
|
"""Get the reserved TCP port number."""
|
|
return reserved_tcp_port[0]
|
|
|
|
|
|
@pytest_asyncio.fixture
|
|
async def yaml_config(request: pytest.FixtureRequest, unused_tcp_port: int) -> str:
|
|
"""Load YAML configuration based on test name."""
|
|
# Get the test function name
|
|
test_name: str = request.node.name
|
|
# Extract the base test name (remove test_ prefix and any parametrization)
|
|
base_name = test_name.replace("test_", "").partition("[")[0]
|
|
|
|
# Load the fixture file
|
|
fixture_path = Path(__file__).parent / "fixtures" / f"{base_name}.yaml"
|
|
if not fixture_path.exists():
|
|
raise FileNotFoundError(f"Fixture file not found: {fixture_path}")
|
|
|
|
loop = asyncio.get_running_loop()
|
|
content = await loop.run_in_executor(None, fixture_path.read_text)
|
|
|
|
# Replace the port in the config if it contains api section
|
|
if "api:" in content:
|
|
# Add port configuration after api:
|
|
content = content.replace("api:", f"api:\n port: {unused_tcp_port}")
|
|
|
|
# Add debug build flags for integration tests to enable assertions
|
|
if "esphome:" in content:
|
|
# Check if platformio_options already exists
|
|
if "platformio_options:" not in content:
|
|
# Add platformio_options with debug flags after esphome:
|
|
content = content.replace(
|
|
"esphome:",
|
|
"esphome:\n"
|
|
" # Enable assertions for integration tests\n"
|
|
" platformio_options:\n"
|
|
" build_flags:\n"
|
|
' - "-DDEBUG" # Enable assert() statements\n'
|
|
' - "-g" # Add debug symbols',
|
|
)
|
|
|
|
return content
|
|
|
|
|
|
@pytest_asyncio.fixture
|
|
async def write_yaml_config(
|
|
integration_test_dir: Path, request: pytest.FixtureRequest
|
|
) -> AsyncGenerator[ConfigWriter]:
|
|
"""Write YAML configuration to a file."""
|
|
# Get the test name for default filename
|
|
test_name = request.node.name
|
|
base_name = test_name.replace("test_", "").split("[")[0]
|
|
|
|
async def _write_config(content: str, filename: str | None = None) -> Path:
|
|
if filename is None:
|
|
filename = f"{base_name}.yaml"
|
|
config_path = integration_test_dir / filename
|
|
loop = asyncio.get_running_loop()
|
|
await loop.run_in_executor(None, config_path.write_text, content)
|
|
return config_path
|
|
|
|
yield _write_config
|
|
|
|
|
|
@pytest_asyncio.fixture
|
|
async def compile_esphome(
|
|
integration_test_dir: Path,
|
|
) -> AsyncGenerator[CompileFunction]:
|
|
"""Compile an ESPHome configuration and return the binary path."""
|
|
|
|
async def _compile(config_path: Path) -> Path:
|
|
# Retry compilation up to 3 times if we get a segfault
|
|
max_retries = 3
|
|
for attempt in range(max_retries):
|
|
# Compile using subprocess, inheriting stdout/stderr to show progress
|
|
proc = await asyncio.create_subprocess_exec(
|
|
"esphome",
|
|
"compile",
|
|
str(config_path),
|
|
cwd=integration_test_dir,
|
|
stdout=None, # Inherit stdout
|
|
stderr=None, # Inherit stderr
|
|
stdin=asyncio.subprocess.DEVNULL,
|
|
# Start in a new process group to isolate signal handling
|
|
start_new_session=True,
|
|
)
|
|
await proc.wait()
|
|
|
|
if proc.returncode == 0:
|
|
# Success!
|
|
break
|
|
elif proc.returncode == -11 and attempt < max_retries - 1:
|
|
# Segfault (-11 = SIGSEGV), retry
|
|
print(
|
|
f"Compilation segfaulted (attempt {attempt + 1}/{max_retries}), retrying..."
|
|
)
|
|
await asyncio.sleep(1) # Brief pause before retry
|
|
continue
|
|
else:
|
|
# Other error or final retry
|
|
raise RuntimeError(
|
|
f"Failed to compile {config_path}, return code: {proc.returncode}. "
|
|
f"Run with 'pytest -s' to see compilation output."
|
|
)
|
|
|
|
# Load the config to get idedata (blocking call, must use executor)
|
|
loop = asyncio.get_running_loop()
|
|
|
|
def _read_config_and_get_binary():
|
|
CORE.config_path = str(config_path)
|
|
config = esphome.config.read_config(
|
|
{"command": "compile", "config": str(config_path)}
|
|
)
|
|
if config is None:
|
|
raise RuntimeError(f"Failed to read config from {config_path}")
|
|
|
|
# Get the compiled binary path
|
|
idedata = get_idedata(config)
|
|
return Path(idedata.firmware_elf_path)
|
|
|
|
binary_path = await loop.run_in_executor(None, _read_config_and_get_binary)
|
|
|
|
if not binary_path.exists():
|
|
raise RuntimeError(f"Compiled binary not found at {binary_path}")
|
|
|
|
return binary_path
|
|
|
|
yield _compile
|
|
|
|
|
|
@asynccontextmanager
|
|
async def create_api_client(
|
|
address: str = LOCALHOST,
|
|
port: int = DEFAULT_API_PORT,
|
|
password: str = "",
|
|
noise_psk: str | None = None,
|
|
client_info: str = "integration-test",
|
|
) -> AsyncGenerator[APIClient]:
|
|
"""Create an API client context manager."""
|
|
client = APIClient(
|
|
address=address,
|
|
port=port,
|
|
password=password,
|
|
noise_psk=noise_psk,
|
|
client_info=client_info,
|
|
)
|
|
try:
|
|
yield client
|
|
finally:
|
|
await client.disconnect()
|
|
|
|
|
|
@pytest_asyncio.fixture
|
|
async def api_client_factory(
|
|
unused_tcp_port: int,
|
|
) -> AsyncGenerator[APIClientFactory]:
|
|
"""Factory for creating API client context managers."""
|
|
|
|
def _create_client(
|
|
address: str = LOCALHOST,
|
|
port: int | None = None,
|
|
password: str = "",
|
|
noise_psk: str | None = None,
|
|
client_info: str = "integration-test",
|
|
) -> AbstractAsyncContextManager[APIClient]:
|
|
return create_api_client(
|
|
address=address,
|
|
port=port if port is not None else unused_tcp_port,
|
|
password=password,
|
|
noise_psk=noise_psk,
|
|
client_info=client_info,
|
|
)
|
|
|
|
yield _create_client
|
|
|
|
|
|
@asynccontextmanager
|
|
async def wait_and_connect_api_client(
|
|
address: str = LOCALHOST,
|
|
port: int = DEFAULT_API_PORT,
|
|
password: str = "",
|
|
noise_psk: str | None = None,
|
|
client_info: str = "integration-test",
|
|
timeout: float = API_CONNECTION_TIMEOUT,
|
|
) -> AsyncGenerator[APIClient]:
|
|
"""Wait for API to be available and connect."""
|
|
client = APIClient(
|
|
address=address,
|
|
port=port,
|
|
password=password,
|
|
noise_psk=noise_psk,
|
|
client_info=client_info,
|
|
)
|
|
|
|
# Create a future to signal when connected
|
|
loop = asyncio.get_running_loop()
|
|
connected_future: asyncio.Future[None] = loop.create_future()
|
|
|
|
async def on_connect() -> None:
|
|
"""Called when successfully connected."""
|
|
if not connected_future.done():
|
|
connected_future.set_result(None)
|
|
|
|
async def on_disconnect(expected_disconnect: bool) -> None:
|
|
"""Called when disconnected."""
|
|
if not connected_future.done() and not expected_disconnect:
|
|
connected_future.set_exception(
|
|
APIConnectionError("Disconnected before fully connected")
|
|
)
|
|
|
|
async def on_connect_error(err: Exception) -> None:
|
|
"""Called when connection fails."""
|
|
if not connected_future.done():
|
|
connected_future.set_exception(err)
|
|
|
|
# Create and start the reconnect logic
|
|
reconnect_logic = ReconnectLogic(
|
|
client=client,
|
|
on_connect=on_connect,
|
|
on_disconnect=on_disconnect,
|
|
zeroconf_instance=None, # Not using zeroconf for integration tests
|
|
name=f"{address}:{port}",
|
|
on_connect_error=on_connect_error,
|
|
)
|
|
|
|
try:
|
|
# Start the connection
|
|
await reconnect_logic.start()
|
|
|
|
# Wait for connection with timeout
|
|
try:
|
|
await asyncio.wait_for(connected_future, timeout=timeout)
|
|
except asyncio.TimeoutError:
|
|
raise TimeoutError(f"Failed to connect to API after {timeout} seconds")
|
|
|
|
yield client
|
|
finally:
|
|
# Stop reconnect logic and disconnect
|
|
await reconnect_logic.stop()
|
|
await client.disconnect()
|
|
|
|
|
|
@pytest_asyncio.fixture
|
|
async def api_client_connected(
|
|
unused_tcp_port: int,
|
|
) -> AsyncGenerator[APIClientConnectedFactory]:
|
|
"""Factory for creating connected API client context managers."""
|
|
|
|
def _connect_client(
|
|
address: str = LOCALHOST,
|
|
port: int | None = None,
|
|
password: str = "",
|
|
noise_psk: str | None = None,
|
|
client_info: str = "integration-test",
|
|
timeout: float = API_CONNECTION_TIMEOUT,
|
|
) -> AbstractAsyncContextManager[APIClient]:
|
|
return wait_and_connect_api_client(
|
|
address=address,
|
|
port=port if port is not None else unused_tcp_port,
|
|
password=password,
|
|
noise_psk=noise_psk,
|
|
client_info=client_info,
|
|
timeout=timeout,
|
|
)
|
|
|
|
yield _connect_client
|
|
|
|
|
|
async def _read_stream_lines(
|
|
stream: asyncio.StreamReader, lines: list[str], output_stream: TextIO
|
|
) -> None:
|
|
"""Read lines from a stream, append to list, and echo to output stream."""
|
|
while line := await stream.readline():
|
|
decoded_line = line.decode("utf-8", errors="replace")
|
|
lines.append(decoded_line.rstrip())
|
|
# Echo to stdout/stderr in real-time
|
|
print(decoded_line.rstrip(), file=output_stream, flush=True)
|
|
|
|
|
|
@asynccontextmanager
|
|
async def run_binary_and_wait_for_port(
|
|
binary_path: Path,
|
|
host: str,
|
|
port: int,
|
|
timeout: float = PORT_WAIT_TIMEOUT,
|
|
) -> AsyncGenerator[None]:
|
|
"""Run a binary, wait for it to open a port, and clean up on exit."""
|
|
# Create a pseudo-terminal to make the binary think it's running interactively
|
|
# This is needed because the ESPHome host logger checks isatty()
|
|
controller_fd, device_fd = pty.openpty()
|
|
|
|
# Run the compiled binary with PTY
|
|
process = await asyncio.create_subprocess_exec(
|
|
str(binary_path),
|
|
stdout=device_fd,
|
|
stderr=device_fd,
|
|
stdin=asyncio.subprocess.DEVNULL,
|
|
# Start in a new process group to isolate signal handling
|
|
start_new_session=True,
|
|
pass_fds=(device_fd,),
|
|
)
|
|
|
|
# Close the device end in the parent process
|
|
os.close(device_fd)
|
|
|
|
# Convert controller_fd to async streams for reading
|
|
loop = asyncio.get_running_loop()
|
|
controller_reader = asyncio.StreamReader()
|
|
controller_protocol = asyncio.StreamReaderProtocol(controller_reader)
|
|
controller_transport, _ = await loop.connect_read_pipe(
|
|
lambda: controller_protocol, os.fdopen(controller_fd, "rb", 0)
|
|
)
|
|
output_reader = controller_reader
|
|
|
|
if process.returncode is not None:
|
|
raise RuntimeError(
|
|
f"Process died immediately with return code {process.returncode}. "
|
|
"Ensure the binary is valid and can run successfully."
|
|
)
|
|
|
|
# Wait for the API server to start listening
|
|
loop = asyncio.get_running_loop()
|
|
start_time = loop.time()
|
|
|
|
# Start collecting output
|
|
stdout_lines: list[str] = []
|
|
output_tasks: list[asyncio.Task] = []
|
|
|
|
try:
|
|
# Read from output stream
|
|
output_tasks = [
|
|
asyncio.create_task(
|
|
_read_stream_lines(output_reader, stdout_lines, sys.stdout)
|
|
)
|
|
]
|
|
|
|
# Small yield to ensure the process has a chance to start
|
|
await asyncio.sleep(0)
|
|
|
|
while loop.time() - start_time < timeout:
|
|
try:
|
|
# Try to connect to the port
|
|
_, writer = await asyncio.open_connection(host, port)
|
|
writer.close()
|
|
await writer.wait_closed()
|
|
# Port is open, yield control
|
|
yield
|
|
return
|
|
except (ConnectionRefusedError, OSError):
|
|
# Check if process died
|
|
if process.returncode is not None:
|
|
break
|
|
# Port not open yet, wait a bit and try again
|
|
await asyncio.sleep(PORT_POLL_INTERVAL)
|
|
|
|
# Timeout or process died - build error message
|
|
error_msg = f"Port {port} on {host} did not open within {timeout} seconds"
|
|
|
|
if process.returncode is not None:
|
|
error_msg += f"\nProcess exited with code: {process.returncode}"
|
|
|
|
# Include any output collected so far
|
|
if stdout_lines:
|
|
error_msg += "\n\n--- Process Output ---\n"
|
|
error_msg += "\n".join(stdout_lines[-100:]) # Last 100 lines
|
|
|
|
raise TimeoutError(error_msg)
|
|
|
|
finally:
|
|
# Cancel output collection tasks
|
|
for task in output_tasks:
|
|
task.cancel()
|
|
# Wait for tasks to complete and check for exceptions
|
|
results = await asyncio.gather(*output_tasks, return_exceptions=True)
|
|
for i, result in enumerate(results):
|
|
if isinstance(result, Exception) and not isinstance(
|
|
result, asyncio.CancelledError
|
|
):
|
|
print(
|
|
f"Error reading from PTY: {result}",
|
|
file=sys.stderr,
|
|
)
|
|
|
|
# Close the PTY transport (Unix only)
|
|
if controller_transport is not None:
|
|
controller_transport.close()
|
|
|
|
# Cleanup: terminate the process gracefully
|
|
if process.returncode is None:
|
|
# Send SIGINT (Ctrl+C) for graceful shutdown
|
|
process.send_signal(signal.SIGINT)
|
|
try:
|
|
await asyncio.wait_for(process.wait(), timeout=SIGINT_TIMEOUT)
|
|
except asyncio.TimeoutError:
|
|
# If SIGINT didn't work, try SIGTERM
|
|
process.terminate()
|
|
try:
|
|
await asyncio.wait_for(process.wait(), timeout=SIGTERM_TIMEOUT)
|
|
except asyncio.TimeoutError:
|
|
# Last resort: SIGKILL
|
|
process.kill()
|
|
await process.wait()
|
|
|
|
|
|
@asynccontextmanager
|
|
async def run_compiled_context(
|
|
yaml_content: str,
|
|
filename: str | None,
|
|
write_yaml_config: ConfigWriter,
|
|
compile_esphome: CompileFunction,
|
|
port: int,
|
|
port_socket: socket.socket | None = None,
|
|
) -> AsyncGenerator[None]:
|
|
"""Context manager to write, compile and run an ESPHome configuration."""
|
|
# Write the YAML config
|
|
config_path = await write_yaml_config(yaml_content, filename)
|
|
|
|
# Compile the configuration and get binary path
|
|
binary_path = await compile_esphome(config_path)
|
|
|
|
# Close the port socket right before running to release the port
|
|
if port_socket is not None:
|
|
port_socket.close()
|
|
|
|
# Run the binary and wait for the API server to start
|
|
async with run_binary_and_wait_for_port(binary_path, LOCALHOST, port):
|
|
yield
|
|
|
|
|
|
@pytest_asyncio.fixture
|
|
async def run_compiled(
|
|
write_yaml_config: ConfigWriter,
|
|
compile_esphome: CompileFunction,
|
|
reserved_tcp_port: tuple[int, socket.socket],
|
|
) -> AsyncGenerator[RunCompiledFunction]:
|
|
"""Write, compile and run an ESPHome configuration."""
|
|
port, port_socket = reserved_tcp_port
|
|
|
|
def _run_compiled(
|
|
yaml_content: str, filename: str | None = None
|
|
) -> AbstractAsyncContextManager[asyncio.subprocess.Process]:
|
|
return run_compiled_context(
|
|
yaml_content,
|
|
filename,
|
|
write_yaml_config,
|
|
compile_esphome,
|
|
port,
|
|
port_socket,
|
|
)
|
|
|
|
yield _run_compiled
|