Compare commits

...

7 Commits

Author SHA1 Message Date
Otto winter
88632b22e2 Complete 2021-09-06 12:44:53 +02:00
Otto winter
44041d2526 Updates 2021-08-23 20:23:39 +02:00
Otto winter
7cfc36cb70 Add noise API transport support 2021-08-16 20:18:01 +02:00
Otto winter
08dd72e716 Echo component and noise 2021-08-12 13:45:51 +02:00
Otto winter
7b7e5f7db5 Fixes 2021-08-10 19:40:38 +02:00
Otto winter
c9b170eab9 Merge remote-tracking branch 'origin/dev' into socket-refactor-and-ssl 2021-08-10 19:06:46 +02:00
Otto winter
40dd9c5dce Socket refactor and SSL 2021-08-09 20:54:50 +02:00
30 changed files with 3281 additions and 246 deletions

View File

@@ -192,6 +192,11 @@ class APIClient(threading.Thread):
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket.settimeout(10.0)
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
import ssl
context = ssl.SSLContext()
self._socket = context.wrap_socket(self._socket)
try:
self._socket.connect((ip, self._port))
except OSError as err:

View File

@@ -1,3 +1,5 @@
import base64
import esphome.codegen as cg
import esphome.config_validation as cv
from esphome import automation
@@ -6,6 +8,7 @@ from esphome.const import (
CONF_DATA,
CONF_DATA_TEMPLATE,
CONF_ID,
CONF_KEY,
CONF_PASSWORD,
CONF_PORT,
CONF_REBOOT_TIMEOUT,
@@ -19,7 +22,7 @@ from esphome.const import (
from esphome.core import coroutine_with_priority
DEPENDENCIES = ["network"]
AUTO_LOAD = ["async_tcp"]
AUTO_LOAD = ["socket"]
CODEOWNERS = ["@OttoWinter"]
api_ns = cg.esphome_ns.namespace("api")
@@ -41,6 +44,22 @@ SERVICE_ARG_NATIVE_TYPES = {
"float[]": cg.std_vector.template(float),
"string[]": cg.std_vector.template(cg.std_string),
}
CONF_ENCRYPTION = "encryption"
def validate_encryption_key(value):
value = cv.string_strict(value)
try:
decoded = base64.b64decode(value, validate=True)
except ValueError as err:
raise cv.Invalid("Invalid key format, please check it's using base64") from err
if len(decoded) != 32:
raise cv.Invalid("Encryption key must be base64 and 32 bytes long")
# Return original data for roundtrip conversion
return value
CONFIG_SCHEMA = cv.Schema(
{
@@ -63,6 +82,11 @@ CONFIG_SCHEMA = cv.Schema(
),
}
),
cv.Optional(CONF_ENCRYPTION): cv.Schema(
{
cv.Required(CONF_KEY): validate_encryption_key,
}
),
}
).extend(cv.COMPONENT_SCHEMA)
@@ -92,6 +116,14 @@ async def to_code(config):
cg.add(var.register_user_service(trigger))
await automation.build_automation(trigger, func_args, conf)
if CONF_ENCRYPTION in config:
conf = config[CONF_ENCRYPTION]
decoded = base64.b64decode(conf[CONF_KEY])
cg.add(var.set_noise_psk(list(decoded)))
cg.add_define("USE_API_NOISE")
else:
cg.add_define("USE_API_PLAINTEXT")
cg.add_define("USE_API")
cg.add_global(api_ns.using)

View File

@@ -2,6 +2,7 @@
#include "esphome/core/log.h"
#include "esphome/core/util.h"
#include "esphome/core/version.h"
#include <errno.h>
#ifdef USE_DEEP_SLEEP
#include "esphome/components/deep_sleep/deep_sleep_component.h"
@@ -18,74 +19,36 @@ namespace api {
static const char *const TAG = "api.connection";
APIConnection::APIConnection(AsyncClient *client, APIServer *parent)
: client_(client), parent_(parent), initial_state_iterator_(parent, this), list_entities_iterator_(parent, this) {
this->client_->onError([](void *s, AsyncClient *c, int8_t error) { ((APIConnection *) s)->on_error_(error); }, this);
this->client_->onDisconnect([](void *s, AsyncClient *c) { ((APIConnection *) s)->on_disconnect_(); }, this);
this->client_->onTimeout([](void *s, AsyncClient *c, uint32_t time) { ((APIConnection *) s)->on_timeout_(time); },
this);
this->client_->onData([](void *s, AsyncClient *c, void *buf,
size_t len) { ((APIConnection *) s)->on_data_(reinterpret_cast<uint8_t *>(buf), len); },
this);
APIConnection::APIConnection(std::unique_ptr<socket::Socket> sock, APIServer *parent)
: parent_(parent),
initial_state_iterator_(parent, this),
list_entities_iterator_(parent, this) {
this->proto_write_buffer_.reserve(64);
this->send_buffer_.reserve(64);
this->recv_buffer_.reserve(32);
this->client_info_ = this->client_->remoteIP().toString().c_str();
#ifdef USE_API_NOISE
helper_ = std::unique_ptr<APIFrameHelper>{new APINoiseFrameHelper(std::move(sock), parent->get_noise_ctx())};
#elif defined(USE_API_PLAINTEXT)
helper_ = std::unique_ptr<APIFrameHelper>{new APIPlaintextFrameHelper(std::move(sock))};
#else
#error "No api frame helper enabled"
#endif
}
void APIConnection::start() {
this->last_traffic_ = millis();
}
APIConnection::~APIConnection() { delete this->client_; }
void APIConnection::on_error_(int8_t error) { this->remove_ = true; }
void APIConnection::on_disconnect_() { this->remove_ = true; }
void APIConnection::on_timeout_(uint32_t time) { this->on_fatal_error(); }
void APIConnection::on_data_(uint8_t *buf, size_t len) {
if (len == 0 || buf == nullptr)
APIError err = helper_->init();
if (err != APIError::OK) {
ESP_LOGW(TAG, "Helper init failed: %d errno=%d", (int) err, errno);
remove_ = true;
return;
this->recv_buffer_.insert(this->recv_buffer_.end(), buf, buf + len);
}
void APIConnection::parse_recv_buffer_() {
if (this->recv_buffer_.empty() || this->remove_)
return;
while (!this->recv_buffer_.empty()) {
if (this->recv_buffer_[0] != 0x00) {
ESP_LOGW(TAG, "Invalid preamble from %s", this->client_info_.c_str());
this->on_fatal_error();
return;
}
uint32_t i = 1;
const uint32_t size = this->recv_buffer_.size();
uint32_t consumed;
auto msg_size_varint = ProtoVarInt::parse(&this->recv_buffer_[i], size - i, &consumed);
if (!msg_size_varint.has_value())
// not enough data there yet
return;
i += consumed;
uint32_t msg_size = msg_size_varint->as_uint32();
auto msg_type_varint = ProtoVarInt::parse(&this->recv_buffer_[i], size - i, &consumed);
if (!msg_type_varint.has_value())
// not enough data there yet
return;
i += consumed;
uint32_t msg_type = msg_type_varint->as_uint32();
if (size - i < msg_size)
// message body not fully received
return;
uint8_t *msg = &this->recv_buffer_[i];
this->read_message(msg_size, msg_type, msg);
if (this->remove_)
return;
// pop front
uint32_t total = i + msg_size;
this->recv_buffer_.erase(this->recv_buffer_.begin(), this->recv_buffer_.begin() + total);
this->last_traffic_ = millis();
}
client_info_ = helper_->getpeername();
helper_->set_log_info(client_info_);
}
void APIConnection::disconnect_client() {
this->client_->close();
void APIConnection::force_disconnect_client() {
this->helper_->close();
this->remove_ = true;
}
@@ -93,61 +56,78 @@ void APIConnection::loop() {
if (this->remove_)
return;
if (this->next_close_) {
this->disconnect_client();
return;
}
if (!network_is_connected()) {
// when network is disconnected force disconnect immediately
// don't wait for timeout
this->on_fatal_error();
return;
}
if (this->client_->disconnected()) {
// failsafe for disconnect logic
this->on_disconnect_();
if (this->next_close_) {
this->helper_->close();
this->remove_ = true;
return;
}
this->parse_recv_buffer_();
APIError err = helper_->loop();
if (err != APIError::OK) {
on_fatal_error();
ESP_LOGW(TAG, "%s: Socket operation failed: %d", client_info_.c_str(), (int) err);
return;
}
ReadPacketBuffer buffer;
err = helper_->read_packet(&buffer);
if (err == APIError::WOULD_BLOCK) {
// pass
} else if (err != APIError::OK) {
on_fatal_error();
ESP_LOGW(TAG, "%s: Reading failed: %d", client_info_.c_str(), (int) err);
return;
} else {
this->last_traffic_ = millis();
// read a packet
this->read_message(
buffer.data_len,
buffer.type,
&buffer.container[buffer.data_offset]
);
if (this->remove_)
return;
}
this->list_entities_iterator_.advance();
this->initial_state_iterator_.advance();
const uint32_t keepalive = 60000;
const uint32_t now = millis();
if (this->sent_ping_) {
// Disconnect if not responded within 2.5*keepalive
if (millis() - this->last_traffic_ > (keepalive * 5) / 2) {
if (now - this->last_traffic_ > (keepalive * 5) / 2) {
this->force_disconnect_client();
ESP_LOGW(TAG, "'%s' didn't respond to ping request in time. Disconnecting...", this->client_info_.c_str());
this->disconnect_client();
}
} else if (millis() - this->last_traffic_ > keepalive) {
} else if (now - this->last_traffic_ > keepalive) {
this->sent_ping_ = true;
this->send_ping_request(PingRequest());
}
#ifdef USE_ESP32_CAMERA
if (this->image_reader_.available()) {
uint32_t space = this->client_->space();
// reserve 15 bytes for metadata, and at least 64 bytes of data
if (space >= 15 + 64) {
uint32_t to_send = std::min(space - 15, this->image_reader_.available());
auto buffer = this->create_buffer();
// fixed32 key = 1;
buffer.encode_fixed32(1, esp32_camera::global_esp32_camera->get_object_id_hash());
// bytes data = 2;
buffer.encode_bytes(2, this->image_reader_.peek_data_buffer(), to_send);
// bool done = 3;
bool done = this->image_reader_.available() == to_send;
buffer.encode_bool(3, done);
bool success = this->send_buffer(buffer, 44);
if (this->image_reader_.available() && this->helper_->can_write_without_blocking()) {
uint32_t to_send = std::min((size_t) 1024, this->image_reader_.available());
auto buffer = this->create_buffer();
// fixed32 key = 1;
buffer.encode_fixed32(1, esp32_camera::global_esp32_camera->get_object_id_hash());
// bytes data = 2;
buffer.encode_bytes(2, this->image_reader_.peek_data_buffer(), to_send);
// bool done = 3;
bool done = this->image_reader_.available() == to_send;
buffer.encode_bool(3, done);
bool success = this->send_buffer(buffer, 44);
if (success) {
this->image_reader_.consume_data(to_send);
}
if (success && done) {
this->image_reader_.return_image();
}
if (success) {
this->image_reader_.consume_data(to_send);
}
if (success && done) {
this->image_reader_.return_image();
}
}
#endif
@@ -718,8 +698,8 @@ bool APIConnection::send_log_message(int level, const char *tag, const char *lin
}
HelloResponse APIConnection::hello(const HelloRequest &msg) {
this->client_info_ = msg.client_info + " (" + this->client_->remoteIP().toString().c_str();
this->client_info_ += ")";
this->client_info_ = msg.client_info + " (" + this->helper_->getpeername() + ")";
this->helper_->set_log_info(client_info_);
ESP_LOGV(TAG, "Hello from client: '%s'", this->client_info_.c_str());
HelloResponse resp;
@@ -797,44 +777,35 @@ void APIConnection::subscribe_home_assistant_states(const SubscribeHomeAssistant
bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) {
if (this->remove_)
return false;
if (!this->helper_->can_write_without_blocking())
return false;
std::vector<uint8_t> header;
header.push_back(0x00);
ProtoVarInt(buffer.get_buffer()->size()).encode(header);
ProtoVarInt(message_type).encode(header);
size_t needed_space = buffer.get_buffer()->size() + header.size();
if (needed_space > this->client_->space()) {
delay(0);
if (needed_space > this->client_->space()) {
// SubscribeLogsResponse
if (message_type != 29) {
ESP_LOGV(TAG, "Cannot send message because of TCP buffer space");
}
delay(0);
return false;
}
APIError err = this->helper_->write_packet(
message_type,
buffer.get_buffer()->data(),
buffer.get_buffer()->size()
);
if (err == APIError::WOULD_BLOCK)
return false;
if (err != APIError::OK) {
on_fatal_error();
ESP_LOGW(TAG, "%s: Packet write failed %d errno=%d", client_info_.c_str(), (int) err, errno);
return false;
}
this->client_->add(reinterpret_cast<char *>(header.data()), header.size(),
ASYNC_WRITE_FLAG_COPY | ASYNC_WRITE_FLAG_MORE);
this->client_->add(reinterpret_cast<char *>(buffer.get_buffer()->data()), buffer.get_buffer()->size(),
ASYNC_WRITE_FLAG_COPY);
bool ret = this->client_->send();
return ret;
this->last_traffic_ = millis();
return true;
}
void APIConnection::on_unauthenticated_access() {
ESP_LOGD(TAG, "'%s' tried to access without authentication.", this->client_info_.c_str());
this->on_fatal_error();
ESP_LOGD(TAG, "'%s' tried to access without authentication.", this->client_info_.c_str());
}
void APIConnection::on_no_setup_connection() {
ESP_LOGD(TAG, "'%s' tried to access without full connection.", this->client_info_.c_str());
this->on_fatal_error();
ESP_LOGD(TAG, "'%s' tried to access without full connection.", this->client_info_.c_str());
}
void APIConnection::on_fatal_error() {
ESP_LOGV(TAG, "Error: Disconnecting %s", this->client_info_.c_str());
this->client_->close();
this->helper_->close();
this->remove_ = true;
}

View File

@@ -5,16 +5,18 @@
#include "api_pb2.h"
#include "api_pb2_service.h"
#include "api_server.h"
#include "api_frame_helper.h"
namespace esphome {
namespace api {
class APIConnection : public APIServerConnection {
public:
APIConnection(AsyncClient *client, APIServer *parent);
virtual ~APIConnection();
APIConnection(std::unique_ptr<socket::Socket> socket, APIServer *parent);
virtual ~APIConnection() = default;
void disconnect_client();
void start();
void force_disconnect_client();
void loop();
bool send_list_info_done() {
@@ -87,8 +89,8 @@ class APIConnection : public APIServerConnection {
#endif
void on_disconnect_response(const DisconnectResponse &value) override {
// we initiated disconnect_client
this->next_close_ = true;
this->helper_->close();
this->remove_ = true;
}
void on_ping_response(const PingResponse &value) override {
// we initiated ping
@@ -102,6 +104,8 @@ class APIConnection : public APIServerConnection {
ConnectResponse connect(const ConnectRequest &msg) override;
DisconnectResponse disconnect(const DisconnectRequest &msg) override {
// remote initiated disconnect_client
// don't close yet, we still need to send the disconnect response
// close will happen on next loop
this->next_close_ = true;
DisconnectResponse resp;
return resp;
@@ -135,19 +139,16 @@ class APIConnection : public APIServerConnection {
void on_unauthenticated_access() override;
void on_no_setup_connection() override;
ProtoWriteBuffer create_buffer() override {
this->send_buffer_.clear();
return {&this->send_buffer_};
// FIXME: ensure no recursive writes can happen
this->proto_write_buffer_.clear();
return {&this->proto_write_buffer_};
}
bool send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) override;
protected:
friend APIServer;
void on_error_(int8_t error);
void on_disconnect_();
void on_timeout_(uint32_t time);
void on_data_(uint8_t *buf, size_t len);
void parse_recv_buffer_();
bool send_(const void *buf, size_t len, bool force);
enum class ConnectionState {
WAITING_FOR_HELLO,
@@ -157,8 +158,10 @@ class APIConnection : public APIServerConnection {
bool remove_{false};
std::vector<uint8_t> send_buffer_;
std::vector<uint8_t> recv_buffer_;
// Buffer used to encode proto messages
// Re-use to prevent allocations
std::vector<uint8_t> proto_write_buffer_;
std::unique_ptr<APIFrameHelper> helper_;
std::string client_info_;
#ifdef USE_ESP32_CAMERA
@@ -170,9 +173,7 @@ class APIConnection : public APIServerConnection {
uint32_t last_traffic_;
bool sent_ping_{false};
bool service_call_subscription_{false};
bool current_nodelay_{false};
bool next_close_{false};
AsyncClient *client_;
bool next_close_ = false;
APIServer *parent_;
InitialStateIterator initial_state_iterator_;
ListEntitiesIterator list_entities_iterator_;

View File

@@ -0,0 +1,907 @@
#include "api_frame_helper.h"
#include "esphome/core/log.h"
#include "esphome/core/helpers.h"
#include "proto.h"
namespace esphome {
namespace api {
static const char *const TAG = "api.socket";
/// Is the given return value (from read/write syscalls) a wouldblock error?
bool is_would_block(ssize_t ret) {
if (ret == -1) {
return errno == EWOULDBLOCK || errno == EAGAIN;
}
return ret == 0;
}
#define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, info_.c_str(), ##__VA_ARGS__)
#ifdef USE_API_NOISE
static const char *const PROLOGUE_INIT = "NoiseAPIInit";
/// Convert a noise error code to a readable error
std::string noise_err_to_str(int err) {
if (err == NOISE_ERROR_NO_MEMORY)
return "NO_MEMORY";
if (err == NOISE_ERROR_UNKNOWN_ID)
return "UNKNOWN_ID";
if (err == NOISE_ERROR_UNKNOWN_NAME)
return "UNKNOWN_NAME";
if (err == NOISE_ERROR_MAC_FAILURE)
return "MAC_FAILURE";
if (err == NOISE_ERROR_NOT_APPLICABLE)
return "NOT_APPLICABLE";
if (err == NOISE_ERROR_SYSTEM)
return "SYSTEM";
if (err == NOISE_ERROR_REMOTE_KEY_REQUIRED)
return "REMOTE_KEY_REQUIRED";
if (err == NOISE_ERROR_LOCAL_KEY_REQUIRED)
return "LOCAL_KEY_REQUIRED";
if (err == NOISE_ERROR_PSK_REQUIRED)
return "PSK_REQUIRED";
if (err == NOISE_ERROR_INVALID_LENGTH)
return "INVALID_LENGTH";
if (err == NOISE_ERROR_INVALID_PARAM)
return "INVALID_PARAM";
if (err == NOISE_ERROR_INVALID_STATE)
return "INVALID_STATE";
if (err == NOISE_ERROR_INVALID_NONCE)
return "INVALID_NONCE";
if (err == NOISE_ERROR_INVALID_PRIVATE_KEY)
return "INVALID_PRIVATE_KEY";
if (err == NOISE_ERROR_INVALID_PUBLIC_KEY)
return "INVALID_PUBLIC_KEY";
if (err == NOISE_ERROR_INVALID_FORMAT)
return "INVALID_FORMAT";
if (err == NOISE_ERROR_INVALID_SIGNATURE)
return "INVALID_SIGNATURE";
return to_string(err);
}
/// Initialize the frame helper, returns OK if successful.
APIError APINoiseFrameHelper::init() {
if (state_ != State::INITIALIZE || socket_ == nullptr) {
HELPER_LOG("Bad state for init %d", (int) state_);
return APIError::BAD_STATE;
}
int err = socket_->setblocking(false);
if (err != 0) {
state_ = State::FAILED;
HELPER_LOG("Setting nonblocking failed with errno %d", errno);
return APIError::TCP_NONBLOCKING_FAILED;
}
int enable = 1;
err = socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int));
if (err != 0) {
state_ = State::FAILED;
HELPER_LOG("Setting nodelay failed with errno %d", errno);
return APIError::TCP_NODELAY_FAILED;
}
// init prologue
prologue_.insert(prologue_.end(), PROLOGUE_INIT, PROLOGUE_INIT + strlen(PROLOGUE_INIT));
state_ = State::CLIENT_HELLO;
return APIError::OK;
}
/// Run through handshake messages (if in that phase)
APIError APINoiseFrameHelper::loop() {
APIError err = state_action_();
if (err == APIError::WOULD_BLOCK)
return APIError::OK;
if (err != APIError::OK)
return err;
if (!tx_buf_.empty()) {
err = try_send_tx_buf_();
if (err != APIError::OK) {
return err;
}
}
return APIError::OK;
}
/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter
*
* @param frame: The struct to hold the frame information in.
* msg_start: points to the start of the payload - this pointer is only valid until the next
* try_receive_raw_ call
*
* @return 0 if a full packet is in rx_buf_
* @return -1 if error, check errno.
*
* errno EWOULDBLOCK: Packet could not be read without blocking. Try again later.
* errno ENOMEM: Not enough memory for reading packet.
* errno API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame.
* errno API_ERROR_HANDSHAKE_PACKET_LEN: Packet too big for this phase.
*/
APIError APINoiseFrameHelper::try_read_frame_(ParsedFrame *frame) {
int err;
APIError aerr;
if (frame == nullptr) {
HELPER_LOG("Bad argument for try_read_frame_");
return APIError::BAD_ARG;
}
// read header
if (rx_header_buf_len_ < 3) {
// no header information yet
size_t to_read = 3 - rx_header_buf_len_;
ssize_t received = socket_->read(&rx_header_buf_[rx_header_buf_len_], to_read);
if (is_would_block(received)) {
return APIError::WOULD_BLOCK;
} else if (received == -1) {
state_ = State::FAILED;
HELPER_LOG("Socket read failed with errno %d", errno);
return APIError::SOCKET_READ_FAILED;
}
rx_header_buf_len_ += received;
if (received != to_read) {
// not a full read
return APIError::WOULD_BLOCK;
}
// header reading done
}
// read body
uint8_t indicator = rx_header_buf_[0];
if (indicator != 0x01) {
state_ = State::FAILED;
HELPER_LOG("Bad indicator byte %u", indicator);
return APIError::BAD_INDICATOR;
}
uint16_t msg_size = (((uint16_t) rx_header_buf_[1]) << 8) | rx_header_buf_[2];
if (state_ != State::DATA && msg_size > 128) {
// for handshake message only permit up to 128 byte
state_ = State::FAILED;
HELPER_LOG("Bad packet len for handshake: %d", msg_size);
return APIError::BAD_HANDSHAKE_PACKET_LEN;
}
// reserve space for body
if (rx_buf_.size() != msg_size) {
rx_buf_.resize(msg_size);
}
if (rx_buf_len_ < msg_size) {
// more data to read
size_t to_read = msg_size - rx_buf_len_;
ssize_t received = socket_->read(&rx_buf_[rx_buf_len_], to_read);
if (is_would_block(received)) {
return APIError::WOULD_BLOCK;
} else if (received == -1) {
state_ = State::FAILED;
HELPER_LOG("Socket read failed with errno %d", errno);
return APIError::SOCKET_READ_FAILED;
}
rx_buf_len_ += received;
if (received != to_read) {
// not all read
return APIError::WOULD_BLOCK;
}
}
// uncomment for even more debugging
// ESP_LOGVV(TAG, "Received frame: %s", hexencode(rx_buf_).c_str());
frame->msg = std::move(rx_buf_);
// consume msg
rx_buf_ = {};
rx_buf_len_ = 0;
rx_header_buf_len_ = 0;
return APIError::OK;
}
/** To be called from read/write methods.
*
* This method runs through the internal handshake methods, if in that state.
*
* If the handshake is still active when this method returns and a read/write can't take place at
* the moment, returns WOULD_BLOCK.
* If an error occured, returns that error. Only returns OK if the transport is ready for data
* traffic.
*/
APIError APINoiseFrameHelper::state_action_() {
int err;
APIError aerr;
if (state_ == State::INITIALIZE) {
HELPER_LOG("Bad state for method: %d", (int) state_);
return APIError::BAD_STATE;
}
if (state_ == State::CLIENT_HELLO) {
// waiting for client hello
ParsedFrame frame;
aerr = try_read_frame_(&frame);
if (aerr != APIError::OK)
return aerr;
// ignore contents, may be used in future for flags
prologue_.push_back((uint8_t) (frame.msg.size() >> 8));
prologue_.push_back((uint8_t) frame.msg.size());
prologue_.insert(prologue_.end(), frame.msg.begin(), frame.msg.end());
state_ = State::SERVER_HELLO;
}
if (state_ == State::SERVER_HELLO) {
// send server hello
uint8_t msg[1];
msg[0] = 0x01; // chosen proto
aerr = write_frame_(msg, 1);
if (aerr != APIError::OK)
return aerr;
// start handshake
aerr = init_handshake_();
if (aerr != APIError::OK)
return aerr;
state_ = State::HANDSHAKE;
}
if (state_ == State::HANDSHAKE) {
int action = noise_handshakestate_get_action(handshake_);
if (action == NOISE_ACTION_READ_MESSAGE) {
// waiting for handshake msg
ParsedFrame frame;
aerr = try_read_frame_(&frame);
if (aerr == APIError::BAD_INDICATOR) {
send_explicit_handshake_reject_("Bad indicator byte");
return aerr;
}
if (frame.msg.size() < 1 || frame.msg[0] != 0x00) {
aerr = APIError::BAD_HANDSHAKE_PACKET_LEN;
}
if (aerr == APIError::BAD_HANDSHAKE_PACKET_LEN) {
send_explicit_handshake_reject_("Bad handshake packet len");
return aerr;
}
if (aerr != APIError::OK)
return aerr;
NoiseBuffer mbuf;
noise_buffer_init(mbuf);
noise_buffer_set_input(mbuf, frame.msg.data() + 1, frame.msg.size() - 1);
err = noise_handshakestate_read_message(handshake_, &mbuf, nullptr);
if (err != 0) {
// TODO: explicit rejection
state_ = State::FAILED;
HELPER_LOG("noise_handshakestate_read_message failed: %s", noise_err_to_str(err).c_str());
if (err == NOISE_ERROR_MAC_FAILURE) {
send_explicit_handshake_reject_("Handshake MAC failure");
} else {
send_explicit_handshake_reject_("Handshake error");
}
return APIError::HANDSHAKESTATE_READ_FAILED;
}
aerr = check_handshake_finished_();
if (aerr != APIError::OK)
return aerr;
} else if (action == NOISE_ACTION_WRITE_MESSAGE) {
uint8_t buffer[65];
NoiseBuffer mbuf;
noise_buffer_init(mbuf);
noise_buffer_set_output(mbuf, buffer + 1, sizeof(buffer) - 1);
err = noise_handshakestate_write_message(handshake_, &mbuf, nullptr);
if (err != 0) {
state_ = State::FAILED;
HELPER_LOG("noise_handshakestate_write_message failed: %s", noise_err_to_str(err).c_str());
return APIError::HANDSHAKESTATE_WRITE_FAILED;
}
buffer[0] = 0x00; // success
aerr = write_frame_(buffer, mbuf.size + 1);
if (aerr != APIError::OK)
return aerr;
aerr = check_handshake_finished_();
if (aerr != APIError::OK)
return aerr;
} else {
// bad state for action
state_ = State::FAILED;
HELPER_LOG("Bad action for handshake: %d", action);
return APIError::HANDSHAKESTATE_BAD_STATE;
}
}
if (state_ == State::CLOSED || state_ == State::FAILED) {
return APIError::BAD_STATE;
}
return APIError::OK;
}
void APINoiseFrameHelper::send_explicit_handshake_reject_(const std::string &reason) {
std::vector<uint8_t> data;
data.reserve(reason.size() + 1);
data[0] = 0x01; // failure
for (size_t i = 0; i < reason.size(); i++) {
data[i+1] = (uint8_t) reason[i];
}
write_frame_(data.data(), data.size());
}
APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) {
int err;
APIError aerr;
aerr = state_action_();
if (aerr != APIError::OK) {
return aerr;
}
if (state_ != State::DATA) {
return APIError::WOULD_BLOCK;
}
ParsedFrame frame;
aerr = try_read_frame_(&frame);
if (aerr != APIError::OK)
return aerr;
NoiseBuffer mbuf;
noise_buffer_init(mbuf);
noise_buffer_set_inout(mbuf, frame.msg.data(), frame.msg.size(), frame.msg.size());
err = noise_cipherstate_decrypt(recv_cipher_, &mbuf);
if (err != 0) {
state_ = State::FAILED;
HELPER_LOG("noise_cipherstate_decrypt failed: %s", noise_err_to_str(err).c_str());
return APIError::CIPHERSTATE_DECRYPT_FAILED;
}
size_t msg_size = mbuf.size;
uint8_t *msg_data = frame.msg.data();
if (msg_size < 4) {
state_ = State::FAILED;
HELPER_LOG("Bad data packet: size %d too short", msg_size);
return APIError::BAD_DATA_PACKET;
}
// uint16_t type;
// uint16_t data_len;
// uint8_t *data;
// uint8_t *padding; zero or more bytes to fill up the rest of the packet
uint16_t type = (((uint16_t) msg_data[0]) << 8) | msg_data[1];
uint16_t data_len = (((uint16_t) msg_data[2]) << 8) | msg_data[3];
if (data_len > msg_size - 4) {
state_ = State::FAILED;
HELPER_LOG("Bad data packet: data_len %u greater than msg_size %u", data_len, msg_size);
return APIError::BAD_DATA_PACKET;
}
buffer->container = std::move(frame.msg);
buffer->data_offset = 4;
buffer->data_len = data_len;
buffer->type = type;
return APIError::OK;
}
bool APINoiseFrameHelper::can_write_without_blocking() {
return state_ == State::DATA && tx_buf_.empty();
}
APIError APINoiseFrameHelper::write_packet(uint16_t type, const uint8_t *payload, size_t payload_len) {
int err;
APIError aerr;
aerr = state_action_();
if (aerr != APIError::OK) {
return aerr;
}
if (state_ != State::DATA) {
return APIError::WOULD_BLOCK;
}
size_t padding = 0;
size_t msg_len = 4 + payload_len + padding;
size_t frame_len = 3 + msg_len + noise_cipherstate_get_mac_length(send_cipher_);
auto tmpbuf = std::unique_ptr<uint8_t[]>{new (std::nothrow) uint8_t[frame_len]};
if (tmpbuf == nullptr) {
HELPER_LOG("Could not allocate for writing packet");
return APIError::OUT_OF_MEMORY;
}
tmpbuf[0] = 0x01; // indicator
// tmpbuf[1], tmpbuf[2] to be set later
const uint8_t msg_offset = 3;
const uint8_t payload_offset = msg_offset + 4;
tmpbuf[msg_offset + 0] = (uint8_t) (type >> 8); // type
tmpbuf[msg_offset + 1] = (uint8_t) type;
tmpbuf[msg_offset + 2] = (uint8_t) (payload_len >> 8); // data_len
tmpbuf[msg_offset + 3] = (uint8_t) payload_len;
// copy data
std::copy(payload, payload + payload_len, &tmpbuf[payload_offset]);
// fill padding with zeros
std::fill(&tmpbuf[payload_offset + payload_len], &tmpbuf[frame_len], 0);
NoiseBuffer mbuf;
noise_buffer_init(mbuf);
noise_buffer_set_inout(mbuf, &tmpbuf[msg_offset], msg_len, frame_len - msg_offset);
err = noise_cipherstate_encrypt(send_cipher_, &mbuf);
if (err != 0) {
state_ = State::FAILED;
HELPER_LOG("noise_cipherstate_encrypt failed: %s", noise_err_to_str(err).c_str());
return APIError::CIPHERSTATE_ENCRYPT_FAILED;
}
size_t total_len = 3 + mbuf.size;
tmpbuf[1] = (uint8_t) (mbuf.size >> 8);
tmpbuf[2] = (uint8_t) mbuf.size;
// write raw to not have two packets sent if NAGLE disabled
aerr = write_raw_(&tmpbuf[0], total_len);
if (aerr != APIError::OK) {
return aerr;
}
return APIError::OK;
}
APIError APINoiseFrameHelper::try_send_tx_buf_() {
// try send from tx_buf
while (state_ != State::CLOSED && !tx_buf_.empty()) {
ssize_t sent = socket_->write(tx_buf_.data(), tx_buf_.size());
if (sent == -1) {
if (errno == EWOULDBLOCK || errno == EAGAIN)
break;
state_ = State::FAILED;
HELPER_LOG("Socket write failed with errno %d", errno);
return APIError::SOCKET_WRITE_FAILED;
} else if (sent == 0) {
break;
}
// TODO: inefficient if multiple packets in txbuf
// replace with deque of buffers
tx_buf_.erase(tx_buf_.begin(), tx_buf_.begin() + sent);
}
return APIError::OK;
}
/** Write the data to the socket, or buffer it a write would block
*
* @param data The data to write
* @param len The length of data
*/
APIError APINoiseFrameHelper::write_raw_(const uint8_t *data, size_t len) {
if (len == 0)
return APIError::OK;
int err;
APIError aerr;
// uncomment for even more debugging
// ESP_LOGVV(TAG, "Sending raw: %s", hexencode(data, len).c_str());
if (!tx_buf_.empty()) {
// try to empty tx_buf_ first
aerr = try_send_tx_buf_();
if (aerr != APIError::OK && aerr != APIError::WOULD_BLOCK)
return aerr;
}
if (!tx_buf_.empty()) {
// tx buf not empty, can't write now because then stream would be inconsistent
tx_buf_.insert(tx_buf_.end(), data, data + len);
return APIError::OK;
}
ssize_t sent = socket_->write(data, len);
if (is_would_block(sent)) {
// operation would block, add buffer to tx_buf
tx_buf_.insert(tx_buf_.end(), data, data + len);
return APIError::OK;
} else if (sent == -1) {
// an error occured
state_ = State::FAILED;
HELPER_LOG("Socket write failed with errno %d", errno);
return APIError::SOCKET_WRITE_FAILED;
} else if (sent != len) {
// partially sent, add end to tx_buf
tx_buf_.insert(tx_buf_.end(), data + sent, data + len);
return APIError::OK;
}
// fully sent
return APIError::OK;
}
APIError APINoiseFrameHelper::write_frame_(const uint8_t *data, size_t len) {
APIError aerr;
uint8_t header[3];
header[0] = 0x01; // indicator
header[1] = (uint8_t) (len >> 8);
header[2] = (uint8_t) len;
aerr = write_raw_(header, 3);
if (aerr != APIError::OK)
return aerr;
aerr = write_raw_(data, len);
return aerr;
}
/** Initiate the data structures for the handshake.
*
* @return 0 on success, -1 on error (check errno)
*/
APIError APINoiseFrameHelper::init_handshake_() {
int err;
memset(&nid_, 0, sizeof(nid_));
// const char *proto = "Noise_NNpsk0_25519_ChaChaPoly_SHA256";
// err = noise_protocol_name_to_id(&nid_, proto, strlen(proto));
nid_.pattern_id = NOISE_PATTERN_NN;
nid_.cipher_id = NOISE_CIPHER_CHACHAPOLY;
nid_.dh_id = NOISE_DH_CURVE25519;
nid_.prefix_id = NOISE_PREFIX_STANDARD;
nid_.hybrid_id = NOISE_DH_NONE;
nid_.hash_id = NOISE_HASH_SHA256;
nid_.modifier_ids[0] = NOISE_MODIFIER_PSK0;
err = noise_handshakestate_new_by_id(&handshake_, &nid_, NOISE_ROLE_RESPONDER);
if (err != 0) {
state_ = State::FAILED;
HELPER_LOG("noise_handshakestate_new_by_id failed: %s", noise_err_to_str(err).c_str());
return APIError::HANDSHAKESTATE_SETUP_FAILED;
}
const auto &psk = ctx_->get_psk();
err = noise_handshakestate_set_pre_shared_key(handshake_, psk.data(), psk.size());
if (err != 0) {
state_ = State::FAILED;
HELPER_LOG("noise_handshakestate_set_pre_shared_key failed: %s", noise_err_to_str(err).c_str());
return APIError::HANDSHAKESTATE_SETUP_FAILED;
}
err = noise_handshakestate_set_prologue(handshake_, prologue_.data(), prologue_.size());
if (err != 0) {
state_ = State::FAILED;
HELPER_LOG("noise_handshakestate_set_prologue failed: %s", noise_err_to_str(err).c_str());
return APIError::HANDSHAKESTATE_SETUP_FAILED;
}
// set_prologue copies it into handshakestate, so we can get rid of it now
prologue_ = {};
err = noise_handshakestate_start(handshake_);
if (err != 0) {
state_ = State::FAILED;
HELPER_LOG("noise_handshakestate_start failed: %s", noise_err_to_str(err).c_str());
return APIError::HANDSHAKESTATE_SETUP_FAILED;
}
return APIError::OK;
}
APIError APINoiseFrameHelper::check_handshake_finished_() {
assert(state_ == State::HANDSHAKE);
int action = noise_handshakestate_get_action(handshake_);
if (action == NOISE_ACTION_READ_MESSAGE || action == NOISE_ACTION_WRITE_MESSAGE)
return APIError::OK;
if (action != NOISE_ACTION_SPLIT) {
state_ = State::FAILED;
HELPER_LOG("Bad action for handshake: %d", action);
return APIError::HANDSHAKESTATE_BAD_STATE;
}
int err = noise_handshakestate_split(handshake_, &send_cipher_, &recv_cipher_);
if (err != 0) {
state_ = State::FAILED;
HELPER_LOG("noise_handshakestate_split failed: %s", noise_err_to_str(err).c_str());
return APIError::HANDSHAKESTATE_SPLIT_FAILED;
}
HELPER_LOG("Handshake complete!");
noise_handshakestate_free(handshake_);
handshake_ = nullptr;
state_ = State::DATA;
return APIError::OK;
}
APINoiseFrameHelper::~APINoiseFrameHelper() {
if (handshake_ != nullptr) {
noise_handshakestate_free(handshake_);
handshake_ = nullptr;
}
if (send_cipher_ != nullptr) {
noise_cipherstate_free(send_cipher_);
send_cipher_ = nullptr;
}
if (recv_cipher_ != nullptr) {
noise_cipherstate_free(recv_cipher_);
recv_cipher_ = nullptr;
}
}
APIError APINoiseFrameHelper::close() {
state_ = State::CLOSED;
int err = socket_->close();
if (err == -1)
return APIError::CLOSE_FAILED;
return APIError::OK;
}
APIError APINoiseFrameHelper::shutdown(int how) {
int err = socket_->shutdown(how);
if (err == -1)
return APIError::SHUTDOWN_FAILED;
if (how == SHUT_RDWR) {
state_ = State::CLOSED;
}
return APIError::OK;
}
extern "C" {
// declare how noise generates random bytes (here with a good HWRNG based on the RF system)
void noise_rand_bytes(void *output, size_t len) {
esphome::fill_random(reinterpret_cast<uint8_t *>(output), len);
}
}
#endif // USE_API_NOISE
#ifdef USE_API_PLAINTEXT
/// Initialize the frame helper, returns OK if successful.
APIError APIPlaintextFrameHelper::init() {
if (state_ != State::INITIALIZE || socket_ == nullptr) {
HELPER_LOG("Bad state for init %d", (int) state_);
return APIError::BAD_STATE;
}
int err = socket_->setblocking(false);
if (err != 0) {
state_ = State::FAILED;
HELPER_LOG("Setting nonblocking failed with errno %d", errno);
return APIError::TCP_NONBLOCKING_FAILED;
}
int enable = 1;
err = socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int));
if (err != 0) {
state_ = State::FAILED;
HELPER_LOG("Setting nodelay failed with errno %d", errno);
return APIError::TCP_NODELAY_FAILED;
}
state_ = State::DATA;
return APIError::OK;
}
/// Not used for plaintext
APIError APIPlaintextFrameHelper::loop() {
if (state_ != State::DATA) {
return APIError::BAD_STATE;
}
// try send pending TX data
if (!tx_buf_.empty()) {
APIError err = try_send_tx_buf_();
if (err != APIError::OK) {
return err;
}
}
return APIError::OK;
}
/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter
*
* @param frame: The struct to hold the frame information in.
* msg: store the parsed frame in that struct
*
* @return See APIError
*
* error API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame.
*/
APIError APIPlaintextFrameHelper::try_read_frame_(ParsedFrame *frame) {
int err;
APIError aerr;
if (frame == nullptr) {
HELPER_LOG("Bad argument for try_read_frame_");
return APIError::BAD_ARG;
}
// read header
while (!rx_header_parsed_) {
uint8_t data;
ssize_t received = socket_->read(&data, 1);
if (is_would_block(received)) {
return APIError::WOULD_BLOCK;
} else if (received == -1) {
state_ = State::FAILED;
HELPER_LOG("Socket read failed with errno %d", errno);
return APIError::SOCKET_READ_FAILED;
}
rx_header_buf_.push_back(data);
// try parse header
if (rx_header_buf_[0] != 0x00) {
state_ = State::FAILED;
HELPER_LOG("Bad indicator byte %u", rx_header_buf_[0]);
return APIError::BAD_INDICATOR;
}
size_t i = 1;
size_t consumed = 0;
auto msg_size_varint = ProtoVarInt::parse(&rx_header_buf_[i], rx_header_buf_.size() - i, &consumed);
if (!msg_size_varint.has_value()) {
// not enough data there yet
continue;
}
i += consumed;
rx_header_parsed_len_ = msg_size_varint->as_uint32();
auto msg_type_varint = ProtoVarInt::parse(&rx_header_buf_[i], rx_header_buf_.size() - i, &consumed);
if (!msg_type_varint.has_value()) {
// not enough data there yet
continue;
}
rx_header_parsed_type_ = msg_type_varint->as_uint32();
rx_header_parsed_ = true;
}
// header reading done
// reserve space for body
if (rx_buf_.size() != rx_header_parsed_len_) {
rx_buf_.resize(rx_header_parsed_len_);
}
if (rx_buf_len_ < rx_header_parsed_len_) {
// more data to read
size_t to_read = rx_header_parsed_len_ - rx_buf_len_;
ssize_t received = socket_->read(&rx_buf_[rx_buf_len_], to_read);
if (is_would_block(received)) {
return APIError::WOULD_BLOCK;
} else if (received == -1) {
state_ = State::FAILED;
HELPER_LOG("Socket read failed with errno %d", errno);
return APIError::SOCKET_READ_FAILED;
}
rx_buf_len_ += received;
if (received != to_read) {
// not all read
return APIError::WOULD_BLOCK;
}
}
// uncomment for even more debugging
// ESP_LOGVV(TAG, "Received frame: %s", hexencode(rx_buf_).c_str());
frame->msg = std::move(rx_buf_);
// consume msg
rx_buf_ = {};
rx_buf_len_ = 0;
rx_header_buf_.clear();
rx_header_parsed_ = false;
return APIError::OK;
}
APIError APIPlaintextFrameHelper::read_packet(ReadPacketBuffer *buffer) {
int err;
APIError aerr;
if (state_ != State::DATA) {
return APIError::WOULD_BLOCK;
}
ParsedFrame frame;
aerr = try_read_frame_(&frame);
if (aerr != APIError::OK)
return aerr;
buffer->container = std::move(frame.msg);
buffer->data_offset = 0;
buffer->data_len = rx_header_parsed_len_;
buffer->type = rx_header_parsed_type_;
return APIError::OK;
}
bool APIPlaintextFrameHelper::can_write_without_blocking() {
return state_ == State::DATA && tx_buf_.empty();
}
APIError APIPlaintextFrameHelper::write_packet(uint16_t type, const uint8_t *payload, size_t payload_len) {
int err;
APIError aerr;
if (state_ != State::DATA) {
return APIError::BAD_STATE;
}
std::vector<uint8_t> header;
header.push_back(0x00);
ProtoVarInt(payload_len).encode(header);
ProtoVarInt(type).encode(header);
aerr = write_raw_(&header[0], header.size());
if (aerr != APIError::OK) {
return aerr;
}
aerr = write_raw_(payload, payload_len);
if (aerr != APIError::OK) {
return aerr;
}
return APIError::OK;
}
APIError APIPlaintextFrameHelper::try_send_tx_buf_() {
// try send from tx_buf
while (state_ != State::CLOSED && !tx_buf_.empty()) {
ssize_t sent = socket_->write(tx_buf_.data(), tx_buf_.size());
if (sent == -1) {
if (errno == EWOULDBLOCK || errno == EAGAIN)
break;
state_ = State::FAILED;
HELPER_LOG("Socket write failed with errno %d", errno);
return APIError::SOCKET_WRITE_FAILED;
} else if (sent == 0) {
break;
}
// TODO: inefficient if multiple packets in txbuf
// replace with deque of buffers
tx_buf_.erase(tx_buf_.begin(), tx_buf_.begin() + sent);
}
return APIError::OK;
}
/** Write the data to the socket, or buffer it a write would block
*
* @param data The data to write
* @param len The length of data
*/
APIError APIPlaintextFrameHelper::write_raw_(const uint8_t *data, size_t len) {
if (len == 0)
return APIError::OK;
int err;
APIError aerr;
// uncomment for even more debugging
// ESP_LOGVV(TAG, "Sending raw: %s", hexencode(data, len).c_str());
if (!tx_buf_.empty()) {
// try to empty tx_buf_ first
aerr = try_send_tx_buf_();
if (aerr != APIError::OK && aerr != APIError::WOULD_BLOCK)
return aerr;
}
if (!tx_buf_.empty()) {
// tx buf not empty, can't write now because then stream would be inconsistent
tx_buf_.insert(tx_buf_.end(), data, data + len);
return APIError::OK;
}
ssize_t sent = socket_->write(data, len);
if (is_would_block(sent)) {
// operation would block, add buffer to tx_buf
tx_buf_.insert(tx_buf_.end(), data, data + len);
return APIError::OK;
} else if (sent == -1) {
// an error occured
state_ = State::FAILED;
HELPER_LOG("Socket write failed with errno %d", errno);
return APIError::SOCKET_WRITE_FAILED;
} else if (sent != len) {
// partially sent, add end to tx_buf
tx_buf_.insert(tx_buf_.end(), data + sent, data + len);
return APIError::OK;
}
// fully sent
return APIError::OK;
}
APIError APIPlaintextFrameHelper::write_frame_(const uint8_t *data, size_t len) {
APIError aerr;
uint8_t header[3];
header[0] = 0x01; // indicator
header[1] = (uint8_t) (len >> 8);
header[2] = (uint8_t) len;
aerr = write_raw_(header, 3);
if (aerr != APIError::OK)
return aerr;
aerr = write_raw_(data, len);
return aerr;
}
APIError APIPlaintextFrameHelper::close() {
state_ = State::CLOSED;
int err = socket_->close();
if (err == -1)
return APIError::CLOSE_FAILED;
return APIError::OK;
}
APIError APIPlaintextFrameHelper::shutdown(int how) {
int err = socket_->shutdown(how);
if (err == -1)
return APIError::SHUTDOWN_FAILED;
if (how == SHUT_RDWR) {
state_ = State::CLOSED;
}
return APIError::OK;
}
#endif // USE_API_PLAINTEXT
} // namespace api
} // namespace esphome

View File

@@ -0,0 +1,186 @@
#pragma once
#include <cstdint>
#include <vector>
#include <deque>
#include "esphome/core/defines.h"
#ifdef USE_API_NOISE
#include "noise/protocol.h"
#endif
#include "esphome/components/socket/socket.h"
#include "api_noise_context.h"
namespace esphome {
namespace api {
struct ReadPacketBuffer {
std::vector<uint8_t> container;
uint16_t type;
size_t data_offset;
size_t data_len;
};
struct PacketBuffer {
const std::vector<uint8_t> container;
uint16_t type;
uint8_t data_offset;
uint8_t data_len;
};
enum class APIError : int {
OK = 0,
WOULD_BLOCK = 1001,
BAD_HANDSHAKE_PACKET_LEN = 1002,
BAD_INDICATOR = 1003,
BAD_DATA_PACKET = 1004,
TCP_NODELAY_FAILED = 1005,
TCP_NONBLOCKING_FAILED = 1006,
CLOSE_FAILED = 1007,
SHUTDOWN_FAILED = 1008,
BAD_STATE = 1009,
BAD_ARG = 1010,
SOCKET_READ_FAILED = 1011,
SOCKET_WRITE_FAILED = 1012,
HANDSHAKESTATE_READ_FAILED = 1013,
HANDSHAKESTATE_WRITE_FAILED = 1014,
HANDSHAKESTATE_BAD_STATE = 1015,
CIPHERSTATE_DECRYPT_FAILED = 1016,
CIPHERSTATE_ENCRYPT_FAILED = 1017,
OUT_OF_MEMORY = 1018,
HANDSHAKESTATE_SETUP_FAILED = 1019,
HANDSHAKESTATE_SPLIT_FAILED = 1020,
};
class APIFrameHelper {
public:
virtual APIError init() = 0;
virtual APIError loop() = 0;
virtual APIError read_packet(ReadPacketBuffer *buffer) = 0;
virtual bool can_write_without_blocking() = 0;
virtual APIError write_packet(uint16_t type, const uint8_t *data, size_t len) = 0;
virtual std::string getpeername() = 0;
virtual APIError close() = 0;
virtual APIError shutdown(int how) = 0;
// Give this helper a name for logging
virtual void set_log_info(std::string info) = 0;
};
#ifdef USE_API_NOISE
class APINoiseFrameHelper : public APIFrameHelper {
public:
APINoiseFrameHelper(std::unique_ptr<socket::Socket> socket, std::shared_ptr<APINoiseContext> ctx) : socket_(std::move(socket)), ctx_(ctx) {}
~APINoiseFrameHelper();
APIError init() override;
APIError loop() override;
APIError read_packet(ReadPacketBuffer *buffer) override;
bool can_write_without_blocking() override;
APIError write_packet(uint16_t type, const uint8_t *data, size_t len) override;
std::string getpeername() override{
return socket_->getpeername();
}
APIError close() override;
APIError shutdown(int how) override;
// Give this helper a name for logging
void set_log_info(std::string info) override {
info_ = std::move(info);
}
protected:
struct ParsedFrame {
std::vector<uint8_t> msg;
};
APIError state_action_();
APIError try_read_frame_(ParsedFrame *frame);
APIError try_send_tx_buf_();
APIError write_frame_(const uint8_t *data, size_t len);
APIError write_raw_(const uint8_t *data, size_t len);
APIError init_handshake_();
APIError check_handshake_finished_();
void send_explicit_handshake_reject_(const std::string &reason);
std::unique_ptr<socket::Socket> socket_;
std::string info_;
uint8_t rx_header_buf_[3];
size_t rx_header_buf_len_ = 0;
std::vector<uint8_t> rx_buf_;
size_t rx_buf_len_ = 0;
std::vector<uint8_t> tx_buf_;
std::vector<uint8_t> prologue_;
std::shared_ptr<APINoiseContext> ctx_;
NoiseHandshakeState *handshake_ = nullptr;
NoiseCipherState *send_cipher_ = nullptr;
NoiseCipherState *recv_cipher_ = nullptr;
NoiseProtocolId nid_;
enum class State {
INITIALIZE = 1,
CLIENT_HELLO = 2,
SERVER_HELLO = 3,
HANDSHAKE = 4,
DATA = 5,
CLOSED = 6,
FAILED = 7,
} state_ = State::INITIALIZE;
};
#endif // USE_API_NOISE
#ifdef USE_API_PLAINTEXT
class APIPlaintextFrameHelper : public APIFrameHelper {
public:
APIPlaintextFrameHelper(std::unique_ptr<socket::Socket> socket) : socket_(std::move(socket)) {}
~APIPlaintextFrameHelper() = default;
APIError init() override;
APIError loop() override;
APIError read_packet(ReadPacketBuffer *buffer) override;
bool can_write_without_blocking() override;
APIError write_packet(uint16_t type, const uint8_t *data, size_t len) override;
std::string getpeername() override {
return socket_->getpeername();
}
APIError close() override;
APIError shutdown(int how) override;
// Give this helper a name for logging
void set_log_info(std::string info) override {
info_ = std::move(info);
}
protected:
struct ParsedFrame {
std::vector<uint8_t> msg;
};
APIError try_read_frame_(ParsedFrame *frame);
APIError try_send_tx_buf_();
APIError write_frame_(const uint8_t *data, size_t len);
APIError write_raw_(const uint8_t *data, size_t len);
std::unique_ptr<socket::Socket> socket_;
std::string info_;
std::vector<uint8_t> rx_header_buf_;
bool rx_header_parsed_ = false;
uint32_t rx_header_parsed_type_ = 0;
uint32_t rx_header_parsed_len_ = 0;
std::vector<uint8_t> rx_buf_;
size_t rx_buf_len_ = 0;
std::vector<uint8_t> tx_buf_;
enum class State {
INITIALIZE = 1,
DATA = 2,
CLOSED = 3,
FAILED = 4,
} state_ = State::INITIALIZE;
};
#endif
} // namespace api
} // namespace esphome

View File

@@ -0,0 +1,27 @@
#pragma once
#include <cstdint>
#include <array>
#include "esphome/core/defines.h"
namespace esphome {
namespace api {
#ifdef USE_API_NOISE
using psk_t = std::array<uint8_t, 32>;
class APINoiseContext {
public:
void set_psk(psk_t psk) {
psk_ = std::move(psk);
}
const psk_t &get_psk() const {
return psk_;
}
protected:
psk_t psk_;
};
#endif // USE_API_NOISE
} // namespace api
} // namespace esphome

View File

@@ -5,6 +5,8 @@
#include "esphome/core/util.h"
#include "esphome/core/defines.h"
#include "esphome/core/version.h"
#include <errno.h>
//#include <arpa/inet.h>
#ifdef USE_LOGGER
#include "esphome/components/logger/logger.h"
@@ -21,20 +23,54 @@ static const char *const TAG = "api";
void APIServer::setup() {
ESP_LOGCONFIG(TAG, "Setting up Home Assistant API server...");
this->setup_controller();
this->server_ = AsyncServer(this->port_);
this->server_.setNoDelay(false);
this->server_.begin();
this->server_.onClient(
[](void *s, AsyncClient *client) {
if (client == nullptr)
return;
socket_ = socket::socket(AF_INET, SOCK_STREAM, 0);
if (socket_ == nullptr) {
ESP_LOGW(TAG, "Could not create socket.");
this->mark_failed();
return;
}
int enable = 1;
int err = socket_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to set reuseaddr: errno %d", err);
// we can still continue
}
err = socket_->setblocking(false);
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to set nonblocking mode: errno %d", err);
this->mark_failed();
return;
}
/*struct sockaddr_storage dest_addr;
memset(&dest_addr, 0, sizeof(dest_addr));
struct sockaddr_in *dest_addr_ip4 = (struct sockaddr_in *) &dest_addr;
dest_addr_ip4->sin_addr.s_addr = htonl(INADDR_ANY);
dest_addr_ip4->sin_family = AF_INET;
dest_addr_ip4->sin_port = htons(this->port_);
err = socket_->bind((struct sockaddr *) &dest_addr, sizeof(dest_addr));*/
struct sockaddr_in server;
memset(&server, 0, sizeof(server));
server.sin_family = AF_INET;
server.sin_addr.s_addr = INADDR_ANY;
server.sin_port = htons(this->port_);
err = socket_->bind((struct sockaddr *) &server, sizeof(server));
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to bind: errno %d", errno);
this->mark_failed();
return;
}
err = socket_->listen(4);
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to listen: errno %d", errno);
this->mark_failed();
return;
}
// can't print here because in lwIP thread
// ESP_LOGD(TAG, "New client connected from %s", client->remoteIP().toString().c_str());
auto *a_this = (APIServer *) s;
a_this->clients_.push_back(new APIConnection(client, a_this));
},
this);
#ifdef USE_LOGGER
if (logger::global_logger != nullptr) {
logger::global_logger->add_on_log_callback([this](int level, const char *tag, const char *message) {
@@ -59,6 +95,20 @@ void APIServer::setup() {
#endif
}
void APIServer::loop() {
// Accept new clients
while (true) {
struct sockaddr_storage source_addr;
socklen_t addr_len = sizeof(source_addr);
auto sock = socket_->accept((struct sockaddr *) &source_addr, &addr_len);
if (!sock)
break;
ESP_LOGD(TAG, "Accepted %s", sock->getpeername().c_str());
auto *conn = new APIConnection(std::move(sock), this);
clients_.push_back(conn);
conn->start();
}
// Partition clients into remove and active
auto new_end =
std::partition(this->clients_.begin(), this->clients_.end(), [](APIConnection *conn) { return !conn->remove_; });

View File

@@ -4,19 +4,14 @@
#include "esphome/core/controller.h"
#include "esphome/core/defines.h"
#include "esphome/core/log.h"
#include "esphome/components/socket/socket.h"
#include "api_pb2.h"
#include "api_pb2_service.h"
#include "util.h"
#include "list_entities.h"
#include "subscribe_state.h"
#include "user_services.h"
#ifdef ARDUINO_ARCH_ESP32
#include <AsyncTCP.h>
#endif
#ifdef ARDUINO_ARCH_ESP8266
#include <ESPAsyncTCP.h>
#endif
#include "api_noise_context.h"
namespace esphome {
namespace api {
@@ -35,6 +30,16 @@ class APIServer : public Component, public Controller {
void set_port(uint16_t port);
void set_password(const std::string &password);
void set_reboot_timeout(uint32_t reboot_timeout);
#ifdef USE_API_NOISE
void set_noise_psk(psk_t psk) {
noise_ctx_->set_psk(std::move(psk));
}
std::shared_ptr<APINoiseContext> get_noise_ctx() {
return noise_ctx_;
}
#endif // USE_API_NOISE
void handle_disconnect(APIConnection *conn);
#ifdef USE_BINARY_SENSOR
void on_binary_sensor_update(binary_sensor::BinarySensor *obj, bool state) override;
@@ -86,7 +91,7 @@ class APIServer : public Component, public Controller {
const std::vector<UserServiceDescriptor *> &get_user_services() const { return this->user_services_; }
protected:
AsyncServer server_{0};
std::unique_ptr<socket::Socket> socket_ = nullptr;
uint16_t port_{6053};
uint32_t reboot_timeout_{300000};
uint32_t last_connected_{0};
@@ -94,6 +99,10 @@ class APIServer : public Component, public Controller {
std::string password_;
std::vector<HomeAssistantStateSubscription> state_subs_;
std::vector<UserServiceDescriptor *> user_services_;
#ifdef USE_API_NOISE
std::shared_ptr<APINoiseContext> noise_ctx_ = std::make_shared<APINoiseContext>();
#endif // USE_API_NOISE
};
extern APIServer *global_api_server; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)

View File

@@ -0,0 +1,38 @@
from esphome.const import CONF_ID
import esphome.config_validation as cv
import esphome.codegen as cg
AUTO_LOAD = ["socket"]
echo_ns = cg.esphome_ns.namespace("echo")
EchoServer = echo_ns.class_("EchoServer", cg.Component)
EchoNoiseServer = echo_ns.class_("EchoNoiseServer", cg.Component)
CONFIG_SCHEMA = cv.Schema(
{
cv.Optional("ssl"): cv.COMPONENT_SCHEMA.extend(
{
cv.GenerateID(): cv.declare_id(EchoServer),
}
),
cv.Optional("noise"): cv.COMPONENT_SCHEMA.extend(
{
cv.GenerateID(): cv.declare_id(EchoNoiseServer),
}
),
}
)
async def to_code(config):
conf = config.get("ssl")
if conf is not None:
var = cg.new_Pvariable(conf[CONF_ID])
await cg.register_component(var, conf)
cg.add_define("USE_ECHO_SSL")
conf = config.get("noise")
if conf is not None:
var = cg.new_Pvariable(conf[CONF_ID])
await cg.register_component(var, conf)
cg.add_define("USE_ECHO_NOISE")
cg.add_library("esphome/noise-c", "0.1.0")

View File

@@ -0,0 +1,500 @@
#include "echo.h"
#include "esphome/core/log.h"
#include "esphome/core/helpers.h"
#include "esphome/core/esphal.h"
namespace esphome {
namespace echo {
static const char *const TAG = "echo";
#ifdef USE_ECHO_SSL
void EchoServer::setup() {
ESP_LOGCONFIG(TAG, "Setting up echo server...");
socket_ = socket::socket(AF_INET, SOCK_STREAM, 0);
if (socket_ == nullptr) {
ESP_LOGW(TAG, "Could not create socket.");
this->mark_failed();
return;
}
int enable = 1;
int err = socket_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to set reuseaddr: errno %d", err);
// we can still continue
}
err = socket_->setblocking(false);
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to set nonblocking mode: errno %d", err);
this->mark_failed();
return;
}
struct sockaddr_in server;
memset(&server, 0, sizeof(server));
server.sin_family = AF_INET;
server.sin_addr.s_addr = INADDR_ANY;
server.sin_port = htons(6055);
err = socket_->bind((struct sockaddr *) &server, sizeof(server));
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to bind: errno %d", errno);
this->mark_failed();
return;
}
err = socket_->listen(4);
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to listen: errno %d", errno);
this->mark_failed();
return;
}
ssl_ = ssl::create_context();
if (!ssl_) {
ESP_LOGW(TAG, "Failed to create SSL context: errno %d", errno);
this->mark_failed();
return;
}
ssl_->set_server_certificate(R"(-----BEGIN CERTIFICATE-----
MIIB3zCCAYWgAwIBAgIUZvjl3kvRTeMLlFUXR8Vbhw0UXnMwCgYIKoZIzj0EAwIw
RTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGElu
dGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMTA4MTAxODA5NDNaFw0yNDA1MDYx
ODA5NDNaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYD
VQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwWTATBgcqhkjOPQIBBggqhkjO
PQMBBwNCAAQrW9JCZlDynrY1DphZGpDxV1jkfoVSTiBWQLWuixolu7aJuR3+o+BJ
ZrvPCdNHpEOyx7r1DV23SWSp1eIZR43co1MwUTAdBgNVHQ4EFgQUPmWned9/9QAq
TCnb3I8dou8plDkwHwYDVR0jBBgwFoAUPmWned9/9QAqTCnb3I8dou8plDkwDwYD
VR0TAQH/BAUwAwEB/zAKBggqhkjOPQQDAgNIADBFAiBi375FEb+w297p0J/12lgp
iA9ppA4/QwtZdzioULmwVAIhALhGbVdbSAaLI+bwoICROHnuttY6mxJmDK8158Xe
s2U4
-----END CERTIFICATE-----
)");
ssl_->set_private_key(R"(-----BEGIN EC PRIVATE KEY-----
MHcCAQEEICqgqSPPEMmoWbwLpLm1lv4FQ48TsLOmXRbdceKs4DQ/oAoGCCqGSM49
AwEHoUQDQgAEK1vSQmZQ8p62NQ6YWRqQ8VdY5H6FUk4gVkC1rosaJbu2ibkd/qPg
SWa7zwnTR6RDsse69Q1dt0lkqdXiGUeN3A==
-----END EC PRIVATE KEY-----
)");
err = ssl_->init();
if (err != 0) {
ESP_LOGW(TAG, "Failed to initialize SSL context: errno %d", errno);
this->mark_failed();
return;
}
}
void EchoServer::loop() {
// Accept new clients
while (true) {
struct sockaddr_storage source_addr;
socklen_t addr_len = sizeof(source_addr);
auto sock = socket_->accept((struct sockaddr *) &source_addr, &addr_len);
if (!sock)
break;
ESP_LOGD(TAG, "Accepted %s", sock->getpeername().c_str());
// wrap socket
auto sock2 = ssl_->wrap_socket(std::move(sock));
if (!sock2) {
ESP_LOGW(TAG, "Failed to wrap socket with SSL: errno %d", errno);
continue;
}
auto cli = std::unique_ptr<EchoClient>{new EchoClient(std::move(sock2))};
cli->start();
clients_.push_back(std::move(cli));
}
auto new_end = std::partition(this->clients_.begin(), this->clients_.end(),
[](const std::unique_ptr<EchoClient> &cli) { return !cli->remove_; });
this->clients_.erase(new_end, this->clients_.end());
for (auto &client : this->clients_) {
client->loop();
}
}
void EchoClient::start() {
ESP_LOGD(TAG, "Starting socket");
int err = socket_->setblocking(false);
if (err != 0) {
on_error_();
ESP_LOGW(TAG, "Socket could not enable non-blocking, errno: %d", errno);
return;
}
int enable = 1;
err = socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int));
if (err != 0) {
on_error_();
ESP_LOGW(TAG, "Socket could not enable tcp nodelay, errno: %d", errno);
return;
}
rx_buffer_.resize(64);
}
void EchoClient::loop() {
uint32_t start = millis();
int err = socket_->loop();
if (err != 0) {
on_error_();
ESP_LOGW(TAG, "Socket loop failed: errno %d", errno);
return;
}
while (!this->remove_) {
size_t capacity = this->rx_buffer_.size();
ssize_t received = socket_->read(rx_buffer_.data(), capacity);
if (received == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// read would block
break;
}
if (errno == ECONNRESET) {
// connection reset
this->on_error_();
ESP_LOGW(TAG, "Client disconnected");
return;
}
this->on_error_();
ESP_LOGW(TAG, "Error reading from socket: errno %d", errno);
return;
} else if (received == 0) {
break;
}
ESP_LOGD(TAG, "received %s", hexencode(rx_buffer_.data(), received).c_str());
tx_buffer_.insert(tx_buffer_.end(), rx_buffer_.begin(), rx_buffer_.begin() + received);
if (received != capacity)
// done with reading
break;
}
while (!this->remove_ && !tx_buffer_.empty()) {
int err = socket_->write(tx_buffer_.data(), tx_buffer_.size());
if (err == -1) {
if (errno == EWOULDBLOCK || errno == EAGAIN) {
break;
}
if (errno == ECONNRESET) {
this->on_error_();
ESP_LOGW(TAG, "Client disconnected");
return;
}
on_error_();
ESP_LOGW(TAG, "Socket write failed: errno %d", errno);
return;
} else if (err == 0) {
break;
}
tx_buffer_.erase(tx_buffer_.begin(), tx_buffer_.begin() + err);
}
uint32_t end = millis();
if (end - start > 10) {
ESP_LOGD(TAG, "loop took %u ms", end - start);
}
}
void EchoClient::on_error_() {
this->socket_->close();
this->remove_ = true;
}
#endif
#ifdef USE_ECHO_NOISE
void EchoNoiseServer::setup() {
ESP_LOGCONFIG(TAG, "Setting up echo server...");
socket_ = socket::socket(AF_INET, SOCK_STREAM, 0);
if (socket_ == nullptr) {
ESP_LOGW(TAG, "Could not create socket.");
this->mark_failed();
return;
}
int enable = 1;
int err = socket_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to set reuseaddr: errno %d", err);
// we can still continue
}
err = socket_->setblocking(false);
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to set nonblocking mode: errno %d", err);
this->mark_failed();
return;
}
struct sockaddr_in server;
memset(&server, 0, sizeof(server));
server.sin_family = AF_INET;
server.sin_addr.s_addr = INADDR_ANY;
server.sin_port = htons(6056);
err = socket_->bind((struct sockaddr *) &server, sizeof(server));
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to bind: errno %d", errno);
this->mark_failed();
return;
}
err = socket_->listen(4);
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to listen: errno %d", errno);
this->mark_failed();
return;
}
}
void EchoNoiseServer::loop() {
// Accept new clients
while (true) {
struct sockaddr_storage source_addr;
socklen_t addr_len = sizeof(source_addr);
auto sock = socket_->accept((struct sockaddr *) &source_addr, &addr_len);
if (!sock)
break;
ESP_LOGD(TAG, "Accepted %s", sock->getpeername().c_str());
auto cli = std::unique_ptr<EchoNoiseClient>{new EchoNoiseClient(std::move(sock))};
cli->start();
clients_.push_back(std::move(cli));
}
auto new_end = std::partition(this->clients_.begin(), this->clients_.end(),
[](const std::unique_ptr<EchoNoiseClient> &cli) { return !cli->remove_; });
this->clients_.erase(new_end, this->clients_.end());
for (auto &client : this->clients_) {
client->loop();
}
}
void EchoNoiseClient::start() {
ESP_LOGD(TAG, "Starting socket");
int err = socket_->setblocking(false);
if (err != 0) {
on_error_();
ESP_LOGW(TAG, "Socket could not enable non-blocking, errno: %d", errno);
return;
}
int enable = 1;
err = socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int));
if (err != 0) {
on_error_();
ESP_LOGW(TAG, "Socket could not enable tcp nodelay, errno: %d", errno);
return;
}
memset(&nid_, 0, sizeof(nid_));
const char *proto = "Noise_NNpsk0_25519_ChaChaPoly_SHA256";
err = noise_protocol_name_to_id(&nid_, proto, strlen(proto));
if (err != 0) {
on_error_();
ESP_LOGW(TAG, "noise_protocol_name_to_id failed: %d", err);
return;
}
/*nid_.pattern_id = NOISE_PATTERN_NN;
nid_.cipher_id = NOISE_CIPHER_CHACHAPOLY;
nid_.dh_id = NOISE_DH_CURVE25519;
nid_.prefix_id = NOISE_PREFIX_STANDARD;
nid_.hybrid_id = NOISE_DH_NONE;
nid_.hash_id = NOISE_HASH_SHA256; // NOISE_HASH_BLAKE2s
nid_.modifier_ids[0] = NOISE_MODIFIER_PSK0;*/
err = noise_handshakestate_new_by_id(&handshake_, &nid_, NOISE_ROLE_RESPONDER);
if (err != 0) {
on_error_();
ESP_LOGW(TAG, "noise_handshakestate_new_by_id failed: %d", err);
return;
}
// initialize_handshake
{
const uint8_t psk[] = {0xC1, 0xD5, 0xE0, 0x72, 0xE7, 0x77, 0x58, 0x02, 0x45, 0xCB, 0x3A,
0x81, 0x04, 0x1B, 0x2D, 0x90, 0x3A, 0x0F, 0x0E, 0xC7, 0x9C, 0xFC,
0xB4, 0x2A, 0x50, 0xC0, 0xE6, 0x35, 0xA1, 0x54, 0x18, 0x12};
static_assert(sizeof(psk) == 32, "error");
// noise_handshakestate_set_prologue(handshake, prologue, strlen(prologue));
err = noise_handshakestate_set_pre_shared_key(handshake_, psk, 32);
if (err != 0) {
on_error_();
ESP_LOGW(TAG, "noise_handshakestate_set_pre_shared_key failed: %d", err);
return;
}
}
// Start the handshake
err = noise_handshakestate_start(handshake_);
if (err != 0) {
on_error_();
ESP_LOGW(TAG, "noise_handshakestate_start failed: %d", err);
return;
}
rx_buffer_.resize(64);
msg_buffer_.resize(64);
do_handshake_ = true;
}
void EchoNoiseClient::loop() {
if (this->remove_)
return;
const uint32_t start = millis();
int err = socket_->loop();
if (err != 0) {
on_error_();
ESP_LOGW(TAG, "Socket loop failed: errno %d", errno);
return;
}
NoiseBuffer mbuf;
noise_buffer_init(mbuf);
if (!this->remove_ && do_handshake_) {
int action = noise_handshakestate_get_action(handshake_);
if (action == NOISE_ACTION_WRITE_MESSAGE) {
// Write the next handshake message with a zero-length payload
noise_buffer_set_output(mbuf, msg_buffer_.data(), msg_buffer_.size());
err = noise_handshakestate_write_message(handshake_, &mbuf, nullptr);
if (err == 0) {
tx_buffer_.push_back((uint8_t)(mbuf.size >> 8));
tx_buffer_.push_back((uint8_t) mbuf.size);
tx_buffer_.insert(tx_buffer_.end(), msg_buffer_.begin(), msg_buffer_.begin() + mbuf.size);
} else {
on_error_();
ESP_LOGW(TAG, "noise_handshakestate_write_message failed: %d", err);
return;
}
} else if (action == NOISE_ACTION_READ_MESSAGE) {
if (rx_size_ >= 2) {
uint16_t msg_size = ((uint16_t)(rx_buffer_[0]) << 8) | (rx_buffer_[1]);
if (rx_size_ >= msg_size + 2) {
ESP_LOGD(TAG, "Message: %s", hexencode(rx_buffer_.data() + 2, msg_size).c_str());
noise_buffer_set_input(mbuf, rx_buffer_.data() + 2, msg_size);
err = noise_handshakestate_read_message(handshake_, &mbuf, nullptr);
if (err != 0) {
on_error_();
ESP_LOGW(TAG, "noise_handshakestate_read_message failed: %d", err);
return;
}
rx_size_ -= msg_size + 2;
}
}
} else if (action == NOISE_ACTION_SPLIT) {
err = noise_handshakestate_split(handshake_, &send_cipher_, &recv_cipher_);
if (err != 0) {
on_error_();
ESP_LOGW(TAG, "noise_handshakestate_split failed: %d", err);
return;
}
noise_handshakestate_free(handshake_);
do_handshake_ = false;
ESP_LOGI(TAG, "handshake finished");
} else {
on_error_();
ESP_LOGW(TAG, "noise_handshakestate_get_action failed: %d", err);
return;
}
}
while (!this->remove_ && !this->do_handshake_) {
if (rx_size_ < 2)
break;
uint16_t msg_size = ((uint16_t)(rx_buffer_[0]) << 8) | (rx_buffer_[1]);
if (rx_size_ < msg_size + 2)
break;
noise_buffer_set_inout(mbuf, rx_buffer_.data() + 2, msg_size, rx_buffer_.size() - 2);
err = noise_cipherstate_decrypt(recv_cipher_, &mbuf);
if (err != 0) {
on_error_();
ESP_LOGW(TAG, "noise_cipherstate_decrypt failed: %d", err);
return;
}
rx_size_ -= msg_size + 2;
err = noise_cipherstate_encrypt(send_cipher_, &mbuf);
if (err != 0) {
on_error_();
ESP_LOGW(TAG, "noise_cipherstate_encrypt failed: %d", err);
return;
}
tx_buffer_.push_back((uint8_t)(mbuf.size >> 8));
tx_buffer_.push_back((uint8_t) mbuf.size);
tx_buffer_.insert(tx_buffer_.end(), rx_buffer_.begin() + 2, rx_buffer_.begin() + 2 + mbuf.size);
}
while (!this->remove_) {
size_t capacity = this->rx_buffer_.size();
size_t used = rx_size_;
size_t space = capacity - used;
if (space == 0) {
rx_buffer_.resize(capacity + 64);
continue;
}
ssize_t received = socket_->read(rx_buffer_.data() + used, space);
if (received == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// read would block
break;
}
if (errno == ECONNRESET) {
// connection reset
this->on_error_();
ESP_LOGW(TAG, "Client disconnected");
return;
}
this->on_error_();
ESP_LOGW(TAG, "Error reading from socket: errno %d", errno);
return;
} else if (received == 0) {
break;
}
ESP_LOGD(TAG, "received %s", hexencode(rx_buffer_.data(), received).c_str());
rx_size_ += received;
if (received != capacity)
// done with reading
break;
}
while (!this->remove_ && !tx_buffer_.empty()) {
int err = socket_->write(tx_buffer_.data(), tx_buffer_.size());
if (err == -1) {
if (errno == EWOULDBLOCK || errno == EAGAIN) {
break;
}
if (errno == ECONNRESET) {
this->on_error_();
ESP_LOGW(TAG, "Client disconnected");
return;
}
on_error_();
ESP_LOGW(TAG, "Socket write failed: errno %d", errno);
return;
} else if (err == 0) {
break;
}
tx_buffer_.erase(tx_buffer_.begin(), tx_buffer_.begin() + err);
}
const uint32_t end = millis();
if (end - start > 10) {
ESP_LOGD(TAG, "noise took %u ms", end - start);
}
}
void EchoNoiseClient::on_error_() {
this->socket_->close();
this->remove_ = true;
}
#endif
} // namespace echo
} // namespace esphome
extern "C" {
void noise_rand_bytes(void *output, size_t len) { esp_fill_random(output, len); }
}

View File

@@ -0,0 +1,88 @@
#pragma once
#include <vector>
#include "esphome/core/component.h"
#include "esphome/core/defines.h"
#include "esphome/components/socket/socket.h"
#ifdef USE_ECHO_SSL
#include "esphome/components/ssl/ssl_context.h"
#endif
#include "noise/protocol.h"
namespace esphome {
namespace echo {
#ifdef USE_ECHO_SSL
class EchoClient {
public:
EchoClient(std::unique_ptr<socket::Socket> socket) : socket_(std::move(socket)) {}
void start();
void loop();
protected:
friend class EchoServer;
void on_error_();
std::unique_ptr<socket::Socket> socket_ = nullptr;
bool remove_ = false;
std::vector<uint8_t> rx_buffer_;
std::vector<uint8_t> tx_buffer_;
};
class EchoServer : public Component {
public:
void setup() override;
void loop() override;
float get_setup_priority() const override { return setup_priority::AFTER_WIFI; }
protected:
friend class EchoClient;
std::unique_ptr<socket::Socket> socket_ = nullptr;
std::unique_ptr<ssl::SSLContext> ssl_ = nullptr;
std::vector<std::unique_ptr<EchoClient>> clients_;
};
#endif
#ifdef USE_ECHO_NOISE
class EchoNoiseClient {
public:
EchoNoiseClient(std::unique_ptr<socket::Socket> socket) : socket_(std::move(socket)) {}
void start();
void loop();
protected:
friend class EchoNoiseServer;
void on_error_();
std::unique_ptr<socket::Socket> socket_ = nullptr;
bool remove_ = false;
std::vector<uint8_t> rx_buffer_;
size_t rx_size_ = 0;
std::vector<uint8_t> tx_buffer_;
std::vector<uint8_t> msg_buffer_;
NoiseHandshakeState *handshake_ = nullptr;
NoiseCipherState *send_cipher_ = nullptr;
NoiseCipherState *recv_cipher_ = nullptr;
NoiseProtocolId nid_;
bool do_handshake_ = false;
};
class EchoNoiseServer : public Component {
public:
void setup() override;
void loop() override;
float get_setup_priority() const override { return setup_priority::AFTER_WIFI; }
protected:
friend class EchoNoiseClient;
std::unique_ptr<socket::Socket> socket_ = nullptr;
std::vector<std::unique_ptr<EchoNoiseClient>> clients_;
};
#endif
} // namespace echo
} // namespace esphome

View File

@@ -0,0 +1,74 @@
import socket
from noise.connection import NoiseConnection
proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
proto.set_as_initiator()
proto.set_psks(
b"\xC1\xD5\xE0\x72\xE7\x77\x58\x02\x45\xCB\x3A\x81\x04\x1B\x2D\x90"
b"\x3A\x0F\x0E\xC7\x9C\xFC\xB4\x2A\x50\xC0\xE6\x35\xA1\x54\x18\x12"
)
# sys.exit(1)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
print("[x] Connecting...")
sock.connect(("192.168.178.154", 6053))
print("[x] Connected!")
prologue = b"NoiseAPIInit"
def write(msg):
print(f"[x] Writing frame {msg.hex()}")
l = len(msg)
buf = bytes(
[
0x01,
(l >> 8) & 0xFF,
(l >> 0) & 0xFF,
]
)
buf += msg
print(f" -> {buf.hex()}")
sock.sendall(buf)
def recv():
buf = b""
while len(buf) < 3:
buf += sock.recv(3 - len(buf))
assert buf[0] == 0x01
l = (buf[1] << 8) | buf[2]
buf = buf[3:]
while len(buf) < l:
buf += sock.recv(l - len(buf))
print(f"[x] Received frame {buf.hex()}")
return buf
write(b"")
prologue += b"\x00\x00"
buf = recv()
print(f"Received msg {buf.hex()}")
proto.set_prologue(prologue)
proto.start_handshake()
do_write = True
while not proto.handshake_finished:
if do_write:
msg = proto.write_message()
write(msg)
else:
msg = recv()
proto.read_message(msg)
do_write = not do_write
print(f"[x] Handshake done!")
while True:
msg = input().encode()
buf = proto.encrypt(msg)
write(buf)
buf = recv()
msg2 = proto.decrypt(buf)
print(msg2)

View File

@@ -43,21 +43,24 @@ void Logger::write_header_(int level, const char *tag, int line) {
}
void HOT Logger::log_vprintf_(int level, const char *tag, int line, const char *format, va_list args) { // NOLINT
if (level > this->level_for(tag))
if (level > this->level_for(tag) || recursion_guard_)
return;
recursion_guard_ = true;
this->reset_buffer_();
this->write_header_(level, tag, line);
this->vprintf_to_buffer_(format, args);
this->write_footer_();
this->log_message_(level, tag);
recursion_guard_ = false;
}
#ifdef USE_STORE_LOG_STR_IN_FLASH
void Logger::log_vprintf_(int level, const char *tag, int line, const __FlashStringHelper *format,
va_list args) { // NOLINT
if (level > this->level_for(tag))
if (level > this->level_for(tag) || recursion_guard_)
return;
recursion_guard_ = true;
this->reset_buffer_();
// copy format string
const char *format_pgm_p = (PGM_P) format;
@@ -78,6 +81,7 @@ void Logger::log_vprintf_(int level, const char *tag, int line, const __FlashStr
this->vprintf_to_buffer_(this->tx_buffer_, args);
this->write_footer_();
this->log_message_(level, tag, offset);
recursion_guard_ = false;
}
#endif

View File

@@ -113,6 +113,8 @@ class Logger : public Component {
};
std::vector<LogLevelOverride> log_levels_;
CallbackManager<void(int, const char *, const char *)> log_callback_{};
/// Prevents recursive log calls, if true a log message is already being processed.
bool recursion_guard_ = false;
};
extern Logger *global_logger; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)

View File

@@ -15,6 +15,7 @@ from esphome.core import CORE, coroutine_with_priority
CODEOWNERS = ["@esphome/core"]
DEPENDENCIES = ["network"]
AUTO_LOAD = ["socket"]
CONF_ON_STATE_CHANGE = "on_state_change"
CONF_ON_BEGIN = "on_begin"

View File

@@ -3,7 +3,7 @@
#include "esphome/core/log.h"
#include "esphome/core/application.h"
#include "esphome/core/util.h"
#include <errno.h>
#include <cstdio>
#include <MD5Builder.h>
#ifdef ARDUINO_ARCH_ESP32
@@ -19,8 +19,44 @@ static const char *const TAG = "ota";
static const uint8_t OTA_VERSION_1_0 = 1;
void OTAComponent::setup() {
this->server_ = new WiFiServer(this->port_);
this->server_->begin();
server_ = socket::socket(AF_INET, SOCK_STREAM, 0);
if (server_ == nullptr) {
ESP_LOGW(TAG, "Could not create socket.");
this->mark_failed();
return;
}
int enable = 1;
int err = server_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to set reuseaddr: errno %d", err);
// we can still continue
}
err = server_->setblocking(false);
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to set nonblocking mode: errno %d", err);
this->mark_failed();
return;
}
struct sockaddr_in server;
memset(&server, 0, sizeof(server));
server.sin_family = AF_INET;
server.sin_addr.s_addr = INADDR_ANY;
server.sin_port = htons(this->port_);
err = server_->bind((struct sockaddr *) &server, sizeof(server));
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to bind: errno %d", errno);
this->mark_failed();
return;
}
err = server_->listen(4);
if (err != 0) {
ESP_LOGW(TAG, "Socket unable to listen: errno %d", errno);
this->mark_failed();
return;
}
this->dump_config();
}
@@ -59,23 +95,28 @@ void OTAComponent::handle_() {
uint8_t ota_features;
(void) ota_features;
if (!this->client_.connected()) {
this->client_ = this->server_->available();
if (client_ == nullptr) {
struct sockaddr_storage source_addr;
socklen_t addr_len = sizeof(source_addr);
client_ = server_->accept((struct sockaddr *) &source_addr, &addr_len);
}
if (client_ == nullptr)
return;
if (!this->client_.connected())
return;
int enable = 1;
int err = client_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int));
if (err != 0) {
ESP_LOGW(TAG, "Socket could not enable tcp nodelay, errno: %d", errno);
return;
}
// enable nodelay for outgoing data
this->client_.setNoDelay(true);
ESP_LOGD(TAG, "Starting OTA Update from %s...", this->client_.remoteIP().toString().c_str());
ESP_LOGD(TAG, "Starting OTA Update from %s...", this->client_->getpeername().c_str());
this->status_set_warning();
#ifdef USE_OTA_STATE_CALLBACK
this->state_callback_.call(OTA_STARTED, 0.0f, 0);
#endif
if (!this->wait_receive_(buf, 5)) {
if (!this->readall_(buf, 5)) {
ESP_LOGW(TAG, "Reading magic bytes failed!");
goto error;
}
@@ -88,11 +129,12 @@ void OTAComponent::handle_() {
}
// Send OK and version - 2 bytes
this->client_.write(OTA_RESPONSE_OK);
this->client_.write(OTA_VERSION_1_0);
buf[0] = OTA_RESPONSE_OK;
buf[1] = OTA_VERSION_1_0;
this->writeall_(buf, 2);
// Read features - 1 byte
if (!this->wait_receive_(buf, 1)) {
if (!this->readall_(buf, 1)) {
ESP_LOGW(TAG, "Reading features failed!");
goto error;
}
@@ -100,10 +142,12 @@ void OTAComponent::handle_() {
ESP_LOGV(TAG, "OTA features is 0x%02X", ota_features);
// Acknowledge header - 1 byte
this->client_.write(OTA_RESPONSE_HEADER_OK);
buf[0] = OTA_RESPONSE_HEADER_OK;
this->writeall_(buf, 1);
if (!this->password_.empty()) {
this->client_.write(OTA_RESPONSE_REQUEST_AUTH);
buf[0] = OTA_RESPONSE_REQUEST_AUTH;
this->writeall_(buf, 1);
MD5Builder md5_builder{};
md5_builder.begin();
sprintf(sbuf, "%08X", random_uint32());
@@ -113,7 +157,7 @@ void OTAComponent::handle_() {
ESP_LOGV(TAG, "Auth: Nonce is %s", sbuf);
// Send nonce, 32 bytes hex MD5
if (this->client_.write(reinterpret_cast<uint8_t *>(sbuf), 32) != 32) {
if (!this->writeall_(reinterpret_cast<uint8_t *>(sbuf), 32)) {
ESP_LOGW(TAG, "Auth: Writing nonce failed!");
goto error;
}
@@ -125,7 +169,7 @@ void OTAComponent::handle_() {
md5_builder.add(sbuf);
// Receive cnonce, 32 bytes hex MD5
if (!this->wait_receive_(buf, 32)) {
if (!this->readall_(buf, 32)) {
ESP_LOGW(TAG, "Auth: Reading cnonce failed!");
goto error;
}
@@ -140,7 +184,7 @@ void OTAComponent::handle_() {
ESP_LOGV(TAG, "Auth: Result is %s", sbuf);
// Receive result, 32 bytes hex MD5
if (!this->wait_receive_(buf + 64, 32)) {
if (!this->writeall_(buf + 64, 32)) {
ESP_LOGW(TAG, "Auth: Reading response failed!");
goto error;
}
@@ -159,10 +203,11 @@ void OTAComponent::handle_() {
}
// Acknowledge auth OK - 1 byte
this->client_.write(OTA_RESPONSE_AUTH_OK);
buf[0] = OTA_RESPONSE_AUTH_OK;
this->writeall_(buf, 1);
// Read size, 4 bytes MSB first
if (!this->wait_receive_(buf, 4)) {
if (!this->readall_(buf, 4)) {
ESP_LOGW(TAG, "Reading size failed!");
goto error;
}
@@ -211,10 +256,11 @@ void OTAComponent::handle_() {
update_started = true;
// Acknowledge prepare OK - 1 byte
this->client_.write(OTA_RESPONSE_UPDATE_PREPARE_OK);
buf[0] = OTA_RESPONSE_UPDATE_PREPARE_OK;
this->writeall_(buf, 1);
// Read binary MD5, 32 bytes
if (!this->wait_receive_(buf, 32)) {
if (!this->readall_(buf, 32)) {
ESP_LOGW(TAG, "Reading binary MD5 checksum failed!");
goto error;
}
@@ -223,17 +269,22 @@ void OTAComponent::handle_() {
Update.setMD5(sbuf);
// Acknowledge MD5 OK - 1 byte
this->client_.write(OTA_RESPONSE_BIN_MD5_OK);
buf[0] = OTA_RESPONSE_BIN_MD5_OK;
this->writeall_(buf, 1);
while (!Update.isFinished()) {
size_t available = this->wait_receive_(buf, 0);
if (!available) {
// TODO: timeout check
ssize_t read = this->client_->read(buf, sizeof(buf));
if (read == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK)
continue;
ESP_LOGW(TAG, "Error receiving data for update, errno: %d", errno);
goto error;
}
uint32_t written = Update.write(buf, available);
if (written != available) {
ESP_LOGW(TAG, "Error writing binary data to flash: %u != %u!", written, available); // NOLINT
uint32_t written = Update.write(buf, read);
if (written != read) {
ESP_LOGW(TAG, "Error writing binary data to flash: %u != %u!", written, read); // NOLINT
error_code = OTA_RESPONSE_ERROR_WRITING_FLASH;
goto error;
}
@@ -253,7 +304,8 @@ void OTAComponent::handle_() {
}
// Acknowledge receive OK - 1 byte
this->client_.write(OTA_RESPONSE_RECEIVE_OK);
buf[0] = OTA_RESPONSE_RECEIVE_OK;
this->writeall_(buf, 1);
if (!Update.end()) {
error_code = OTA_RESPONSE_ERROR_UPDATE_END;
@@ -261,16 +313,17 @@ void OTAComponent::handle_() {
}
// Acknowledge Update end OK - 1 byte
this->client_.write(OTA_RESPONSE_UPDATE_END_OK);
buf[0] = OTA_RESPONSE_UPDATE_END_OK;
this->writeall_(buf, 1);
// Read ACK
if (!this->wait_receive_(buf, 1, false) || buf[0] != OTA_RESPONSE_OK) {
if (!this->readall_(buf, 1) || buf[0] != OTA_RESPONSE_OK) {
ESP_LOGW(TAG, "Reading back acknowledgement failed!");
// do not go to error, this is not fatal
}
this->client_.flush();
this->client_.stop();
this->client_->close();
this->client_ = nullptr;
delay(10);
ESP_LOGI(TAG, "OTA update finished!");
this->status_clear_warning();
@@ -286,11 +339,10 @@ error:
Update.printError(ss);
ESP_LOGW(TAG, "Update end failed! Error: %s", ss.c_str());
}
if (this->client_.connected()) {
this->client_.write(static_cast<uint8_t>(error_code));
this->client_.flush();
}
this->client_.stop();
buf[0] = static_cast<uint8_t>(error_code);
this->writeall_(buf, 1);
this->client_->close();
this->client_ = nullptr;
#ifdef ARDUINO_ARCH_ESP32
if (update_started) {
@@ -314,52 +366,56 @@ error:
#endif
}
size_t OTAComponent::wait_receive_(uint8_t *buf, size_t bytes, bool check_disconnected) {
size_t available = 0;
bool OTAComponent::readall_(uint8_t *buf, size_t len) {
uint32_t start = millis();
do {
App.feed_wdt();
if (check_disconnected && !this->client_.connected()) {
ESP_LOGW(TAG, "Error client disconnected while receiving data!");
return 0;
}
int availi = this->client_.available();
if (availi < 0) {
ESP_LOGW(TAG, "Error reading data!");
return 0;
}
uint32_t at = 0;
while (len - at > 0) {
uint32_t now = millis();
if (availi == 0 && now - start > 10000) {
ESP_LOGW(TAG, "Timeout waiting for data!");
return 0;
if (now - start > 1000) {
ESP_LOGW(TAG, "Timed out reading %d bytes of data", len);
return false;
}
available = size_t(availi);
yield();
} while (bytes == 0 ? available == 0 : available < bytes);
if (bytes == 0)
bytes = std::min(available, size_t(1024));
bool success = false;
for (uint32_t i = 0; !success && i < 100; i++) {
int res = this->client_.read(buf, bytes);
if (res != int(bytes)) {
// ESP32 implementation has an issue where calling read can fail with EAGAIN (race condition)
// so just re-try it until it works (with generous timeout of 1s)
// because we check with available() first this should not cause us any trouble in all other cases
delay(10);
ssize_t read = this->client_->read(buf + at, len - at);
if (read == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
delay(1);
continue;
}
ESP_LOGW(TAG, "Failed to read %d bytes of data, errno: %d", len, errno);
return false;
} else {
success = true;
at += read;
}
delay(1);
}
if (!success) {
ESP_LOGW(TAG, "Reading %u bytes of binary data failed!", bytes); // NOLINT
return 0;
}
return true;
}
bool OTAComponent::writeall_(const uint8_t *buf, size_t len) {
uint32_t start = millis();
uint32_t at = 0;
while (len - at > 0) {
uint32_t now = millis();
if (now - start > 1000) {
ESP_LOGW(TAG, "Timed out writing %d bytes of data", len);
return false;
}
return bytes;
ssize_t written = this->client_->write(buf + at, len - at);
if (written == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
delay(1);
continue;
}
ESP_LOGW(TAG, "Failed to write %d bytes of data, errno: %d", len, errno);
return false;
} else {
at += written;
}
delay(1);
}
return true;
}
void OTAComponent::set_auth_password(const std::string &password) { this->password_ = password; }

View File

@@ -1,10 +1,9 @@
#pragma once
#include "esphome/components/socket/socket.h"
#include "esphome/core/component.h"
#include "esphome/core/preferences.h"
#include "esphome/core/helpers.h"
#include <WiFiServer.h>
#include <WiFiClient.h>
namespace esphome {
namespace ota {
@@ -74,14 +73,15 @@ class OTAComponent : public Component {
uint32_t read_rtc_();
void handle_();
size_t wait_receive_(uint8_t *buf, size_t bytes, bool check_disconnected = true);
bool readall_(uint8_t *buf, size_t len);
bool writeall_(const uint8_t *buf, size_t len);
std::string password_;
uint16_t port_;
WiFiServer *server_{nullptr};
WiFiClient client_{};
std::unique_ptr<socket::Socket> server_;
std::unique_ptr<socket::Socket> client_;
bool has_safe_mode_{false}; ///< stores whether safe mode can be enabled.
uint32_t safe_mode_start_time_; ///< stores when safe mode was enabled.

View File

@@ -0,0 +1,28 @@
import esphome.config_validation as cv
import esphome.codegen as cg
CODEOWNERS = ["@esphome/core"]
CONF_IMPLEMENTATION = "implementation"
IMPLEMENTATION_LWIP_TCP = "lwip_tcp"
IMPLEMENTATION_BSD_SOCKETS = "bsd_sockets"
CONFIG_SCHEMA = cv.Schema(
{
cv.SplitDefault(
CONF_IMPLEMENTATION,
esp8266=IMPLEMENTATION_LWIP_TCP,
esp32=IMPLEMENTATION_BSD_SOCKETS,
): cv.one_of(
IMPLEMENTATION_LWIP_TCP, IMPLEMENTATION_BSD_SOCKETS, lower=True, space="_"
),
}
)
async def to_code(config):
impl = config[CONF_IMPLEMENTATION]
if impl == IMPLEMENTATION_LWIP_TCP:
cg.add_define("USE_SOCKET_IMPL_LWIP_TCP")
elif impl == IMPLEMENTATION_BSD_SOCKETS:
cg.add_define("USE_SOCKET_IMPL_BSD_SOCKETS")

View File

@@ -0,0 +1,105 @@
#include "socket.h"
#include "esphome/core/defines.h"
#ifdef USE_SOCKET_IMPL_BSD_SOCKETS
#include <string.h>
namespace esphome {
namespace socket {
std::string format_sockaddr(const struct sockaddr_storage &storage) {
if (storage.ss_family == AF_INET) {
const struct sockaddr_in *addr = reinterpret_cast<const struct sockaddr_in *>(&storage);
char buf[INET_ADDRSTRLEN];
const char *ret = inet_ntop(AF_INET, &addr->sin_addr, buf, sizeof(buf));
if (ret == NULL)
return {};
return std::string{buf};
} else if (storage.ss_family == AF_INET6) {
const struct sockaddr_in6 *addr = reinterpret_cast<const struct sockaddr_in6 *>(&storage);
char buf[INET6_ADDRSTRLEN];
const char *ret = inet_ntop(AF_INET6, &addr->sin6_addr, buf, sizeof(buf));
if (ret == NULL)
return {};
return std::string{buf};
}
return {};
}
class BSDSocketImpl : public Socket {
public:
BSDSocketImpl(int fd) : Socket(), fd_(fd) {}
~BSDSocketImpl() override {
if (!closed_) {
close();
}
}
std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) override {
int fd = ::accept(fd_, addr, addrlen);
if (fd == -1)
return {};
return std::unique_ptr<BSDSocketImpl>{new BSDSocketImpl(fd)};
}
int bind(const struct sockaddr *addr, socklen_t addrlen) override { return ::bind(fd_, addr, addrlen); }
int close() override {
int ret = ::close(fd_);
closed_ = true;
return ret;
}
int shutdown(int how) override { return ::shutdown(fd_, how); }
int getpeername(struct sockaddr *addr, socklen_t *addrlen) override { return ::getpeername(fd_, addr, addrlen); }
std::string getpeername() override {
struct sockaddr_storage storage;
socklen_t len = sizeof(storage);
int err = this->getpeername((struct sockaddr *) &storage, &len);
if (err != 0)
return {};
return format_sockaddr(storage);
}
int getsockname(struct sockaddr *addr, socklen_t *addrlen) override { return ::getsockname(fd_, addr, addrlen); }
std::string getsockname() override {
struct sockaddr_storage storage;
socklen_t len = sizeof(storage);
int err = this->getsockname((struct sockaddr *) &storage, &len);
if (err != 0)
return {};
return format_sockaddr(storage);
}
int getsockopt(int level, int optname, void *optval, socklen_t *optlen) override {
return ::getsockopt(fd_, level, optname, optval, optlen);
}
int setsockopt(int level, int optname, const void *optval, socklen_t optlen) override {
return ::setsockopt(fd_, level, optname, optval, optlen);
}
int listen(int backlog) override { return ::listen(fd_, backlog); }
ssize_t read(void *buf, size_t len) override { return ::read(fd_, buf, len); }
ssize_t write(const void *buf, size_t len) override { return ::write(fd_, buf, len); }
int setblocking(bool blocking) override {
int fl = ::fcntl(fd_, F_GETFL, 0);
if (blocking) {
fl &= ~O_NONBLOCK;
} else {
fl |= O_NONBLOCK;
}
::fcntl(fd_, F_SETFL, fl);
return 0;
}
protected:
int fd_;
bool closed_ = false;
};
std::unique_ptr<Socket> socket(int domain, int type, int protocol) {
int ret = ::socket(domain, type, protocol);
if (ret == -1)
return nullptr;
return std::unique_ptr<Socket>{new BSDSocketImpl(ret)};
}
} // namespace socket
} // namespace esphome
#endif // USE_SOCKET_IMPL_BSD_SOCKETS

View File

@@ -0,0 +1,118 @@
#pragma once
#include "esphome/core/defines.h"
// Helper file to include all socket-related system headers (or use our own
// definitions where system ones don't exist)
#ifdef USE_SOCKET_IMPL_LWIP_TCP
#define LWIP_INTERNAL
#include <sys/types.h>
#include "lwip/inet.h"
#include <stdint.h>
#include <errno.h>
/* Address families. */
#define AF_UNSPEC 0
#define AF_INET 2
#define AF_INET6 10
#define PF_INET AF_INET
#define PF_INET6 AF_INET6
#define PF_UNSPEC AF_UNSPEC
#define IPPROTO_IP 0
#define IPPROTO_TCP 6
#define IPPROTO_IPV6 41
#define IPPROTO_ICMPV6 58
#define TCP_NODELAY 0x01
#define F_GETFL 3
#define F_SETFL 4
#define O_NONBLOCK 1
#define SHUT_RD 0
#define SHUT_WR 1
#define SHUT_RDWR 2
/* Socket protocol types (TCP/UDP/RAW) */
#define SOCK_STREAM 1
#define SOCK_DGRAM 2
#define SOCK_RAW 3
#define SO_REUSEADDR 0x0004 /* Allow local address reuse */
#define SO_KEEPALIVE 0x0008 /* keep connections alive */
#define SO_BROADCAST 0x0020 /* permit to send and to receive broadcast messages (see IP_SOF_BROADCAST option) */
#define SOL_SOCKET 0xfff /* options for socket level */
typedef uint8_t sa_family_t;
typedef uint16_t in_port_t;
struct sockaddr_in {
uint8_t sin_len;
sa_family_t sin_family;
in_port_t sin_port;
struct in_addr sin_addr;
#define SIN_ZERO_LEN 8
char sin_zero[SIN_ZERO_LEN];
};
struct sockaddr_in6 {
uint8_t sin6_len; /* length of this structure */
sa_family_t sin6_family; /* AF_INET6 */
in_port_t sin6_port; /* Transport layer port # */
uint32_t sin6_flowinfo; /* IPv6 flow information */
struct in6_addr sin6_addr; /* IPv6 address */
uint32_t sin6_scope_id; /* Set of interfaces for scope */
};
struct sockaddr {
uint8_t sa_len;
sa_family_t sa_family;
char sa_data[14];
};
struct sockaddr_storage {
uint8_t s2_len;
sa_family_t ss_family;
char s2_data1[2];
uint32_t s2_data2[3];
uint32_t s2_data3[3];
};
typedef uint32_t socklen_t;
#ifdef ARDUINO_ARCH_ESP8266
// arduino-esp8266 declares a global vars called INADDR_NONE/ANY which are invalid with the define
#ifdef INADDR_ANY
#undef INADDR_ANY
#endif
#ifdef INADDR_NONE
#undef INADDR_NONE
#endif
#endif
#endif // USE_SOCKET_IMPL_LWIP_TCP
#ifdef USE_SOCKET_IMPL_BSD_SOCKETS
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/ioctl.h>
#include <unistd.h>
#include <fcntl.h>
#include <stdint.h>
#ifdef ARDUINO_ARCH_ESP32
// arduino-esp32 declares a global var called INADDR_NONE which is replaced
// by the define
#ifdef INADDR_NONE
#undef INADDR_NONE
#endif
// not defined for ESP32
typedef uint32_t socklen_t;
#endif // ARDUINO_ARCH_ESP32
#endif // USE_SOCKET_IMPL_BSD_SOCKETS

View File

@@ -0,0 +1,470 @@
#include "socket.h"
#include "esphome/core/defines.h"
#ifdef USE_SOCKET_IMPL_LWIP_TCP
#include <queue>
#include <string.h>
#include "lwip/opt.h"
#include "lwip/ip.h"
#include "lwip/tcp.h"
#include "lwip/netif.h"
#include "errno.h"
#include "esphome/core/log.h"
namespace esphome {
namespace socket {
static const char *const TAG = "lwip";
class LWIPRawImpl : public Socket {
public:
LWIPRawImpl(struct tcp_pcb *pcb) : pcb_(pcb) {}
~LWIPRawImpl() override {
if (pcb_ != nullptr) {
tcp_abort(pcb_);
pcb_ = nullptr;
}
}
void init() {
tcp_arg(pcb_, this);
tcp_accept(pcb_, LWIPRawImpl::s_accept_fn);
tcp_recv(pcb_, LWIPRawImpl::s_recv_fn);
tcp_err(pcb_, LWIPRawImpl::s_err_fn);
}
std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) override {
if (pcb_ == nullptr) {
errno = EBADF;
return nullptr;
}
if (accepted_sockets_.empty()) {
errno = EWOULDBLOCK;
return nullptr;
}
std::unique_ptr<LWIPRawImpl> sock = std::move(accepted_sockets_.front());
accepted_sockets_.pop();
if (addr != nullptr) {
sock->getpeername(addr, addrlen);
}
sock->init();
return std::unique_ptr<Socket>(std::move(sock));
}
int bind(const struct sockaddr *name, socklen_t addrlen) override {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
if (name == nullptr) {
errno = EINVAL;
return 0;
}
ip_addr_t ip;
in_port_t port;
auto family = name->sa_family;
#if LWIP_IPV6
if (family == AF_INET) {
if (addrlen < sizeof(sockaddr_in6)) {
errno = EINVAL;
return -1;
}
auto *addr4 = reinterpret_cast<const sockaddr_in *>(name);
port = ntohs(addr4->sin_port);
ip.type = IPADDR_TYPE_V4;
ip.u_addr.ip4.addr = addr4->sin_addr.s_addr;
} else if (family == AF_INET6) {
if (addrlen < sizeof(sockaddr_in)) {
errno = EINVAL;
return -1;
}
auto *addr6 = reinterpret_cast<const sockaddr_in6 *>(name);
port = ntohs(addr6->sin6_port);
ip.type = IPADDR_TYPE_V6;
memcpy(&ip.u_addr.ip6.addr, &addr6->sin6_addr.un.u8_addr, 16);
} else {
errno = EINVAL;
return -1;
}
#else
if (family != AF_INET) {
errno = EINVAL;
return -1;
}
auto *addr4 = reinterpret_cast<const sockaddr_in *>(name);
port = ntohs(addr4->sin_port);
ip.addr = addr4->sin_addr.s_addr;
#endif
err_t err = tcp_bind(pcb_, &ip, port);
if (err == ERR_USE) {
errno = EADDRINUSE;
return -1;
}
if (err == ERR_VAL) {
errno = EINVAL;
return -1;
}
if (err != ERR_OK) {
errno = EIO;
return -1;
}
return 0;
}
int close() {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
err_t err = tcp_close(pcb_);
if (err != ERR_OK) {
tcp_abort(pcb_);
pcb_ = nullptr;
errno = err == ERR_MEM ? ENOMEM : EIO;
return -1;
}
pcb_ = nullptr;
return 0;
}
int shutdown(int how) override {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
bool shut_rx = false, shut_tx = false;
if (how == SHUT_RD) {
shut_rx = true;
} else if (how == SHUT_WR) {
shut_tx = true;
} else if (how == SHUT_RDWR) {
shut_rx = shut_tx = true;
} else {
errno = EINVAL;
return -1;
}
err_t err = tcp_shutdown(pcb_, shut_rx, shut_tx);
if (err != ERR_OK) {
errno = err == ERR_MEM ? ENOMEM : EIO;
return -1;
}
return 0;
}
int getpeername(struct sockaddr *name, socklen_t *addrlen) override {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
if (name == nullptr || addrlen == nullptr) {
errno = EINVAL;
return -1;
}
if (*addrlen < sizeof(struct sockaddr_in)) {
errno = EINVAL;
return -1;
}
struct sockaddr_in *addr = reinterpret_cast<struct sockaddr_in *>(name);
addr->sin_family = AF_INET;
*addrlen = addr->sin_len = sizeof(struct sockaddr_in);
addr->sin_port = pcb_->remote_port;
addr->sin_addr.s_addr = pcb_->remote_ip.addr;
return 0;
}
std::string getpeername() override {
if (pcb_ == nullptr) {
errno = EBADF;
return "";
}
char buffer[24];
uint32_t ip4 = pcb_->remote_ip.addr;
snprintf(buffer, sizeof(buffer), "%d.%d.%d.%d", (ip4 >> 24) & 0xFF, (ip4 >> 16) & 0xFF, (ip4 >> 8) & 0xFF, (ip4 >> 0) & 0xFF);
return std::string(buffer);
}
int getsockname(struct sockaddr *name, socklen_t *addrlen) override {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
if (name == nullptr || addrlen == nullptr) {
errno = EINVAL;
return -1;
}
if (*addrlen < sizeof(struct sockaddr_in)) {
errno = EINVAL;
return -1;
}
struct sockaddr_in *addr = reinterpret_cast<struct sockaddr_in *>(name);
addr->sin_family = AF_INET;
*addrlen = addr->sin_len = sizeof(struct sockaddr_in);
addr->sin_port = pcb_->local_port;
addr->sin_addr.s_addr = pcb_->local_ip.addr;
return 0;
}
std::string getsockname() override {
if (pcb_ == nullptr) {
errno = EBADF;
return "";
}
char buffer[24];
uint32_t ip4 = pcb_->local_ip.addr;
snprintf(buffer, sizeof(buffer), "%d.%d.%d.%d", (ip4 >> 24) & 0xFF, (ip4 >> 16) & 0xFF, (ip4 >> 8) & 0xFF, (ip4 >> 0) & 0xFF);
return std::string(buffer);
}
int getsockopt(int level, int optname, void *optval, socklen_t *optlen) override {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
if (level == SOL_SOCKET && optname == SO_REUSEADDR) {
if (optlen < 4) {
errno = EINVAL;
return -1;
}
// lwip doesn't seem to have this feature. Don't send an error
// to prevent warnings
*reinterpret_cast<int *>(optval) = 1;
*optlen = 4;
return 0;
}
if (level == IPPROTO_TCP && optname == TCP_NODELAY) {
if (optlen < 4) {
errno = EINVAL;
return -1;
}
*reinterpret_cast<int *>(optval) = tcp_nagle_disabled(pcb_);
*optlen = 4;
return 0;
}
errno = EINVAL;
return -1;
}
int setsockopt(int level, int optname, const void *optval, socklen_t optlen) override {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
if (level == SOL_SOCKET && optname == SO_REUSEADDR) {
if (optlen != 4) {
errno = EINVAL;
return -1;
}
// lwip doesn't seem to have this feature. Don't send an error
// to prevent warnings
return 0;
}
if (level == IPPROTO_TCP && optname == TCP_NODELAY) {
if (optlen != 4) {
errno = EINVAL;
return -1;
}
int val = *reinterpret_cast<const int *>(optval);
if (val != 0) {
tcp_nagle_disable(pcb_);
} else {
tcp_nagle_enable(pcb_);
}
return 0;
}
errno = EINVAL;
return -1;
}
int listen(int backlog) override {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
struct tcp_pcb *listen_pcb = tcp_listen_with_backlog(pcb_, backlog);
if (listen_pcb == nullptr) {
tcp_abort(pcb_);
pcb_ = nullptr;
errno = EOPNOTSUPP;
return -1;
}
// tcp_listen reallocates the pcb, replace ours
pcb_ = listen_pcb;
// set callbacks on new pcb
tcp_arg(pcb_, this);
tcp_accept(pcb_, LWIPRawImpl::s_accept_fn);
return 0;
}
ssize_t read(void *buf, size_t len) override {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
if (rx_closed_ && rx_buf_ == nullptr) {
errno = ECONNRESET;
return -1;
}
if (len == 0) {
return 0;
}
if (rx_buf_ == nullptr) {
errno = EWOULDBLOCK;
return -1;
}
size_t read = 0;
uint8_t *buf8 = reinterpret_cast<uint8_t *>(buf);
while (len) {
size_t pb_len = rx_buf_->len;
size_t pb_left = pb_len - rx_buf_offset_;
if (pb_left == 0)
break;
size_t copysize = std::min(len, pb_left);
memcpy(buf8, reinterpret_cast<uint8_t *>(rx_buf_->payload) + rx_buf_offset_, copysize);
if (pb_left == copysize) {
// full pb copied, free it
if (rx_buf_->next == nullptr) {
// last buffer in chain
pbuf_free(rx_buf_);
rx_buf_ = nullptr;
rx_buf_offset_ = 0;
} else {
auto *old_buf = rx_buf_;
rx_buf_ = rx_buf_->next;
pbuf_ref(rx_buf_);
pbuf_free(old_buf);
rx_buf_offset_ = 0;
}
} else {
rx_buf_offset_ += copysize;
}
tcp_recved(pcb_, copysize);
buf8 += copysize;
len -= copysize;
read += copysize;
}
return read;
}
ssize_t write(const void *buf, size_t len) {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
if (len == 0)
return 0;
if (buf == nullptr) {
errno = EINVAL;
return 0;
}
auto space = tcp_sndbuf(pcb_);
if (space == 0) {
errno = EWOULDBLOCK;
return -1;
}
size_t to_send = std::min((size_t) space, len);
err_t err = tcp_write(pcb_, buf, to_send, TCP_WRITE_FLAG_COPY);
if (err == ERR_MEM) {
errno = EWOULDBLOCK;
return -1;
}
if (err != ERR_OK) {
errno = EIO;
return -1;
}
err = tcp_output(pcb_);
if (err != ERR_OK) {
errno = EIO;
return -1;
}
return to_send;
}
int setblocking(bool blocking) {
if (pcb_ == nullptr) {
errno = EBADF;
return -1;
}
if (blocking) {
// blocking operation not supported
errno = EINVAL;
return -1;
}
return 0;
}
err_t accept_fn(struct tcp_pcb *newpcb, err_t err) {
if (err != ERR_OK || newpcb == 0) {
// "An error code if there has been an error accepting. Only return ERR_ABRT if you have
// called tcp_abort from within the callback function!"
// https://www.nongnu.org/lwip/2_1_x/tcp_8h.html#a00517abce6856d6c82f0efebdafb734d
// nothing to do here, we just don't push it to the queue
return ERR_OK;
}
accepted_sockets_.emplace(new LWIPRawImpl(newpcb));
return ERR_OK;
}
void err_fn(err_t err) {
// "If a connection is aborted because of an error, the application is alerted of this event by
// the err callback."
// pcb is already freed when this callback is called
// ERR_RST: connection was reset by remote host
// ERR_ABRT: aborted through tcp_abort or TCP timer
pcb_ = nullptr;
}
err_t recv_fn(struct pbuf *pb, err_t err) {
if (err != 0) {
// "An error code if there has been an error receiving Only return ERR_ABRT if you have
// called tcp_abort from within the callback function!"
rx_closed_ = true;
return ERR_OK;
}
if (pb == nullptr) {
rx_closed_ = true;
return ERR_OK;
}
if (rx_buf_ == nullptr) {
// no need to copy because lwIP gave control of it to us
rx_buf_ = pb;
rx_buf_offset_ = 0;
} else {
pbuf_cat(rx_buf_, pb);
}
return ERR_OK;
}
static err_t s_accept_fn(void *arg, struct tcp_pcb *newpcb, err_t err) {
LWIPRawImpl *arg_this = reinterpret_cast<LWIPRawImpl *>(arg);
return arg_this->accept_fn(newpcb, err);
}
static void s_err_fn(void *arg, err_t err) {
LWIPRawImpl *arg_this = reinterpret_cast<LWIPRawImpl *>(arg);
return arg_this->err_fn(err);
}
static err_t s_recv_fn(void *arg, struct tcp_pcb *pcb, struct pbuf *pb, err_t err) {
LWIPRawImpl *arg_this = reinterpret_cast<LWIPRawImpl *>(arg);
return arg_this->recv_fn(pb, err);
}
protected:
struct tcp_pcb *pcb_;
std::queue<std::unique_ptr<LWIPRawImpl>> accepted_sockets_;
bool rx_closed_ = false;
pbuf *rx_buf_ = nullptr;
size_t rx_buf_offset_ = 0;
};
std::unique_ptr<Socket> socket(int domain, int type, int protocol) {
auto *pcb = tcp_new();
if (pcb == nullptr)
return nullptr;
auto *sock = new LWIPRawImpl(pcb);
sock->init();
return std::unique_ptr<Socket>{sock};
}
} // namespace socket
} // namespace esphome
#endif // USE_SOCKET_IMPL_LWIP_TCP

View File

@@ -0,0 +1,42 @@
#pragma once
#include <string>
#include <memory>
#include "headers.h"
#include "esphome/core/optional.h"
namespace esphome {
namespace socket {
class Socket {
public:
Socket() = default;
virtual ~Socket() = default;
Socket(const Socket &) = delete;
Socket &operator=(const Socket &) = delete;
virtual std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) = 0;
virtual int bind(const struct sockaddr *addr, socklen_t addrlen) = 0;
virtual int close() = 0;
// not supported yet:
// virtual int connect(const std::string &address) = 0;
// virtual int connect(const struct sockaddr *addr, socklen_t addrlen) = 0;
virtual int shutdown(int how) = 0;
virtual int getpeername(struct sockaddr *addr, socklen_t *addrlen) = 0;
virtual std::string getpeername() = 0;
virtual int getsockname(struct sockaddr *addr, socklen_t *addrlen) = 0;
virtual std::string getsockname() = 0;
virtual int getsockopt(int level, int optname, void *optval, socklen_t *optlen) = 0;
virtual int setsockopt(int level, int optname, const void *optval, socklen_t optlen) = 0;
virtual int listen(int backlog) = 0;
virtual ssize_t read(void *buf, size_t len) = 0;
virtual ssize_t write(const void *buf, size_t len) = 0;
virtual int setblocking(bool blocking) = 0;
virtual int loop() { return 0; };
};
std::unique_ptr<Socket> socket(int domain, int type, int protocol);
} // namespace socket
} // namespace esphome

View File

@@ -0,0 +1,9 @@
import esphome.config_validation as cv
AUTO_LOAD = ["socket"]
CONFIG_SCHEMA = cv.Schema({})
async def to_code(config):
pass

View File

@@ -0,0 +1,272 @@
#include "ssl_context.h"
#include <string.h>
#include "mbedtls/platform.h"
#include "mbedtls/net_sockets.h"
#include "mbedtls/esp_debug.h"
#include "mbedtls/ssl.h"
#include "mbedtls/entropy.h"
#include "mbedtls/ctr_drbg.h"
#include "mbedtls/error.h"
#include "mbedtls/certs.h"
#ifdef INADDR_NONE
#undef INADDR_NONE
#endif
#include "esphome/core/log.h"
namespace esphome {
namespace ssl {
static const char *const TAG = "ssl.mbedtls";
static int entropy_hw_random_source(void *data, uint8_t *output, size_t len, size_t *olen) {
esp_fill_random(output, len);
*olen = len;
return 0;
}
struct MbedTLSBioCtx {
socket::Socket *sock;
static int send(void *raw, const uint8_t *buf, size_t len) {
auto *ctx = reinterpret_cast<MbedTLSBioCtx *>(raw);
ssize_t ret = ctx->sock->write(buf, len);
if (ret != -1)
return ret;
if (errno == EWOULDBLOCK || errno == EAGAIN)
return MBEDTLS_ERR_SSL_WANT_WRITE;
if (errno == EPIPE || errno == ECONNRESET)
return MBEDTLS_ERR_NET_CONN_RESET;
if (errno == EINTR)
return MBEDTLS_ERR_SSL_WANT_WRITE;
return MBEDTLS_ERR_NET_SEND_FAILED;
}
static int recv(void *raw, uint8_t *buf, size_t len) {
auto *ctx = reinterpret_cast<MbedTLSBioCtx *>(raw);
ssize_t ret = ctx->sock->read(buf, len);
if (ret != -1)
return ret;
if (errno == EWOULDBLOCK || errno == EAGAIN)
return MBEDTLS_ERR_SSL_WANT_WRITE;
if (errno == EPIPE || errno == ECONNRESET)
return MBEDTLS_ERR_NET_CONN_RESET;
if (errno == EINTR)
return MBEDTLS_ERR_SSL_WANT_WRITE;
return MBEDTLS_ERR_NET_SEND_FAILED;
}
};
void test();
class MbedTLSWrappedSocket : public socket::Socket {
public:
MbedTLSWrappedSocket(std::unique_ptr<socket::Socket> sock) : socket::Socket(), sock_(std::move(sock)) {}
~MbedTLSWrappedSocket() override {
mbedtls_ssl_free(&ssl_);
sock_ = nullptr;
}
void init(const mbedtls_ssl_config *conf) {
// TODO: reuse ssl contexts?
mbedtls_ssl_init(&ssl_);
int err = mbedtls_ssl_setup(&ssl_, conf);
if (err != 0) {
ESP_LOGW(TAG, "mbedtls_ssl_setup failed: %d", err);
return;
}
// sock pointer does not fit in void*
// instead store it in a heap-allocated var
auto *ctx = new MbedTLSBioCtx;
// unsafe, but should be fine because we free before sock is reset
ctx->sock = sock_.get();
mbedtls_ssl_set_bio(&ssl_, ctx, MbedTLSBioCtx::send, MbedTLSBioCtx::recv, nullptr);
do_handshake_ = true;
}
std::unique_ptr<Socket> accept(struct sockaddr *addr, socklen_t *addrlen) override {
// only for server sockets
errno = EBADF;
return {};
}
int bind(const struct sockaddr *addr, socklen_t addrlen) override {
errno = EBADF;
return -1;
}
int close() override {
do_handshake_ = false;
return sock_->close();
}
int connect(const std::string &address) override { return sock_->connect(address); }
int connect(const struct sockaddr *addr, socklen_t addrlen) override { return sock_->connect(addr, addrlen); }
int shutdown(int how) override {
do_handshake_ = false;
int ret = mbedtls_ssl_close_notify(&ssl_);
if (ret != 0)
return this->mbedtls_to_errno_(ret);
return this->sock_->shutdown(how);
}
int getpeername(struct sockaddr *addr, socklen_t *addrlen) override { return sock_->getpeername(addr, addrlen); }
std::string getpeername() override { return sock_->getpeername(); }
int getsockname(struct sockaddr *addr, socklen_t *addrlen) override { return sock_->getsockname(addr, addrlen); }
std::string getsockname() override { return sock_->getsockname(); }
int getsockopt(int level, int optname, void *optval, socklen_t *optlen) override {
return sock_->getsockopt(level, optname, optval, optlen);
}
int setsockopt(int level, int optname, const void *optval, socklen_t optlen) override {
return sock_->setsockopt(level, optname, optval, optlen);
}
int listen(int backlog) override {
errno = EBADF;
return -1;
}
ssize_t read(void *buf, size_t len) override {
// mbedtls will automatically perform handshake here if necessary
loop();
if (do_handshake_) {
errno = EWOULDBLOCK;
return -1;
}
int ret = mbedtls_ssl_read(&ssl_, reinterpret_cast<uint8_t *>(buf), len);
return this->mbedtls_to_errno_(ret);
}
ssize_t write(const void *buf, size_t len) override {
loop();
if (do_handshake_) {
errno = EWOULDBLOCK;
return -1;
}
int ret = mbedtls_ssl_write(&ssl_, reinterpret_cast<const uint8_t *>(buf), len);
return this->mbedtls_to_errno_(ret);
}
int setblocking(bool blocking) override {
// TODO: handle blocking modes
return sock_->setblocking(blocking);
}
int loop() override {
if (do_handshake_) {
int err = mbedtls_ssl_handshake_step(&ssl_);
if (err == 0) {
do_handshake_ = false;
} else if (err == MBEDTLS_ERR_SSL_WANT_WRITE || err == MBEDTLS_ERR_SSL_WANT_READ) {
} else {
do_handshake_ = false;
return -1;
}
}
return 0;
}
protected:
int mbedtls_to_errno_(int ret) {
if (ret >= 0) {
return ret;
} else if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
errno = EWOULDBLOCK;
return -1;
} else if (ret == MBEDTLS_ERR_NET_CONN_RESET) {
errno = ECONNRESET;
return -1;
} else {
if (errno == 0)
errno = EIO;
ESP_LOGW(TAG, "mbedtls failed with %d", ret);
return -1;
}
}
std::unique_ptr<socket::Socket> sock_;
mbedtls_ssl_context ssl_;
bool do_handshake_ = false;
};
void debug_callback(void *arg, int level, const char *filename, int line_number, const char *msg) {
ESP_LOGV(TAG, "mbedtls %d [%s:%d]: %s", level, filename, line_number, msg);
}
class MbedTLSContext : public SSLContext {
public:
MbedTLSContext() = default;
~MbedTLSContext() override {
mbedtls_pk_free(&pkey_);
mbedtls_entropy_free(&entropy_);
mbedtls_ctr_drbg_free(&ctr_drbg_);
mbedtls_x509_crt_free(&srv_cert_);
mbedtls_ssl_config_free(&conf_);
}
void set_server_certificate(const char *cert) override { this->srv_cert_str_ = cert; }
void set_private_key(const char *private_key) override { this->privkey_str_ = private_key; }
int init() override {
mbedtls_x509_crt_init(&srv_cert_);
mbedtls_ctr_drbg_init(&ctr_drbg_);
mbedtls_entropy_init(&entropy_);
mbedtls_pk_init(&pkey_);
mbedtls_ssl_config_init(&conf_);
// TODO check what this does
int err = mbedtls_entropy_add_source(&entropy_, entropy_hw_random_source, NULL, 134, MBEDTLS_ENTROPY_SOURCE_STRONG);
if (err != 0) {
ESP_LOGW(TAG, "mbedtls_entropy_add_source failed: %d", err);
return 1;
}
err = mbedtls_ctr_drbg_seed(&ctr_drbg_, mbedtls_entropy_func, &entropy_, NULL, 0);
if (err != 0) {
ESP_LOGW(TAG, "mbedtls_ctr_drbg_seed failed: %d", err);
return 1;
}
err = mbedtls_x509_crt_parse(&srv_cert_, reinterpret_cast<const uint8_t *>(srv_cert_str_),
// "including the terminating NULL byte"
strlen(srv_cert_str_) + 1);
if (err != 0) {
ESP_LOGW(TAG, "mbedtls_x509_crt_parse failed: %d", err);
return 1;
}
err = mbedtls_pk_parse_key(&pkey_, reinterpret_cast<const uint8_t *>(privkey_str_),
// "including the terminating NULL byte"
strlen(privkey_str_) + 1, nullptr, 0);
if (err != 0) {
ESP_LOGW(TAG, "mbedtls_pk_parse_key failed: %d", err);
return 1;
}
err = mbedtls_ssl_config_defaults(&conf_, MBEDTLS_SSL_IS_SERVER, MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_DEFAULT);
if (err != 0) {
ESP_LOGW(TAG, "mbedtls_ssl_config_defaults failed: %d", err);
return 1;
}
mbedtls_ssl_conf_rng(&conf_, mbedtls_ctr_drbg_random, &ctr_drbg_);
err = mbedtls_ssl_conf_own_cert(&conf_, &srv_cert_, &pkey_);
if (err != 0) {
ESP_LOGW(TAG, "mbedtls_ssl_conf_own_cert failed: %d", err);
return 1;
}
mbedtls_ssl_conf_dbg(&conf_, debug_callback, nullptr);
return 0;
}
std::unique_ptr<socket::Socket> wrap_socket(std::unique_ptr<socket::Socket> sock) override {
auto *wrapped = new MbedTLSWrappedSocket(std::move(sock));
wrapped->init(&conf_);
return std::unique_ptr<socket::Socket>{wrapped};
}
protected:
const char *srv_cert_str_ = nullptr;
const char *privkey_str_ = nullptr;
mbedtls_entropy_context entropy_;
mbedtls_ctr_drbg_context ctr_drbg_;
mbedtls_x509_crt srv_cert_;
mbedtls_pk_context pkey_;
mbedtls_ssl_config conf_;
};
std::unique_ptr<SSLContext> create_context() { return std::unique_ptr<SSLContext>{new MbedTLSContext()}; }
} // namespace ssl
} // namespace esphome

View File

@@ -0,0 +1,24 @@
#pragma once
#include <memory>
#include "esphome/components/socket/socket.h"
namespace esphome {
namespace ssl {
class SSLContext {
public:
SSLContext() = default;
virtual ~SSLContext() = default;
SSLContext(const SSLContext &) = delete;
SSLContext &operator=(const SSLContext &) = delete;
virtual int init() = 0;
virtual void set_server_certificate(const char *cert) = 0;
virtual void set_private_key(const char *private_key) = 0;
virtual std::unique_ptr<socket::Socket> wrap_socket(std::unique_ptr<socket::Socket> sock) = 0;
};
std::unique_ptr<SSLContext> create_context();
} // namespace ssl
} // namespace esphome

View File

@@ -29,3 +29,7 @@
#define USE_CAPTIVE_PORTAL
#define ESPHOME_BOARD "dummy_board"
#define USE_MDNS
#define USE_SOCKET_IMPL_LWIP_TCP
#define USE_SOCKET_IMPL_BSD_SOCKETS
#define USE_API_NOISE
#define USE_API_PLAINTEXT

View File

@@ -55,6 +55,15 @@ double random_double() { return random_uint32() / double(UINT32_MAX); }
float random_float() { return float(random_double()); }
void fill_random(uint8_t *data, size_t len) {
#ifdef ARDUINO_ARCH_ESP32
esp_fill_random(data, len);
#else
int err = os_get_random(data, len);
assert(err == 0);
#endif
}
static uint32_t fast_random_seed = 0; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
void fast_random_set_seed(uint32_t seed) { fast_random_seed = seed; }

View File

@@ -109,6 +109,8 @@ double random_double();
/// Returns a random float between 0 and 1. Essentially just casts random_double() to a float.
float random_float();
void fill_random(uint8_t *data, size_t len);
void fast_random_set_seed(uint32_t seed);
uint32_t fast_random_32();
uint16_t fast_random_16();

View File

@@ -36,6 +36,7 @@ lib_deps =
6306@1.0.3 ; HM3301
glmnet/Dsmr@0.3 ; used by dsmr
rweather/Crypto@0.2.0 ; used by dsmr
esphome/noise-c@0.1.0
build_flags =
-DESPHOME_LOG_LEVEL=ESPHOME_LOG_LEVEL_VERY_VERBOSE